aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-15 12:18:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-15 12:27:29 -0800
commit657dca6747077af556ec6f6781650009c8606a6f (patch)
tree78c7251efc13844af5318aaf88b2c2c0879d2354
parent1b87fec1180fbf5c13ccafaa39beda0e618cda74 (diff)
Fix geometric_test.py.
Change: 147628659
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/geometric.py9
2 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py
index a8ecd1e161..3dbad7b607 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py
@@ -217,6 +217,8 @@ class GeometricTest(test.TestCase):
x = np.array([0., 2., 3., 4., 5., 6., 7.], dtype=np.float32)
expected_log_prob = stats.geom.logpmf(x, [1.], loc=-1)
+ # Scipy incorrectly returns nan.
+ expected_log_prob[np.isnan(expected_log_prob)] = 0.
log_prob = geom.log_prob(x)
self.assertEqual([7,], log_prob.get_shape())
diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py
index a584588c90..fd9a50021f 100644
--- a/tensorflow/contrib/distributions/python/ops/geometric.py
+++ b/tensorflow/contrib/distributions/python/ops/geometric.py
@@ -162,7 +162,14 @@ class Geometric(distribution.Distribution):
if self.validate_args:
counts = distribution_util.embed_check_nonnegative_discrete(
counts, check_integer=True)
- return counts * math_ops.log1p(-self.probs) + math_ops.log(self.probs)
+ counts *= array_ops.ones_like(self.probs)
+ probs = self.probs * array_ops.ones_like(counts)
+
+ safe_domain = array_ops.where(
+ math_ops.equal(counts, 0.),
+ array_ops.zeros_like(probs),
+ probs)
+ return counts * math_ops.log1p(-safe_domain) + math_ops.log(probs)
def _entropy(self):
probs = self._probs