Diffstat (limited to 'tensorflow/python/ops/distributions/multinomial.py')
-rw-r--r--  tensorflow/python/ops/distributions/multinomial.py  49
1 file changed, 29 insertions(+), 20 deletions(-)
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index 00b5697c83..d49fac59ca 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
@@ -140,6 +141,8 @@ class Multinomial(distribution.Distribution):
counts = [[2., 1, 1], [3, 1, 1]]
dist.prob(counts) # Shape [2]
+
+ dist.sample(5) # Shape [5, 2, 3]
```
"""
@@ -231,29 +234,35 @@ class Multinomial(distribution.Distribution):
def _sample_n(self, n, seed=None):
n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
- if self.total_count.get_shape().ndims is not None:
- if self.total_count.get_shape().ndims != 0:
- raise NotImplementedError(
- "Sample only supported for scalar number of draws.")
- elif self.validate_args:
- is_scalar = check_ops.assert_rank(
- n_draws, 0,
- message="Sample only supported for scalar number of draws.")
- n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
k = self.event_shape_tensor()[0]
- # Flatten batch dims so logits has shape [B, k],
- # where B = reduce_prod(self.batch_shape_tensor()).
- x = random_ops.multinomial(
- logits=array_ops.reshape(self.logits, [-1, k]),
- num_samples=n * n_draws,
- seed=seed)
- x = array_ops.reshape(x, shape=[-1, n, n_draws])
- x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k),
- axis=-2) # shape: [B, n, k]
+
+ # Broadcast total_count and logits to the same batch shape.
+ n_draws = array_ops.ones_like(
+ self.logits[..., 0], dtype=n_draws.dtype) * n_draws
+ logits = array_ops.ones_like(
+ n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits
+
+ # Flatten the batch dims of total_count and logits into one dimension.
+ flat_logits = array_ops.reshape(logits, [-1, k]) # [B1B2...Bm, k]
+ flat_ndraws = n * array_ops.reshape(n_draws, [-1]) # [B1B2...Bm]
+
+ # Draw samples for each (logits, total_count) pair independently via map_fn.
+ def _sample_single(args):
+ logits, n_draw = args[0], args[1] # [k], []; n_draw = n * total_count
+ x = random_ops.multinomial(logits[array_ops.newaxis, ...],
+ n_draw, seed) # [1, n * total_count]
+ x = array_ops.reshape(x, shape=[n, -1]) # [n, total_count]
+ x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2) # [n, k]
+ return x
+ x = functional_ops.map_fn(_sample_single,
+ [flat_logits, flat_ndraws],
+ dtype=self.dtype) # [B1B2...Bm, n, k]
+
+ # Reshape the result to the final shape [n, B1, ..., Bm, k].
x = array_ops.transpose(x, perm=[1, 0, 2])
final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
- x = array_ops.reshape(x, final_shape)
- return math_ops.cast(x, self.dtype)
+ x = array_ops.reshape(x, final_shape) # [n, B1, B2, ..., Bm, k]
+ return x
@distribution_util.AppendDocstring(_multinomial_sample_note)
def _log_prob(self, counts):
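
For reference, a minimal usage sketch of what this patch enables, hedged: it assumes a TF 1.x graph/session environment where this class is exposed as `tf.distributions.Multinomial`. Before the change, `sample` raised `NotImplementedError` for any non-scalar `total_count`.

```python
import tensorflow as tf

# A 2-batch of 3-class multinomials with a different number of draws per
# batch member: exactly the case the old _sample_n rejected.
dist = tf.distributions.Multinomial(
    total_count=[4., 5.],
    probs=[[.2, .3, .5], [.3, .3, .4]])

samples = dist.sample(7, seed=123)  # Tensor of shape [7, 2, 3]

with tf.Session() as sess:
  vals = sess.run(samples)
  # Each draw sums to its batch member's total_count: 4 in row 0, 5 in row 1.
  print(vals.sum(axis=-1))  # [[4., 5.], [4., 5.], ...]
```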