From: Taylor R Campbell Date: Sun, 21 Oct 2018 07:15:32 +0000 (+0000) Subject: Add logistic/logit-family functions. X-Git-Tag: mit-scheme-pucked-10.1.2~16^2~174^2~11 X-Git-Url: https://birchwood-abbey.net/git?a=commitdiff_plain;h=8decee26f13f3ef09abef2662d2c92b87aeef04c;p=mit-scheme.git Add logistic/logit-family functions. Inverse pairs: (logistic x) = 1/(1 + e^{-x}) (logit p) = log p/(1 - p) (log-logistic x) = log 1/(1 + e^{-x}) (logit-exp t) = log e^t/(1 - e^t) --- diff --git a/src/runtime/arith.scm b/src/runtime/arith.scm index 1beb8fff6..2a8d874f4 100644 --- a/src/runtime/arith.scm +++ b/src/runtime/arith.scm @@ -1972,6 +1972,50 @@ USA. (define (cube z) (complex:* z (complex:* z z))) +;;; log(1 - e^x), defined only on negative x +(define (log1mexp x) + (guarantee-real x 'log1mexp) + (guarantee-negative x 'log1mexp) + (if (< (- flo:log2) x) + (log (- (expm1 x))) + (log1p (- (exp x))))) + +;;; log(1 + e^x) +(define (log1pexp x) + (guarantee-real x 'log1pexp) + (cond ((<= x -745) 0.) + ((<= x -37) (exp x)) + ((<= x 18) (log1p (exp x))) + ((<= x 33.3) (+ x (exp (- x)))) + (else x))) + +;;; 1/(1 + e^{-x}) +(define (logistic x) + (guarantee-real x 'logistic) + (cond ((<= x -745) 0.) + ((<= x -37) (exp x)) + ((<= x 745) (/ 1 (+ 1 (exp (- x))))) + (else 1.))) + +;;; log p/(1 - p) +(define (logit p) + (guarantee-real p 'logit) + (if (not (<= 0 p 1)) + (error:bad-range-argument p 'logit)) + (log (/ p (- 1 p)))) + +;;; log logistic(x) = -log (1 + e^{-x}) +(define (log-logistic x) + (guarantee-real x 'log-logistic) + (- (log1pexp (- x)))) + +;;; log e^t/(1 - e^t) = -log (1 - e^t)/e^t = -log (e^{-t} - 1) +(define (logit-exp t) + (guarantee-real t 'logit-exp) + (if (<= t -37) + t + (- (log (expm1 (- t)))))) + ;;; Replaced with arity-dispatched version in INITIALIZE-PACKAGE!. (define =) diff --git a/src/runtime/runtime.pkg b/src/runtime/runtime.pkg index 693ff665d..dbcf62d04 100644 --- a/src/runtime/runtime.pkg +++ b/src/runtime/runtime.pkg @@ -3365,6 +3365,12 @@ USA. integer-divide-quotient integer-divide-remainder lcm + log-logistic + log1mexp + log1pexp + logistic + logit + logit-exp max min modulo diff --git a/tests/runtime/test-arith.scm b/tests/runtime/test-arith.scm index 290966bc7..b8dee7d12 100644 --- a/tests/runtime/test-arith.scm +++ b/tests/runtime/test-arith.scm @@ -137,8 +137,10 @@ USA. (lambda (v) (assert-eqv (cdr v) (expm1 (car v))))) -(define (relerr a e) - (abs (/ (- a e) a))) +(define (relerr e a) + (if (= e 0) + (if (= a 0) 0 1) + (abs (/ (- e a) a)))) (define-enumerated-test 'expm1-approx (vector @@ -160,4 +162,61 @@ USA. (cons (- 1 (sqrt 1/2)) 0.25688251232181475) (cons 0.3 .26236426446749106)) (lambda (v) - (assert->= 1e-15 (relerr (cdr v) (log1p (car v)))))) \ No newline at end of file + (assert->= 1e-15 (relerr (cdr v) (log1p (car v)))))) + +(define-enumerated-test 'log1mexp + (vector + (cons -1e-17 -39.1439465808987777) + (cons -0.69 -0.696304297144056727) + (cons (- (log 2)) (- (log 2))) + (cons -0.70 -0.686341002808385170) + (cons -708 -3.30755300363840783e-308)) + (lambda (v) + (assert->= 1e-15 (relerr (cdr v) (log1mexp (car v)))))) + +(define-enumerated-test 'log1pexp + (vector + (cons -1000 0) + (cons -708 3.30755300363840783e-308) + (cons -38 3.13913279204802960e-17) + (cons -37 8.53304762574406580e-17) + (cons -36 2.31952283024356914e-16) + (cons 0 (log 2)) + (cons 17 17.0000000413993746) + (cons 18 18.0000000152299791) + (cons 19 19.0000000056027964) + (cons 33 33.0000000000000071) + (cons 34 34)) + (lambda (v) + (assert->= 1e-15 (relerr (cdr v) (log1pexp (car v)))))) + +(define-enumerated-test 'logit-logistic + (vector + (vector -1 0.2689414213699951 (log 0.2689414213699951)) + (vector 0 1/2 (log 1/2)) + (vector +1 0.7310585786300049 (log 0.7310585786300049)) + ;; Would like to do +/-710 but we get inexact result traps. + (vector +708 1 -3.307553003638408e-308) + (vector -708 3.307553003638408e-308 -708) + (vector +1000 1 0) + (vector -1000 0 -1000)) + (lambda (v) + (let ((x (vector-ref v 0)) + (p (vector-ref v 1)) + (t (vector-ref v 2)) + (maxerr (* 5 microcode-id/floating-epsilon))) + (assert->= maxerr (relerr p (logistic x))) + (if (and (not (= p 0)) + (not (= p 1))) + (assert->= maxerr (relerr x (logit p)))) + (if (< p 1) + (begin + (assert->= maxerr (relerr (- 1 p) (logistic (- x)))) + (if (< (- 1 p) 1) + (assert->= maxerr (relerr (- x) (logit (- 1 p)))))) + (assert->= 1e-300 (logistic (- x)))) + (assert->= maxerr (relerr t (log-logistic x))) + (if (<= x 709) + (assert->= maxerr (relerr x (logit-exp t)))) + (if (< p 1) + (assert->= maxerr (relerr (log1p (- p)) (log-logistic (- x)))))))) \ No newline at end of file