aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2016-12-14 12:15:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-14 12:25:11 -0800
commit48191da012d73eea3a322d839cc279ae976e67d7 (patch)
tree9e3d74c364f88e48c1e753dc9f91eb2c22b4e381
parent2051298c57886787c38c2d590212ae60c9beedba (diff)
(Part 2 of 2.) Allow TransformedDistribution to override batch_shape or event_shape when
is_scalar_batch or is_event_batch but not necessarily both. Change: 142047746
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py119
-rw-r--r--tensorflow/contrib/distributions/python/ops/transformed_distribution.py128
2 files changed, 152 insertions, 95 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 20a87add88..fdb69b8df8 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -159,36 +159,36 @@ class TransformedDistributionTest(tf.test.TestCase):
self.assertAllClose(actual_mvn.entropy().eval(),
fake_mvn.entropy().eval())
- def testMultivariateFromScalarBatchScalarEvent(self):
- with self.test_session() as sess:
- shift = np.array([-1, 0, 1], dtype=np.float32)
- scale = la.LinearOperatorTriL(
- [[[-1., 0, 0],
- [2, 1, 0],
- [3, 2, 1]],
- [[2, 0, 0],
- [3, -2, 0],
- [4, 3, 2]]],
- is_non_singular=True,
- is_positive_definite=False)
+class ScalarToMultiTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._shift = np.array([-1, 0, 1], dtype=np.float32)
+ self._tril = np.array(
+ [[[-1., 0, 0],
+ [2, 1, 0],
+ [3, 2, 1]],
+ [[2, 0, 0],
+ [3, -2, 0],
+ [4, 3, 2]]], dtype=np.float32)
+
+ def _testMVN(self, base_distribution, batch_shape=None,
+ event_shape=None, not_implemented_message=None):
+ with self.test_session() as sess:
# Overriding shapes must be compatible w/bijector; most bijectors are
# batch_shape agnostic and only care about event_ndims.
# In the case of `Affine`, if we got it wrong then it would fire an
# exception due to incompatible dimensions.
fake_mvn = ds.TransformedDistribution(
- distribution=ds.Normal(mu=0., sigma=1.),
- bijector=bs.AffineLinearOperator(shift, scale),
- batch_shape=scale.batch_shape, # [2]
- event_shape=[scale.domain_dimension.value], # [3]
+ distribution=base_distribution[0](validate_args=True,
+ **base_distribution[1]),
+ bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
+ batch_shape=batch_shape,
+ event_shape=event_shape,
validate_args=True)
- # Note: Affine ellided this tile.
- actual_mean = np.tile(shift, [2, 1])
- # Since LinOp.apply doesn't support `adjoint_b` nor composition,
- # we cannot do: scale.apply(scale, adjoint_b=True).eval()
- actual_cov = scale.apply(tf.matrix_transpose(scale.to_dense())).eval()
-
+ actual_mean = np.tile(self._shift, [2, 1]) # Affine elided this tile.
+ actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))
actual_mvn = ds.MultivariateNormalFull(mu=actual_mean, sigma=actual_cov)
# Ensure sample works by checking first, second moments.
@@ -226,63 +226,60 @@ class TransformedDistributionTest(tf.test.TestCase):
fake_mvn.survival_function,
fake_mvn.log_survival_function):
with self.assertRaisesRegexp(
- NotImplementedError, "not implemented when overriding event_shape"):
+ NotImplementedError, not_implemented_message):
self.assertRaisesRegexp(unsupported_fn(x))
- def testMultivariateFromNonScalarBatchOrNonScalarEvent(self):
+ def testScalarBatchScalarEvent(self):
+ self._testMVN(
+ base_distribution=[ds.Normal, {"mu": 0., "sigma": 1.}],
+ batch_shape=[2],
+ event_shape=[3],
+ not_implemented_message="not implemented when overriding event_shape")
+
+ def testScalarBatchNonScalarEvent(self):
+ self._testMVN(
+ base_distribution=[ds.MultivariateNormalDiag, {
+ "mu": [0., 0., 0.], "diag_stdev": [1., 1, 1]}],
+ batch_shape=[2],
+ not_implemented_message="not implemented$")
+
with self.test_session():
- shift = np.array([[-1, 0], [1, 0]], dtype=np.float32)
- scale = la.LinearOperatorDiag(
- [[-1., 2, 3],
- [-1., 2, 3]],
- is_non_singular=True,
- is_positive_definite=False)
-
- # Can't override batch_shape for scalar batch, non-scalar event.
+ # Can't override event_shape for scalar batch, non-scalar event.
with self.assertRaisesRegexp(ValueError, "requires scalar"):
ds.TransformedDistribution(
- distribution=ds.MultivariateNormalDiag(
- mu=[0.], diag_stdev=[1.]),
- bijector=bs.AffineLinearOperator(shift, scale),
- batch_shape=scale.batch_shape, # [2]
+ distribution=ds.MultivariateNormalDiag(mu=[0.], diag_stdev=[1.]),
+ bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
+ batch_shape=[2],
+ event_shape=[3],
validate_args=True)
- # Can't override event_shape for non-scalar batch, scalar event.
- with self.assertRaisesRegexp(ValueError, "requires scalar"):
- ds.TransformedDistribution(
- distribution=ds.Normal(mu=[0.], sigma=[1.]),
- bijector=bs.AffineLinearOperator(shift, scale),
- event_shape=[scale.domain_dimension.value], # [3]
- validate_args=True)
+ def testNonScalarBatchScalarEvent(self):
+ self._testMVN(
+ base_distribution=[ds.Normal, {"mu": [0., 0], "sigma": [1., 1]}],
+ event_shape=[3],
+ not_implemented_message="not implemented when overriding event_shape")
+ with self.test_session():
# Can't override batch_shape for non-scalar batch, scalar event.
with self.assertRaisesRegexp(ValueError, "requires scalar"):
ds.TransformedDistribution(
distribution=ds.Normal(mu=[0.], sigma=[1.]),
- bijector=bs.AffineLinearOperator(shift, scale),
- batch_shape=scale.batch_shape, # [2]
- event_shape=[scale.domain_dimension.value], # [3]
- validate_args=True)
-
- # Can't override event_shape for scalar batch, non-scalar event.
- with self.assertRaisesRegexp(ValueError, "requires scalar"):
- ds.TransformedDistribution(
- distribution=ds.MultivariateNormalDiag(
- mu=[0.], diag_stdev=[1.]),
- bijector=bs.AffineLinearOperator(shift, scale),
- batch_shape=scale.batch_shape, # [2]
- event_shape=[scale.domain_dimension.value], # [3]
+ bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
+ batch_shape=[2],
+ event_shape=[3],
validate_args=True)
+ def testNonScalarBatchNonScalarEvent(self):
+ with self.test_session():
# Can't override event_shape and/or batch_shape for non_scalar batch,
# non-scalar event.
with self.assertRaisesRegexp(ValueError, "requires scalar"):
ds.TransformedDistribution(
- distribution=ds.MultivariateNormalDiag(
- mu=[[0.]], diag_stdev=[[1.]]),
- bijector=bs.AffineLinearOperator(shift, scale),
- batch_shape=scale.batch_shape, # [2]
- event_shape=[scale.domain_dimension.value], # [3]
+ distribution=ds.MultivariateNormalDiag(mu=[[0.]],
+ diag_stdev=[[1.]]),
+ bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
+ batch_shape=[2],
+ event_shape=[3],
validate_args=True)
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
index a66b06fd79..6aa521e1af 100644
--- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
@@ -47,12 +47,30 @@ _condition_kwargs_dict = {
# graph construction.
-def _logical_and(*args):
- """Convenience function which attempts to statically `reduce_all`."""
- args_static = [tensor_util.constant_value(x) for x in args]
- if any(x is None for x in args_static):
- return math_ops.reduce_all(args)
- return ops.convert_to_tensor(all(args_static), name="logical_and")
+def _logical_not(x):
+ """Convenience function which attempts to statically apply `logical_not`."""
+ if tensor_util.constant_value(x) is not None:
+ return not tensor_util.constant_value(x)
+ return math_ops.logical_not(x)
+
+
+def _concat_vectors(*args):
+ """Convenience function which flattens input vectors."""
+ vals = [tensor_util.constant_value(ops.convert_to_tensor(x)) for x in args]
+ if any(x is None for x in vals):
+ return array_ops.concat_v2(args, 0)
+ return [x for v in vals for x in v]
+
+
+def _pick_scalar_condition(pred, cond_true, cond_false):
+ """Convenience function which chooses the condition based on the predicate."""
+ # Note: This function is only valid if all of pred, cond_true, and cond_false
+ # are scalars. This means its semantics are arguably more like tf.cond than
+ # tf.select even though we use tf.select to implement it.
+ pred_static = tensor_util.constant_value(pred)
+ if pred_static is None:
+ return math_ops.select(pred, cond_true, cond_false)
+ return cond_true if pred_static else cond_false
def _ones_like(x):
@@ -109,14 +127,15 @@ class TransformedDistribution(distributions.Distribution):
Mathematically:
```none
- (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) + (log o det o J o g^{-1})(y)
+ (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) +
+ (log o abs o det o J o g^{-1})(y)
```
Programmatically:
```python
- return (bijector.inverse_log_det_jacobian(x) +
- distribution.log_prob(bijector.inverse(x))
+ return (distribution.log_prob(bijector.inverse(x)) +
+ bijector.inverse_log_det_jacobian(x))
```
* `log_cdf`:
@@ -130,7 +149,7 @@ class TransformedDistribution(distributions.Distribution):
Programmatically:
```python
- return distribution.log_prob(bijector.inverse(x))
+ return distribution.log_cdf(bijector.inverse(x))
```
* and similarly for: `cdf`, `prob`, `log_survival_function`,
@@ -215,12 +234,10 @@ class TransformedDistribution(distributions.Distribution):
bijector: The object responsible for calculating the transformation.
Typically an instance of `Bijector`. `None` means `Identity()`.
batch_shape: `integer` vector `Tensor` which overrides `distribution`
- `batch_shape`; valid only if `distribution.is_scalar_batch` and
- `distribution.is_scalar_event`.
+ `batch_shape`; valid only if `distribution.is_scalar_batch`.
event_shape: `integer` vector `Tensor` which overrides `distribution`
- `event_shape`; valid only if `distribution.is_scalar_batch` and
- `distribution.is_scalar_event`
- validate_args: Python `Boolean`. Whether to validate input with asserts.
+ `event_shape`; valid only if `distribution.is_scalar_event`.
+ validate_args: Python Boolean. Whether to validate input with asserts.
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
name: The name for the distribution. Default:
@@ -232,26 +249,50 @@ class TransformedDistribution(distributions.Distribution):
bijector = bijectors.Identity(validate_args=validate_args)
name = name or bijector.name + distribution.name
with ops.name_scope(name, values=[event_shape, batch_shape]):
- if batch_shape is not None or event_shape is not None:
- is_scalar_batch_and_scalar_event = _logical_and(
- distribution.is_scalar_batch,
- distribution.is_scalar_event)
if batch_shape is not None:
batch_shape = self._maybe_validate_shape_override(
ops.convert_to_tensor(batch_shape, name="batch_shape"),
- is_scalar_batch_and_scalar_event, validate_args)
+ distribution.is_scalar_batch, validate_args)
self._override_batch_shape = batch_shape
if event_shape is not None:
event_shape = self._maybe_validate_shape_override(
ops.convert_to_tensor(event_shape, name="event_shape"),
- is_scalar_batch_and_scalar_event, validate_args)
- event_ndims = (event_shape.get_shape().ndims
- if event_shape.get_shape().ndims is not None
- else array_ops.rank(event_shape, "event_ndims"))
- self._reduce_event_indices = math_ops.range(-event_ndims, 0)
+ distribution.is_scalar_event, validate_args)
+ self._override_event_ndims = (
+ event_shape.get_shape().ndims
+ if event_shape.get_shape().ndims is not None
+ else array_ops.rank(event_shape, name="event_ndims"))
+ else:
+ self._override_event_ndims = 0
self._override_event_shape = event_shape
+ # To convert a scalar distribution into a multivariate distribution we
+ # will draw dims from the sample dims, which are otherwise iid. This is
+ # easy to do except in the case that:
+ # batch_shape is None and
+ # event_shape is not None and
+ # not distribution.is_scalar_batch.
+ # When that case happens the event dims will incorrectly be to the left of
+ # the batch dims. In this case we'll cyclically permute left the new dims.
+ if batch_shape is None and event_shape is not None:
+ self._needs_rotation = ops.convert_to_tensor(
+ _logical_not(distribution.is_scalar_batch), name="needs_rotation")
+ n = _pick_scalar_condition(self._needs_rotation,
+ self._override_event_ndims, 0)
+ # We'll be reducing the head dims (if at all), i.e., this will be []
+ # if we don't need to reduce.
+ self._reduce_event_indices = math_ops.range(
+ n - self._override_event_ndims, n)
+ else:
+ self._needs_rotation = ops.convert_to_tensor(False,
+ name="needs_rotation")
+ # We'll be reducing the tail dims (if at all), i.e., this will be []
+ # if we don't need to reduce.
+ self._reduce_event_indices = (
+ math_ops.range(-self._override_event_ndims, 0)
+ if event_shape is not None else [])
+
self._distribution = distribution
self._bijector = bijector
super(TransformedDistribution, self).__init__(
@@ -313,18 +354,23 @@ class TransformedDistribution(distributions.Distribution):
self._override_event_shape is None):
sample_shape = [n]
else:
- sample_shape = [[n]]
- if self._override_batch_shape is not None:
- sample_shape += [self._override_batch_shape]
- if self._override_event_shape is not None:
- sample_shape += [self._override_event_shape]
- sample_shape = array_ops.concat_v2(sample_shape, 0)
+ if (self._override_batch_shape is not None and
+ self._override_event_shape is not None):
+ sample_shape = [[n],
+ self._override_batch_shape,
+ self._override_event_shape]
+ elif self._override_batch_shape is not None:
+ sample_shape = [[n], self._override_batch_shape]
+ elif self._override_event_shape is not None:
+ sample_shape = [self._override_event_shape, [n]]
+ sample_shape = _concat_vectors(*sample_shape)
x = self.distribution.sample(sample_shape=sample_shape, seed=seed,
**distribution_kwargs)
+ x = self._maybe_rotate_dims(x)
return self.bijector.forward(x, **bijector_kwargs)
@distribution_util.AppendDocstring(
- """Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`,
+ """Implements `(log o p o g^{-1})(y) + (log o abs o det o J o g^{-1})(y)`,
where `g^{-1}` is the inverse of `transform`.
Also raises a `ValueError` if `inverse` was not provided to the
@@ -335,6 +381,7 @@ class TransformedDistribution(distributions.Distribution):
distribution_kwargs = distribution_kwargs or {}
x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(
y, **bijector_kwargs)
+ x = self._maybe_rotate_dims(x, rotate_right=True)
log_prob = self.distribution.log_prob(x, **distribution_kwargs)
if self._override_event_shape is not None:
log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
@@ -352,6 +399,7 @@ class TransformedDistribution(distributions.Distribution):
distribution_kwargs = distribution_kwargs or {}
x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(
y, **bijector_kwargs)
+ x = self._maybe_rotate_dims(x, rotate_right=True)
prob = self.distribution.prob(x, **distribution_kwargs)
if self._override_event_shape is not None:
prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
@@ -409,9 +457,9 @@ class TransformedDistribution(distributions.Distribution):
raise NotImplementedError("entropy is not implemented")
# Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
# can be shown that:
- # H[Y] = H[X] + E_X[(log o det o Jacobian o g)(X)].
+ # H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
# If is_constant_jacobian then:
- # E_X[(log o det o Jacobian o g)(X)] = (log o det o Jacobian o g)(c)
+ # E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
# where c can by anything.
entropy = self.distribution.entropy()
if self._override_event_shape is not None:
@@ -463,3 +511,15 @@ class TransformedDistribution(distributions.Distribution):
[is_scalar], override_shape)
return override_shape
+
+ def _maybe_rotate_dims(self, x, rotate_right=False):
+ """Helper which rolls left event_dims left or right event_dims right."""
+ if tensor_util.constant_value(self._needs_rotation) is False:
+ return x
+ ndims = array_ops.rank(x)
+ n = _pick_scalar_condition(self._needs_rotation,
+ self._override_event_ndims, 0)
+ if rotate_right:
+ n = ndims - n
+ return array_ops.transpose(
+ x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))