aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 21:02:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 21:05:03 -0700
commit3027f580046866cb74d5edf4e41c9406e007234c (patch)
tree3b7c6d26b2f1a5edcc0c92340cc0cd07a7a76439
parent38e0139329482d8e44629dea2e87853808eacd0d (diff)
BUG_FIX: Allow Uniform pdf to work on float64 inputs.
PiperOrigin-RevId: 191391778
-rw-r--r--tensorflow/python/kernel_tests/distributions/uniform_test.py16
-rw-r--r--tensorflow/python/ops/distributions/uniform.py3
2 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index df99a0ed25..a8def95b14 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -281,6 +281,22 @@ class UniformTest(test.TestCase):
expected_pdf = [1.0, 0.1]
self.assertAllClose(expected_pdf, pdf.eval())
+ def testUniformFloat64(self):
+ uniform = uniform_lib.Uniform(
+ low=np.float64(0.), high=np.float64(1.))
+
+ self.assertAllClose(
+ [1., 1.],
+ self.evaluate(uniform.prob(np.array([0.5, 0.6], dtype=np.float64))))
+
+ self.assertAllClose(
+ [0.5, 0.6],
+ self.evaluate(uniform.cdf(np.array([0.5, 0.6], dtype=np.float64))))
+
+ self.assertAllClose(0.5, self.evaluate(uniform.mean()))
+ self.assertAllClose(1 / 12., self.evaluate(uniform.variance()))
+ self.assertAllClose(0., self.evaluate(uniform.entropy()))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index ec623b55eb..0891bffdd5 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -166,7 +166,8 @@ class Uniform(distribution.Distribution):
return self.low + self.range() * samples
def _prob(self, x):
- broadcasted_x = x * array_ops.ones(self.batch_shape_tensor())
+ broadcasted_x = x * array_ops.ones(
+ self.batch_shape_tensor(), dtype=x.dtype)
return array_ops.where(
math_ops.is_nan(broadcasted_x),
broadcasted_x,