aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-23 09:18:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-23 09:22:21 -0800
commit22c0a40e2f1980445b9616ea969931fb096595ff (patch)
tree372c23b5ce0f6c6ece96f978b302387ba0a736be
parentee2010ade01e7705157e2845643b664b5719db53 (diff)
Correct handling of dtype for Categorical sampling.
PiperOrigin-RevId: 182943806
-rw-r--r--tensorflow/python/kernel_tests/distributions/categorical_test.py4
-rw-r--r--tensorflow/python/ops/distributions/categorical.py4
2 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index 019c1bc353..ca2358fe99 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -100,6 +100,10 @@ class CategoricalTest(test.TestCase):
self.assertEqual(
dist.logits.dtype, dist.log_prob(np.array(
0, dtype=np.int64)).dtype)
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ dist = make_categorical([], 5, dtype=dtype)
+ self.assertEqual(dist.dtype, dtype)
+ self.assertEqual(dist.dtype, dist.sample(5).dtype)
def testUnknownShape(self):
with self.test_session():
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 60a515583e..9161e3fa9f 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -265,7 +265,9 @@ class Categorical(distribution.Distribution):
logits_2d = self.logits
else:
logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
- draws = random_ops.multinomial(logits_2d, n, seed=seed)
+ sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
+ draws = random_ops.multinomial(
+ logits_2d, n, seed=seed, output_dtype=sample_dtype)
draws = array_ops.reshape(
array_ops.transpose(draws),
array_ops.concat([[n], self.batch_shape_tensor()], 0))