From fa257b411b3469fdc38c4af3e3eee2b15540b2f6 Mon Sep 17 00:00:00 2001 From: Taylor R Campbell Date: Sun, 28 Oct 2018 04:13:26 +0000 Subject: [PATCH] Tidy up error analysis of logit and logistic. Add (logit1/2+ p) = (logit (+ 1/2 p)) and (logistic-1/2 x) = (- (logistic x) 1/2), for reasons like log1p and expm1. Add some trivial tests of the edge cases they cover where logit and logistic are ill-conditioned. --- src/runtime/arith.scm | 201 +++++++++++++++++++++++++++++------ src/runtime/runtime.pkg | 2 + tests/runtime/test-arith.scm | 14 +++ 3 files changed, 183 insertions(+), 34 deletions(-) diff --git a/src/runtime/arith.scm b/src/runtime/arith.scm index 087b593c7..019ad736d 100644 --- a/src/runtime/arith.scm +++ b/src/runtime/arith.scm @@ -2016,6 +2016,7 @@ USA. (complex:* z (complex:* z z))) ;;; log(1 - e^x), defined only on negative x + (define (log1mexp x) (guarantee-real x 'log1mexp) (guarantee-negative x 'log1mexp) @@ -2024,6 +2025,7 @@ USA. (log1p (- (exp x))))) ;;; log(1 + e^x) + (define (log1pexp x) (guarantee-real x 'log1pexp) (cond ((<= x flo:subnormal-exponent-min-base-e) 0.) @@ -2032,73 +2034,83 @@ USA. ((<= x 33.3) (+ x (exp (- x)))) (else (exact->inexact x)))) -;;; Lemma 1. If |d'| < 1/2, then |(d' - d)/(1 + d')| <= 2|d' - d|. +;;; Some lemmas for the bounds below. +;;; +;;; Lemma 1. If |d| < 1/2, then 1/(1 + d) <= 2. ;;; -;;; Proof. Case I: If d' > 0, then 1 + d' > 1 so 0 < 1/(1 + d') <= 1. -;;; Case II: If -1/2 < d' < 0, then 1/2 < 1 + d' < 1 so that 1 < 1/(1 + -;;; d') <= 2. QED. +;;; Proof. If 0 <= d <= 1/2, then 1 + d >= 1, so that 1/(1 + d) <= 1. +;;; If -1/2 <= d <= 0, then 1 + d >= 1/2, so that 1/(1 + d) <= 2. QED. ;;; ;;; Lemma 2. If b = a*(1 + d)/(1 + d') for |d'| < 1/2 and nonzero a, b, ;;; then b = a*(1 + e) for |e| <= 2|d' - d|. ;;; ;;; Proof. 
|a - b|/|a| -;;; = |a - a*(1 + d)/(1 + d')|/|a| -;;; = |1 - (1 + d)/(1 + d')| -;;; = |(1 + d' - 1 - d)/(1 + d')| -;;; = |(d' - d)/(1 + d')| -;;; <= 2|d' - d|, by Lemma 1, +;;; = |a - a*(1 + d)/(1 + d')|/|a| +;;; = |1 - (1 + d)/(1 + d')| +;;; = |(1 + d' - 1 - d)/(1 + d')| +;;; = |(d' - d)/(1 + d')| +;;; <= 2|d' - d|, by Lemma 1, ;;; ;;; QED. ;;; ;;; Lemma 3. For |d|, |d'| < 1/4, ;;; -;;; |log((1 + d)/(1 + d'))| <= 4|d - d'|. +;;; |log((1 + d)/(1 + d'))| <= 4|d - d'|. ;;; ;;; Proof. Write ;;; -;;; log((1 + d)/(1 + d')) -;;; = log(1 + (1 + d)/(1 + d') - 1) -;;; = log(1 + (1 + d - 1 - d')/(1 + d') -;;; = log(1 + (d - d')/(1 + d')). +;;; log((1 + d)/(1 + d')) +;;; = log(1 + (1 + d)/(1 + d') - 1) +;;; = log(1 + (1 + d - 1 - d')/(1 + d')) +;;; = log(1 + (d - d')/(1 + d')). ;;; ;;; By Lemma 1, |(d - d')/(1 + d')| < 2|d' - d| < 1, so the Taylor ;;; series of log(1 + x) converges absolutely for (d - d')/(1 + d'), ;;; and thus we have ;;; -;;; |log(1 + (d - d')/(1 + d'))| -;;; = |\sum_{n=1}^\infty ((d - d')/(1 + d'))^n/n| -;;; <= \sum_{n=1}^\infty |(d - d')/(1 + d')|^n/n -;;; <= \sum_{n=1}^\infty |2(d' - d)|^n/n -;;; <= \sum_{n=1}^\infty |2(d' - d)|^n -;;; = 1/(1 - |2(d' - d)|) -;;; <= 4|d' - d|, +;;; |log(1 + (d - d')/(1 + d'))| +;;; = |\sum_{n=1}^\infty ((d - d')/(1 + d'))^n/n| +;;; <= \sum_{n=1}^\infty |(d - d')/(1 + d')|^n/n +;;; <= \sum_{n=1}^\infty |2(d' - d)|^n/n +;;; <= \sum_{n=1}^\infty |2(d' - d)|^n +;;; = |2(d' - d)|/(1 - |2(d' - d)|) +;;; <= 4|d' - d|, ;;; ;;; QED. ;;; ;;; Lemma 4. If 1/e <= 1 + x <= e, then ;;; -;;; log(1 + (1 + d) x) = (1 + d') log(1 + x) +;;; log(1 + (1 + d) x) = (1 + d') log(1 + x) ;;; ;;; for |d'| < 8|d|. ;;; ;;; Proof. Write ;;; -;;; log(1 + (1 + d) x) -;;; = log(1 + x + x*d) -;;; = log((1 + x) (1 + x + x*d)/(1 + x)) -;;; = log(1 + x) + log((1 + x + x*d)/(1 + x)) -;;; = log(1 + x) (1 + log((1 + x + x*d)/(1 + x))/log(1 + x)).
+;;; log(1 + (1 + d) x) +;;; = log(1 + x + x*d) +;;; = log((1 + x) (1 + x + x*d)/(1 + x)) +;;; = log(1 + x) + log((1 + x + x*d)/(1 + x)) +;;; = log(1 + x) (1 + log((1 + x + x*d)/(1 + x))/log(1 + x)). ;;; ;;; The relative error is bounded by ;;; -;;; |log((1 + x + x*d)/(1 + x))/log(1 + x)| -;;; <= 4|x + x*d - x|/|log(1 + x)|, by Lemma 3, -;;; = 4|x*d|/|log(1 + x)| -;;; < 8|d|, +;;; |log((1 + x + x*d)/(1 + x))/log(1 + x)| +;;; <= 4|x + x*d - x|/|log(1 + x)|, by Lemma 3, +;;; = 4|x*d|/|log(1 + x)| +;;; < 8|d|, ;;; ;;; since in this range 0 < 1 - 1/e < x/log(1 + x) <= e - 1 < 2. QED. -;;; 1/(1 + e^{-x}) = e^x/(1 + e^x) +;;; Logistic function: 1/(1 + e^{-x}) = e^x/(1 + e^x). Maps a +;;; log-odds-space probability in [-\infty, +\infty] into a +;;; direct-space probability in [0,1]. Inverse of logit. +;;; +;;; Ill-conditioned for large x; the identity logistic(-x) = 1 - +;;; logistic(x) and the function (logistic-1/2 x) = (- (logistic x) +;;; 1/2) may help to rearrange a computation. +;;; +;;; This implementation gives relative error bounded by 7 eps. + (define (logistic x) (guarantee-real x 'logistic) (cond ((<= x flo:subnormal-exponent-min-base-e) @@ -2162,12 +2174,60 @@ USA. ;; 1.))) +;;; Logistic function, translated in output by 1/2: logistic(x) - 1/2 = +;;; 1/(1 + e^{-x}) - 1/2. Well-conditioned on the entire real line, +;;; with maximum condition number 1 at 0. +;;; +;;; This implementation gives relative error bounded by 9 eps. + +(define (logistic-1/2 x) + ;; logistic(x) - 1/2 is odd, so for x < 0 we negate the value at -x; + ;; then e^{-x} <= 1 cannot overflow, whereas for x < -709.8 or so + ;; e^{-x} would overflow, giving inf/inf = NaN or a trap. So take + ;; x >= 0. Suppose exp has error d0, + has error d1, expm1 has error + ;; d2, and / has error d3, so we evaluate + ;; + ;; -(1 + d2) (1 + d3) (e^{-x} - 1) + ;; / [2 (1 + d1) (1 + (1 + d0) e^{-x})].
+ ;; + ;; In the denominator, + ;; + ;; 1 + (1 + d0) e^{-x} + ;; = 1 + e^{-x} + d0 e^{-x} + ;; = (1 + e^{-x}) (1 + d0 e^{-x}/(1 + e^{-x})), + ;; + ;; so the relative error of the numerator is + ;; d' = d2 + d3 + d2 d3, + ;; and of the denominator, + ;; d'' = d1 + d0 e^{-x}/(1 + e^{-x}) + d0 d1 e^{-x}/(1 + e^{-x}) + ;; = d1 + d0 L(-x) + d0 d1 L(-x), + ;; where L(-x) is logistic(-x). By Lemma 2 the relative error of the + ;; quotient is bounded by + ;; 2|d2 + d3 + d2 d3 - d1 - d0 L(-x) - d0 d1 L(-x)|. + ;; Since 0 < L(-x) < 1, this is bounded by + ;; 2|d2| + 2|d3| + 2|d2 d3| + 2|d1| + 2|d0| + 2|d0 d1| + ;; <= 8 eps + 4 eps^2. + (guarantee-real x 'logistic-1/2) + (if (negative? x) + (- (logistic-1/2 (- x))) + (- (/ (expm1 (- x)) (* 2 (+ 1 (exp (- x)))))))) + (define-integrable logit-boundary-lo ;logistic(-1) (flo:/ (flo:exp -1.) (flo:+ 1. (flo:exp -1.)))) (define-integrable logit-boundary-hi ;logistic(+1) (flo:/ 1. (flo:+ 1. (flo:exp -1.)))) -;;; log p/(1 - p) +;;; Logit function: log p/(1 - p). Defined on [0,1]. Maps a +;;; direct-space probability in [0,1] to a log-odds-space probability +;;; in [-\infty, +\infty]. Inverse of logistic. +;;; +;;; Ill-conditioned near 1/2 and 1; the identity logit(1 - p) = +;;; -logit(p) and the function (logit1/2+ p0) = (logit (+ 1/2 p0)) may +;;; help to rearrange a computation for p in [1/(1 + e), 1 - 1/(1 + +;;; e)]. +;;; +;;; This implementation gives relative error bounded by 10 eps. + (define (logit p) (guarantee-real p 'logit) (if (not (<= 0 p 1)) @@ -2182,8 +2242,8 @@ USA. ;; = -log(1 + (1 - 2p)/p). ;; ;; to get an intermediate quotient near zero. + ;; (cond ((<= logit-boundary-lo p logit-boundary-hi) - ;; ;; Since p = 2p/2 <= 1 <= 2*2p = 4p, the floating-point ;; evaluation of 1 - 2p is exact; the only error arises from ;; division and log1p. First, note that if logistic(-1) <= p @@ -2228,7 +2288,79 @@ USA. 
;; (log (/ p (- 1 p)))))) +;;; Logit function, translated in input by 1/2: (logit1/2+ p-1/2) = +;;; (logit (+ 1/2 p-1/2)). Defined on [-1/2, 1/2].
Inverse of +;;; logistic-1/2. +;;; +;;; Ill-conditioned near +/-1/2. If |p0| > 1/2 - 1/(1 + e), it may be +;;; better to compute 1/2 + p0 or 1/2 - p0 and to use logit instead. +;;; This implementation gives relative error bounded by 10 eps. + +(define (logit1/2+ p-1/2) + (guarantee-real p-1/2 'logit1/2+) + (cond ((<= (abs p-1/2) (- 1/2 (/ 1 (+ 1 (exp 1))))) + ;; If p' = p - 1/2, then p = 1/2 + p', so we compute: + ;; + ;; log(p/(1 - p)) + ;; = log((1/2 + p')/(1 - (1/2 + p'))) + ;; = log((1/2 + p')/(1/2 - p')) + ;; = log(1 + (1/2 + p')/(1/2 - p') - 1) + ;; = log(1 + (1/2 + p' - (1/2 - p'))/(1/2 - p')) + ;; = log(1 + (1/2 + p' - 1/2 + p')/(1/2 - p')) + ;; = log(1 + 2 p'/(1/2 - p')) + ;; + ;; Write p0 for p'. NOTE(review): this analysis takes + ;; 1/2 - p0 to be exact, but Sterbenz's lemma needs + ;; 1/4 <= p0 <= 1, which this branch excludes, so the + ;; subtraction may round and the bound below may be + ;; optimistic -- TODO confirm. If the error of division + ;; is d0 and the error of log1p is d1, we compute + ;; + ;; (1 + d1) log(1 + (1 + d0) 2 p0/(1/2 - p0)) + ;; = (1 + d1) (1 + d') log(1 + 2 p0/(1/2 - p0)) + ;; = (1 + d1 + d' + d1 d') log(1 + 2 p0/(1/2 - p0)). + ;; + ;; where |d'| < 8|d0| by Lemma 4, since + ;; + ;; 1/e <= 1 + 2*p0/(1/2 - p0) <= e + ;; + ;; when |p0| <= 1/2 - 1/(1 + e). Hence the relative + ;; error is bounded by + ;; + ;; |d1 + d' + d1 d'| + ;; <= |d1| + |d'| + |d1 d'| + ;; <= |d1| + 8 |d0| + 8 |d1 d0| + ;; <= 9 eps + 8 eps^2. + ;; + (log1p (/ (* 2 p-1/2) (- 1/2 p-1/2)))) + (else + ;; We have a choice of computing logit(1/2 + p0) or -logit(1 - + ;; (1/2 + p0)) = -logit(1/2 - p0); we do the one that skips the + ;; final negation. NOTE(review): the original analysis assumed + ;; 1/2 + p0 and 1/2 - p0 are both exact because p0/2 <= 1/2 <= + ;; 2 p0; that covers only 1/2 - |p0| with |p0| >= 1/4 -- confirm. + ;; + ;; If so, the only error arises from division and log.
So the + ;; result is + ;; + ;; (1 + d1) log((1 + d0) (1/2 + p0)/(1/2 - p0)) + ;; = (1 + d1) (1 + log(1 + d0)/log((1/2 + p0)/(1/2 - p0))) + ;; * log((1/2 + p0)/(1/2 - p0)) + ;; = (1 + d') log((1/2 + p0)/(1/2 - p0)) + ;; where + ;; d' = d1 + log(1 + d0)/log((1/2 + p0)/(1/2 - p0)) + ;; + d1 log(1 + d0)/log((1/2 + p0)/(1/2 - p0)). + ;; + ;; For |p0| > 1/2 - 1/(1 + e), |logit(1/2 + p0)| > 1. For |d0| < + ;; 1/2, |log(1 + d0)| < 2|d0|. Hence this is bounded by + ;; + ;; |d'| <= |d1| + 2|d0| + 2|d0 d1| + ;; <= 3 eps + 2 eps^2. + ;; + (log (/ (+ 1/2 p-1/2) (- 1/2 p-1/2)))))) + ;;; log logistic(x) = -log (1 + e^{-x}) + (define (log-logistic x) (guarantee-real x 'log-logistic) (- (log1pexp (- x)))) @@ -2239,6 +2371,7 @@ USA. (flo:- 0. (flo:log (flo:+ 1. (flo:exp -1.))))) ;;; log e^t/(1 - e^t) = logit(e^t) + (define (logit-exp t) (guarantee-real t 'logit-exp) (cond ((<= t flo:log-epsilon) diff --git a/src/runtime/runtime.pkg b/src/runtime/runtime.pkg index 893cabc65..6eea6fe64 100644 --- a/src/runtime/runtime.pkg +++ b/src/runtime/runtime.pkg @@ -3383,8 +3383,10 @@ USA. log1mexp log1pexp logistic + logistic-1/2 logit logit-exp + logit1/2+ max min modulo diff --git a/tests/runtime/test-arith.scm b/tests/runtime/test-arith.scm index 9384402e9..b959afd71 100644 --- a/tests/runtime/test-arith.scm +++ b/tests/runtime/test-arith.scm @@ -232,6 +232,20 @@ USA. (if (< p 1) (assert-<= (relerr (log1p (- p)) (log-logistic (- x))) 1e-15))))) +(define-enumerated-test 'logit-logistic-1/2 + (vector + (vector 1e-300 4e-300) + (vector 1e-16 4e-16) + (vector .2310585786300049 1.) + (vector .49999999999999994 37.42994775023705)) + (lambda (v) + (let ((p (vector-ref v 0)) + (x (vector-ref v 1))) + (assert-<= (relerr x (logit1/2+ p)) 1e-15) + (assert-<= (relerr p (logistic-1/2 x)) 1e-15) + (assert-= (- (logit1/2+ p)) (logit1/2+ (- p))) + (assert-= (- (logistic-1/2 x)) (logistic-1/2 (- x)))))) + (define-enumerated-test 'expt-exact (vector (vector 2. -1075 "0.") -- 2.25.1