]> birchwood-abbey.net Git - mit-scheme.git/commitdiff
Fix accuracy and nonsensical logic for edge cases in logsumexp.
authorTaylor R Campbell <campbell+mit-scheme@mumble.net>
Thu, 11 Feb 2021 04:48:04 +0000 (04:48 +0000)
committerTaylor R Campbell <campbell+mit-scheme@mumble.net>
Thu, 11 Feb 2021 05:04:28 +0000 (05:04 +0000)
(What was I thinking?!?)

(cherry picked from commit fc747b66302b64ab18e3cfe640ee5a4c232f4151)

doc/ref-manual/numbers.texi
src/runtime/arith.scm
tests/runtime/test-arith.scm

index 01c50e042de436f0902c35a9427b7cca83226446..cc6ac1d33c3b804e94faf2b41418590a865e10de 100644 (file)
@@ -1290,34 +1290,9 @@ log(exp(@var{x1}) + exp(@var{x2}) + @dots{} + exp(@var{xn})).
 @end example
 
 @end ifnottex
-The approximation avoids intermediate overflow and underflow.
-To minimize error, the caller should arrange for the numbers to be
-sorted from least to greatest.
-
-Edge cases:
-
-@itemize @bullet
-@item
-If @var{list} is empty, the result is @code{-inf}, as if the
-intermediate sum were zero.
-
-@item
-If @var{list} contains only finite numbers and @code{-inf}, the
-@code{-inf} elements are ignored, since the exponential of @code{-inf}
-is zero.
-
-@item
-If @var{list} contains only finite numbers and @code{+inf}, the result
-is @code{+inf} as if the sum had overflowed.
-(Otherwise, overflow is not possible.)
-
-@item
-If @var{list} contains both @code{-inf} and @code{+inf}, or if
-@var{list} contains any NaNs, the result is a NaN.
-@end itemize
-
-@code{Logsumexp} never raises any of the standard @acronym{IEEE 754-2008}
-floating-point exceptions other than invalid-operation.
+The computation avoids intermediate overflow; @code{logsumexp}
+returns @code{+inf.0} if and only if one of the inputs is
+@code{+inf.0}.
 @end deffn
 
 @deffn procedure sqrt z
index 7320035133f6a6cb5bd7086fc827b20091488e00..dee4d70d21588c34e9e8c396aea7986392253156 100644 (file)
@@ -2525,40 +2525,52 @@ USA.
   ;; Cases:
   ;;
   ;; 1. No inputs.  Empty sum is zero, so yield log(0) = -inf.
-  ;;
   ;; 2. One input.  Computation is exact.  Preserve it.
-  ;;
-  ;; 3. NaN among the inputs: invalid operation; result is NaN.
-  ;;
-  ;; 2. Maximum is +inf, and
-  ;;    (a) the minimum is -inf: inf - inf = NaN.
-  ;;    (b) the minimum is finite: sum overflows, so +inf.
-  ;;
-  ;; 3. Maximum is -inf: all inputs are -inf, so -inf.
+  ;; 3. Maximum is infinite:
+  ;;    - if +inf, sum overflows, so +inf.
+  ;;    - if -inf, all inputs are -inf, so -inf.
+  ;; 4. NaN among the inputs: invalid operation; result is NaN.
   ;;
   ;; Most likely all the inputs are finite, so prioritize that case by
   ;; checking for an infinity first -- if there is a NaN, the usual
   ;; computation will propagate it.
   ;;
   ;; Overflow is not possible because everything is normalized to be
-  ;; below zero.  Underflow can be safely ignored because it can't
-  ;; change the outcome: even if you had 2^64 copies of the largest
-  ;; subnormal in the sum, 2^64 * largest subnormal < 2^-900 <<<
-  ;; epsilon = 2^-53, and at least one addend in the sum is 1 since we
-  ;; compute e^{m - m} = e^0 = 1.
+  ;; below zero.
   ;;
-  (let ((m (reduce max #f l)))
-    (cond ((not (pair? l)) (flo:-inf.0))
-         ((not (pair? (cdr l))) (car l))
-         ((and (infinite? m)
-               (not (= (- m) (reduce min #f l)))
-               (not (any nan? l)))
-          m)
-         (else
-          (flo:with-exceptions-untrapped (flo:exception:underflow)
-            (lambda ()
-              (+ m
-                 (log (reduce + 0 (map (lambda (x) (exp (- x m))) l))))))))))
+  (cond ((not (pair? l))
+        (flo:-inf.0))
+       ((not (pair? (cdr l)))
+        (guarantee-real (car l) 'logsumexp)
+        (car l))
+       (else
+        (let* ((v (list->vector l))
+               (n (vector-length v)))
+          ;; Find the position of the maximum.
+          (let loop ((i 1) (imax 0) (xmax (vector-ref v 0)))
+            (if (< i n)
+                (let ((x (vector-ref v i)))
+                  (if (<= x xmax)
+                      (loop (+ i 1) imax xmax)
+                      (loop (+ i 1) i x)))
+                ;; Avoid arithmetic on +inf.0 and just pass it
+                ;; through.
+                (if (= xmax +inf.0)
+                    (or (find nan? l) +inf.0)
+                    ;; Compute xmax + log(1 + sum_i e^{xi - xmax}), with
+                    ;; imax omitted from the sum.
+                    (let loop ((i 0) (sumexp 0))
+                      (if (>= i n)
+                          (+ xmax (log1p sumexp))
+                          (let ((x (vector-ref v i)))
+                            (cond ((= x -inf.0)
+                                   (loop (+ i 1) (exact->inexact sumexp)))
+                                  ((= i imax)
+                                   (loop (+ i 1) sumexp))
+                                  (else
+                                   (loop (+ i 1)
+                                         (+ sumexp
+                                            (exp (- x xmax))))))))))))))))
 \f
 ;;; Logistic function: 1/(1 + e^{-x}) = e^x/(1 + e^x). Maps a
 ;;; log-odds-space probability in [-\infty, +\infty] into a
index a133614eb72f339c3a7f2ab34e8ede9307b06d53..c3d12c49d8a99226e88a9ca95e408fa7a6183b36 100644 (file)
@@ -445,8 +445,7 @@ USA.
    (list '(0 0) (log 2))
    ;; log(2^-30), log(1 + 2^-29) -> log(1 + 2^-29 + 2^-30)
    (list (list -20.79441541679836 1.8626451474962336e-9)
-        2.7939677199433077e-9
-        expect-failure))
+        2.7939677199433077e-9))
   (lambda (l s #!optional xfail)
     (with-expected-failure xfail
       (lambda ()
@@ -467,12 +466,12 @@ USA.
    (list (list (flo:+inf.0)) (flo:+inf.0))
    (list (list (flo:+inf.0) 1) (flo:+inf.0))
    (list (list 1 (flo:+inf.0)) (flo:+inf.0))
-   (list (list 1 (flo:-inf.0) (flo:+inf.0)) (flo:+inf.0) expect-failure)
-   (list (list (flo:-inf.0) (flo:+inf.0) 1) (flo:+inf.0) expect-failure)
+   (list (list 1 (flo:-inf.0) (flo:+inf.0)) (flo:+inf.0))
+   (list (list (flo:-inf.0) (flo:+inf.0) 1) (flo:+inf.0))
    (list (list (flo:-inf.0) (flo:-inf.0)) (flo:-inf.0))
-   (list (list (flo:-inf.0) (flo:+inf.0)) (flo:+inf.0) expect-failure)
+   (list (list (flo:-inf.0) (flo:+inf.0)) (flo:+inf.0))
    (list (list (flo:+inf.0) (flo:+inf.0)) (flo:+inf.0))
-   (list (list (flo:+inf.0) (flo:-inf.0)) (flo:+inf.0) expect-failure))
+   (list (list (flo:+inf.0) (flo:-inf.0)) (flo:+inf.0)))
   (lambda (l s #!optional xfail)
     (with-expected-failure xfail
       (lambda ()