diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-02 21:02:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-02 21:05:03 -0700 |
commit | 3027f580046866cb74d5edf4e41c9406e007234c (patch) | |
tree | 3b7c6d26b2f1a5edcc0c92340cc0cd07a7a76439 | |
parent | 38e0139329482d8e44629dea2e87853808eacd0d (diff) |
BUG_FIX: Allow Uniform pdf to work on float64 inputs.
PiperOrigin-RevId: 191391778
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/uniform_test.py | 16 | ||||
-rw-r--r-- | tensorflow/python/ops/distributions/uniform.py | 3 |
2 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index df99a0ed25..a8def95b14 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -281,6 +281,22 @@ class UniformTest(test.TestCase): expected_pdf = [1.0, 0.1] self.assertAllClose(expected_pdf, pdf.eval()) + def testUniformFloat64(self): + uniform = uniform_lib.Uniform( + low=np.float64(0.), high=np.float64(1.)) + + self.assertAllClose( + [1., 1.], + self.evaluate(uniform.prob(np.array([0.5, 0.6], dtype=np.float64)))) + + self.assertAllClose( + [0.5, 0.6], + self.evaluate(uniform.cdf(np.array([0.5, 0.6], dtype=np.float64)))) + + self.assertAllClose(0.5, self.evaluate(uniform.mean())) + self.assertAllClose(1 / 12., self.evaluate(uniform.variance())) + self.assertAllClose(0., self.evaluate(uniform.entropy())) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py index ec623b55eb..0891bffdd5 100644 --- a/tensorflow/python/ops/distributions/uniform.py +++ b/tensorflow/python/ops/distributions/uniform.py @@ -166,7 +166,8 @@ class Uniform(distribution.Distribution): return self.low + self.range() * samples def _prob(self, x): - broadcasted_x = x * array_ops.ones(self.batch_shape_tensor()) + broadcasted_x = x * array_ops.ones( + self.batch_shape_tensor(), dtype=x.dtype) return array_ops.where( math_ops.is_nan(broadcasted_x), broadcasted_x, |