author     Eugene Brevdo <ebrevdo@google.com>               2017-02-02 14:37:00 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-02-02 14:49:25 -0800
commit     4eb0164dbcf690b9e33160004198905e96b3f049 (patch)
tree       d087cd13e22329dfe52eebb8dbdffb52ba1124a5
parent     4ebaa6bcec0fadc2c10d73d0626bc92cdd34d6c1 (diff)
Simplify bijector API: remove the need for batch_ndims.
batch_ndims never needs to be provided explicitly. The two bijectors that
use it, Affine and AffineLinearOperator, can both derive it from their
arguments. This brings us back to a more TITO (Tensor-in, Tensor-out) form
for the bijector API.

Change: 146408773
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py             |  47
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py |   1
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijector.py                           | 145
3 files changed, 100 insertions, 93 deletions
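For illustration, the user-facing shape of this change for a `Bijector` subclass (mirroring the `Exp` example updated in `bijector.py` below):

```python
# Before: subclasses had to forward batch_ndims alongside event_ndims.
super(Exp, self).__init__(batch_ndims=0, event_ndims=event_ndims,
                          validate_args=validate_args, name=name)

# After: only event_ndims is passed. Affine and AffineLinearOperator now
# derive batch_ndims from their own arguments and keep the shape machinery
# (_DistributionShape) private, so batch_ndims is gone from the public API.
super(Exp, self).__init__(event_ndims=event_ndims,
                          validate_args=validate_args, name=name)
```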
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
index 71d9310aa9..092758e6ef 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
@@ -97,7 +97,7 @@ def assert_scalar_congruency(bijector,
"""
# Checks and defaults.
- assert bijector.shaper is None or bijector.shaper.event_ndims.eval() == 0
+ assert bijector.event_ndims.eval() == 0
if sess is None:
sess = ops.get_default_session()
@@ -255,7 +255,6 @@ class BrokenBijectorWithInverseAndInverseLogDetJacobian(bijectors.Bijector):
def __init__(self, forward_missing=False, inverse_missing=False):
super(BrokenBijectorWithInverseAndInverseLogDetJacobian, self).__init__(
- batch_ndims=0,
event_ndims=0,
validate_args=False,
name="BrokenBijectorDual")
@@ -287,7 +286,7 @@ class BrokenBijectorSeparateInverseAndInverseLogDetJacobian(bijectors.Bijector):
def __init__(self, forward_missing=False, inverse_missing=False):
super(BrokenBijectorSeparateInverseAndInverseLogDetJacobian, self).__init__(
- batch_ndims=0, event_ndims=0, validate_args=False, name="broken")
+ event_ndims=0, validate_args=False, name="broken")
self._forward_missing = forward_missing
self._inverse_missing = inverse_missing
@@ -641,7 +640,7 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to scale = 2
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=2., event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1., 2, 3] # Three scalar samples (no batches).
self.assertAllClose([1., 3, 5], run(bijector.forward, x))
self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -663,7 +662,7 @@ class AffineBijectorTest(test.TestCase):
mu = -1.
# Corresponds to scale = 2
bijector = bijectors.Affine(shift=mu, scale_diag=[2.], event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1., 2, 3] # Three scalar samples (no batches).
self.assertAllClose([1., 3, 5], run(bijector.forward, x))
self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -686,7 +685,7 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to scale = 2.
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=2., event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [[1., 2, 3], [4, 5, 6]] # Weird sample shape.
self.assertAllClose([[1., 3, 5],
[7, 9, 11]],
@@ -713,7 +712,7 @@ class AffineBijectorTest(test.TestCase):
# One batch, scalar.
# Corresponds to scale = 1.
bijector = bijectors.Affine(shift=mu, event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1.] # One sample from one batch.
self.assertAllClose([2.], run(bijector.forward, x))
self.assertAllClose([0.], run(bijector.inverse, x))
@@ -735,7 +734,7 @@ class AffineBijectorTest(test.TestCase):
# One batch, scalar.
# Corresponds to scale = 1.
bijector = bijectors.Affine(shift=mu, scale_diag=[1.], event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1.] # One sample from one batch.
self.assertAllClose([2.], run(bijector.forward, x))
self.assertAllClose([0.], run(bijector.inverse, x))
@@ -757,7 +756,7 @@ class AffineBijectorTest(test.TestCase):
# Univariate, two batches.
# Corresponds to scale = 1.
bijector = bijectors.Affine(shift=mu, event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1., 1] # One sample from each of two batches.
self.assertAllClose([2., 0], run(bijector.forward, x))
self.assertAllClose([0., 2], run(bijector.inverse, x))
@@ -779,7 +778,7 @@ class AffineBijectorTest(test.TestCase):
# Univariate, two batches.
# Corresponds to scale = 1.
bijector = bijectors.Affine(shift=mu, scale_diag=[1.], event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1., 1] # One sample from each of two batches.
self.assertAllClose([2., 0], run(bijector.forward, x))
self.assertAllClose([0., 2], run(bijector.inverse, x))
@@ -801,7 +800,7 @@ class AffineBijectorTest(test.TestCase):
# Multivariate
# Corresponds to scale = [[1., 0], [0, 1.]]
bijector = bijectors.Affine(shift=mu)
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 1]
# matmul(sigma, x) + shift
# = [-1, -1] + [1, -1]
@@ -832,7 +831,7 @@ class AffineBijectorTest(test.TestCase):
# Multivariate
# Corresponds to scale = [[2., 0], [0, 1.]]
bijector = bijectors.Affine(shift=mu, scale_diag=[2., 1])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 1]
# matmul(sigma, x) + shift
# = [-1, -1] + [1, -1]
@@ -875,7 +874,7 @@ class AffineBijectorTest(test.TestCase):
bijector = bijectors.Affine(
shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
- self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
+ self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict))
self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict))
self.assertAllClose(
@@ -898,7 +897,7 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to 1 2x2 matrix, with twos on the diagonal.
scale = 2.
bijector = bijectors.Affine(shift=mu, scale_identity_multiplier=scale)
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[[1., 1]]]
self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -921,7 +920,7 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to 1 2x2 matrix, with twos on the diagonal.
scale_diag = [[2., 2]]
bijector = bijectors.Affine(shift=mu, scale_diag=scale_diag)
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[[1., 1]]]
self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -949,7 +948,7 @@ class AffineBijectorTest(test.TestCase):
bijector = bijectors.Affine(
shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
- self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
+ self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict))
self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict))
self.assertAllClose([-math.log(4)],
@@ -975,7 +974,7 @@ class AffineBijectorTest(test.TestCase):
scale_identity_multiplier=1.,
scale_diag=[1.],
event_ndims=0)
- self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
x = [1., 2, 3] # Three scalar samples (no batches).
self.assertAllClose([1., 3, 5], run(bijector.forward, x))
self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -1000,7 +999,7 @@ class AffineBijectorTest(test.TestCase):
shift=mu,
scale_identity_multiplier=1.,
scale_tril=[[1., 0], [2., 1]])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[1., 2]] # One multivariate sample.
self.assertAllClose([[1., 5]], run(bijector.forward, x))
self.assertAllClose([[1., 0.5]], run(bijector.inverse, x))
@@ -1023,7 +1022,7 @@ class AffineBijectorTest(test.TestCase):
# scale = [[2., 0], [2, 3]]
bijector = bijectors.Affine(
shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[1., 2]] # One multivariate sample.
self.assertAllClose([[1., 7]], run(bijector.forward, x))
self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x))
@@ -1049,7 +1048,7 @@ class AffineBijectorTest(test.TestCase):
scale_identity_multiplier=1.0,
scale_diag=[1., 2.],
scale_tril=[[1., 0], [2., 1]])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[1., 2]] # One multivariate sample.
self.assertAllClose([[2., 9]], run(bijector.forward, x))
self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x))
@@ -1079,7 +1078,7 @@ class AffineBijectorTest(test.TestCase):
[0, 1]])
bijector_ref = bijectors.Affine(shift=mu, scale_diag=[10., 2, 3])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([9., 3, 8], run(bijector.forward, x))
self.assertAllClose(
@@ -1117,7 +1116,7 @@ class AffineBijectorTest(test.TestCase):
[0, 1]])
bijector_ref = bijectors.Affine(shift=mu, scale_diag=[10., 3, 5])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([9., 5, 14], run(bijector.forward, x))
self.assertAllClose(
@@ -1159,7 +1158,7 @@ class AffineBijectorTest(test.TestCase):
[1, 3, 0],
[2, 3, 5]])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([9., 6, 22], run(bijector.forward, x))
self.assertAllClose(
@@ -1195,7 +1194,7 @@ class AffineBijectorTest(test.TestCase):
bijector_ref = bijectors.Affine(
shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]])
- self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([5., 6, 22], run(bijector.forward, x))
self.assertAllClose(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py
index 50e62435de..17a1fe13cf 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_bijector_test.py
@@ -27,7 +27,6 @@ class _TestBijector(conditional_bijector.ConditionalBijector):
def __init__(self):
super(_TestBijector, self).__init__(
- batch_ndims=0,
event_ndims=0,
graph_parents=[],
is_constant_jacobian=True,
diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index 5cfee1f9fa..5e003b3f5f 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -256,7 +256,7 @@ class Bijector(object):
```
class Exp(Bijector):
def __init__(self, event_ndims=0, validate_args=False, name="exp"):
- super(Exp, self).__init__(batch_ndims=0, event_ndims=event_ndims,
+ super(Exp, self).__init__(event_ndims=event_ndims,
validate_args=validate_args, name=name)
def _forward(self, x):
return math_ops.exp(x)
@@ -264,9 +264,9 @@ class Bijector(object):
x = math_ops.log(y)
return x, -self._forward_log_det_jacobian(x)
def _forward_log_det_jacobian(self, x):
- if self.shaper is None:
+ if self.event_ndims is None:
raise ValueError("Jacobian requires known event_ndims.")
- _, _, event_dims = self.shaper.get_dims(x)
+ event_dims = self._event_dims_tensor(x)
return math_ops.reduce_sum(x, reduction_indices=event_dims)
```
@@ -389,7 +389,6 @@ class Bijector(object):
@abc.abstractmethod
def __init__(self,
- batch_ndims=None,
event_ndims=None,
graph_parents=None,
is_constant_jacobian=False,
@@ -403,17 +402,16 @@ class Bijector(object):
Examples:
```python
- # Create the Y = g(X) = X transform which operates on 4-Tensors of vectors.
- identity = Identity(batch_ndims=4, event_ndims=1)
+ # Create the Y = g(X) = X transform which operates on vector events.
+ identity = Identity(event_ndims=1)
# Create the Y = g(X) = exp(X) transform which operates on matrices.
- exp = Exp(batch_ndims=0, event_ndims=2)
+ exp = Exp(event_ndims=2)
```
See `Bijector` subclass docstring for more details and specific examples.
Args:
- batch_ndims: number of dimensions associated with batch coordinates.
event_ndims: number of dimensions associated with event coordinates.
graph_parents: Python list of graph prerequisites of this `Bijector`.
is_constant_jacobian: `Boolean` indicating that the Jacobian is not a
@@ -425,13 +423,9 @@ class Bijector(object):
enforced.
name: The name to give Ops created by the initializer.
"""
- if batch_ndims is None or event_ndims is None:
- self._shaper = None # Apparently subclass will create.
- else:
- self._shaper = _DistributionShape(
- batch_ndims=batch_ndims,
- event_ndims=event_ndims,
- validate_args=validate_args)
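+ # Normalize event_ndims to an int32 Tensor (or None when unspecified).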
+ self._event_ndims = (
+ ops.convert_to_tensor(event_ndims, dtype=dtypes.int32)
+ if event_ndims is not None else None)
self._graph_parents = graph_parents or []
self._is_constant_jacobian = is_constant_jacobian
self._validate_args = validate_args
@@ -452,9 +446,9 @@ class Bijector(object):
self._name = camel_to_snake(type(self).__name__)
@property
- def shaper(self):
- """Returns shape object used to manage shape constraints."""
- return self._shaper
+ def event_ndims(self):
+ """Returns then number of event dimensions this bijector operates on."""
+ return self._event_ndims
@property
def graph_parents(self):
@@ -857,10 +851,29 @@ class Bijector(object):
return self._from_y.get(mapping.y_key, mapping)
return mapping
+ def _event_dims_tensor(self, sample):
+ """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`."""
+ if self.event_ndims is None:
+ raise ValueError("Jacobian cannot be computed with unknown event_ndims")
+ static_event_ndims = tensor_util.constant_value(self.event_ndims)
+ static_rank = sample.get_shape().ndims
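+ # Fast path: event_ndims and the sample's rank are both statically known,
+ # so the reduction axes can be returned as a constant.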
+ if static_event_ndims is not None and static_rank is not None:
+ return ops.convert_to_tensor(
+ static_rank + np.arange(-static_event_ndims, 0).astype(np.int32))
+
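+ # Otherwise build the trailing-axis range as negative offsets, keeping it
+ # static when event_ndims is.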
+ if static_event_ndims is not None:
+ event_range = np.arange(-static_event_ndims, 0).astype(np.int32)
+ else:
+ event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32)
+
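+ # Shift the negative offsets by the rank to get absolute axis indices.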
+ if static_rank is not None:
+ return event_range + static_rank
+ else:
+ return event_range + array_ops.rank(sample)
+
class Inline(Bijector):
- # pylint: disable=line-too-long
- """Bijector constructed from callables implementing forward, inverse, and inverse_log_det_jacobian.
+ """Bijector constructed from custom callables.
Example Use:
@@ -875,7 +888,6 @@ class Inline(Bijector):
The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`.
"""
- # pylint: enable=line-too-long
def __init__(self,
forward_fn=None,
@@ -913,7 +925,6 @@ class Inline(Bijector):
name: `String`, name given to ops managed by this object.
"""
super(Inline, self).__init__(
- batch_ndims=0,
event_ndims=0,
is_constant_jacobian=is_constant_jacobian,
validate_args=validate_args,
@@ -1013,12 +1024,12 @@ class Invert(Bijector):
self._bijector = bijector
super(Invert, self).__init__(
+ event_ndims=bijector.event_ndims,
graph_parents=bijector.graph_parents,
is_constant_jacobian=bijector.is_constant_jacobian,
validate_args=validate_args,
dtype=bijector.dtype,
name=name or "_".join(["invert", bijector.name]))
- self._shaper = bijector.shaper
def _forward_event_shape(self, input_shape):
return self.bijector.inverse_event_shape(input_shape)
@@ -1102,16 +1113,21 @@ class Chain(Bijector):
raise ValueError("incompatible dtypes: %s" % dtype)
elif len(dtype) == 2:
dtype = dtype[1] if dtype[0] is None else dtype[0]
+ event_ndims = bijectors[0].event_ndims
elif len(dtype) == 1:
dtype = dtype[0]
+ event_ndims = bijectors[0].event_ndims
else:
dtype = None
+ event_ndims = None
+
super(Chain, self).__init__(
graph_parents=list(itertools.chain.from_iterable(
b.graph_parents for b in bijectors)),
is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors),
validate_args=validate_args,
dtype=dtype,
+ event_ndims=event_ndims,
name=name or ("identity" if not bijectors else
"_of_".join(["chain"] + [b.name for b in bijectors])))
@@ -1172,14 +1188,14 @@ class Chain(Bijector):
class Identity(Bijector):
- """Bijector which computes Y = g(X) = X.
+ """Compute Y = g(X) = X.
Example Use:
```python
# Create the Y=g(X)=X transform which operates on vector events.
- identity = Identity(batch_ndims=1, event_ndims=1)
+ identity = Identity(event_ndims=1)
x = [[1., 2],
[3, 4]]
x == identity.forward(x) == identity.inverse(x)
@@ -1187,9 +1203,10 @@ class Identity(Bijector):
"""
- def __init__(self, validate_args=False, name="identity"):
+ def __init__(self, validate_args=False, event_ndims=0, name="identity"):
super(Identity, self).__init__(
is_constant_jacobian=True,
+ event_ndims=event_ndims,
validate_args=validate_args,
name=name)
@@ -1204,7 +1221,7 @@ class Identity(Bijector):
class PowerTransform(Bijector):
- """Bijector which computes `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`.
+ """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`.
The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps
inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse`
@@ -1242,7 +1259,6 @@ class PowerTransform(Bijector):
raise ValueError("`power` must be a non-negative TF constant.")
self._power = power
super(PowerTransform, self).__init__(
- batch_ndims=0,
event_ndims=event_ndims,
validate_args=validate_args,
name=name)
@@ -1262,9 +1278,7 @@ class PowerTransform(Bijector):
def _inverse_and_inverse_log_det_jacobian(self, y):
y = self._maybe_assert_valid_y(y)
- if self.shaper is None:
- raise ValueError("Jacobian cannot be computed with unknown event_ndims")
- _, _, event_dims = self.shaper.get_dims(y)
+ event_dims = self._event_dims_tensor(y)
if self.power == 0.:
x = math_ops.log(y)
ildj = -math_ops.reduce_sum(x, reduction_indices=event_dims)
@@ -1279,9 +1293,7 @@ class PowerTransform(Bijector):
def _forward_log_det_jacobian(self, x):
x = self._maybe_assert_valid_x(x)
- if self.shaper is None:
- raise ValueError("Jacobian cannot be computed with unknown event_ndims")
- _, _, event_dims = self.shaper.get_dims(x)
+ event_dims = self._event_dims_tensor(x)
if self.power == 0.:
return math_ops.reduce_sum(x, reduction_indices=event_dims)
return (1. / self.power - 1.) * math_ops.reduce_sum(
@@ -1306,14 +1318,14 @@ class PowerTransform(Bijector):
class Exp(PowerTransform):
- """Bijector which computes Y = g(X) = exp(X).
+ """Compute `Y = g(X) = exp(X)`.
Example Use:
```python
# Create the Y=g(X)=exp(X) transform which operates on matrix events.
- exp = Exp(batch_ndims=1, event_ndims=2)
+ exp = Exp(event_ndims=2)
x = [[[1., 2],
[3, 4]],
[[5, 6],
@@ -1483,11 +1495,11 @@ class _TriLPlusVDVTLightweightOperatorPD(object):
class Affine(Bijector):
- # pylint: disable=line-too-long
- """Bijector which computes `Y = g(X; shift, scale) = matmul(scale, X) + shift` where `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.
+ """Compute `Y = g(X; shift, scale) = scale @ X + shift`.
+
+ Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.
- Write `A @ X` for `matmul(A, X)`. In TF parlance, the `scale` term is
- logically equivalent to:
+ In TF parlance, the `scale` term is logically equivalent to:
```python
scale = (
@@ -1537,7 +1549,6 @@ class Affine(Bijector):
```
"""
- # pylint: enable=line-too-long
def __init__(self,
shift=None,
@@ -1653,9 +1664,12 @@ class Affine(Bijector):
self._shift.dtype.base_dtype != self._scale.dtype.base_dtype):
raise TypeError("shift.dtype({}) does not match scale.dtype({})".format(
self._shift.dtype, self._scale.dtype))
- super(Affine, self).__init__(
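+ # Affine still derives batch_ndims from its arguments; the shape machinery
+ # is now private to this class rather than part of the Bijector base API.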
+ self._shaper = _DistributionShape(
batch_ndims=self._infer_batch_ndims(),
event_ndims=event_ndims,
+ validate_args=validate_args)
+ super(Affine, self).__init__(
+ event_ndims=event_ndims,
graph_parents=(
[event_ndims] +
[self._scale] if contrib_framework.is_tensor(self._scale)
@@ -1686,7 +1700,7 @@ class Affine(Bijector):
for correctness.
Returns:
- scale and batch_ndims. In the case of scaling by a constant, scale is a
+ scale. In the case of scaling by a constant, scale is a
floating point `Tensor`. Otherwise, scale is an `OperatorPD`.
Raises:
@@ -1831,9 +1845,9 @@ class Affine(Bijector):
if self.shift is not None:
return y + self.shift
return y
- y, sample_shape = self.shaper.make_batch_of_event_sample_matrices(y)
+ y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(y)
y = self._scale.sqrt_matmul(y)
- y = self.shaper.undo_make_batch_of_event_sample_matrices(y, sample_shape)
+ y = self._shaper.undo_make_batch_of_event_sample_matrices(y, sample_shape)
if self.shift is not None:
return y + self.shift
return y
@@ -1844,9 +1858,9 @@ class Affine(Bijector):
x -= self.shift
if self._is_only_identity_multiplier:
return x / self._scale
- x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
+ x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(x)
x = self._scale.sqrt_solve(x)
- x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
+ x = self._shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
return x
def _inverse_log_det_jacobian(self, y):
@@ -1858,7 +1872,7 @@ class Affine(Bijector):
# applied via broadcast.
d = math_ops.cast(array_ops.shape(x)[-1], dtype=self._scale.dtype)
return math_ops.log(math_ops.abs(self._scale)) * array_ops.where(
- math_ops.equal(self.shaper.event_ndims, 0), 1., d)
+ math_ops.equal(self._shaper.event_ndims, 0), 1., d)
fldj = self._scale.sqrt_log_abs_det()
# We need to squeeze off the padded dimension.
start = array_ops.where(self._rank_two_event_ndims_one, 1, 0)
@@ -1866,7 +1880,7 @@ class Affine(Bijector):
class AffineLinearOperator(Bijector):
- """Bijector which computes `Y = g(X; shift, scale) = scale @ X.T + shift`.
+ """Compute `Y = g(X; shift, scale) = scale(X.T) + shift`.
`shift` is a numeric `Tensor` and `scale` is a `LinearOperator`.
@@ -1989,10 +2003,12 @@ class AffineLinearOperator(Bijector):
else:
batch_ndims = 0 # We won't need shape inference when scale is None.
self._scale = scale
-
- super(AffineLinearOperator, self).__init__(
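+ # As with Affine, the inferred batch_ndims now feeds a private
+ # _DistributionShape instead of the Bijector constructor.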
+ self._shaper = _DistributionShape(
batch_ndims=batch_ndims,
event_ndims=event_ndims,
+ validate_args=validate_args)
+ super(AffineLinearOperator, self).__init__(
+ event_ndims=event_ndims,
graph_parents=graph_parents,
is_constant_jacobian=True,
validate_args=validate_args,
@@ -2011,12 +2027,12 @@ class AffineLinearOperator(Bijector):
def _forward(self, x):
y = x
if self.scale is not None:
- y, sample_shape = self.shaper.make_batch_of_event_sample_matrices(
+ y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
y, expand_batch_dim=False)
with ops.control_dependencies([self.scale.assert_non_singular()] if
self.validate_args else []):
y = self.scale.apply(y)
- y = self.shaper.undo_make_batch_of_event_sample_matrices(
+ y = self._shaper.undo_make_batch_of_event_sample_matrices(
y, sample_shape, expand_batch_dim=False)
if self.shift is not None:
y += self.shift
@@ -2027,11 +2043,11 @@ class AffineLinearOperator(Bijector):
if self.shift is not None:
x -= self.shift
if self.scale is not None:
- x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(
+ x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
x, expand_batch_dim=False)
# Solve fails if the op is singular so we may safely skip this assertion.
x = self.scale.solve(x)
- x = self.shaper.undo_make_batch_of_event_sample_matrices(
+ x = self._shaper.undo_make_batch_of_event_sample_matrices(
x, sample_shape, expand_batch_dim=False)
return x
@@ -2060,7 +2076,7 @@ class Softplus(Bijector):
```python
# Create the Y=g(X)=softplus(X) transform which operates on matrix events.
- softplus = Softplus(batch_ndims=1, event_ndims=2)
+ softplus = Softplus(event_ndims=2)
x = [[[1., 2],
[3, 4]],
[[5, 6],
@@ -2078,7 +2094,6 @@ class Softplus(Bijector):
validate_args=False,
name="softplus"):
super(Softplus, self).__init__(
- batch_ndims=0,
event_ndims=event_ndims,
validate_args=validate_args,
name=name)
@@ -2087,9 +2102,7 @@ class Softplus(Bijector):
return nn_ops.softplus(x)
def _inverse_and_inverse_log_det_jacobian(self, y):
- if self.shaper is None:
- raise ValueError("Jacobian cannot be computed with unknown event_ndims")
- _, _, event_dims = self.shaper.get_dims(y)
+ event_dims = self._event_dims_tensor(y)
# Could also do:
# ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y),
# reduction_indices=event_dims)
@@ -2104,9 +2117,7 @@ class Softplus(Bijector):
return distribution_util.softplus_inverse(y), ildj
def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument
- if self.shaper is None:
- raise ValueError("Jacobian cannot be computed with unknown event_ndims")
- _, _, event_dims = self.shaper.get_dims(x)
+ event_dims = self._event_dims_tensor(x)
return -math_ops.reduce_sum(
nn_ops.softplus(-x), reduction_indices=event_dims)
@@ -2154,7 +2165,6 @@ class SoftmaxCentered(Bijector):
raise ValueError("`event_ndims` must be a TF constant which is 0 or 1")
self._static_event_ndims = event_ndims
super(SoftmaxCentered, self).__init__(
- batch_ndims=0, # We'll regard all non-event dims as sample dims.
event_ndims=event_ndims,
validate_args=validate_args,
name=name)
@@ -2330,12 +2340,11 @@ class SigmoidCentered(SoftmaxCentered):
def __init__(self, validate_args=False, name="sigmoid_centered"):
super(SigmoidCentered, self).__init__(
- validate_args=validate_args, name=name)
+ event_ndims=0, validate_args=validate_args, name=name)
class CholeskyOuterProduct(Bijector):
- # pylint: disable=line-too-long
- """Bijector which computes Y = g(X) = X X.T where X is a lower-triangular, positive-diagonal matrix.
+ """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.
`event_ndims` must be 0 or 2, i.e., scalar or matrix.
@@ -2352,7 +2361,6 @@ class CholeskyOuterProduct(Bijector):
```
"""
- # pylint: enable=line-too-long
def __init__(self, event_ndims=2, validate_args=False,
name="cholesky_outer_product"):
@@ -2378,6 +2386,7 @@ class CholeskyOuterProduct(Bijector):
raise ValueError("`event_ndims` must be a TF constant which is 0 or 2")
self._static_event_ndims = event_ndims
super(CholeskyOuterProduct, self).__init__(
+ event_ndims=event_ndims,
validate_args=validate_args,
name=name)