From 22c0a40e2f1980445b9616ea969931fb096595ff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 23 Jan 2018 09:18:11 -0800 Subject: Correct handling of dtype for Categorical sampling. PiperOrigin-RevId: 182943806 --- tensorflow/python/kernel_tests/distributions/categorical_test.py | 4 ++++ tensorflow/python/ops/distributions/categorical.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py index 019c1bc353..ca2358fe99 100644 --- a/tensorflow/python/kernel_tests/distributions/categorical_test.py +++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py @@ -100,6 +100,10 @@ class CategoricalTest(test.TestCase): self.assertEqual( dist.logits.dtype, dist.log_prob(np.array( 0, dtype=np.int64)).dtype) + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + dist = make_categorical([], 5, dtype=dtype) + self.assertEqual(dist.dtype, dtype) + self.assertEqual(dist.dtype, dist.sample(5).dtype) def testUnknownShape(self): with self.test_session(): diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py index 60a515583e..9161e3fa9f 100644 --- a/tensorflow/python/ops/distributions/categorical.py +++ b/tensorflow/python/ops/distributions/categorical.py @@ -265,7 +265,9 @@ class Categorical(distribution.Distribution): logits_2d = self.logits else: logits_2d = array_ops.reshape(self.logits, [-1, self.event_size]) - draws = random_ops.multinomial(logits_2d, n, seed=seed) + sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) draws = array_ops.reshape( array_ops.transpose(draws), array_ops.concat([[n], self.batch_shape_tensor()], 0)) -- cgit v1.2.3