diff options
author | 2018-10-01 16:16:43 -0700 | |
---|---|---|
committer | 2018-10-01 16:20:57 -0700 | |
commit | 49bbfec04b729960999ef054e3acab719631b101 (patch) | |
tree | 38d7b43176036eb466a52196ef9c1fc5108d5e5e /tensorflow/python | |
parent | 24333d8e55bdd995089e93122750340bf8d1ddba (diff) |
Override implementation of log survival for Exponential distribution to better handle small values.
PiperOrigin-RevId: 215299532
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/exponential_test.py | 16 | ||||
-rw-r--r-- | tensorflow/python/ops/distributions/exponential.py | 3 |
2 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py index 27d1291912..367f8bb0f1 100644 --- a/tensorflow/python/kernel_tests/distributions/exponential_test.py +++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py @@ -81,6 +81,22 @@ class ExponentialTest(test.TestCase): expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) self.assertAllClose(self.evaluate(cdf), expected_cdf) + def testExponentialLogSurvival(self): + batch_size = 7 + lam = constant_op.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32) + + exponential = exponential_lib.Exponential(rate=lam) + + log_survival = exponential.log_survival_function(x) + self.assertEqual(log_survival.get_shape(), (7,)) + + if not stats: + return + expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v) + self.assertAllClose(self.evaluate(log_survival), expected_log_survival) + def testExponentialMean(self): lam_v = np.array([1.0, 4.0, 2.5]) exponential = exponential_lib.Exponential(rate=lam_v) diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py index 4325a14449..02129b5e2a 100644 --- a/tensorflow/python/ops/distributions/exponential.py +++ b/tensorflow/python/ops/distributions/exponential.py @@ -114,6 +114,9 @@ class Exponential(gamma.Gamma): def rate(self): return self._rate + def _log_survival_function(self, value): + return self._log_prob(value) - math_ops.log(self._rate) + def _sample_n(self, n, seed=None): shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0) # Uniform variates must be sampled from the open-interval `(0, 1)` rather |