Add logistic/logit-family functions.
authorTaylor R Campbell <campbell@mumble.net>
Sun, 21 Oct 2018 07:15:32 +0000 (07:15 +0000)
committerTaylor R Campbell <campbell@mumble.net>
Sun, 21 Oct 2018 07:15:32 +0000 (07:15 +0000)
Inverse pairs:

(logistic x) = 1/(1 + e^{-x})
(logit p) = log p/(1 - p)

(log-logistic x) = log 1/(1 + e^{-x})
(logit-exp t) = log e^t/(1 - e^t)

src/runtime/arith.scm
src/runtime/runtime.pkg
tests/runtime/test-arith.scm

index 1beb8fff63e7990269025431a934c650f4f937e2..2a8d874f4588f4eae13f7982f9e0b74bec4a1c9b 100644 (file)
@@ -1972,6 +1972,50 @@ USA.
 (define (cube z)
   (complex:* z (complex:* z z)))
 \f
+;;; log(1 - e^x), defined only on negative x
+(define (log1mexp x)
+  (guarantee-real x 'log1mexp)
+  (guarantee-negative x 'log1mexp)
+  (if (< (- flo:log2) x)
+      (log (- (expm1 x)))
+      (log1p (- (exp x)))))
+
+;;; log(1 + e^x)
+(define (log1pexp x)
+  (guarantee-real x 'log1pexp)
+  (cond ((<= x -745) 0.)
+       ((<= x -37) (exp x))
+       ((<= x 18) (log1p (exp x)))
+       ((<= x 33.3) (+ x (exp (- x))))
+       (else x)))
+
+;;; 1/(1 + e^{-x})
+(define (logistic x)
+  (guarantee-real x 'logistic)
+  (cond ((<= x -745) 0.)
+       ((<= x -37) (exp x))
+       ((<= x 745) (/ 1 (+ 1 (exp (- x)))))
+       (else 1.)))
+
+;;; log p/(1 - p)
+(define (logit p)
+  (guarantee-real p 'logit)
+  (if (not (<= 0 p 1))
+      (error:bad-range-argument p 'logit))
+  (log (/ p (- 1 p))))
+
+;;; log logistic(x) = -log (1 + e^{-x})
+(define (log-logistic x)
+  (guarantee-real x 'log-logistic)
+  (- (log1pexp (- x))))
+
+;;; log e^t/(1 - e^t) = -log (1 - e^t)/e^t = -log (e^{-t} - 1)
+(define (logit-exp t)
+  (guarantee-real t 'logit-exp)
+  (if (<= t -37)
+      t
+      (- (log (expm1 (- t))))))
+\f
 ;;; Replaced with arity-dispatched version in INITIALIZE-PACKAGE!.
 
 (define =)
index 693ff665d9a590f640b5924287ce31ea4b1cb0f6..dbcf62d041131b077cf862f77fbef6f16f39b677 100644 (file)
@@ -3365,6 +3365,12 @@ USA.
          integer-divide-quotient
          integer-divide-remainder
          lcm
+         log-logistic
+         log1mexp
+         log1pexp
+         logistic
+         logit
+         logit-exp
          max
          min
          modulo
index 290966bc7006acfdba5d1bb745a762bac534d7b1..b8dee7d12a5b495fe72cebff877778e2190ea561 100644 (file)
@@ -137,8 +137,10 @@ USA.
   (lambda (v)
     (assert-eqv (cdr v) (expm1 (car v)))))
 
-(define (relerr a e)
-  (abs (/ (- a e) a)))
+(define (relerr e a)
+  (if (= e 0)
+      (if (= a 0) 0 1)
+      (abs (/ (- e a) a))))
 
 (define-enumerated-test 'expm1-approx
   (vector
@@ -160,4 +162,61 @@ USA.
    (cons (- 1 (sqrt 1/2)) 0.25688251232181475)
    (cons 0.3 .26236426446749106))
   (lambda (v)
-    (assert->= 1e-15 (relerr (cdr v) (log1p (car v))))))
\ No newline at end of file
+    (assert->= 1e-15 (relerr (cdr v) (log1p (car v))))))
+
+(define-enumerated-test 'log1mexp
+  (vector
+   (cons -1e-17 -39.1439465808987777)
+   (cons -0.69 -0.696304297144056727)
+   (cons (- (log 2)) (- (log 2)))
+   (cons -0.70 -0.686341002808385170)
+   (cons -708 -3.30755300363840783e-308))
+  (lambda (v)
+    (assert->= 1e-15 (relerr (cdr v) (log1mexp (car v))))))
+
+(define-enumerated-test 'log1pexp
+  (vector
+   (cons -1000 0)
+   (cons -708 3.30755300363840783e-308)
+   (cons -38 3.13913279204802960e-17)
+   (cons -37 8.53304762574406580e-17)
+   (cons -36 2.31952283024356914e-16)
+   (cons 0 (log 2))
+   (cons 17 17.0000000413993746)
+   (cons 18 18.0000000152299791)
+   (cons 19 19.0000000056027964)
+   (cons 33 33.0000000000000071)
+   (cons 34 34))
+  (lambda (v)
+    (assert->= 1e-15 (relerr (cdr v) (log1pexp (car v))))))
+
+(define-enumerated-test 'logit-logistic
+  (vector
+   (vector -1 0.2689414213699951 (log 0.2689414213699951))
+   (vector 0 1/2 (log 1/2))
+   (vector +1 0.7310585786300049 (log 0.7310585786300049))
+   ;; Would like to do +/-710 but we get inexact result traps.
+   (vector +708 1 -3.307553003638408e-308)
+   (vector -708 3.307553003638408e-308 -708)
+   (vector +1000 1 0)
+   (vector -1000 0 -1000))
+  (lambda (v)
+    (let ((x (vector-ref v 0))
+          (p (vector-ref v 1))
+          (t (vector-ref v 2))
+          (maxerr (* 5 microcode-id/floating-epsilon)))
+      (assert->= maxerr (relerr p (logistic x)))
+      (if (and (not (= p 0))
+               (not (= p 1)))
+          (assert->= maxerr (relerr x (logit p))))
+      (if (< p 1)
+          (begin
+            (assert->= maxerr (relerr (- 1 p) (logistic (- x))))
+            (if (< (- 1 p) 1)
+                (assert->= maxerr (relerr (- x) (logit (- 1 p))))))
+          (assert->= 1e-300 (logistic (- x))))
+      (assert->= maxerr (relerr t (log-logistic x)))
+      (if (<= x 709)
+          (assert->= maxerr (relerr x (logit-exp t))))
+      (if (< p 1)
+          (assert->= maxerr (relerr (log1p (- p)) (log-logistic (- x))))))))
\ No newline at end of file