From ab79270e1da3a7e0eddc0e5cac9888e27aad6118 Mon Sep 17 00:00:00 2001 From: Taylor R Campbell Date: Tue, 30 Oct 2018 15:41:22 +0000 Subject: [PATCH] Add (logsumexp (list 1 2 3)) = (log (+ (exp 1) (exp 2) (exp 3))). --- src/runtime/arith.scm | 32 ++++++++++++++++++++++++++ src/runtime/runtime.pkg | 1 + tests/runtime/test-arith.scm | 44 ++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/src/runtime/arith.scm b/src/runtime/arith.scm index 51042d68e..a129257fd 100644 --- a/src/runtime/arith.scm +++ b/src/runtime/arith.scm @@ -2036,6 +2036,38 @@ USA. ((<= x 33.3) (+ x (exp (- x)))) (else (exact->inexact x)))) +;;; log(e^x + e^y + ...) +;;; +;;; Caller can minimize error by passing descending inputs below 0, or +;;; ascending inputs above 1. + +(define (logsumexp l) + (if (pair? l) + (if (pair? (cdr l)) + (let ((m (reduce max #f l))) + ;; Cases: + ;; + ;; 1. There is a NaN among the inputs: invalid operation; + ;; result is NaN. + ;; + ;; 2. The maximum is +inf, and + ;; (a) the minimum is -inf: NaN. + ;; (b) the minimum is finite: +inf. + ;; + ;; 3. The maximum is -inf: all inputs are -inf, so -inf. + ;; + ;; Most likely all the inputs are finite, so prioritize + ;; that case by checking for an infinity first -- if there + ;; is a NaN, the usual computation will propagate it. + ;; + (if (and (infinite? m) + (not (= (- m) (reduce min #f l))) + (not (any nan? l))) + m + (+ m (log (reduce + 0 (map (lambda (x) (exp (- x m))) l)))))) + (car l)) + (flo:-inf.0))) + ;;; Some lemmas for the bounds below. ;;; ;;; Lemma 1. If |d| < 1/2, then 1/(1 + d) <= 2. diff --git a/src/runtime/runtime.pkg b/src/runtime/runtime.pkg index 75032dd44..7b3c89686 100644 --- a/src/runtime/runtime.pkg +++ b/src/runtime/runtime.pkg @@ -3388,6 +3388,7 @@ USA. logit logit-exp logit1/2+ + logsumexp max min modulo diff --git a/tests/runtime/test-arith.scm b/tests/runtime/test-arith.scm index b959afd71..ebce10c03 100644 --- a/tests/runtime/test-arith.scm +++ b/tests/runtime/test-arith.scm @@ -190,6 +190,50 @@ USA. (lambda (v) (assert-<= (relerr (cdr v) (log1pexp (car v))) 1e-15))) +(define-enumerated-test 'logsumexp-values + (vector + (vector '(999 1000) 1000.3132616875182) + (vector '(-1000 -1000) (+ -1000 (log 2))) + (vector '(0 0) (log 2))) + (lambda (v) + (let ((l (vector-ref v 0)) + (s (vector-ref v 1))) + (assert-<= (relerr s (logsumexp l)) 1e-15)))) + +(define-enumerated-test 'logsumexp-edges + (vector + (vector '() (flo:-inf.0)) + (vector '(-1000) -1000) + (vector '(-1000.) -1000.) + (vector (list (flo:-inf.0)) (flo:-inf.0)) + (vector (list (flo:-inf.0) 1) 1.) + (vector (list 1 (flo:-inf.0)) 1.) + (vector (list (flo:+inf.0)) (flo:+inf.0)) + (vector (list (flo:+inf.0) 1) (flo:+inf.0)) + (vector (list 1 (flo:+inf.0)) (flo:+inf.0)) + (vector (list (flo:-inf.0) (flo:-inf.0)) (flo:-inf.0)) + (vector (list (flo:+inf.0) (flo:+inf.0)) (flo:+inf.0))) + (lambda (v) + (let ((l (vector-ref v 0)) + (s (vector-ref v 1))) + (assert-eqv (logsumexp l) s)))) + +(define-enumerated-test 'logsumexp-nan + (vector + (list (flo:-inf.0) (flo:+inf.0)) + (list (flo:+inf.0) (flo:-inf.0)) + (list 1 (flo:-inf.0) (flo:+inf.0)) + (list (flo:-inf.0) (flo:+inf.0) 1) + (list (flo:nan.0)) + (list (flo:+inf.0) (flo:nan.0)) + (list (flo:-inf.0) (flo:nan.0)) + (list 1 (flo:nan.0)) + (list (flo:nan.0) (flo:+inf.0)) + (list (flo:nan.0) (flo:-inf.0)) + (list (flo:nan.0) 1)) + (lambda (l) + (assert-nan (flo:with-trapped-exceptions 0 (lambda () (logsumexp l)))))) + (define-enumerated-test 'logit-logistic (vector (vector -36.7368005696771 -- 2.25.1