Add (logsumexp (list 1 2 3)) = (log (+ (exp 1) (exp 2) (exp 3))).
authorTaylor R Campbell <campbell@mumble.net>
Tue, 30 Oct 2018 15:41:22 +0000 (15:41 +0000)
committerTaylor R Campbell <campbell@mumble.net>
Tue, 30 Oct 2018 15:41:22 +0000 (15:41 +0000)
src/runtime/arith.scm
src/runtime/runtime.pkg
tests/runtime/test-arith.scm

index 51042d68e26f02cbecadf5281a9a700a67b7961c..a129257fd0cb8c08df4fadbe267c24ca2a4a8f0e 100644 (file)
@@ -2036,6 +2036,38 @@ USA.
        ((<= x 33.3) (+ x (exp (- x))))
        (else (exact->inexact x))))
 
+;;; log(e^x + e^y + ...)
+;;;
+;;; Caller can minimize error by passing descending inputs below 0, or
+;;; ascending inputs above 1.
+
+(define (logsumexp l)
+  (if (pair? l)
+      (if (pair? (cdr l))
+         (let ((m (reduce max #f l)))
+           ;; Cases:
+           ;;
+           ;; 1. There is a NaN among the inputs: invalid operation;
+           ;;    result is NaN.
+           ;;
+           ;; 2. The maximum is +inf, and
+           ;;    (a) the minimum is -inf: NaN.
+           ;;    (b) the minimum is finite: +inf.
+           ;;
+           ;; 3. The maximum is -inf: all inputs are -inf, so -inf.
+           ;;
+           ;; Most likely all the inputs are finite, so prioritize
+           ;; that case by checking for an infinity first -- if there
+           ;; is a NaN, the usual computation will propagate it.
+           ;;
+           (if (and (infinite? m)
+                    (not (= (- m) (reduce min #f l)))
+                    (not (any nan? l)))
+               m
+               (+ m (log (reduce + 0 (map (lambda (x) (exp (- x m))) l))))))
+         (car l))
+      (flo:-inf.0)))
+
 ;;; Some lemmas for the bounds below.
 ;;;
 ;;; Lemma 1.  If |d| < 1/2, then 1/(1 + d) <= 2.
index 75032dd44abcc7a2d73b822050e6619c9bb53fb4..7b3c89686405eaad06da1fcb36fc09e593b7993c 100644 (file)
@@ -3388,6 +3388,7 @@ USA.
          logit
          logit-exp
          logit1/2+
+         logsumexp
          max
          min
          modulo
index b959afd71f2055e2a38fa3e2a8c6843999866d9e..ebce10c033909ab0f19cb47aa9203f5b949eddb5 100644 (file)
@@ -190,6 +190,50 @@ USA.
   (lambda (v)
     (assert-<= (relerr (cdr v) (log1pexp (car v))) 1e-15)))
 
+(define-enumerated-test 'logsumexp-values
+  (vector
+   (vector '(999 1000) 1000.3132616875182)
+   (vector '(-1000 -1000) (+ -1000 (log 2)))
+   (vector '(0 0) (log 2)))
+  (lambda (v)
+    (let ((l (vector-ref v 0))
+         (s (vector-ref v 1)))
+      (assert-<= (relerr s (logsumexp l)) 1e-15))))
+
+(define-enumerated-test 'logsumexp-edges
+  (vector
+   (vector '() (flo:-inf.0))
+   (vector '(-1000) -1000)
+   (vector '(-1000.) -1000.)
+   (vector (list (flo:-inf.0)) (flo:-inf.0))
+   (vector (list (flo:-inf.0) 1) 1.)
+   (vector (list 1 (flo:-inf.0)) 1.)
+   (vector (list (flo:+inf.0)) (flo:+inf.0))
+   (vector (list (flo:+inf.0) 1) (flo:+inf.0))
+   (vector (list 1 (flo:+inf.0)) (flo:+inf.0))
+   (vector (list (flo:-inf.0) (flo:-inf.0)) (flo:-inf.0))
+   (vector (list (flo:+inf.0) (flo:+inf.0)) (flo:+inf.0)))
+  (lambda (v)
+    (let ((l (vector-ref v 0))
+         (s (vector-ref v 1)))
+      (assert-eqv (logsumexp l) s))))
+
+(define-enumerated-test 'logsumexp-nan
+  (vector
+   (list (flo:-inf.0) (flo:+inf.0))
+   (list (flo:+inf.0) (flo:-inf.0))
+   (list 1 (flo:-inf.0) (flo:+inf.0))
+   (list (flo:-inf.0) (flo:+inf.0) 1)
+   (list (flo:nan.0))
+   (list (flo:+inf.0) (flo:nan.0))
+   (list (flo:-inf.0) (flo:nan.0))
+   (list 1 (flo:nan.0))
+   (list (flo:nan.0) (flo:+inf.0))
+   (list (flo:nan.0) (flo:-inf.0))
+   (list (flo:nan.0) 1))
+  (lambda (l)
+    (assert-nan (flo:with-trapped-exceptions 0 (lambda () (logsumexp l))))))
+
 (define-enumerated-test 'logit-logistic
   (vector
    (vector -36.7368005696771