diff options
author | 2016-12-14 12:15:32 -0800 | |
---|---|---|
committer | 2016-12-14 12:25:11 -0800 | |
commit | 48191da012d73eea3a322d839cc279ae976e67d7 (patch) | |
tree | 9e3d74c364f88e48c1e753dc9f91eb2c22b4e381 | |
parent | 2051298c57886787c38c2d590212ae60c9beedba (diff) |
(Part 2 of 2.) Allow TransformedDistribution to override batch_shape or event_shape when
is_scalar_batch or is_event_batch but not necessarily both.
Change: 142047746
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py | 119 | ||||
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/transformed_distribution.py | 128 |
2 files changed, 152 insertions, 95 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 20a87add88..fdb69b8df8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -159,36 +159,36 @@ class TransformedDistributionTest(tf.test.TestCase): self.assertAllClose(actual_mvn.entropy().eval(), fake_mvn.entropy().eval()) - def testMultivariateFromScalarBatchScalarEvent(self): - with self.test_session() as sess: - shift = np.array([-1, 0, 1], dtype=np.float32) - scale = la.LinearOperatorTriL( - [[[-1., 0, 0], - [2, 1, 0], - [3, 2, 1]], - [[2, 0, 0], - [3, -2, 0], - [4, 3, 2]]], - is_non_singular=True, - is_positive_definite=False) +class ScalarToMultiTest(tf.test.TestCase): + + def setUp(self): + self._shift = np.array([-1, 0, 1], dtype=np.float32) + self._tril = np.array( + [[[-1., 0, 0], + [2, 1, 0], + [3, 2, 1]], + [[2, 0, 0], + [3, -2, 0], + [4, 3, 2]]], dtype=np.float32) + + def _testMVN(self, base_distribution, batch_shape=None, + event_shape=None, not_implemented_message=None): + with self.test_session() as sess: # Overriding shapes must be compatible w/bijector; most bijectors are # batch_shape agnostic and only care about event_ndims. # In the case of `Affine`, if we got it wrong then it would fire an # exception due to incompatible dimensions. fake_mvn = ds.TransformedDistribution( - distribution=ds.Normal(mu=0., sigma=1.), - bijector=bs.AffineLinearOperator(shift, scale), - batch_shape=scale.batch_shape, # [2] - event_shape=[scale.domain_dimension.value], # [3] + distribution=base_distribution[0](validate_args=True, + **base_distribution[1]), + bijector=bs.Affine(shift=self._shift, scale_tril=self._tril), + batch_shape=batch_shape, + event_shape=event_shape, validate_args=True) - # Note: Affine ellided this tile. - actual_mean = np.tile(shift, [2, 1]) - # Since LinOp.apply doesn't support `adjoint_b` nor composition, - # we cannot do: scale.apply(scale, adjoint_b=True).eval() - actual_cov = scale.apply(tf.matrix_transpose(scale.to_dense())).eval() - + actual_mean = np.tile(self._shift, [2, 1]) # Affine elided this tile. + actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1])) actual_mvn = ds.MultivariateNormalFull(mu=actual_mean, sigma=actual_cov) # Ensure sample works by checking first, second moments. @@ -226,63 +226,60 @@ class TransformedDistributionTest(tf.test.TestCase): fake_mvn.survival_function, fake_mvn.log_survival_function): with self.assertRaisesRegexp( - NotImplementedError, "not implemented when overriding event_shape"): + NotImplementedError, not_implemented_message): self.assertRaisesRegexp(unsupported_fn(x)) - def testMultivariateFromNonScalarBatchOrNonScalarEvent(self): + def testScalarBatchScalarEvent(self): + self._testMVN( + base_distribution=[ds.Normal, {"mu": 0., "sigma": 1.}], + batch_shape=[2], + event_shape=[3], + not_implemented_message="not implemented when overriding event_shape") + + def testScalarBatchNonScalarEvent(self): + self._testMVN( + base_distribution=[ds.MultivariateNormalDiag, { + "mu": [0., 0., 0.], "diag_stdev": [1., 1, 1]}], + batch_shape=[2], + not_implemented_message="not implemented$") + with self.test_session(): - shift = np.array([[-1, 0], [1, 0]], dtype=np.float32) - scale = la.LinearOperatorDiag( - [[-1., 2, 3], - [-1., 2, 3]], - is_non_singular=True, - is_positive_definite=False) - - # Can't override batch_shape for scalar batch, non-scalar event. + # Can't override event_shape for scalar batch, non-scalar event. with self.assertRaisesRegexp(ValueError, "requires scalar"): ds.TransformedDistribution( - distribution=ds.MultivariateNormalDiag( - mu=[0.], diag_stdev=[1.]), - bijector=bs.AffineLinearOperator(shift, scale), - batch_shape=scale.batch_shape, # [2] + distribution=ds.MultivariateNormalDiag(mu=[0.], diag_stdev=[1.]), + bijector=bs.Affine(shift=self._shift, scale_tril=self._tril), + batch_shape=[2], + event_shape=[3], validate_args=True) - # Can't override event_shape for non-scalar batch, scalar event. - with self.assertRaisesRegexp(ValueError, "requires scalar"): - ds.TransformedDistribution( - distribution=ds.Normal(mu=[0.], sigma=[1.]), - bijector=bs.AffineLinearOperator(shift, scale), - event_shape=[scale.domain_dimension.value], # [3] - validate_args=True) + def testNonScalarBatchScalarEvent(self): + self._testMVN( + base_distribution=[ds.Normal, {"mu": [0., 0], "sigma": [1., 1]}], + event_shape=[3], + not_implemented_message="not implemented when overriding event_shape") + with self.test_session(): # Can't override batch_shape for non-scalar batch, scalar event. with self.assertRaisesRegexp(ValueError, "requires scalar"): ds.TransformedDistribution( distribution=ds.Normal(mu=[0.], sigma=[1.]), - bijector=bs.AffineLinearOperator(shift, scale), - batch_shape=scale.batch_shape, # [2] - event_shape=[scale.domain_dimension.value], # [3] - validate_args=True) - - # Can't override event_shape for scalar batch, non-scalar event. - with self.assertRaisesRegexp(ValueError, "requires scalar"): - ds.TransformedDistribution( - distribution=ds.MultivariateNormalDiag( - mu=[0.], diag_stdev=[1.]), - bijector=bs.AffineLinearOperator(shift, scale), - batch_shape=scale.batch_shape, # [2] - event_shape=[scale.domain_dimension.value], # [3] + bijector=bs.Affine(shift=self._shift, scale_tril=self._tril), + batch_shape=[2], + event_shape=[3], validate_args=True) + def testNonScalarBatchNonScalarEvent(self): + with self.test_session(): # Can't override event_shape and/or batch_shape for non_scalar batch, # non-scalar event. with self.assertRaisesRegexp(ValueError, "requires scalar"): ds.TransformedDistribution( - distribution=ds.MultivariateNormalDiag( - mu=[[0.]], diag_stdev=[[1.]]), - bijector=bs.AffineLinearOperator(shift, scale), - batch_shape=scale.batch_shape, # [2] - event_shape=[scale.domain_dimension.value], # [3] + distribution=ds.MultivariateNormalDiag(mu=[[0.]], + diag_stdev=[[1.]]), + bijector=bs.Affine(shift=self._shift, scale_tril=self._tril), + batch_shape=[2], + event_shape=[3], validate_args=True) diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py index a66b06fd79..6aa521e1af 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py @@ -47,12 +47,30 @@ _condition_kwargs_dict = { # graph construction. -def _logical_and(*args): - """Convenience function which attempts to statically `reduce_all`.""" - args_static = [tensor_util.constant_value(x) for x in args] - if any(x is None for x in args_static): - return math_ops.reduce_all(args) - return ops.convert_to_tensor(all(args_static), name="logical_and") +def _logical_not(x): + """Convenience function which attempts to statically apply `logical_not`.""" + if tensor_util.constant_value(x) is not None: + return not tensor_util.constant_value(x) + return math_ops.logical_not(x) + + +def _concat_vectors(*args): + """Convenience function which flattens input vectors.""" + vals = [tensor_util.constant_value(ops.convert_to_tensor(x)) for x in args] + if any(x is None for x in vals): + return array_ops.concat_v2(args, 0) + return [x for v in vals for x in v] + + +def _pick_scalar_condition(pred, cond_true, cond_false): + """Convenience function which chooses the condition based on the predicate.""" + # Note: This function is only valid if all of pred, cond_true, and cond_false + # are scalars. This means its semantics are arguably more like tf.cond than + # tf.select even though we use tf.select to implement it. + pred_static = tensor_util.constant_value(pred) + if pred_static is None: + return math_ops.select(pred, cond_true, cond_false) + return cond_true if pred_static else cond_false def _ones_like(x): @@ -109,14 +127,15 @@ class TransformedDistribution(distributions.Distribution): Mathematically: ```none - (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) + (log o det o J o g^{-1})(y) + (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) + + (log o abs o det o J o g^{-1})(y) ``` Programmatically: ```python - return (bijector.inverse_log_det_jacobian(x) + - distribution.log_prob(bijector.inverse(x)) + return (distribution.log_prob(bijector.inverse(x)) + + bijector.inverse_log_det_jacobian(x)) ``` * `log_cdf`: @@ -130,7 +149,7 @@ class TransformedDistribution(distributions.Distribution): Programmatically: ```python - return distribution.log_prob(bijector.inverse(x)) + return distribution.log_cdf(bijector.inverse(x)) ``` * and similarly for: `cdf`, `prob`, `log_survival_function`, @@ -215,12 +234,10 @@ class TransformedDistribution(distributions.Distribution): bijector: The object responsible for calculating the transformation. Typically an instance of `Bijector`. `None` means `Identity()`. batch_shape: `integer` vector `Tensor` which overrides `distribution` - `batch_shape`; valid only if `distribution.is_scalar_batch` and - `distribution.is_scalar_event`. + `batch_shape`; valid only if `distribution.is_scalar_batch`. event_shape: `integer` vector `Tensor` which overrides `distribution` - `event_shape`; valid only if `distribution.is_scalar_batch` and - `distribution.is_scalar_event` - validate_args: Python `Boolean`. Whether to validate input with asserts. + `event_shape`; valid only if `distribution.is_scalar_event`. + validate_args: Python Boolean. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. name: The name for the distribution. Default: @@ -232,26 +249,50 @@ class TransformedDistribution(distributions.Distribution): bijector = bijectors.Identity(validate_args=validate_args) name = name or bijector.name + distribution.name with ops.name_scope(name, values=[event_shape, batch_shape]): - if batch_shape is not None or event_shape is not None: - is_scalar_batch_and_scalar_event = _logical_and( - distribution.is_scalar_batch, - distribution.is_scalar_event) if batch_shape is not None: batch_shape = self._maybe_validate_shape_override( ops.convert_to_tensor(batch_shape, name="batch_shape"), - is_scalar_batch_and_scalar_event, validate_args) + distribution.is_scalar_batch, validate_args) self._override_batch_shape = batch_shape if event_shape is not None: event_shape = self._maybe_validate_shape_override( ops.convert_to_tensor(event_shape, name="event_shape"), - is_scalar_batch_and_scalar_event, validate_args) - event_ndims = (event_shape.get_shape().ndims - if event_shape.get_shape().ndims is not None - else array_ops.rank(event_shape, "event_ndims")) - self._reduce_event_indices = math_ops.range(-event_ndims, 0) + distribution.is_scalar_event, validate_args) + self._override_event_ndims = ( + event_shape.get_shape().ndims + if event_shape.get_shape().ndims is not None + else array_ops.rank(event_shape, name="event_ndims")) + else: + self._override_event_ndims = 0 self._override_event_shape = event_shape + # To convert a scalar distribution into a multivariate distribution we + # will draw dims from the sample dims, which are otherwise iid. This is + # easy to do except in the case that: + # batch_shape is None and + # event_shape is not None and + # not distribution.is_scalar_batch. + # When that case happens the event dims will incorrectly be to the left of + # the batch dims. In this case we'll cyclically permute left the new dims. + if batch_shape is None and event_shape is not None: + self._needs_rotation = ops.convert_to_tensor( + _logical_not(distribution.is_scalar_batch), name="needs_rotation") + n = _pick_scalar_condition(self._needs_rotation, + self._override_event_ndims, 0) + # We'll be reducing the head dims (if at all), i.e., this will be [] + # if we don't need to reduce. + self._reduce_event_indices = math_ops.range( + n - self._override_event_ndims, n) + else: + self._needs_rotation = ops.convert_to_tensor(False, + name="needs_rotation") + # We'll be reducing the tail dims (if at all), i.e., this will be [] + # if we don't need to reduce. + self._reduce_event_indices = ( + math_ops.range(-self._override_event_ndims, 0) + if event_shape is not None else []) + self._distribution = distribution self._bijector = bijector super(TransformedDistribution, self).__init__( @@ -313,18 +354,23 @@ class TransformedDistribution(distributions.Distribution): self._override_event_shape is None): sample_shape = [n] else: - sample_shape = [[n]] - if self._override_batch_shape is not None: - sample_shape += [self._override_batch_shape] - if self._override_event_shape is not None: - sample_shape += [self._override_event_shape] - sample_shape = array_ops.concat_v2(sample_shape, 0) + if (self._override_batch_shape is not None and + self._override_event_shape is not None): + sample_shape = [[n], + self._override_batch_shape, + self._override_event_shape] + elif self._override_batch_shape is not None: + sample_shape = [[n], self._override_batch_shape] + elif self._override_event_shape is not None: + sample_shape = [self._override_event_shape, [n]] + sample_shape = _concat_vectors(*sample_shape) x = self.distribution.sample(sample_shape=sample_shape, seed=seed, **distribution_kwargs) + x = self._maybe_rotate_dims(x) return self.bijector.forward(x, **bijector_kwargs) @distribution_util.AppendDocstring( - """Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`, + """Implements `(log o p o g^{-1})(y) + (log o abs o det o J o g^{-1})(y)`, where `g^{-1}` is the inverse of `transform`. Also raises a `ValueError` if `inverse` was not provided to the @@ -335,6 +381,7 @@ class TransformedDistribution(distributions.Distribution): distribution_kwargs = distribution_kwargs or {} x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian( y, **bijector_kwargs) + x = self._maybe_rotate_dims(x, rotate_right=True) log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._override_event_shape is not None: log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) @@ -352,6 +399,7 @@ class TransformedDistribution(distributions.Distribution): distribution_kwargs = distribution_kwargs or {} x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian( y, **bijector_kwargs) + x = self._maybe_rotate_dims(x, rotate_right=True) prob = self.distribution.prob(x, **distribution_kwargs) if self._override_event_shape is not None: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) @@ -409,9 +457,9 @@ class TransformedDistribution(distributions.Distribution): raise NotImplementedError("entropy is not implemented") # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It # can be shown that: - # H[Y] = H[X] + E_X[(log o det o Jacobian o g)(X)]. + # H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)]. # If is_constant_jacobian then: - # E_X[(log o det o Jacobian o g)(X)] = (log o det o Jacobian o g)(c) + # E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c) # where c can by anything. entropy = self.distribution.entropy() if self._override_event_shape is not None: @@ -463,3 +511,15 @@ class TransformedDistribution(distributions.Distribution): [is_scalar], override_shape) return override_shape + + def _maybe_rotate_dims(self, x, rotate_right=False): + """Helper which rolls left event_dims left or right event_dims right.""" + if tensor_util.constant_value(self._needs_rotation) is False: + return x + ndims = array_ops.rank(x) + n = _pick_scalar_condition(self._needs_rotation, + self._override_event_ndims, 0) + if rotate_right: + n = ndims - n + return array_ops.transpose( + x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n))) |