((<= x 33.3) (+ x (exp (- x))))
(else (exact->inexact x))))
+;;; log(e^x + e^y + ...)
+;;;
+;;; Caller can minimize error by passing descending inputs below 0, or
+;;; ascending inputs above 1.
+
+(define (logsumexp l)
+ (if (pair? l)
+ (if (pair? (cdr l))
+ (let ((m (reduce max #f l)))
+ ;; Cases:
+ ;;
+ ;; 1. There is a NaN among the inputs: invalid operation;
+ ;; result is NaN.
+ ;;
+ ;; 2. The maximum is +inf, and
+ ;; (a) the minimum is -inf: NaN.
+ ;; (b) the minimum is finite: +inf.
+ ;;
+ ;; 3. The maximum is -inf: all inputs are -inf, so -inf.
+ ;;
+ ;; Most likely all the inputs are finite, so prioritize
+ ;; that case by checking for an infinity first -- if there
+ ;; is a NaN, the usual computation will propagate it.
+ ;;
+ (if (and (infinite? m)
+ (not (= (- m) (reduce min #f l)))
+ (not (any nan? l)))
+ m
+ (+ m (log (reduce + 0 (map (lambda (x) (exp (- x m))) l))))))
+ (car l))
+ (flo:-inf.0)))
+
;;; Some lemmas for the bounds below.
;;;
;;; Lemma 1. If |d| < 1/2, then 1/(1 + d) <= 2.
(lambda (v)
(assert-<= (relerr (cdr v) (log1pexp (car v))) 1e-15)))
+(define-enumerated-test 'logsumexp-values
+ (vector
+ (vector '(999 1000) 1000.3132616875182)
+ (vector '(-1000 -1000) (+ -1000 (log 2)))
+ (vector '(0 0) (log 2)))
+ (lambda (v)
+ (let ((l (vector-ref v 0))
+ (s (vector-ref v 1)))
+ (assert-<= (relerr s (logsumexp l)) 1e-15))))
+
+(define-enumerated-test 'logsumexp-edges
+ (vector
+ (vector '() (flo:-inf.0))
+ (vector '(-1000) -1000)
+ (vector '(-1000.) -1000.)
+ (vector (list (flo:-inf.0)) (flo:-inf.0))
+ (vector (list (flo:-inf.0) 1) 1.)
+ (vector (list 1 (flo:-inf.0)) 1.)
+ (vector (list (flo:+inf.0)) (flo:+inf.0))
+ (vector (list (flo:+inf.0) 1) (flo:+inf.0))
+ (vector (list 1 (flo:+inf.0)) (flo:+inf.0))
+ (vector (list (flo:-inf.0) (flo:-inf.0)) (flo:-inf.0))
+ (vector (list (flo:+inf.0) (flo:+inf.0)) (flo:+inf.0)))
+ (lambda (v)
+ (let ((l (vector-ref v 0))
+ (s (vector-ref v 1)))
+ (assert-eqv (logsumexp l) s))))
+
+(define-enumerated-test 'logsumexp-nan
+ (vector
+ (list (flo:-inf.0) (flo:+inf.0))
+ (list (flo:+inf.0) (flo:-inf.0))
+ (list 1 (flo:-inf.0) (flo:+inf.0))
+ (list (flo:-inf.0) (flo:+inf.0) 1)
+ (list (flo:nan.0))
+ (list (flo:+inf.0) (flo:nan.0))
+ (list (flo:-inf.0) (flo:nan.0))
+ (list 1 (flo:nan.0))
+ (list (flo:nan.0) (flo:+inf.0))
+ (list (flo:nan.0) (flo:-inf.0))
+ (list (flo:nan.0) 1))
+ (lambda (l)
+ (assert-nan (flo:with-trapped-exceptions 0 (lambda () (logsumexp l))))))
+
(define-enumerated-test 'logit-logistic
(vector
(vector -36.7368005696771