author    Ian Langmore <langmore@google.com>  2016-10-31 15:07:11 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-10-31 16:16:40 -0700
commit    337d4a60b8e3f6666b79cb60e83ca297ae28c67d (patch)
tree      c17c814cd2fce09e6ae405bb3bfa9a990cc9b460
parent    197da59c5089501ff03d3ab00cf87785591ca794 (diff)
ScaleAndShift Bijector updates/fixes:
DOC: Clarify that (for the ScaleAndShift bijector) scale * X is matrix
multiplication, and that scale must be triangular and non-singular.
CODE: Check that the conditions on scale are met.
Change: 137761917
-rw-r--r--  tensorflow/contrib/distributions/BUILD                                  |  2
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py  | 62
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijector.py                | 66
3 files changed, 110 insertions(+), 20 deletions(-)
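For context, here is a minimal NumPy sketch (not part of the commit) of the
forward/inverse operations this change documents, assuming a lower-triangular,
non-singular `scale`:

```python
import numpy as np

# Y = g(X; shift, scale) = matmul(scale, X) + shift, with `scale` lower
# triangular and non-singular (i.e., no zeros on its diagonal).
scale = np.array([[2., 0.],
                  [1., 3.]])  # lower triangular, diag != 0
shift = np.array([1., -1.])
x = np.array([1., 1.])

y = scale.dot(x) + shift                    # forward: [3., 3.]
x_back = np.linalg.solve(scale, y - shift)  # inverse recovers x: [1., 1.]
# For triangular scale, log|det J^{-1}| = -sum(log|diag(scale)|).
ildj = -np.sum(np.log(np.abs(np.diag(scale))))
print(y, x_back, ildj)
```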
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 850cbf8d26..326cba4dd3 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -388,7 +388,7 @@ cuda_py_tests(
cuda_py_tests(
name = "bijector_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/bijector_test.py"],
additional_deps = [
":distributions_py",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
index 7356511a12..b69caf0bf4 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
@@ -368,23 +368,29 @@ class ScaleAndShiftBijectorTest(tf.test.TestCase):
for run in (static_run, dynamic_run):
mu = [1., -1]
- sigma = np.eye(2, dtype=np.float32)
+ # Note: sigma is -1 * identity matrix.
+ sigma = -np.eye(2, dtype=np.float32)
bijector = bijectors.ScaleAndShift(
shift=mu, scale=sigma, event_ndims=1)
self.assertEqual(0, bijector.shaper.batch_ndims.eval()) # "no batches"
self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
x = [1., 1]
- self.assertAllClose([2., 0], run(bijector.forward, x))
- self.assertAllClose([0., 2], run(bijector.inverse, x))
+ # matmul(sigma, x) + shift
+ # = [-1, -1] + [1, -1]
+ self.assertAllClose([0., -2], run(bijector.forward, x))
+ self.assertAllClose([0., -2], run(bijector.inverse, x))
self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+ # x is a 2-batch of 2-vectors.
+ # The first vector is [1, 1], the second is [-1, -1].
+ # Each undergoes matmul(sigma, x) + shift.
x = [[1., 1],
[-1., -1]]
- self.assertAllClose([[2., 0],
- [0, -2]],
+ self.assertAllClose([[0., -2],
+ [2., 0]],
run(bijector.forward, x))
- self.assertAllClose([[0., 2],
- [-2., 0]],
+ self.assertAllClose([[0., -2],
+ [2., 0]],
run(bijector.inverse, x))
self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
@@ -476,11 +482,53 @@ class ScaleAndShiftBijectorTest(tf.test.TestCase):
self.assertAllClose(
[0.], sess.run(bijector.inverse_log_det_jacobian(x), feed_dict))
+ def testNoBatchMultivariateRaisesWhenSingular(self):
+ with self.test_session():
+ mu = [1., -1]
+ sigma = [[0., 1.], [1., 1.]] # Has zero on the diag!
+ bijector = bijectors.ScaleAndShift(
+ shift=mu, scale=sigma, event_ndims=1, validate_args=True)
+ with self.assertRaisesOpError("Singular"):
+ bijector.forward([1., 1.]).eval()
+
+ def testEventNdimsLargerThanOneRaises(self):
+ with self.test_session():
+ mu = [1., -1]
+ sigma = [[1., 1.], [1., 1.]]
+ bijector = bijectors.ScaleAndShift(
+ shift=mu, scale=sigma, event_ndims=2, validate_args=True)
+ with self.assertRaisesOpError("event_ndims"):
+ bijector.forward([1., 1.]).eval()
+
+ def testNonSquareMatrixScaleRaises(self):
+ # event_ndims = 1, so scale must be a square matrix; we feed a non-square one.
+ with self.test_session():
+ mu = [1., -1]
+ sigma = [[1., 1., 1.], [1., 1., 1.]]
+ bijector = bijectors.ScaleAndShift(
+ shift=mu, scale=sigma, event_ndims=1, validate_args=True)
+ with self.assertRaisesOpError("square"):
+ bijector.forward([1., 1.]).eval()
+
+ def testScaleZeroScalarRaises(self):
+ with self.test_session():
+ mu = -1.
+ sigma = 0. # Scalar, leads to non-invertible bijector
+ bijector = bijectors.ScaleAndShift(
+ shift=mu, scale=sigma, validate_args=True)
+ with self.assertRaisesOpError("Singular"):
+ bijector.forward(1.).eval()
+
def testScalarCongruency(self):
with self.test_session():
bijector = bijectors.ScaleAndShift(shift=3.6, scale=0.42, event_ndims=0)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
+ def testScalarCongruencyWithNegativeScale(self):
+ with self.test_session():
+ bijector = bijectors.ScaleAndShift(shift=3.6, scale=-0.42, event_ndims=0)
+ assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
+
class SoftplusBijectorTest(tf.test.TestCase):
"""Tests the correctness of the Y = g(X) = Log[1 + exp(X)] transformation."""
diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index 2472c12d3f..b32500bf74 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -1048,28 +1048,39 @@ class Exp(Bijector):
class ScaleAndShift(Bijector):
- """Bijector which computes Y = g(X; shift, scale) = scale * X + shift.
+ """Bijector which computes Y = g(X; shift, scale) = matmul(scale, X) + shift.
+
+ `scale` is either a non-zero scalar or a lower triangular matrix with
+ non-zero diagonal. This ensures the `Bijector` is invertible and that its
+ determinant and inverse can be computed efficiently.
+
+ The mean and covariance are transformed as:
+
+ ```
+ E[Y] = matmul(scale, E[X]) + shift
+ Cov[Y] = matmul(scale, matmul(Cov[X], scale, transpose_b=True))
+ ```
Example Use:
```python
- # No batch, scalar.
+ # No batch, scalar
mu = 0 # shape=[]
- sigma = 1 # shape=[]
+ sigma = 1 # shape=[], treated like a 1x1 matrix.
b = ScaleAndShift(shift=mu, scale=sigma)
# b.shaper.batch_ndims == 0
# b.shaper.event_ndims == 0
# One batch, scalar.
mu = ... # shape=[b], b>0
- sigma = ... # shape=[b], b>0
+ sigma = ... # shape=[b], b>0, treated like a batch of 1x1 matrices
b = ScaleAndShift(shift=mu, scale=sigma)
# b.shaper.batch_ndims == 1
# b.shaper.event_ndims == 0
# No batch, multivariate.
mu = ... # shape=[d], d>0
- sigma = ... # shape=[d, d], d>0
+ sigma = ... # shape=[d, d], d>0, treated like a single dxd matrix.
b = ScaleAndShift(shift=mu, scale=sigma, event_ndims=1)
# b.shaper.batch_ndims == 0
# b.shaper.event_ndims == 1
@@ -1097,13 +1108,22 @@ class ScaleAndShift(Bijector):
event_ndims=0,
validate_args=False,
name="scale_and_shift"):
- """Instantiates the `Exp` bijector.
+ """Instantiates the `ScaleAndShift` bijector.
+
+ This `Bijector` is initialized with `scale` and `shift` `Tensors`, giving
+ the forward operation:
+
+ ```Y = g(X) = matmul(scale, X) + shift```
Args:
- shift: `Tensor` used to shift input, i.e., `Y = g(X) = scale * X + shift`.
- scale: `Tensor` used to scale input, i.e., `Y = g(X) = scale * X + shift`.
+ shift: Numeric `Tensor`.
+ scale: Numeric `Tensor` of same `dtype` as `shift`. If `event_ndims = 0`,
+ `scale` is treated like a `1x1` matrix or a batch thereof.
+ Otherwise, the last two dimensions of `scale` define a matrix.
+ `scale` must have non-zero diagonal entries. The upper triangular
+ part of `scale` is ignored, effectively making it lower triangular.
event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions
- associated with a particular draw from the distribution.
+ associated with a particular draw from the distribution. Must be 0 or 1.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String` name given to ops managed by this object.
@@ -1111,10 +1131,16 @@ class ScaleAndShift(Bijector):
self._parameters = {}
self._name = name
+ self._validate_args = validate_args
with self._name_scope("init", values=[shift, scale, event_ndims]):
self._shift = ops.convert_to_tensor(shift, name="shift")
self._scale = ops.convert_to_tensor(scale, name="scale")
event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+ if validate_args:
+ event_ndims = control_flow_ops.with_dependencies(
+ [check_ops.assert_less(
+ event_ndims, 2, message="event_ndims must be 0 or 1")],
+ event_ndims)
if self.shift.dtype.base_dtype != self.scale.dtype.base_dtype:
raise TypeError("%s.dtype=%s does not match %s.dtype=%s" %
(self.shift.name, self.shift.dtype, self.scale.name,
@@ -1173,6 +1199,23 @@ class ScaleAndShift(Bijector):
array_ops.ones([right], dtype=dtypes.int32)))
scale = array_ops.reshape(scale, pad)
batch_ndims = ndims - 2 + right
+ # For safety, explicitly zero-out the upper triangular part.
+ scale = array_ops.matrix_band_part(scale, -1, 0)
+ if self.validate_args:
+ # No rank check is needed here: matrix_band_part above already fails
+ # if scale is not at least rank 2.
+ shape = array_ops.shape(scale)
+ assert_square = check_ops.assert_equal(
+ shape[-2], shape[-1],
+ message="Input must be a (batch of) square matrix.")
+ # Since scale is lower triangular, we only need to check diag != 0.
+ diag = array_ops.matrix_diag_part(scale)
+ is_non_singular = math_ops.logical_not(
+ math_ops.reduce_any(
+ math_ops.equal(diag, ops.convert_to_tensor(0, dtype=diag.dtype))))
+ assert_non_singular = control_flow_ops.Assert(
+ is_non_singular, ["Singular matrix encountered", diag])
+ scale = control_flow_ops.with_dependencies(
+ [assert_square, assert_non_singular], scale)
return scale, batch_ndims
@property
@@ -1198,9 +1241,8 @@ class ScaleAndShift(Bijector):
return x
def _inverse_log_det_jacobian(self, y): # pylint: disable=unused-argument
- return -math_ops.reduce_sum(
- math_ops.log(array_ops.matrix_diag_part(self.scale)),
- reduction_indices=[-1])
+ abs_diag = math_ops.abs(array_ops.matrix_diag_part(self.scale))
+ return -math_ops.reduce_sum(math_ops.log(abs_diag), reduction_indices=[-1])
def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument
return -self._inverse_log_det_jacobian(x)
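The `abs` in the new `_inverse_log_det_jacobian` is what makes negative
diagonal entries legal (cf. testScalarCongruencyWithNegativeScale above): for
triangular `scale`, det(scale) = prod(diag(scale)), so the log-determinant
must be taken on absolute values or a negative entry yields NaN. A quick
NumPy check, not part of the commit:

```python
import numpy as np

scale = np.array([[-0.42, 0.],
                  [1., 2.]])  # lower triangular; negative diagonal entry
# ILDJ as computed by the new code: -sum(log|diag(scale)|).
ildj = -np.sum(np.log(np.abs(np.diag(scale))))
# For triangular matrices det = prod(diag), so this equals -log|det(scale)|.
assert np.isclose(ildj, -np.log(np.abs(np.linalg.det(scale))))
# Without abs, log(-0.42) would be NaN.
print(ildj)  # ~0.174
```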