diff options
author | 2016-08-28 12:08:42 -0800 | |
---|---|---|
committer | 2016-08-28 13:17:45 -0700 | |
commit | 532d1c88ae50584f8ce42f08b294944a40942e10 (patch) | |
tree | efa287ae85fd713796396c23d37be61b492a6320 | |
parent | 676caaeaaa5b4978e9c89e5a3f208fa6ee57bb35 (diff) |
Implement Beta.log_cdf and Beta.cdf.
Change: 131537182
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/beta_test.py | 25 | ||||
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/beta.py | 6 |
2 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py index abf56c00dd..a36753d6a1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py @@ -262,5 +262,30 @@ class BetaTest(tf.test.TestCase): stats.beta.mean(a, b)[1, :], atol=1e-1) + def testBetaCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = tf.contrib.distributions.Beta(a, b).cdf(x).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + + def testBetaLogCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = tf.exp(tf.contrib.distributions.Beta(a, b).log_cdf(x)).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py index 04f9d9acb9..6024dd9fc2 100644 --- a/tensorflow/contrib/distributions/python/ops/beta.py +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -196,6 +196,12 @@ class Beta(distribution.Distribution): def _prob(self, x): return math_ops.exp(self._log_prob(x)) + def _log_cdf(self, x): + return math_ops.log(self._cdf(x)) + + def _cdf(self, x): + return math_ops.betainc(self.a, self.b, x) + def _entropy(self): return (math_ops.lgamma(self.a) - (self.a - 1.) * math_ops.digamma(self.a) + |