diff options
Diffstat (limited to 'tensorflow/python/ops/distributions/multinomial.py')
-rw-r--r-- | tensorflow/python/ops/distributions/multinomial.py | 49 |
1 files changed, 20 insertions, 29 deletions
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py index d49fac59ca..00b5697c83 100644 --- a/tensorflow/python/ops/distributions/multinomial.py +++ b/tensorflow/python/ops/distributions/multinomial.py @@ -26,7 +26,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util @@ -141,8 +140,6 @@ class Multinomial(distribution.Distribution): counts = [[2., 1, 1], [3, 1, 1]] dist.prob(counts) # Shape [2] - - dist.sample(5) # Shape [5, 2, 3] ``` """ @@ -234,35 +231,29 @@ class Multinomial(distribution.Distribution): def _sample_n(self, n, seed=None): n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32) + if self.total_count.get_shape().ndims is not None: + if self.total_count.get_shape().ndims != 0: + raise NotImplementedError( + "Sample only supported for scalar number of draws.") + elif self.validate_args: + is_scalar = check_ops.assert_rank( + n_draws, 0, + message="Sample only supported for scalar number of draws.") + n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws) k = self.event_shape_tensor()[0] - - # boardcast the total_count and logits to same shape - n_draws = array_ops.ones_like( - self.logits[..., 0], dtype=n_draws.dtype) * n_draws - logits = array_ops.ones_like( - n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits - - # flatten the total_count and logits - flat_logits = array_ops.reshape(logits, [-1, k]) # [B1B2...Bm, k] - flat_ndraws = n * array_ops.reshape(n_draws, [-1]) # [B1B2...Bm] - - # computes each total_count and logits situation by map_fn - def _sample_single(args): - logits, n_draw = args[0], args[1] # [K], [] - x = random_ops.multinomial(logits[array_ops.newaxis, ...], - n_draw, seed) # [1, n*n_draw] - x = array_ops.reshape(x, shape=[n, -1]) # [n, n_draw] - x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2) # [n, k] - return x - x = functional_ops.map_fn(_sample_single, - [flat_logits, flat_ndraws], - dtype=self.dtype) # [B1B2...Bm, n, k] - - # reshape the results to proper shape + # Flatten batch dims so logits has shape [B, k], + # where B = reduce_prod(self.batch_shape_tensor()). + x = random_ops.multinomial( + logits=array_ops.reshape(self.logits, [-1, k]), + num_samples=n * n_draws, + seed=seed) + x = array_ops.reshape(x, shape=[-1, n, n_draws]) + x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), + axis=-2) # shape: [B, n, k] x = array_ops.transpose(x, perm=[1, 0, 2]) final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0) - x = array_ops.reshape(x, final_shape) # [n, B1, B2,..., Bm, k] - return x + x = array_ops.reshape(x, final_shape) + return math_ops.cast(x, self.dtype) @distribution_util.AppendDocstring(_multinomial_sample_note) def _log_prob(self, counts): |