From 4eb0164dbcf690b9e33160004198905e96b3f049 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Thu, 2 Feb 2017 14:37:00 -0800
Subject: Simplify bijector API: remove the need for batch_ndims.

batch_ndims never needs to be provided explicitly. The two bijectors
that use it, Affine and AffineLinearOperator, can both derive it from
their arguments. This brings us back to a more TITO (Tensor-in,
Tensor-out) form for the bijector API.
Change: 146408773
---
 .../python/kernel_tests/bijector_test.py          |  47 ++++---
 .../kernel_tests/conditional_bijector_test.py     |   1 -
 .../contrib/distributions/python/ops/bijector.py  | 145 +++++++++++----------
 3 files changed, 100 insertions(+), 93 deletions(-)

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
index 71d9310aa9..092758e6ef 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
@@ -97,7 +97,7 @@ def assert_scalar_congruency(bijector,
   """
   # Checks and defaults.
-  assert bijector.shaper is None or bijector.shaper.event_ndims.eval() == 0
+  assert bijector.event_ndims.eval() == 0
   if sess is None:
     sess = ops.get_default_session()
@@ -255,7 +255,6 @@ class BrokenBijectorWithInverseAndInverseLogDetJacobian(bijectors.Bijector):

   def __init__(self, forward_missing=False, inverse_missing=False):
     super(BrokenBijectorWithInverseAndInverseLogDetJacobian, self).__init__(
-        batch_ndims=0,
         event_ndims=0,
         validate_args=False,
         name="BrokenBijectorDual")
@@ -287,7 +286,7 @@ class BrokenBijectorSeparateInverseAndInverseLogDetJacobian(bijectors.Bijector):

   def __init__(self, forward_missing=False, inverse_missing=False):
     super(BrokenBijectorSeparateInverseAndInverseLogDetJacobian, self).__init__(
-        batch_ndims=0, event_ndims=0, validate_args=False, name="broken")
+        event_ndims=0, validate_args=False, name="broken")
     self._forward_missing = forward_missing
     self._inverse_missing = inverse_missing
@@ -641,7 +640,7 @@ class AffineBijectorTest(test.TestCase):
       # Corresponds to scale = 2
       bijector = bijectors.Affine(
           shift=mu, scale_identity_multiplier=2., event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1., 2, 3]  # Three scalar samples (no batches).
       self.assertAllClose([1., 3, 5], run(bijector.forward, x))
       self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -663,7 +662,7 @@ class AffineBijectorTest(test.TestCase):
       mu = -1.
       # Corresponds to scale = 2
       bijector = bijectors.Affine(shift=mu, scale_diag=[2.], event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1., 2, 3]  # Three scalar samples (no batches).
       self.assertAllClose([1., 3, 5], run(bijector.forward, x))
       self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -686,7 +685,7 @@ class AffineBijectorTest(test.TestCase):
       # Corresponds to scale = 2.
       bijector = bijectors.Affine(
           shift=mu, scale_identity_multiplier=2., event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [[1., 2, 3], [4, 5, 6]]  # Weird sample shape.
       self.assertAllClose([[1., 3, 5], [7, 9, 11]],
                           run(bijector.forward, x))
@@ -713,7 +712,7 @@ class AffineBijectorTest(test.TestCase):
       # One batch, scalar.
       # Corresponds to scale = 1.
       bijector = bijectors.Affine(shift=mu, event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1.]  # One sample from one batches.
       self.assertAllClose([2.], run(bijector.forward, x))
       self.assertAllClose([0.], run(bijector.inverse, x))
@@ -735,7 +734,7 @@ class AffineBijectorTest(test.TestCase):
       # One batch, scalar.
       # Corresponds to scale = 1.
       bijector = bijectors.Affine(shift=mu, scale_diag=[1.], event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1.]  # One sample from one batches.
       self.assertAllClose([2.], run(bijector.forward, x))
       self.assertAllClose([0.], run(bijector.inverse, x))
@@ -757,7 +756,7 @@ class AffineBijectorTest(test.TestCase):
       # Univariate, two batches.
       # Corresponds to scale = 1.
       bijector = bijectors.Affine(shift=mu, event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1., 1]  # One sample from each of two batches.
       self.assertAllClose([2., 0], run(bijector.forward, x))
       self.assertAllClose([0., 2], run(bijector.inverse, x))
@@ -779,7 +778,7 @@ class AffineBijectorTest(test.TestCase):
       # Univariate, two batches.
       # Corresponds to scale = 1.
       bijector = bijectors.Affine(shift=mu, scale_diag=[1.], event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is scalar"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1., 1]  # One sample from each of two batches.
       self.assertAllClose([2., 0], run(bijector.forward, x))
       self.assertAllClose([0., 2], run(bijector.inverse, x))
@@ -801,7 +800,7 @@ class AffineBijectorTest(test.TestCase):
       # Multivariate
       # Corresponds to scale = [[1., 0], [0, 1.]]
       bijector = bijectors.Affine(shift=mu)
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [1., 1]
       # matmul(sigma, x) + shift
       # = [-1, -1] + [1, -1]
@@ -832,7 +831,7 @@ class AffineBijectorTest(test.TestCase):
       # Multivariate
       # Corresponds to scale = [[2., 0], [0, 1.]]
       bijector = bijectors.Affine(shift=mu, scale_diag=[2., 1])
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [1., 1]
       # matmul(sigma, x) + shift
       # = [-1, -1] + [1, -1]
@@ -875,7 +874,7 @@ class AffineBijectorTest(test.TestCase):
       bijector = bijectors.Affine(
           shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
-      self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
+      self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
       self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict))
       self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict))
       self.assertAllClose(
@@ -898,7 +897,7 @@ class AffineBijectorTest(test.TestCase):
       # Corresponds to 1 2x2 matrix, with twos on the diagonal.
       scale = 2.
       bijector = bijectors.Affine(shift=mu, scale_identity_multiplier=scale)
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [[[1., 1]]]
       self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
       self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -921,7 +920,7 @@ class AffineBijectorTest(test.TestCase):
       # Corresponds to 1 2x2 matrix, with twos on the diagonal.
       scale_diag = [[2., 2]]
       bijector = bijectors.Affine(shift=mu, scale_diag=scale_diag)
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [[[1., 1]]]
       self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
       self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -949,7 +948,7 @@ class AffineBijectorTest(test.TestCase):
       bijector = bijectors.Affine(
           shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
-      self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
+      self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
       self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict))
       self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict))
       self.assertAllClose([-math.log(4)],
@@ -975,7 +974,7 @@ class AffineBijectorTest(test.TestCase):
           scale_identity_multiplier=1.,
           scale_diag=[1.],
           event_ndims=0)
-      self.assertEqual(0, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
       x = [1., 2, 3]  # Three scalar samples (no batches).
       self.assertAllClose([1., 3, 5], run(bijector.forward, x))
       self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -1000,7 +999,7 @@ class AffineBijectorTest(test.TestCase):
           shift=mu,
           scale_identity_multiplier=1.,
           scale_tril=[[1., 0], [2., 1]])
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [[1., 2]]  # One multivariate sample.
       self.assertAllClose([[1., 5]], run(bijector.forward, x))
       self.assertAllClose([[1., 0.5]], run(bijector.inverse, x))
@@ -1023,7 +1022,7 @@ class AffineBijectorTest(test.TestCase):
       # scale = [[2., 0], [2, 3]]
       bijector = bijectors.Affine(
           shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]])
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [[1., 2]]  # One multivariate sample.
       self.assertAllClose([[1., 7]], run(bijector.forward, x))
       self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x))
@@ -1049,7 +1048,7 @@ class AffineBijectorTest(test.TestCase):
           scale_identity_multiplier=1.0,
           scale_diag=[1., 2.],
           scale_tril=[[1., 0], [2., 1]])
-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [[1., 2]]  # One multivariate sample.
       self.assertAllClose([[2., 9]], run(bijector.forward, x))
       self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x))
@@ -1079,7 +1078,7 @@ class AffineBijectorTest(test.TestCase):
                        [0, 1]])
       bijector_ref = bijectors.Affine(shift=mu, scale_diag=[10., 2, 3])

-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [1., 2, 3]  # Vector.
       self.assertAllClose([9., 3, 8], run(bijector.forward, x))
       self.assertAllClose(
@@ -1117,7 +1116,7 @@ class AffineBijectorTest(test.TestCase):
                        [0, 1]])
       bijector_ref = bijectors.Affine(shift=mu, scale_diag=[10., 3, 5])

-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [1., 2, 3]  # Vector.
       self.assertAllClose([9., 5, 14], run(bijector.forward, x))
       self.assertAllClose(
@@ -1159,7 +1158,7 @@ class AffineBijectorTest(test.TestCase):
                        [1, 3, 0],
                        [2, 3, 5]])

-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [1., 2, 3]  # Vector.
       self.assertAllClose([9., 6, 22], run(bijector.forward, x))
       self.assertAllClose(
@@ -1195,7 +1194,7 @@ class AffineBijectorTest(test.TestCase):
       bijector_ref = bijectors.Affine(
           shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]])

-      self.assertEqual(1, bijector.shaper.event_ndims.eval())  # "is vector"
+      self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
       x = [1., 2, 3]  # Vector.
       self.assertAllClose([5., 6, 22], run(bijector.forward, x))
       self.assertAllClose(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py
index 50e62435de..17a1fe13cf 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py
@@ -27,7 +27,6 @@ class _TestBijector(conditional_bijector.ConditionalBijector):

   def __init__(self):
     super(_TestBijector, self).__init__(
-        batch_ndims=0,
         event_ndims=0,
         graph_parents=[],
         is_constant_jacobian=True,
diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index 5cfee1f9fa..5e003b3f5f 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -256,7 +256,7 @@ class Bijector(object):
     ```
     class Exp(Bijector):
       def __init__(self, event_ndims=0, validate_args=False, name="exp"):
-        super(Exp, self).__init__(batch_ndims=0, event_ndims=event_ndims,
+        super(Exp, self).__init__(event_ndims=event_ndims,
                                   validate_args=validate_args, name=name)
       def _forward(self, x):
         return math_ops.exp(x)
@@ -264,9 +264,9 @@ class Bijector(object):
         x = math_ops.log(y)
         return x, -self._forward_log_det_jacobian(x)
       def _forward_log_det_jacobian(self, x):
-        if self.shaper is None:
+        if self.event_ndims is None:
           raise ValueError("Jacobian requires known event_ndims.")
-        _, _, event_dims = self.shaper.get_dims(x)
+        event_dims = array_ops.shape(x)[-self.event_ndims:]
         return math_ops.reduce_sum(x, reduction_indices=event_dims)
     ```
@@ -389,7 +389,6 @@ class Bijector(object):

   @abc.abstractmethod
   def __init__(self,
-               batch_ndims=None,
                event_ndims=None,
                graph_parents=None,
                is_constant_jacobian=False,
@@ -403,17 +402,16 @@ class Bijector(object):
     Examples:

     ```python
-    # Create the Y = g(X) = X transform which operates on 4-Tensors of vectors.
-    identity = Identity(batch_ndims=4, event_ndims=1)
+    # Create the Y = g(X) = X transform which operates on vector events.
+    identity = Identity(event_ndims=1)

     # Create the Y = g(X) = exp(X) transform which operates on matrices.
-    exp = Exp(batch_ndims=0, event_ndims=2)
+    exp = Exp(event_ndims=2)
     ```

     See `Bijector` subclass docstring for more details and specific examples.

     Args:
-      batch_ndims: number of dimensions associated with batch coordinates.
       event_ndims: number of dimensions associated with event coordinates.
       graph_parents: Python list of graph prerequisites of this `Bijector`.
       is_constant_jacobian: `Boolean` indicating that the Jacobian is not a
         function of the input.
       validate_args: `Boolean`, default `False`. Whether to validate input
         with asserts. If `validate_args` is `False`, and the inputs are
         invalid, correct behavior is not guaranteed.
       dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is
         not enforced.
       name: The name to give Ops created by the initializer.
""" - if batch_ndims is None or event_ndims is None: - self._shaper = None # Apparently subclass will create. - else: - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) + self._event_ndims = ( + ops.convert_to_tensor(event_ndims, dtype=dtypes.int32) + if event_ndims is not None else None) self._graph_parents = graph_parents or [] self._is_constant_jacobian = is_constant_jacobian self._validate_args = validate_args @@ -452,9 +446,9 @@ class Bijector(object): self._name = camel_to_snake(type(self).__name__) @property - def shaper(self): - """Returns shape object used to manage shape constraints.""" - return self._shaper + def event_ndims(self): + """Returns then number of event dimensions this bijector operates on.""" + return self._event_ndims @property def graph_parents(self): @@ -857,10 +851,29 @@ class Bijector(object): return self._from_y.get(mapping.y_key, mapping) return mapping + def _event_dims_tensor(self, sample): + """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`.""" + if self.event_ndims is None: + raise ValueError("Jacobian cannot be computed with unknown event_ndims") + static_event_ndims = tensor_util.constant_value(self.event_ndims) + static_rank = sample.get_shape().ndims + if static_event_ndims is not None and static_rank is not None: + return ops.convert_to_tensor( + static_rank + np.arange(-static_event_ndims, 0).astype(np.int32)) + + if static_event_ndims is not None: + event_range = np.arange(-static_event_ndims, 0).astype(np.int32) + else: + event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32) + + if static_rank is not None: + return event_range + static_rank + else: + return event_range + array_ops.rank(sample) + class Inline(Bijector): - # pylint: disable=line-too-long - """Bijector constructed from callables implementing forward, inverse, and inverse_log_det_jacobian. + """Bijector constructed from custom callables. Example Use: @@ -875,7 +888,6 @@ class Inline(Bijector): The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. """ - # pylint: enable=line-too-long def __init__(self, forward_fn=None, @@ -913,7 +925,6 @@ class Inline(Bijector): name: `String`, name given to ops managed by this object. 
""" super(Inline, self).__init__( - batch_ndims=0, event_ndims=0, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, @@ -1013,12 +1024,12 @@ class Invert(Bijector): self._bijector = bijector super(Invert, self).__init__( + event_ndims=bijector.event_ndims, graph_parents=bijector.graph_parents, is_constant_jacobian=bijector.is_constant_jacobian, validate_args=validate_args, dtype=bijector.dtype, name=name or "_".join(["invert", bijector.name])) - self._shaper = bijector.shaper def _forward_event_shape(self, input_shape): return self.bijector.inverse_event_shape(input_shape) @@ -1102,16 +1113,21 @@ class Chain(Bijector): raise ValueError("incompatible dtypes: %s" % dtype) elif len(dtype) == 2: dtype = dtype[1] if dtype[0] is None else dtype[0] + event_ndims = bijectors[0].event_ndims elif len(dtype) == 1: dtype = dtype[0] + event_ndims = bijectors[0].event_ndims else: dtype = None + event_ndims = None + super(Chain, self).__init__( graph_parents=list(itertools.chain.from_iterable( b.graph_parents for b in bijectors)), is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), validate_args=validate_args, dtype=dtype, + event_ndims=event_ndims, name=name or ("identity" if not bijectors else "_of_".join(["chain"] + [b.name for b in bijectors]))) @@ -1172,14 +1188,14 @@ class Chain(Bijector): class Identity(Bijector): - """Bijector which computes Y = g(X) = X. + """Compute Y = g(X) = X. Example Use: ```python # Create the Y=g(X)=X transform which is intended for Tensors with 1 batch # ndim and 1 event ndim (i.e., vector of vectors). - identity = Identity(batch_ndims=1, event_ndims=1) + identity = Identity(event_ndims=1) x = [[1., 2], [3, 4]] x == identity.forward(x) == identity.inverse(x) @@ -1187,9 +1203,10 @@ class Identity(Bijector): """ - def __init__(self, validate_args=False, name="identity"): + def __init__(self, validate_args=False, event_ndims=0, name="identity"): super(Identity, self).__init__( is_constant_jacobian=True, + event_ndims=event_ndims, validate_args=validate_args, name=name) @@ -1204,7 +1221,7 @@ class Identity(Bijector): class PowerTransform(Bijector): - """Bijector which computes `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. + """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` @@ -1242,7 +1259,6 @@ class PowerTransform(Bijector): raise ValueError("`power` must be a non-negative TF constant.") self._power = power super(PowerTransform, self).__init__( - batch_ndims=0, event_ndims=event_ndims, validate_args=validate_args, name=name) @@ -1262,9 +1278,7 @@ class PowerTransform(Bijector): def _inverse_and_inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - if self.shaper is None: - raise ValueError("Jacobian cannot be computed with unknown event_ndims") - _, _, event_dims = self.shaper.get_dims(y) + event_dims = self._event_dims_tensor(y) if self.power == 0.: x = math_ops.log(y) ildj = -math_ops.reduce_sum(x, reduction_indices=event_dims) @@ -1279,9 +1293,7 @@ class PowerTransform(Bijector): def _forward_log_det_jacobian(self, x): x = self._maybe_assert_valid_x(x) - if self.shaper is None: - raise ValueError("Jacobian cannot be computed with unknown event_ndims") - _, _, event_dims = self.shaper.get_dims(x) + event_dims = self._event_dims_tensor(x) if self.power == 0.: return math_ops.reduce_sum(x, reduction_indices=event_dims) return (1. / self.power - 1.) 
@@ -1306,14 +1318,14 @@ class PowerTransform(Bijector):


 class Exp(PowerTransform):
-  """Bijector which computes Y = g(X) = exp(X).
+  """Compute `Y = g(X) = exp(X)`.

   Example Use:

   ```python
   # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1
   # batch ndim and 2 event ndims (i.e., vector of matrices).
-  exp = Exp(batch_ndims=1, event_ndims=2)
+  exp = Exp(event_ndims=2)
   x = [[[1., 2],
         [3, 4]],
        [[5, 6],
@@ -1483,11 +1495,11 @@ class _TriLPlusVDVTLightweightOperatorPD(object):


 class Affine(Bijector):
-  # pylint: disable=line-too-long
-  """Bijector which computes `Y = g(X; shift, scale) = matmul(scale, X) + shift` where `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.
+  """Compute `Y = g(X; shift, scale) = scale @ X + shift`.
+
+  Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.

-  Write `A @ X` for `matmul(A, X)`. In TF parlance, the `scale` term is
-  logically equivalent to:
+  In TF parlance, the `scale` term is logically equivalent to:

   ```python
   scale = (
@@ -1537,7 +1549,6 @@ class Affine(Bijector):
   ```

   """
-  # pylint: enable=line-too-long

   def __init__(self,
                shift=None,
@@ -1653,8 +1664,11 @@ class Affine(Bijector):
         self._shift.dtype.base_dtype != self._scale.dtype.base_dtype):
       raise TypeError("shift.dtype({}) does not match scale.dtype({})".format(
           self._shift.dtype, self._scale.dtype))
-    super(Affine, self).__init__(
+    self._shaper = _DistributionShape(
         batch_ndims=self._infer_batch_ndims(),
+        event_ndims=event_ndims,
+        validate_args=validate_args)
+    super(Affine, self).__init__(
         event_ndims=event_ndims,
         graph_parents=(
             [event_ndims] +
@@ -1686,7 +1700,7 @@ class Affine(Bijector):
         for correctness.

     Returns:
-      scale and batch_ndims. In the case of scaling by a constant, scale is a
+      scale. In the case of scaling by a constant, scale is a
       floating point `Tensor`. Otherwise, scale is an `OperatorPD`.

     Raises:
@@ -1831,9 +1845,9 @@ class Affine(Bijector):
       if self.shift is not None:
         return y + self.shift
       return y
-    y, sample_shape = self.shaper.make_batch_of_event_sample_matrices(y)
+    y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(y)
     y = self._scale.sqrt_matmul(y)
-    y = self.shaper.undo_make_batch_of_event_sample_matrices(y, sample_shape)
+    y = self._shaper.undo_make_batch_of_event_sample_matrices(y, sample_shape)
     if self.shift is not None:
       return y + self.shift
     return y
@@ -1844,9 +1858,9 @@ class Affine(Bijector):
       x -= self.shift
     if self._is_only_identity_multiplier:
       return x / self._scale
-    x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
+    x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(x)
     x = self._scale.sqrt_solve(x)
-    x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
+    x = self._shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
     return x

   def _inverse_log_det_jacobian(self, y):
@@ -1858,7 +1872,7 @@ class Affine(Bijector):
       # applied via broadcast.
       d = math_ops.cast(array_ops.shape(x)[-1], dtype=self._scale.dtype)
       return math_ops.log(math_ops.abs(self._scale)) * array_ops.where(
-          math_ops.equal(self.shaper.event_ndims, 0), 1., d)
+          math_ops.equal(self._shaper.event_ndims, 0), 1., d)
     fldj = self._scale.sqrt_log_abs_det()
     # We need to squeeze off the padded dimension.
     start = array_ops.where(self._rank_two_event_ndims_one, 1, 0)
@@ -1866,7 +1880,7 @@ class Affine(Bijector):


 class AffineLinearOperator(Bijector):
-  """Bijector which computes `Y = g(X; shift, scale) = scale @ X.T + shift`.
+ """Compute `Y = g(X; shift, scale) = scale(X.T) + shift`. `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. @@ -1989,9 +2003,11 @@ class AffineLinearOperator(Bijector): else: batch_ndims = 0 # We won't need shape inference when scale is None. self._scale = scale - - super(AffineLinearOperator, self).__init__( + self._shaper = _DistributionShape( batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(AffineLinearOperator, self).__init__( event_ndims=event_ndims, graph_parents=graph_parents, is_constant_jacobian=True, @@ -2011,12 +2027,12 @@ class AffineLinearOperator(Bijector): def _forward(self, x): y = x if self.scale is not None: - y, sample_shape = self.shaper.make_batch_of_event_sample_matrices( + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( y, expand_batch_dim=False) with ops.control_dependencies([self.scale.assert_non_singular()] if self.validate_args else []): y = self.scale.apply(y) - y = self.shaper.undo_make_batch_of_event_sample_matrices( + y = self._shaper.undo_make_batch_of_event_sample_matrices( y, sample_shape, expand_batch_dim=False) if self.shift is not None: y += self.shift @@ -2027,11 +2043,11 @@ class AffineLinearOperator(Bijector): if self.shift is not None: x -= self.shift if self.scale is not None: - x, sample_shape = self.shaper.make_batch_of_event_sample_matrices( + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( x, expand_batch_dim=False) # Solve fails if the op is singular so we may safely skip this assertion. x = self.scale.solve(x) - x = self.shaper.undo_make_batch_of_event_sample_matrices( + x = self._shaper.undo_make_batch_of_event_sample_matrices( x, sample_shape, expand_batch_dim=False) return x @@ -2060,7 +2076,7 @@ class Softplus(Bijector): ```python # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(batch_ndims=1, event_ndims=2) + softplus = Softplus(event_ndims=2) x = [[[1., 2], [3, 4]], [[5, 6], @@ -2078,7 +2094,6 @@ class Softplus(Bijector): validate_args=False, name="softplus"): super(Softplus, self).__init__( - batch_ndims=0, event_ndims=event_ndims, validate_args=validate_args, name=name) @@ -2087,9 +2102,7 @@ class Softplus(Bijector): return nn_ops.softplus(x) def _inverse_and_inverse_log_det_jacobian(self, y): - if self.shaper is None: - raise ValueError("Jacobian cannot be computed with unknown event_ndims") - _, _, event_dims = self.shaper.get_dims(y) + event_dims = self._event_dims_tensor(y) # Could also do: # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), # reduction_indices=event_dims) @@ -2104,9 +2117,7 @@ class Softplus(Bijector): return distribution_util.softplus_inverse(y), ildj def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument - if self.shaper is None: - raise ValueError("Jacobian cannot be computed with unknown event_ndims") - _, _, event_dims = self.shaper.get_dims(x) + event_dims = self._event_dims_tensor(x) return -math_ops.reduce_sum( nn_ops.softplus(-x), reduction_indices=event_dims) @@ -2154,7 +2165,6 @@ class SoftmaxCentered(Bijector): raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") self._static_event_ndims = event_ndims super(SoftmaxCentered, self).__init__( - batch_ndims=0, # We'll regard all non-event dims as sample dims. 
         event_ndims=event_ndims,
         validate_args=validate_args,
         name=name)
@@ -2330,12 +2340,11 @@ class SigmoidCentered(SoftmaxCentered):

   def __init__(self, validate_args=False, name="sigmoid_centered"):
     super(SigmoidCentered, self).__init__(
-        validate_args=validate_args, name=name)
+        event_ndims=0, validate_args=validate_args, name=name)


 class CholeskyOuterProduct(Bijector):
-  # pylint: disable=line-too-long
-  """Bijector which computes Y = g(X) = X X.T where X is a lower-triangular, positive-diagonal matrix.
+  """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.

   `event_ndims` must be 0 or 2, i.e., scalar or matrix.

   Examples:

   ```python
   bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]])
   # Result: [[1, 2], [2, 5]], i.e., x x.T

   bijector.SoftmaxCentered(event_ndims=0).forward(x=[[1., 0], [2, 1]])
   # Result: [[1, 0], [4, 1]], i.e., x**2
   ```

   """
-  # pylint: enable=line-too-long

   def __init__(self, event_ndims=2, validate_args=False,
                name="cholesky_outer_product"):
@@ -2378,6 +2386,7 @@ class CholeskyOuterProduct(Bijector):
       raise ValueError("`event_ndims` must be a TF constant which is 0 or 2")
     self._static_event_ndims = event_ndims
     super(CholeskyOuterProduct, self).__init__(
+        event_ndims=event_ndims,
         validate_args=validate_args,
         name=name)
-- 
cgit v1.2.3
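
A minimal NumPy sketch of the axis arithmetic behind the new
`Bijector._event_dims_tensor` helper introduced above. The function name
`event_dims` and the NumPy setting are illustrative assumptions, not part of
the patch: the reduction axes for a log-det-Jacobian are always the trailing
`event_ndims` axes, so the batch/sample partition is implied by the input's
rank and `batch_ndims` never has to be passed in.

```python
import numpy as np


def event_dims(sample, event_ndims):
  """Illustrative stand-in for Bijector._event_dims_tensor.

  Returns the indices of the trailing `event_ndims` axes of `sample`,
  i.e. range(rank(sample))[-event_ndims:].
  """
  if event_ndims is None:
    raise ValueError("Jacobian cannot be computed with unknown event_ndims")
  rank = np.ndim(sample)
  # range(-event_ndims, 0) shifted by rank yields the trailing axis indices.
  return rank + np.arange(-event_ndims, 0, dtype=np.int32)


x = np.ones([2, 3])      # e.g. a batch of 2 length-3 vectors
print(event_dims(x, 0))  # => []  : a scalar bijector reduces over no axes
print(event_dims(x, 1))  # => [1] : a vector bijector reduces over the last axis
```

The TF version in the patch performs the same computation with
`tensor_util.constant_value`, `math_ops.range`, and `array_ops.rank` so that
it still works when `event_ndims` or the input's rank is only known at graph
execution time.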