diff options
Diffstat (limited to 'tensorflow/python/ops/distributions/multinomial.py')
-rw-r--r-- | tensorflow/python/ops/distributions/multinomial.py | 49 |
1 files changed, 29 insertions, 20 deletions
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py index 00b5697c83..d49fac59ca 100644 --- a/tensorflow/python/ops/distributions/multinomial.py +++ b/tensorflow/python/ops/distributions/multinomial.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util @@ -140,6 +141,8 @@ class Multinomial(distribution.Distribution): counts = [[2., 1, 1], [3, 1, 1]] dist.prob(counts) # Shape [2] + + dist.sample(5) # Shape [5, 2, 3] ``` """ @@ -231,29 +234,35 @@ class Multinomial(distribution.Distribution): def _sample_n(self, n, seed=None): n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32) - if self.total_count.get_shape().ndims is not None: - if self.total_count.get_shape().ndims != 0: - raise NotImplementedError( - "Sample only supported for scalar number of draws.") - elif self.validate_args: - is_scalar = check_ops.assert_rank( - n_draws, 0, - message="Sample only supported for scalar number of draws.") - n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws) k = self.event_shape_tensor()[0] - # Flatten batch dims so logits has shape [B, k], - # where B = reduce_prod(self.batch_shape_tensor()). - x = random_ops.multinomial( - logits=array_ops.reshape(self.logits, [-1, k]), - num_samples=n * n_draws, - seed=seed) - x = array_ops.reshape(x, shape=[-1, n, n_draws]) - x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), - axis=-2) # shape: [B, n, k] + + # boardcast the total_count and logits to same shape + n_draws = array_ops.ones_like( + self.logits[..., 0], dtype=n_draws.dtype) * n_draws + logits = array_ops.ones_like( + n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits + + # flatten the total_count and logits + flat_logits = array_ops.reshape(logits, [-1, k]) # [B1B2...Bm, k] + flat_ndraws = n * array_ops.reshape(n_draws, [-1]) # [B1B2...Bm] + + # computes each total_count and logits situation by map_fn + def _sample_single(args): + logits, n_draw = args[0], args[1] # [K], [] + x = random_ops.multinomial(logits[array_ops.newaxis, ...], + n_draw, seed) # [1, n*n_draw] + x = array_ops.reshape(x, shape=[n, -1]) # [n, n_draw] + x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2) # [n, k] + return x + x = functional_ops.map_fn(_sample_single, + [flat_logits, flat_ndraws], + dtype=self.dtype) # [B1B2...Bm, n, k] + + # reshape the results to proper shape x = array_ops.transpose(x, perm=[1, 0, 2]) final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0) - x = array_ops.reshape(x, final_shape) - return math_ops.cast(x, self.dtype) + x = array_ops.reshape(x, final_shape) # [n, B1, B2,..., Bm, k] + return x @distribution_util.AppendDocstring(_multinomial_sample_note) def _log_prob(self, counts): |