aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar Brian Patton <bjp@google.com>2018-09-26 14:10:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:14:46 -0700
commit72b927960625cd2920fea06e242df1ff0d220c77 (patch)
tree633fa27b1fec1c0db08b657877e9131488e5d60b /tensorflow/python/kernel_tests
parentce58563454de6c33ea3bdea5840234eeefbc835e (diff)
Specify a preferred_dtype=self.dtype when converting Distribution methods' sample-like args to Tensors.
After this change, you could conceivably write tfd.Normal(0., 1.).log_prob(1) The tf core distributions can't use tfp dtype_util.common_dtype, so you can't yet write tfd.Normal(0, 1). Works around an eager bug that loses precision in the presence in tf.convert_to_tensor(0.5, preferred_dtype=tf.int32) PiperOrigin-RevId: 214666222
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py12
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py8
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 26d013bccb..37b35ba51a 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -118,7 +118,9 @@ class BernoulliTest(test.TestCase):
self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)
dist64 = make_bernoulli([], dtypes.int64)
self.assertEqual(dist64.dtype, dtypes.int64)
@@ -181,6 +183,16 @@ class BernoulliTest(test.TestCase):
return
self._testPmf(logits=special.logit(p))
+ @test_util.run_in_graph_and_eager_modes
+ def testPmfWithFloatArgReturnsXEntropy(self):
+ p = [[0.2], [0.4], [0.3], [0.6]]
+ samps = [0, 0.1, 0.8]
+ self.assertAllClose(
+ np.float32(samps) * np.log(np.float32(p)) +
+ (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
+ self.evaluate(
+ bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))
+
def testBroadcasting(self):
with self.cached_session():
p = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index de73a40b23..6625a88843 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -78,6 +78,14 @@ class NormalTest(test.TestCase):
self.assertEqual(expected, sigma_shape)
@test_util.run_in_graph_and_eager_modes
+ def testSampleLikeArgsGetDistDType(self):
+ dist = normal_lib.Normal(0., 1.)
+ self.assertEqual(dtypes.float32, dist.dtype)
+ for method in ("log_prob", "prob", "log_cdf", "cdf",
+ "log_survival_function", "survival_function", "quantile"):
+ self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)
+
+ @test_util.run_in_graph_and_eager_modes
def testParamShapes(self):
sample_shape = [10, 3, 4]
self._testParamShapes(sample_shape, sample_shape)