aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-02 08:30:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 08:34:53 -0700
commit28757ad658243526d84fd16d53b9eefbf809c6ff (patch)
tree4ac420c5bc9effbf9858293356f32d86026b6c9c /tensorflow/python/ops
parent97d515273a1e86a861cdfb338671a42b3b1126a7 (diff)
Use xlogy in a few places in TFP to avoid NaN's for certain special cases.
PiperOrigin-RevId: 215392621
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/distributions/beta.py4
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py2
-rw-r--r--tensorflow/python/ops/distributions/gamma.py2
3 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 2ba1ea6744..d6f89a3517 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -267,8 +267,8 @@ class Beta(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return ((self.concentration1 - 1.) * math_ops.log(x)
- + (self.concentration0 - 1.) * math_ops.log1p(-x))
+ return (math_ops.xlogy(self.concentration1 - 1., x) +
+ (self.concentration0 - 1.) * math_ops.log1p(-x))
def _log_normalization(self):
return (math_ops.lgamma(self.concentration1)
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 415249a958..997b1d392d 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -236,7 +236,7 @@ class Dirichlet(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return math_ops.reduce_sum((self.concentration - 1.) * math_ops.log(x), -1)
+ return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)
def _log_normalization(self):
return special_math_ops.lbeta(self.concentration)
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index 3293cda874..bbc64da7bc 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -225,7 +225,7 @@ class Gamma(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return (self.concentration - 1.) * math_ops.log(x) - self.rate * x
+ return math_ops.xlogy(self.concentration - 1., x) - self.rate * x
def _log_normalization(self):
return (math_ops.lgamma(self.concentration)