author      2017-07-20 04:52:20 -0700
committer   2017-07-20 04:56:38 -0700
commit      7bab2180c239171a3d0d3017e3f140c027171459 (patch)
tree        1ed33d324a1afdeb1f573932e557c373cd07a9ca
parent      679c30e10ded33fa98ae8aab8e439f02df4ad4d6 (diff)
Generalize categorical CDF broadcasting over argument.
PiperOrigin-RevId: 162601793
-rw-r--r--  tensorflow/python/kernel_tests/distributions/categorical_test.py | 96
-rw-r--r--  tensorflow/python/ops/distributions/categorical.py               | 70

2 files changed, 138 insertions, 28 deletions
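In effect, this change lets `Categorical.cdf` (and, via the refactored helper, `log_prob`) accept an event tensor whose leading dimensions broadcast against the distribution's batch shape, mirroring the behaviour of continuous distributions such as `Normal`. Below is a minimal sketch of the new behaviour, with values taken from the `testCDFBroadcasting` case in this diff; it assumes a TF 1.x graph-mode build that includes this commit and exposes the class as `tf.distributions.Categorical` (older builds expose it under `tf.contrib.distributions`):

import tensorflow as tf

# Two histograms -> batch shape [2], three bins each.
dist = tf.distributions.Categorical(probs=[[0.2, 0.1, 0.7],
                                           [0.3, 0.45, 0.25]])

# Event shape [3, 2]: the leading dimension of size 3 broadcasts
# against the batch shape [2], so the result also has shape [3, 2].
cdf = dist.cdf([[0, 0], [1, 1], [2, 2]])

with tf.Session() as sess:
  print(sess.run(cdf))
  # [[0.   0.  ]
  #  [0.2  0.3 ]
  #  [0.3  0.75]]

Note that `cdf(k)` here sums the mass of the bins strictly below `k` (that is what `sequence_mask` computes in `_cdf` below), which is why `cdf(0)` is 0.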
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index 33db933e82..019c1bc353 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops.distributions import categorical
 from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.ops.distributions import normal
 from tensorflow.python.platform import test
@@ -183,11 +184,105 @@ class CategoricalTest(test.TestCase):
     with self.test_session():
       self.assertAlmostEqual(cdf_op.eval(), expected_cdf)

+  def testCDFBroadcasting(self):
+    # shape: [batch=2, n_bins=3]
+    histograms = [[0.2, 0.1, 0.7],
+                  [0.3, 0.45, 0.25]]
+
+    # shape: [batch=3, batch=2]
+    devent = [
+        [0, 0],
+        [1, 1],
+        [2, 2]
+    ]
+    dist = categorical.Categorical(probs=histograms)
+
+    # We test that the probabilities are correctly broadcasted over the
+    # additional leading batch dimension of size 3.
+    expected_cdf_result = np.zeros((3, 2))
+    expected_cdf_result[0, 0] = 0
+    expected_cdf_result[0, 1] = 0
+    expected_cdf_result[1, 0] = 0.2
+    expected_cdf_result[1, 1] = 0.3
+    expected_cdf_result[2, 0] = 0.3
+    expected_cdf_result[2, 1] = 0.75
+
+    with self.test_session():
+      self.assertAllClose(dist.cdf(devent).eval(), expected_cdf_result)
+
+  def testBroadcastWithBatchParamsAndBiggerEvent(self):
+    ## The parameters have a single batch dimension, and the event has two.
+
+    # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
+    cat_params_py = [
+        [0.2, 0.15, 0.35, 0.3],
+        [0.1, 0.05, 0.68, 0.17],
+        [0.1, 0.05, 0.68, 0.17]
+    ]
+
+    # event shape = [5, 3], both are "batch" dimensions.
+    disc_event_py = [
+        [0, 1, 2],
+        [1, 2, 3],
+        [0, 0, 0],
+        [1, 1, 1],
+        [2, 1, 0]
+    ]
+
+    # shape is [3]
+    normal_params_py = [
+        -10.0,
+        120.0,
+        50.0
+    ]
+
+    # shape is [5, 3]
+    real_event_py = [
+        [-1.0, 0.0, 1.0],
+        [100.0, 101, -50],
+        [90, 90, 90],
+        [-4, -400, 20.0],
+        [0.0, 0.0, 0.0]
+    ]
+
+    cat_params_tf = array_ops.constant(cat_params_py)
+    disc_event_tf = array_ops.constant(disc_event_py)
+    cat = categorical.Categorical(probs=cat_params_tf)
+
+    normal_params_tf = array_ops.constant(normal_params_py)
+    real_event_tf = array_ops.constant(real_event_py)
+    norm = normal.Normal(loc=normal_params_tf, scale=1.0)
+
+    # Check that normal and categorical have the same broadcasting behaviour.
+    to_run = {
+        "cat_prob": cat.prob(disc_event_tf),
+        "cat_log_prob": cat.log_prob(disc_event_tf),
+        "cat_cdf": cat.cdf(disc_event_tf),
+        "cat_log_cdf": cat.log_cdf(disc_event_tf),
+        "norm_prob": norm.prob(real_event_tf),
+        "norm_log_prob": norm.log_prob(real_event_tf),
+        "norm_cdf": norm.cdf(real_event_tf),
+        "norm_log_cdf": norm.log_cdf(real_event_tf),
+    }
+
+    with self.test_session() as sess:
+      run_result = sess.run(to_run)
+
+    self.assertAllEqual(run_result["cat_prob"].shape,
+                        run_result["norm_prob"].shape)
+    self.assertAllEqual(run_result["cat_log_prob"].shape,
+                        run_result["norm_log_prob"].shape)
+    self.assertAllEqual(run_result["cat_cdf"].shape,
+                        run_result["norm_cdf"].shape)
+    self.assertAllEqual(run_result["cat_log_cdf"].shape,
+                        run_result["norm_log_cdf"].shape)
+
   def testLogPMF(self):
     logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
     dist = categorical.Categorical(logits)
     with self.test_session():
       self.assertAllClose(dist.log_prob([0, 1]).eval(), np.log([0.2, 0.4]))
+      self.assertAllClose(dist.log_prob([0.0, 1.0]).eval(), np.log([0.2, 0.4]))

   def testEntropyNoBatch(self):
     logits = np.log([0.2, 0.8]) - 50.
@@ -263,6 +358,7 @@ class CategoricalTest(test.TestCase):

   def testLogPMFBroadcasting(self):
     with self.test_session():
+      # 1 x 2 x 2
       histograms = [[[0.2, 0.8], [0.4, 0.6]]]
       dist = categorical.Categorical(math_ops.log(histograms) - 50.)

diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index d72f6a1fe5..84ca6db4c4 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -31,6 +31,33 @@ from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.ops.distributions import util as distribution_util


+def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
+  """Broadcasts the event or distribution parameters."""
+  if event.shape.ndims is None:
+    raise NotImplementedError(
+        "Cannot broadcast with an event tensor of unknown rank.")
+
+  if event.dtype.is_integer:
+    pass
+  elif event.dtype.is_floating:
+    # When `validate_args=True` we've already ensured int/float casting
+    # is closed.
+    event = math_ops.cast(event, dtype=dtypes.int32)
+  else:
+    raise TypeError("`value` should have integer `dtype` or "
+                    "`self.dtype` ({})".format(base_dtype))
+
+  if params.get_shape()[:-1] == event.get_shape():
+    params = params
+  else:
+    params *= array_ops.ones_like(
+        array_ops.expand_dims(event, -1), dtype=params.dtype)
+    params_shape = array_ops.shape(params)[:-1]
+    event *= array_ops.ones(params_shape, dtype=event.dtype)
+    event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))
+  return event, params
+
+
 class Categorical(distribution.Distribution):
   """Categorical distribution.
@@ -248,43 +275,30 @@ class Categorical(distribution.Distribution):
       k = distribution_util.embed_check_integer_casting_closed(
           k, target_dtype=dtypes.int32)

-    # If there are multiple batch dimension, flatten them into one.
-    batch_flattened_probs = array_ops.reshape(self._probs,
-                                              [-1, self._event_size])
+    k, probs = _broadcast_cat_event_and_params(
+        k, self.probs, base_dtype=self.dtype.base_dtype)
+
+    # batch-flatten everything in order to use `sequence_mask()`.
+    batch_flattened_probs = array_ops.reshape(probs,
+                                              (-1, self._event_size))
     batch_flattened_k = array_ops.reshape(k, [-1])
-    # Form a tensor to sum over.
-    # We don't need to cast k to integer since `sequence_mask` does this for us.
-    mask_tensor = array_ops.sequence_mask(batch_flattened_k, self._event_size)
-    to_sum_over = array_ops.where(mask_tensor,
-                                  batch_flattened_probs,
-                                  array_ops.zeros_like(batch_flattened_probs))
-    batch_flat_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
-    return array_ops.reshape(batch_flat_cdf, self._batch_shape())
+    to_sum_over = array_ops.where(
+        array_ops.sequence_mask(batch_flattened_k, self._event_size),
+        batch_flattened_probs,
+        array_ops.zeros_like(batch_flattened_probs))
+    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
+    # Reshape back to the shape of the argument.
+    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

   def _log_prob(self, k):
     k = ops.convert_to_tensor(k, name="k")
     if self.validate_args:
       k = distribution_util.embed_check_integer_casting_closed(
           k, target_dtype=dtypes.int32)
+    k, logits = _broadcast_cat_event_and_params(
+        k, self.logits, base_dtype=self.dtype.base_dtype)

-    if self.logits.get_shape()[:-1] == k.get_shape():
-      logits = self.logits
-    else:
-      logits = self.logits * array_ops.ones_like(
-          array_ops.expand_dims(k, -1), dtype=self.logits.dtype)
-      logits_shape = array_ops.shape(logits)[:-1]
-      k *= array_ops.ones(logits_shape, dtype=k.dtype)
-      k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
-    if k.dtype.is_integer:
-      pass
-    elif k.dtype.is_floating:
-      # When `validate_args=True` we've already ensured int/float casting
-      # is closed.
-      return ops.cast(k, dtype=dtypes.int32)
-    else:
-      raise TypeError("`value` should have integer `dtype` or "
-                      "`self.dtype` ({})".format(self.dtype.base_dtype))
     return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                             logits=logits)
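For reference, the broadcast-then-flatten strategy that `_broadcast_cat_event_and_params` and the new `_cdf` implement can be sketched in plain NumPy. This is a hypothetical standalone re-implementation for illustration only, not code from the commit; `categorical_cdf` and its argument names are invented here:

import numpy as np

def categorical_cdf(probs, k):
  """NumPy sketch of the broadcast-then-flatten CDF computation.

  probs: histograms of shape [..., event_size]; k: integer events whose
  shape broadcasts against probs.shape[:-1]. Returns P(X < k), matching
  the `sequence_mask` semantics of `_cdf`.
  """
  probs = np.asarray(probs, dtype=np.float64)
  k = np.asarray(k, dtype=np.int64)
  event_size = probs.shape[-1]
  # Broadcast the params over the event, then the event over the params,
  # mirroring `_broadcast_cat_event_and_params`.
  probs = probs * np.ones(k.shape + (1,))
  k = k * np.ones(probs.shape[:-1], dtype=np.int64)
  # Batch-flatten, then sum the first k bins of each histogram.
  flat_probs = probs.reshape(-1, event_size)
  flat_k = k.reshape(-1)
  mask = np.arange(event_size)[None, :] < flat_k[:, None]
  return np.where(mask, flat_probs, 0.0).sum(axis=-1).reshape(k.shape)

histograms = [[0.2, 0.1, 0.7], [0.3, 0.45, 0.25]]  # batch shape [2]
print(categorical_cdf(histograms, [[0, 0], [1, 1], [2, 2]]))
# [[0.   0.  ]
#  [0.2  0.3 ]
#  [0.3  0.75]]

The multiply-by-ones trick is how the TF code materializes the broadcast before flattening: `sequence_mask` needs a rank-1 lengths argument, so both operands are first expanded to a common batch shape and only then reshaped to [-1, event_size] and [-1].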