aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/distributions/categorical.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/distributions/categorical.py')
-rw-r--r--tensorflow/python/ops/distributions/categorical.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 60a515583e..9161e3fa9f 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -265,7 +265,9 @@ class Categorical(distribution.Distribution):
logits_2d = self.logits
else:
logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
- draws = random_ops.multinomial(logits_2d, n, seed=seed)
+ sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
+ draws = random_ops.multinomial(
+ logits_2d, n, seed=seed, output_dtype=sample_dtype)
draws = array_ops.reshape(
array_ops.transpose(draws),
array_ops.concat([[n], self.batch_shape_tensor()], 0))