diff options
author | Brian Patton <bjp@google.com> | 2018-09-26 14:10:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 14:14:46 -0700 |
commit | 72b927960625cd2920fea06e242df1ff0d220c77 (patch) | |
tree | 633fa27b1fec1c0db08b657877e9131488e5d60b /tensorflow/python/kernel_tests | |
parent | ce58563454de6c33ea3bdea5840234eeefbc835e (diff) |
Specify a preferred_dtype=self.dtype when converting Distribution methods' sample-like args to Tensors.
After this change, you could conceivably write tfd.Normal(0., 1.).log_prob(1)
The tf core distributions can't use tfp dtype_util.common_dtype, so you can't yet write tfd.Normal(0, 1).
Works around an eager bug that loses precision in the presence in tf.convert_to_tensor(0.5, preferred_dtype=tf.int32)
PiperOrigin-RevId: 214666222
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/bernoulli_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/normal_test.py | 8 |
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index 26d013bccb..37b35ba51a 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -118,7 +118,9 @@ class BernoulliTest(test.TestCase): self.assertEqual(dist.probs.dtype, dist.stddev().dtype) self.assertEqual(dist.probs.dtype, dist.entropy().dtype) self.assertEqual(dist.probs.dtype, dist.prob(0).dtype) + self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype) self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype) + self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype) dist64 = make_bernoulli([], dtypes.int64) self.assertEqual(dist64.dtype, dtypes.int64) @@ -181,6 +183,16 @@ class BernoulliTest(test.TestCase): return self._testPmf(logits=special.logit(p)) + @test_util.run_in_graph_and_eager_modes + def testPmfWithFloatArgReturnsXEntropy(self): + p = [[0.2], [0.4], [0.3], [0.6]] + samps = [0, 0.1, 0.8] + self.assertAllClose( + np.float32(samps) * np.log(np.float32(p)) + + (1 - np.float32(samps)) * np.log(1 - np.float32(p)), + self.evaluate( + bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps))) + def testBroadcasting(self): with self.cached_session(): p = array_ops.placeholder(dtypes.float32) diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index de73a40b23..6625a88843 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -78,6 +78,14 @@ class NormalTest(test.TestCase): self.assertEqual(expected, sigma_shape) @test_util.run_in_graph_and_eager_modes + def testSampleLikeArgsGetDistDType(self): + dist = normal_lib.Normal(0., 1.) + self.assertEqual(dtypes.float32, dist.dtype) + for method in ("log_prob", "prob", "log_cdf", "cdf", + "log_survival_function", "survival_function", "quantile"): + self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype) + + @test_util.run_in_graph_and_eager_modes def testParamShapes(self): sample_shape = [10, 3, 4] self._testParamShapes(sample_shape, sample_shape) |