aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/distributions/multinomial.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/distributions/multinomial.py')
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py49
1 files changed, 20 insertions, 29 deletions
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index d49fac59ca..00b5697c83 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -26,7 +26,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
@@ -141,8 +140,6 @@ class Multinomial(distribution.Distribution):
counts = [[2., 1, 1], [3, 1, 1]]
dist.prob(counts) # Shape [2]
-
- dist.sample(5) # Shape [5, 2, 3]
```
"""
@@ -234,35 +231,29 @@ class Multinomial(distribution.Distribution):
def _sample_n(self, n, seed=None):
n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
+ if self.total_count.get_shape().ndims is not None:
+ if self.total_count.get_shape().ndims != 0:
+ raise NotImplementedError(
+ "Sample only supported for scalar number of draws.")
+ elif self.validate_args:
+ is_scalar = check_ops.assert_rank(
+ n_draws, 0,
+ message="Sample only supported for scalar number of draws.")
+ n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
k = self.event_shape_tensor()[0]
-
- # boardcast the total_count and logits to same shape
- n_draws = array_ops.ones_like(
- self.logits[..., 0], dtype=n_draws.dtype) * n_draws
- logits = array_ops.ones_like(
- n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits
-
- # flatten the total_count and logits
- flat_logits = array_ops.reshape(logits, [-1, k]) # [B1B2...Bm, k]
- flat_ndraws = n * array_ops.reshape(n_draws, [-1]) # [B1B2...Bm]
-
- # computes each total_count and logits situation by map_fn
- def _sample_single(args):
- logits, n_draw = args[0], args[1] # [K], []
- x = random_ops.multinomial(logits[array_ops.newaxis, ...],
- n_draw, seed) # [1, n*n_draw]
- x = array_ops.reshape(x, shape=[n, -1]) # [n, n_draw]
- x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2) # [n, k]
- return x
- x = functional_ops.map_fn(_sample_single,
- [flat_logits, flat_ndraws],
- dtype=self.dtype) # [B1B2...Bm, n, k]
-
- # reshape the results to proper shape
+ # Flatten batch dims so logits has shape [B, k],
+ # where B = reduce_prod(self.batch_shape_tensor()).
+ x = random_ops.multinomial(
+ logits=array_ops.reshape(self.logits, [-1, k]),
+ num_samples=n * n_draws,
+ seed=seed)
+ x = array_ops.reshape(x, shape=[-1, n, n_draws])
+ x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k),
+ axis=-2) # shape: [B, n, k]
x = array_ops.transpose(x, perm=[1, 0, 2])
final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
- x = array_ops.reshape(x, final_shape) # [n, B1, B2,..., Bm, k]
- return x
+ x = array_ops.reshape(x, final_shape)
+ return math_ops.cast(x, self.dtype)
@distribution_util.AppendDocstring(_multinomial_sample_note)
def _log_prob(self, counts):