Tidy up error analysis of logit and logistic.
authorTaylor R Campbell <campbell@mumble.net>
Sun, 28 Oct 2018 04:13:26 +0000 (04:13 +0000)
committerTaylor R Campbell <campbell@mumble.net>
Sun, 28 Oct 2018 04:16:22 +0000 (04:16 +0000)
Add (logit1/2+ p) = (logit (+ 1/2 p)) and (logistic-1/2 x) = (-
(logistic x) 1/2), for reasons like log1p and expm1.  Add some
trivial tests of the edge cases they cover where logit and logistic
are ill-conditioned.

src/runtime/arith.scm
src/runtime/runtime.pkg
tests/runtime/test-arith.scm

index 087b593c7bdd09e16e6b9497eb95ab0c92cde47c..019ad736d749a4d7a7b6d393b3522db1125ac8c7 100644 (file)
@@ -2016,6 +2016,7 @@ USA.
   (complex:* z (complex:* z z)))
 \f
 ;;; log(1 - e^x), defined only on negative x
+
 (define (log1mexp x)
   (guarantee-real x 'log1mexp)
   (guarantee-negative x 'log1mexp)
@@ -2024,6 +2025,7 @@ USA.
       (log1p (- (exp x)))))
 
 ;;; log(1 + e^x)
+
 (define (log1pexp x)
   (guarantee-real x 'log1pexp)
   (cond ((<= x flo:subnormal-exponent-min-base-e) 0.)
@@ -2032,73 +2034,83 @@ USA.
        ((<= x 33.3) (+ x (exp (- x))))
        (else (exact->inexact x))))
 
-;;; Lemma 1.  If |d'| < 1/2, then |(d' - d)/(1 + d')| <= 2|d' - d|.
+;;; Some lemmas for the bounds below.
+;;;
+;;; Lemma 1.  If |d| < 1/2, then 1/(1 + d) <= 2.
 ;;;
-;;; Proof.  Case I: If d' > 0, then 1 + d' > 1 so 0 < 1/(1 + d') <= 1.
-;;; Case II: If -1/2 < d' < 0, then 1/2 < 1 + d' < 1 so that 1 < 1/(1 +
-;;; d') <= 2.  QED.
+;;; Proof.  If 0 <= d <= 1/2, then 1 + d >= 1, so that 1/(1 + d) <= 1.
+;;; If -1/2 <= d <= 0, then 1 + d >= 1/2, so that 1/(1 + d) <= 2.  QED.
 ;;;
 ;;; Lemma 2. If b = a*(1 + d)/(1 + d') for |d'| < 1/2 and nonzero a, b,
 ;;; then b = a*(1 + e) for |e| <= 2|d' - d|.
 ;;;
 ;;; Proof.  |a - b|/|a|
-;;;             = |a - a*(1 + d)/(1 + d')|/|a|
-;;;             = |1 - (1 + d)/(1 + d')|
-;;;             = |(1 + d' - 1 - d)/(1 + d')|
-;;;             = |(d' - d)/(1 + d')|
-;;;            <= 2|d' - d|, by Lemma 1,
+;;;            = |a - a*(1 + d)/(1 + d')|/|a|
+;;;            = |1 - (1 + d)/(1 + d')|
+;;;            = |(1 + d' - 1 - d)/(1 + d')|
+;;;            = |(d' - d)/(1 + d')|
+;;;           <= 2|d' - d|, by Lemma 1,
 ;;;
 ;;; QED.
 ;;;
 ;;; Lemma 3.  For |d|, |d'| < 1/4,
 ;;;
-;;;     |log((1 + d)/(1 + d'))| <= 4|d - d'|.
+;;;    |log((1 + d)/(1 + d'))| <= 4|d - d'|.
 ;;;
 ;;; Proof.  Write
 ;;;
-;;;     log((1 + d)/(1 + d'))
-;;;      = log(1 + (1 + d)/(1 + d') - 1)
-;;;      = log(1 + (1 + d - 1 - d')/(1 + d')
-;;;      = log(1 + (d - d')/(1 + d')).
+;;;    log((1 + d)/(1 + d'))
+;;;     = log(1 + (1 + d)/(1 + d') - 1)
+;;;     = log(1 + (1 + d - 1 - d')/(1 + d')
+;;;     = log(1 + (d - d')/(1 + d')).
 ;;;
 ;;; By Lemma 1, |(d - d')/(1 + d')| < 2|d' - d| < 1, so the Taylor
 ;;; series of log(1 + x) converges absolutely for (d - d')/(1 + d'),
 ;;; and thus we have
 ;;;
-;;;     |log(1 + (d - d')/(1 + d'))|
-;;;      = |\sum_{n=1}^\infty ((d - d')/(1 + d'))^n/n|
-;;;     <= \sum_{n=1}^\infty |(d - d')/(1 + d')|^n/n
-;;;     <= \sum_{n=1}^\infty |2(d' - d)|^n/n
-;;;     <= \sum_{n=1}^\infty |2(d' - d)|^n
-;;;      = 1/(1 - |2(d' - d)|)
-;;;     <= 4|d' - d|,
+;;;    |log(1 + (d - d')/(1 + d'))|
+;;;     = |\sum_{n=1}^\infty ((d - d')/(1 + d'))^n/n|
+;;;    <= \sum_{n=1}^\infty |(d - d')/(1 + d')|^n/n
+;;;    <= \sum_{n=1}^\infty |2(d' - d)|^n/n
+;;;    <= \sum_{n=1}^\infty |2(d' - d)|^n
+;;;     = 1/(1 - |2(d' - d)|)
+;;;    <= 4|d' - d|,
 ;;;
 ;;; QED.
 ;;;
 ;;; Lemma 4.  If 1/e <= 1 + x <= e, then
 ;;;
-;;;     log(1 + (1 + d) x) = (1 + d') log(1 + x)
+;;;    log(1 + (1 + d) x) = (1 + d') log(1 + x)
 ;;;
 ;;; for |d'| < 8|d|.
 ;;;
 ;;; Proof.  Write
 ;;;
-;;;     log(1 + (1 + d) x)
-;;;     = log(1 + x + x*d)
-;;;     = log((1 + x) (1 + x + x*d)/(1 + x))
-;;;     = log(1 + x) + log((1 + x + x*d)/(1 + x))
-;;;     = log(1 + x) (1 + log((1 + x + x*d)/(1 + x))/log(1 + x)).
+;;;    log(1 + (1 + d) x)
+;;;    = log(1 + x + x*d)
+;;;    = log((1 + x) (1 + x + x*d)/(1 + x))
+;;;    = log(1 + x) + log((1 + x + x*d)/(1 + x))
+;;;    = log(1 + x) (1 + log((1 + x + x*d)/(1 + x))/log(1 + x)).
 ;;;
 ;;; The relative error is bounded by
 ;;;
-;;;     |log((1 + x + x*d)/(1 + x))/log(1 + x)|
-;;;     <= 4|x + x*d - x|/|log(1 + x)|, by Lemma 3,
-;;;      = 4|x*d|/|log(1 + x)|
-;;;      < 8|d|,
+;;;    |log((1 + x + x*d)/(1 + x))/log(1 + x)|
+;;;    <= 4|x + x*d - x|/|log(1 + x)|, by Lemma 3,
+;;;     = 4|x*d|/|log(1 + x)|
+;;;     < 8|d|,
 ;;;
 ;;; since in this range 0 < 1 - 1/e < x/log(1 + x) <= e - 1 < 2.  QED.
 
-;;; 1/(1 + e^{-x}) = e^x/(1 + e^x)
+;;; Logistic function: 1/(1 + e^{-x}) = e^x/(1 + e^x). Maps a
+;;; log-odds-space probability in [-\infty, +\infty] into a
+;;; direct-space probability in [0,1]. Inverse of logit.
+;;;
+;;; Ill-conditioned for large x; the identity logistic(-x) = 1 -
+;;; logistic(x) and the function (logistic-1/2 x) = (- (logistic x)
+;;; 1/2) may help to rearrange a computation.
+;;;
+;;; This implementation gives relative error bounded by 7 eps.
+
 (define (logistic x)
   (guarantee-real x 'logistic)
   (cond ((<= x flo:subnormal-exponent-min-base-e)
@@ -2162,12 +2174,60 @@ USA.
         ;;
         1.)))
 
+;;; Logistic function, translated in output by 1/2: logistic(x) - 1/2 =
+;;; 1/(1 + e^{-x}) - 1/2. Well-conditioned on the entire real plane,
+;;; with maximum condition number 1 at 0.
+;;;
+;;; This implementation gives relative error bounded by 5 eps.
+
+(define (logistic-1/2 x)
+  ;; Suppose exp has error d0, + has error d1, expm1 has error d2, and
+  ;; / has error d3, so we evaluate
+  ;;
+  ;;   -(1 + d2) (1 + d3) (e^{-x} - 1)
+  ;;     / [2 (1 + d1) (1 + (1 + d0) e^{-x})].
+  ;;
+  ;; In the denominator,
+  ;;
+  ;;   1 + (1 + d0) e^{-x}
+  ;;   = 1 + e^{-x} + d0 e^{-x}
+  ;;   = (1 + e^{-x}) (1 + d0 e^{-x}/(1 + e^{-x})),
+  ;;
+  ;; so the relative error of the numerator is
+  ;;
+  ;;   d' = d2 + d3 + d2 d3,
+  ;; and of the denominator,
+  ;;   d'' = d1 + d0 e^{-x}/(1 + e^{-x}) + d0 d1 e^{-x}/(1 + e^{-x})
+  ;;       = d1 + d0 L(-x) + d0 d1 L(-x),
+  ;;
+  ;; where L(-x) is logistic(-x).  By Lemma 1 the relative error of the
+  ;; quotient is bounded by
+  ;;
+  ;;   2|d2 + d3 + d2 d3 - d1 - d0 L(x) + d0 d1 L(x)|,
+  ;;
+  ;; Since 0 < L(x) < 1, this is bounded by
+  ;;
+  ;;   2|d2| + 2|d3| + 2|d2 d3| + 2|d1| + 2|d0| + 2|d0 d1|
+  ;;   <= 4 eps + 2 eps^2.
+  ;;
+  (- (/ (expm1 (- x)) (* 2 (+ 1 (exp (- x)))))))
+
 (define-integrable logit-boundary-lo   ;logistic(-1)
   (flo:/ (flo:exp -1.) (flo:+ 1. (flo:exp -1.))))
 (define-integrable logit-boundary-hi   ;logistic(+1)
   (flo:/ 1. (flo:+ 1. (flo:exp -1.))))
 
-;;; log p/(1 - p)
+;;; Logit function: log p/(1 - p).  Defined on [0,1].  Maps a
+;;; direct-space probability in [0,1] to a log-odds-space probability
+;;; in [-\infty, +\infty].  Inverse of logistic.
+;;;
+;;; Ill-conditioned near 1/2 and 1; the identity logit(1 - p) =
+;;; -logit(p) and the function (logit1/2+ p0) = (logit (+ 1/2 p0)) may
+;;; help to rearrange a computation for p in [1/(1 + e), 1 - 1/(1 +
+;;; e)].
+;;;
+;;; This implementation gives relative error bounded by 10 eps.
+
 (define (logit p)
   (guarantee-real p 'logit)
   (if (not (<= 0 p 1))
@@ -2182,8 +2242,8 @@ USA.
   ;;     = -log(1 + (1 - 2p)/p).
   ;;
   ;; to get an intermediate quotient near zero.
+  ;;
   (cond ((<= logit-boundary-lo p logit-boundary-hi)
-        ;;
         ;; Since p = 2p/2 <= 1 <= 2*2p = 4p, the floating-point
         ;; evaluation of 1 - 2p is exact; the only error arises from
         ;; division and log1p.  First, note that if logistic(-1) <= p
@@ -2228,7 +2288,79 @@ USA.
         ;;
         (log (/ p (- 1 p))))))
 
+;;; Logit function, translated in input by 1/2: (logit1/2+ p-1/2) =
+;;; (logit (+ 1/2 p-1/2)).  Defined on [-1/2, 1/2].  Inverse of
+;;; logistic-1/2.
+;;;
+;;; Ill-conditioned near +/-1/2.  If |p0| > 1/2 - 1/(1 + e), it may be
+;;; better to compute 1/2 + p0 or -1/2 - p0 and to use logit instead.
+;;; This implementation gives relative error bounded by 10 eps.
+
+(define (logit1/2+ p-1/2)
+  (cond ((<= (abs p-1/2) (- 1/2 (/ 1 (+ 1 (exp 1)))))
+        ;; If p' = p - 1/2, then p = 1/2 + p', so we compute:
+        ;;
+        ;; log(p/(1 - p))
+        ;; = log((1/2 + p')/(1 - (1/2 + p')))
+        ;; = log((1/2 + p')/(1/2 - p'))
+        ;; = log(1 + (1/2 + p')/(1/2 - p') - 1)
+        ;; = log(1 + (1/2 + p' - (1/2 - p'))/(1/2 - p'))
+        ;; = log(1 + (1/2 + p' - 1/2 + p')/(1/2 - p'))
+        ;; = log(1 + 2 p'/(1/2 - p'))
+        ;;
+        ;; Note that since p0/2 <= 1/2 <= 2 p0, 1/2 - p0 is
+        ;; computed exactly without error; the only error
+        ;; arises from division and log1p.  If the error of
+        ;; division is d0 and the error of log1p is d1, then
+        ;; what we compute is
+        ;;
+        ;;     (1 + d1) log(1 + (1 + d0) 2 p0/(1/2 - p0))
+        ;;     = (1 + d1) (1 + d') log(1 + 2 p0/(1/2 - p0))
+        ;;     = (1 + d1 + d' + d1 d') log(1 + 2 p0/(1/2 - p0)).
+        ;;
+        ;; where |d'| < 8|d0| by Lemma 4, since
+        ;;
+        ;;     1/e <= 1 + 2*p0/(1/2 - p0) <= e
+        ;;
+        ;; when |p0| <= 1/2 - 1/(1 + e).  Hence the relative
+        ;; error is bounded by
+        ;;
+        ;;     |d1 + d' + d1 d'|
+        ;;     <= |d1| + |d'| + |d1 d'|
+        ;;     <= |d1| + 8 |d0| + 8 |d1 d0|
+        ;;     <= 9 eps + 8 eps^2.
+        ;;
+        (log1p (/ (* 2 p-1/2) (- 1/2 p-1/2))))
+       (else
+        ;; We have a choice of computing logit(1/2 + p0) or -logit(1 -
+        ;; (1/2 + p0)) = -logit(1/2 - p0).  It doesn't matter which
+        ;; way we do this: either way, since 1/2 p0 <= 1/2 <= 2 p0,
+        ;; the sum and difference are computed exactly.  So let's do
+        ;; the one that skips the final negation.
+        ;;
+        ;; Again, the only error arises from division and log.  So the
+        ;; result is
+        ;;
+        ;;     (1 + d1) log((1 + d0) (1/2 + p0)/(1/2 - p0))
+        ;;     = (1 + d1) (1 + log(1 + d0)/log((1/2 + p0)/(1/2 - p0)))
+        ;;       * log((1/2 + p0)/(1/2 - p0))
+        ;;     = (1 + d') log((1/2 + p0)/(1/2 - p0))
+        ;;
+        ;; where
+        ;;
+        ;;     d' = d1 + log(1 + d0)/log((1/2 + p0)/(1/2 - p0))
+        ;;          + d1 log(1 + d0)/log((1/2 + p0)/(1/2 - p0)).
+        ;;
+        ;; For |p| > 1/2 - 1/(1 + e), logit(1/2 + p0) > 1.  For |d0| <
+        ;; 1/2, |log(1 + d0)| < 2|d0|.  Hence this is bounded by
+        ;;
+        ;;     |d'| <= |d1| + 2|d0| + 2|d0 d1|
+        ;;          <= 3 eps + 2 eps^2.
+        ;;
+        (log (/ (+ 1/2 p-1/2) (- 1/2 p-1/2))))))
+
 ;;; log logistic(x) = -log (1 + e^{-x})
+
 (define (log-logistic x)
   (guarantee-real x 'log-logistic)
   (- (log1pexp (- x))))
@@ -2239,6 +2371,7 @@ USA.
   (flo:- 0. (flo:log (flo:+ 1. (flo:exp -1.)))))
 
 ;;; log e^t/(1 - e^t) = logit(e^t)
+
 (define (logit-exp t)
   (guarantee-real t 'logit-exp)
   (cond ((<= t flo:log-epsilon)
index 893cabc6529a08253950a2ea674dd7e60b388722..6eea6fe64b4bdf331c27cdc8ebf7eb2899fb41c0 100644 (file)
@@ -3383,8 +3383,10 @@ USA.
          log1mexp
          log1pexp
          logistic
+         logistic-1/2
          logit
          logit-exp
+         logit1/2+
          max
          min
          modulo
index 9384402e94c4450c20815560b75eedccae712e6e..b959afd71f2055e2a38fa3e2a8c6843999866d9e 100644 (file)
@@ -232,6 +232,20 @@ USA.
       (if (< p 1)
           (assert-<= (relerr (log1p (- p)) (log-logistic (- x))) 1e-15)))))
 
+(define-enumerated-test 'logit-logistic-1/2
+  (vector
+   (vector 1e-300 4e-300)
+   (vector 1e-16 4e-16)
+   (vector .2310585786300049 1.)
+   (vector .49999999999999994 37.42994775023705))
+  (lambda (v)
+    (let ((p (vector-ref v 0))
+          (x (vector-ref v 1)))
+      (assert-<= (relerr x (logit1/2+ p)) 1e-15)
+      (assert-<= (relerr p (logistic-1/2 x)) 1e-15)
+      (assert-= (- (logit1/2+ p)) (logit1/2+ (- p)))
+      (assert-= (- (logistic-1/2 x)) (logistic-1/2 (- x))))))
+
 (define-enumerated-test 'expt-exact
   (vector
    (vector 2. -1075 "0.")