aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 16:16:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 16:20:57 -0700
commit49bbfec04b729960999ef054e3acab719631b101 (patch)
tree38d7b43176036eb466a52196ef9c1fc5108d5e5e /tensorflow/python/kernel_tests
parent24333d8e55bdd995089e93122750340bf8d1ddba (diff)
Override implementation of log survival for Exponential distribution to better handle small values.
PiperOrigin-RevId: 215299532
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/distributions/exponential_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 27d1291912..367f8bb0f1 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -81,6 +81,22 @@ class ExponentialTest(test.TestCase):
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ def testExponentialLogSurvival(self):
+ batch_size = 7
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)
+
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ log_survival = exponential.log_survival_function(x)
+ self.assertEqual(log_survival.get_shape(), (7,))
+
+ if not stats:
+ return
+ expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(log_survival), expected_log_survival)
+
def testExponentialMean(self):
lam_v = np.array([1.0, 4.0, 2.5])
exponential = exponential_lib.Exponential(rate=lam_v)