diff options
author | 2017-02-15 12:18:48 -0800 | |
---|---|---|
committer | 2017-02-15 12:27:29 -0800 | |
commit | 657dca6747077af556ec6f6781650009c8606a6f (patch) | |
tree | 78c7251efc13844af5318aaf88b2c2c0879d2354 | |
parent | 1b87fec1180fbf5c13ccafaa39beda0e618cda74 (diff) |
Fix geometric_test.py.
Change: 147628659
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/geometric.py | 9 |
2 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py index a8ecd1e161..3dbad7b607 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py @@ -217,6 +217,8 @@ class GeometricTest(test.TestCase): x = np.array([0., 2., 3., 4., 5., 6., 7.], dtype=np.float32) expected_log_prob = stats.geom.logpmf(x, [1.], loc=-1) + # Scipy incorrectly returns nan. + expected_log_prob[np.isnan(expected_log_prob)] = 0. log_prob = geom.log_prob(x) self.assertEqual([7,], log_prob.get_shape()) diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index a584588c90..fd9a50021f 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -162,7 +162,14 @@ class Geometric(distribution.Distribution): if self.validate_args: counts = distribution_util.embed_check_nonnegative_discrete( counts, check_integer=True) - return counts * math_ops.log1p(-self.probs) + math_ops.log(self.probs) + counts *= array_ops.ones_like(self.probs) + probs = self.probs * array_ops.ones_like(counts) + + safe_domain = array_ops.where( + math_ops.equal(counts, 0.), + array_ops.zeros_like(probs), + probs) + return counts * math_ops.log1p(-safe_domain) + math_ops.log(probs) def _entropy(self): probs = self._probs |