# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Categorical distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export


def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
  """Broadcasts the event or distribution parameters."""
  if event.shape.ndims is None:
    raise NotImplementedError(
        "Cannot broadcast with an event tensor of unknown rank.")

  if event.dtype.is_integer:
    pass
  elif event.dtype.is_floating:
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = math_ops.cast(event, dtype=dtypes.int32)
  else:
    raise TypeError("`value` should have integer `dtype` or "
                    "`self.dtype` ({})".format(base_dtype))

  if params.get_shape()[:-1] != event.get_shape():
    # Tile `params` along the event shape and `event` along the batch shape
    # so that both agree on the common broadcast batch shape.
    params *= array_ops.ones_like(
        array_ops.expand_dims(event, -1), dtype=params.dtype)
    params_shape = array_ops.shape(params)[:-1]
    event *= array_ops.ones(params_shape, dtype=event.dtype)
    event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))

  return event, params
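
# A minimal illustration (not part of the original module) of the broadcast
# contract above, assuming a params batch of two 2-class distributions and a
# scalar integer event:
#
#   probs = tf.constant([[0.1, 0.9], [0.5, 0.5]])  # shape [2, 2]
#   event = tf.constant(1)                         # shape []
#   event, params = _broadcast_cat_event_and_params(event, probs)
#   # event is tiled to shape [2]; params keeps shape [2, 2].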


@tf_export("distributions.Categorical")
class Categorical(distribution.Distribution):
  """Categorical distribution.

  The Categorical distribution is parameterized by either probabilities or
  log-probabilities of a set of `K` classes. It is defined over the integers
  `{0, 1, ..., K-1}`.

  The Categorical distribution is closely related to the `OneHotCategorical`
  and `Multinomial` distributions. The Categorical distribution can be
  intuited as generating samples according to
  `argmax{ OneHotCategorical(probs) }`, itself identical to
  `argmax{ Multinomial(probs, total_count=1) }`.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; pi) = prod_j pi_j**[k == j]
  ```

  #### Pitfalls

  The number of classes, `K`, must not exceed:

  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53}[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Creates a 3-class distribution with the 2nd class being most likely.

  ```python
  dist = Categorical(probs=[0.1, 0.5, 0.4])
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
          dist.sample(int(n)),
          [0., 2],
          nbins=3),
      dtype=tf.float32) / n
  # ==> array([ 0.1005,  0.5037,  0.3958], dtype=float32)
  ```

  Creates a 3-class distribution with the 2nd class being most likely.
  Parameterized by [logits](https://en.wikipedia.org/wiki/Logit) rather than
  probabilities.

  ```python
  dist = Categorical(logits=np.log([0.1, 0.5, 0.4]))
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
          dist.sample(int(n)),
          [0., 2],
          nbins=3),
      dtype=tf.float32) / n
  # ==> array([0.1045,  0.5047, 0.3908], dtype=float32)
  ```

  Creates a 3-class distribution with the 3rd class being most likely.
  The distribution functions can be evaluated on counts.

  ```python
  # counts is a scalar.
  p = [0.1, 0.4, 0.5]
  dist = Categorical(probs=p)
  dist.prob(0)  # Shape []

  # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match counts.
  counts = [1, 0]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [3, 5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7, 3]
  ```
  """

  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="Categorical"):
    """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
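
    A batched construction sketch (illustrative only; it restates the
    `logits` semantics above):

    ```python
    # 2 batch members, 3 classes each.
    dist = Categorical(logits=[[0., 1., 2.], [0., 0., 0.]])
    dist.batch_shape  # TensorShape([2])
    dist.event_size   # scalar int32 Tensor: 3
    ```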
""" parameters = locals() with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits=logits, probs=probs, validate_args=validate_args, multidimensional=True, name=name) if validate_args: self._logits = distribution_util.embed_check_categorical_event_shape( self._logits) logits_shape_static = self._logits.get_shape().with_rank_at_least(1) if logits_shape_static.ndims is not None: self._batch_rank = ops.convert_to_tensor( logits_shape_static.ndims - 1, dtype=dtypes.int32, name="batch_rank") else: with ops.name_scope(name="batch_rank"): self._batch_rank = array_ops.rank(self._logits) - 1 logits_shape = array_ops.shape(self._logits, name="logits_shape") if logits_shape_static[-1].value is not None: self._event_size = ops.convert_to_tensor( logits_shape_static[-1].value, dtype=dtypes.int32, name="event_size") else: with ops.name_scope(name="event_size"): self._event_size = logits_shape[self._batch_rank] if logits_shape_static[:-1].is_fully_defined(): self._batch_shape_val = constant_op.constant( logits_shape_static[:-1].as_list(), dtype=dtypes.int32, name="batch_shape") else: with ops.name_scope(name="batch_shape"): self._batch_shape_val = logits_shape[:-1] super(Categorical, self).__init__( dtype=dtype, reparameterization_type=distribution.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=[self._logits, self._probs], name=name) @property def event_size(self): """Scalar `int32` tensor: the number of classes.""" return self._event_size @property def logits(self): """Vector of coordinatewise logits.""" return self._logits @property def probs(self): """Vector of coordinatewise probabilities.""" return self._probs def _batch_shape_tensor(self): return array_ops.identity(self._batch_shape_val) def _batch_shape(self): return self.logits.get_shape()[:-1] def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): return tensor_shape.scalar() def _sample_n(self, n, seed=None): if self.logits.get_shape().ndims == 2: logits_2d = self.logits else: logits_2d = array_ops.reshape(self.logits, [-1, self.event_size]) sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32 draws = random_ops.multinomial( logits_2d, n, seed=seed, output_dtype=sample_dtype) draws = array_ops.reshape( array_ops.transpose(draws), array_ops.concat([[n], self.batch_shape_tensor()], 0)) return math_ops.cast(draws, self.dtype) def _cdf(self, k): k = ops.convert_to_tensor(k, name="k") if self.validate_args: k = distribution_util.embed_check_integer_casting_closed( k, target_dtype=dtypes.int32) k, probs = _broadcast_cat_event_and_params( k, self.probs, base_dtype=self.dtype.base_dtype) # batch-flatten everything in order to use `sequence_mask()`. batch_flattened_probs = array_ops.reshape(probs, (-1, self._event_size)) batch_flattened_k = array_ops.reshape(k, [-1]) to_sum_over = array_ops.where( array_ops.sequence_mask(batch_flattened_k, self._event_size), batch_flattened_probs, array_ops.zeros_like(batch_flattened_probs)) batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1) # Reshape back to the shape of the argument. 

  def _cdf(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    k, probs = _broadcast_cat_event_and_params(
        k, self.probs, base_dtype=self.dtype.base_dtype)

    # Batch-flatten everything in order to use `sequence_mask()`.
    batch_flattened_probs = array_ops.reshape(probs, (-1, self._event_size))
    batch_flattened_k = array_ops.reshape(k, [-1])

    # `sequence_mask` selects, per batch member, the classes strictly below
    # `k`; summing their probabilities gives the cumulative probability over
    # classes `0, ..., k - 1`.
    to_sum_over = array_ops.where(
        array_ops.sequence_mask(batch_flattened_k, self._event_size),
        batch_flattened_probs,
        array_ops.zeros_like(batch_flattened_probs))
    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    # Reshape back to the shape of the argument.
    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)

  def _entropy(self):
    return -math_ops.reduce_sum(
        nn_ops.log_softmax(self.logits) * self.probs, axis=-1)

  def _mode(self):
    ret = math_ops.argmax(self.logits, axis=self._batch_rank)
    ret = math_ops.cast(ret, self.dtype)
    ret.set_shape(self.batch_shape)
    return ret


@kullback_leibler.RegisterKL(Categorical, Categorical)
def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_categorical_categorical",
                      values=[a.logits, b.logits]):
    # KL(a || b) = sum(p_a * (log p_a - log p_b))
    delta_log_probs1 = (nn_ops.log_softmax(a.logits)
                        - nn_ops.log_softmax(b.logits))
    return math_ops.reduce_sum(nn_ops.softmax(a.logits) * delta_log_probs1,
                               axis=-1)
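

# A minimal usage sketch (assumed, not part of the original module): once the
# KL is registered above, it is reachable through the generic dispatcher in
# `kullback_leibler`, e.g.
#
#   a = Categorical(probs=[[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]])
#   b = Categorical(probs=[[0.3, 0.3, 0.4], [0.3, 0.3, 0.4]])
#   kl = kullback_leibler.kl_divergence(a, b)  # shape [2], KL(a || b)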