author A. Unique TensorFlower <gardener@tensorflow.org> 2017-07-20 04:52:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-20 04:56:38 -0700
commit 7bab2180c239171a3d0d3017e3f140c027171459 (patch)
tree 1ed33d324a1afdeb1f573932e557c373cd07a9ca
parent 679c30e10ded33fa98ae8aab8e439f02df4ad4d6 (diff)
Generalize categorical CDF broadcasting over the argument.
PiperOrigin-RevId: 162601793
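
A minimal usage sketch of the behavior this change enables, assuming a TF 1.x graph/session environment (the probabilities and expected values mirror the new testCDFBroadcasting below):

import tensorflow as tf
from tensorflow.python.ops.distributions import categorical

# probs shape: [batch=2, n_bins=3].
dist = categorical.Categorical(probs=[[0.2, 0.1, 0.7],
                                      [0.3, 0.45, 0.25]])

# event shape: [3, 2]; the leading dimension of 3 now broadcasts against
# the batch dimension of 2 instead of raising a shape error.
event = [[0, 0], [1, 1], [2, 2]]

with tf.Session() as sess:
  # cdf(k) = P(X < k); the result takes the broadcast shape [3, 2].
  print(sess.run(dist.cdf(event)))
  # [[0.    0.  ]
  #  [0.2   0.3 ]
  #  [0.3   0.75]]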
-rw-r--r-- tensorflow/python/kernel_tests/distributions/categorical_test.py 96
-rw-r--r-- tensorflow/python/ops/distributions/categorical.py 70
2 files changed, 138 insertions, 28 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index 33db933e82..019c1bc353 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test
@@ -183,11 +184,105 @@ class CategoricalTest(test.TestCase):
with self.test_session():
self.assertAlmostEqual(cdf_op.eval(), expected_cdf)
+ def testCDFBroadcasting(self):
+ # shape: [batch=2, n_bins=3]
+ histograms = [[0.2, 0.1, 0.7],
+ [0.3, 0.45, 0.25]]
+
+ # shape: [batch=3, batch=2]
+ devent = [
+ [0, 0],
+ [1, 1],
+ [2, 2]
+ ]
+ dist = categorical.Categorical(probs=histograms)
+
+ # We test that the probabilities are correctly broadcast over the
+ # additional leading batch dimension of size 3.
+ expected_cdf_result = np.zeros((3, 2))
+ expected_cdf_result[0, 0] = 0
+ expected_cdf_result[0, 1] = 0
+ expected_cdf_result[1, 0] = 0.2
+ expected_cdf_result[1, 1] = 0.3
+ expected_cdf_result[2, 0] = 0.3
+ expected_cdf_result[2, 1] = 0.75
+
+ with self.test_session():
+ self.assertAllClose(dist.cdf(devent).eval(), expected_cdf_result)
+
+ def testBroadcastWithBatchParamsAndBiggerEvent(self):
+ ## The parameters have a single batch dimension, and the event has two.
+
+ # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
+ cat_params_py = [
+ [0.2, 0.15, 0.35, 0.3],
+ [0.1, 0.05, 0.68, 0.17],
+ [0.1, 0.05, 0.68, 0.17]
+ ]
+
+ # event shape = [5, 3], both are "batch" dimensions.
+ disc_event_py = [
+ [0, 1, 2],
+ [1, 2, 3],
+ [0, 0, 0],
+ [1, 1, 1],
+ [2, 1, 0]
+ ]
+
+ # shape is [3]
+ normal_params_py = [
+ -10.0,
+ 120.0,
+ 50.0
+ ]
+
+ # shape is [5, 3]
+ real_event_py = [
+ [-1.0, 0.0, 1.0],
+ [100.0, 101, -50],
+ [90, 90, 90],
+ [-4, -400, 20.0],
+ [0.0, 0.0, 0.0]
+ ]
+
+ cat_params_tf = array_ops.constant(cat_params_py)
+ disc_event_tf = array_ops.constant(disc_event_py)
+ cat = categorical.Categorical(probs=cat_params_tf)
+
+ normal_params_tf = array_ops.constant(normal_params_py)
+ real_event_tf = array_ops.constant(real_event_py)
+ norm = normal.Normal(loc=normal_params_tf, scale=1.0)
+
+ # Check that normal and categorical have the same broadcasting behaviour.
+ to_run = {
+ "cat_prob": cat.prob(disc_event_tf),
+ "cat_log_prob": cat.log_prob(disc_event_tf),
+ "cat_cdf": cat.cdf(disc_event_tf),
+ "cat_log_cdf": cat.log_cdf(disc_event_tf),
+ "norm_prob": norm.prob(real_event_tf),
+ "norm_log_prob": norm.log_prob(real_event_tf),
+ "norm_cdf": norm.cdf(real_event_tf),
+ "norm_log_cdf": norm.log_cdf(real_event_tf),
+ }
+
+ with self.test_session() as sess:
+ run_result = sess.run(to_run)
+
+ self.assertAllEqual(run_result["cat_prob"].shape,
+ run_result["norm_prob"].shape)
+ self.assertAllEqual(run_result["cat_log_prob"].shape,
+ run_result["norm_log_prob"].shape)
+ self.assertAllEqual(run_result["cat_cdf"].shape,
+ run_result["norm_cdf"].shape)
+ self.assertAllEqual(run_result["cat_log_cdf"].shape,
+ run_result["norm_log_cdf"].shape)
+
def testLogPMF(self):
logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
dist = categorical.Categorical(logits)
with self.test_session():
self.assertAllClose(dist.log_prob([0, 1]).eval(), np.log([0.2, 0.4]))
+ self.assertAllClose(dist.log_prob([0.0, 1.0]).eval(), np.log([0.2, 0.4]))
def testEntropyNoBatch(self):
logits = np.log([0.2, 0.8]) - 50.
@@ -263,6 +358,7 @@ class CategoricalTest(test.TestCase):
def testLogPMFBroadcasting(self):
with self.test_session():
+ # 1 x 2 x 2
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
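
As a sanity check, the expected values in testCDFBroadcasting above can be reproduced by hand; a short NumPy sketch, assuming cdf(k) denotes the exclusive sum P(X < k):

import numpy as np

histograms = np.array([[0.2, 0.1, 0.7],
                       [0.3, 0.45, 0.25]])
# Exclusive cumulative sums with a leading 0: cum[j, k] = P(X < k) under
# batch member j, for k = 0..3.
cum = np.concatenate([np.zeros((2, 1)),
                      np.cumsum(histograms, axis=-1)], axis=-1)
devent = np.array([[0, 0], [1, 1], [2, 2]])
# Fancy indexing picks batch member j's value at event devent[i, j].
print(cum[np.arange(2), devent])
# [[0.    0.  ]
#  [0.2   0.3 ]
#  [0.3   0.75]]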
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index d72f6a1fe5..84ca6db4c4 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -31,6 +31,33 @@ from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
+  """Broadcasts the event or distribution parameters."""
+  if event.shape.ndims is None:
+    raise NotImplementedError(
+        "Cannot broadcast with an event tensor of unknown rank.")
+
+  if event.dtype.is_integer:
+    pass
+  elif event.dtype.is_floating:
+    # When `validate_args=True` we've already ensured int/float casting
+    # is closed.
+    event = math_ops.cast(event, dtype=dtypes.int32)
+  else:
+    raise TypeError("`value` should have integer `dtype` or "
+                    "`self.dtype` ({})".format(base_dtype))
+
+  if params.get_shape()[:-1] != event.get_shape():
+    params *= array_ops.ones_like(
+        array_ops.expand_dims(event, -1), dtype=params.dtype)
+  params_shape = array_ops.shape(params)[:-1]
+  event *= array_ops.ones(params_shape, dtype=event.dtype)
+  event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))
+  return event, params
+
+
class Categorical(distribution.Distribution):
"""Categorical distribution.
@@ -248,43 +275,30 @@ class Categorical(distribution.Distribution):
k = distribution_util.embed_check_integer_casting_closed(
k, target_dtype=dtypes.int32)
- # If there are multiple batch dimension, flatten them into one.
- batch_flattened_probs = array_ops.reshape(self._probs,
- [-1, self._event_size])
+ k, probs = _broadcast_cat_event_and_params(
+ k, self.probs, base_dtype=self.dtype.base_dtype)
+
+ # batch-flatten everything in order to use `sequence_mask()`.
+ batch_flattened_probs = array_ops.reshape(probs,
+ (-1, self._event_size))
batch_flattened_k = array_ops.reshape(k, [-1])
- # Form a tensor to sum over.
- # We don't need to cast k to integer since `sequence_mask` does this for us.
- mask_tensor = array_ops.sequence_mask(batch_flattened_k, self._event_size)
- to_sum_over = array_ops.where(mask_tensor,
- batch_flattened_probs,
- array_ops.zeros_like(batch_flattened_probs))
- batch_flat_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
- return array_ops.reshape(batch_flat_cdf, self._batch_shape())
+ to_sum_over = array_ops.where(
+ array_ops.sequence_mask(batch_flattened_k, self._event_size),
+ batch_flattened_probs,
+ array_ops.zeros_like(batch_flattened_probs))
+ batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
+ # Reshape back to the shape of the argument.
+ return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))
def _log_prob(self, k):
k = ops.convert_to_tensor(k, name="k")
if self.validate_args:
k = distribution_util.embed_check_integer_casting_closed(
k, target_dtype=dtypes.int32)
+ k, logits = _broadcast_cat_event_and_params(
+ k, self.logits, base_dtype=self.dtype.base_dtype)
- if self.logits.get_shape()[:-1] == k.get_shape():
- logits = self.logits
- else:
- logits = self.logits * array_ops.ones_like(
- array_ops.expand_dims(k, -1), dtype=self.logits.dtype)
- logits_shape = array_ops.shape(logits)[:-1]
- k *= array_ops.ones(logits_shape, dtype=k.dtype)
- k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
- if k.dtype.is_integer:
- pass
- elif k.dtype.is_floating:
- # When `validate_args=True` we've already ensured int/float casting
- # is closed.
- return ops.cast(k, dtype=dtypes.int32)
- else:
- raise TypeError("`value` should have integer `dtype` or "
- "`self.dtype` ({})".format(self.dtype.base_dtype))
return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
logits=logits)
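
Putting the pieces together, a self-contained NumPy sketch of the new code path (the name categorical_cdf is hypothetical; it mirrors _broadcast_cat_event_and_params followed by the mask-and-sum in the rewritten _cdf):

import numpy as np

def categorical_cdf(probs, k):
  """Hypothetical NumPy mirror of the rewritten _cdf.

  probs: float array of shape [..., n_bins]; k: integer array whose shape
  broadcasts against probs' batch shape. Returns P(X < k) with the
  broadcast shape.
  """
  probs = np.asarray(probs)
  k = np.asarray(k, dtype=np.int32)
  n_bins = probs.shape[-1]
  # Broadcast params against the event, as in
  # _broadcast_cat_event_and_params.
  probs = probs * np.ones_like(np.expand_dims(k, -1), dtype=probs.dtype)
  k = k * np.ones(probs.shape[:-1], dtype=k.dtype)
  # Batch-flatten, build the sequence mask, and sum the masked probs.
  flat_probs = probs.reshape(-1, n_bins)
  flat_k = k.reshape(-1)
  mask = np.arange(n_bins) < flat_k[:, None]  # sequence_mask equivalent
  flat_cdf = np.where(mask, flat_probs, 0.).sum(axis=-1)
  # Reshape back to the shape of the (broadcast) argument.
  return flat_cdf.reshape(k.shape)

print(categorical_cdf([[0.2, 0.1, 0.7], [0.3, 0.45, 0.25]],
                      [[0, 0], [1, 1], [2, 2]]))
# [[0.    0.  ]
#  [0.2   0.3 ]
#  [0.3   0.75]]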