aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-30 14:44:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 14:47:19 -0700
commit5810723cc8f25fcf651be56c5b0271f70011fc2d (patch)
treebf7183e53d9ba004602ff6d09de85885e3aee150 /tensorflow/contrib/distributions
parent144c2b4a5fadb6cfed371dc9d72119826dbaf418 (diff)
Add `tf.contrib.distributions.bijectors.MatrixInverseTriL`: Bijector that inverts a lower-triangular matrix.
PiperOrigin-RevId: 198622553
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/BUILD19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py190
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py145
4 files changed, 356 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 6192f04c8b..23d9dbcd91 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -1033,6 +1033,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "matrix_inverse_tril_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "real_nvp_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/real_nvp_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
new file mode 100644
index 0000000000..1839703557
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
@@ -0,0 +1,190 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MatrixInverseTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
class MatrixInverseTriLBijectorTest(test.TestCase):
  """Tests the correctness of the Y = inv(tril) transformation."""

  def _evaluateTransforms(self, x_, x_inv_):
    """Evaluates forward, inverse, FLDJ and ILDJ of a validating bijector.

    Args:
      x_: numpy array fed to `forward` and `forward_log_det_jacobian`.
      x_inv_: numpy array fed to `inverse` and `inverse_log_det_jacobian`.

    Returns:
      The evaluated `[forward(x_), inverse(x_inv_), fldj(x_), ildj(x_inv_)]`.
    """
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    y = inv.forward(x_)
    x_back = inv.inverse(x_inv_)
    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)
    return self.evaluate([y, x_back, fldj, ildj])

  def _assertAllTransformsRaise(self, x_, exception_type, message):
    """Asserts that every public transform rejects the invalid input `x_`.

    `self.evaluate` is used instead of `test_session()`/`.eval()` so the same
    code path is valid in both graph and eager modes: the error may surface
    either while the op is being built or while it is being run, and both
    happen inside the `assertRaises...` context.

    Args:
      x_: numpy array that violates one of the bijector's preconditions.
      exception_type: exception class expected to be raised.
      message: substring expected in the exception message.
    """
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    transforms = [
        lambda: inv.forward(x_),
        lambda: inv.inverse(x_),
        lambda: inv.forward_log_det_jacobian(x_, event_ndims=2),
        lambda: inv.inverse_log_det_jacobian(x_, event_ndims=2),
    ]
    for transform in transforms:
      with self.assertRaisesWithPredicateMatch(exception_type, message):
        self.evaluate(transform())

  @test_util.run_in_graph_and_eager_modes()
  def testComputesCorrectValues(self):
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    self.assertEqual("matrix_inverse_tril", inv.name)
    x_ = np.array([[0.7, 0., 0.],
                   [0.1, -1., 0.],
                   [0.3, 0.25, 0.5]], dtype=np.float32)
    x_inv_ = np.linalg.inv(x_)
    # FLDJ = -2 * n * sum(log|diag|) with n = 3.
    expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_))))

    y_, x_back_, fldj_, ildj_ = self._evaluateTransforms(x_, x_inv_)

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertNear(expected_fldj_, fldj_, err=1e-3)
    self.assertNear(-expected_fldj_, ildj_, err=1e-3)

  @test_util.run_in_graph_and_eager_modes()
  def testOneByOneMatrix(self):
    x_ = np.array([[5.]], dtype=np.float32)
    x_inv_ = np.array([[0.2]], dtype=np.float32)
    # FLDJ = -2 * 1 * log(5) = log(1 / 25).
    expected_fldj_ = np.log(0.04)

    y_, x_back_, fldj_, ildj_ = self._evaluateTransforms(x_, x_inv_)

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertNear(expected_fldj_, fldj_, err=1e-3)
    self.assertNear(-expected_fldj_, ildj_, err=1e-3)

  @test_util.run_in_graph_and_eager_modes()
  def testZeroByZeroMatrix(self):
    x_ = np.eye(0, dtype=np.float32)
    x_inv_ = np.eye(0, dtype=np.float32)
    # The sum over an empty diagonal is zero.
    expected_fldj_ = 0.

    y_, x_back_, fldj_, ildj_ = self._evaluateTransforms(x_, x_inv_)

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertNear(expected_fldj_, fldj_, err=1e-3)
    self.assertNear(-expected_fldj_, ildj_, err=1e-3)

  @test_util.run_in_graph_and_eager_modes()
  def testBatch(self):
    # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
    # (2, 1).
    x_ = np.array([[[[1., 0.],
                     [2., 3.]]],
                   [[[4., 0.],
                     [5., -6.]]]], dtype=np.float32)
    x_inv_ = np.linalg.inv(x_)
    # FLDJ = -2 * n * sum(log|diag|) with n = 2, computed per batch member.
    expected_fldj_ = -4. * np.sum(
        np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1)

    y_, x_back_, fldj_, ildj_ = self._evaluateTransforms(x_, x_inv_)

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
    self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)

  @test_util.run_in_graph_and_eager_modes()
  def testErrorOnInputRankTooLow(self):
    x_ = np.array([0.1], dtype=np.float32)
    self._assertAllTransformsRaise(
        x_, ValueError, "must have rank at least 2")

  # TODO(b/80481923): Figure out why these assertions fail, and fix them.
  ## def testErrorOnInputNonSquare(self):
  ##   x_ = np.array([[1., 2., 3.],
  ##                  [4., 5., 6.]], dtype=np.float32)
  ##   self._assertAllTransformsRaise(
  ##       x_, errors.InvalidArgumentError, "must be a square matrix")

  @test_util.run_in_graph_and_eager_modes()
  def testErrorOnInputNotLowerTriangular(self):
    x_ = np.array([[1., 2.],
                   [3., 4.]], dtype=np.float32)
    self._assertAllTransformsRaise(
        x_, errors.InvalidArgumentError, "must be lower triangular")

  @test_util.run_in_graph_and_eager_modes()
  def testErrorOnInputSingular(self):
    x_ = np.array([[1., 0.],
                   [0., 0.]], dtype=np.float32)
    # `errors.OpError` is the base class matched by `assertRaisesOpError`.
    self._assertAllTransformsRaise(
        x_, errors.OpError, "must have all diagonal entries nonzero")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 51478dbeff..4965381ef3 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -30,6 +30,7 @@
@@Invert
@@Kumaraswamy
@@MaskedAutoregressiveFlow
+@@MatrixInverseTriL
@@Ordered
@@Permute
@@PowerTransform
@@ -68,6 +69,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import *
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
+from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import *
from tensorflow.contrib.distributions.python.ops.bijectors.ordered import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
new file mode 100644
index 0000000000..71903f7052
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
@@ -0,0 +1,145 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MatrixInverseTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+ "MatrixInverseTriL",
+]
+
+
class MatrixInverseTriL(bijector.Bijector):
  """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix.

  `L` must be nonsingular; equivalently, all diagonal entries of `L` must be
  nonzero.

  The input must have `rank >= 2`.  The input is treated as a batch of matrices
  with batch shape `input.shape[:-2]`, where each matrix has dimensions
  `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal
  `input.shape[-1]`).

  The identity matrix and validation constants are created with the dtype of
  the input, so the bijector is not restricted to `float32` inputs.

  #### Examples

  ```python
  tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]])
  # Result: [[1., 0], [-2, 1]], i.e., inv(x)

  tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]])
  # Result: [[1., 0], [2, 1]], i.e., inv(y).
  ```

  """

  def __init__(self, validate_args=False, name="matrix_inverse_tril"):
    """Instantiates the `MatrixInverseTriL` bijector.

    Args:
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
    """
    self._graph_parents = []
    self._name = name
    super(MatrixInverseTriL, self).__init__(
        forward_min_event_ndims=2,
        validate_args=validate_args,
        name=name)

  def _forward(self, x):
    """Returns `inv(x)` by solving `x @ Y = I` with a triangular solve."""
    with ops.control_dependencies(self._assertions(x)):
      shape = array_ops.shape(x)
      # `eye` defaults to float32; the solver needs both operands to share a
      # dtype, so build the identity in the dtype of `x`.
      identity = linalg_ops.eye(
          shape[-1], batch_shape=shape[:-2], dtype=x.dtype.base_dtype)
      return linalg_ops.matrix_triangular_solve(x, identity, lower=True)

  def _inverse(self, y):
    """Returns `inv(y)`; matrix inversion is its own inverse."""
    return self._forward(y)

  def _forward_log_det_jacobian(self, x):
    """Returns `-2 * n * sum(log|diag(x)|)` for each `n`-by-`n` matrix."""
    # Calculation of the Jacobian:
    #
    # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates.  Let Z =
    # X^{-1} where Z = (z_{ij}).  Then
    #
    #     dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1},
    #
    # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j)
    # entry and zeros elsewhere.  By the product rule,
    #
    #     0 = d/dt [Identity matrix]
    #       = d/dt [Y Y^{-1}]
    #       = Y d/dt[Y^{-1}] + dY/dt Y^{-1}
    #
    # so
    #
    #     d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1}
    #                  = -Y^{-1} E_{ij} Y^{-1}.
    #
    # Evaluating at t=0,
    #
    #     dZ/dx_{ij} = -Z E_{ij} Z.
    #
    # Taking the (r,s) entry of each side,
    #
    #     dz_{rs}/dx_{ij} = -z_{ri} z_{js}.
    #
    # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose
    # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}.  Considering J as an n-by-n
    # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij}
    # shows that the block at position (r,i) is -z_{ri} Z^T.  Hence
    #
    #     J = -KroneckerProduct(Z, Z^T),
    #     det(J) = (-1)^(n^2) (det Z)^n (det Z^T)^n
    #            = (-1)^(n^2) (det Z)^(2n)
    #            = (-1)^n (det X)^(-2n),
    #
    # and, X being triangular, log|det(J)| = -2n * sum_i log|x_{ii}|.
    #
    # NOTE(review): this differentiates with respect to all n^2 entries.
    # Restricting to the n(n+1)/2 lower-triangular coordinates would give the
    # exponent -(n+1) instead of -2n -- confirm the intended convention.
    with ops.control_dependencies(self._assertions(x)):
      matrix_size = math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype)
      log_abs_diag = math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x)))
      return -2. * matrix_size * math_ops.reduce_sum(log_abs_diag, axis=-1)

  def _assertions(self, x):
    """Returns assertion ops validating `x`, or `[]` if validation is off.

    The checks are: rank at least 2, square trailing dimensions, all entries
    strictly above the diagonal equal to zero, and all diagonal entries
    nonzero.

    Args:
      x: `Tensor` holding a (batch of) candidate lower-triangular matrices.

    Returns:
      A list of assertion ops, empty when `validate_args` is `False`.
    """
    if not self.validate_args:
      return []
    shape = array_ops.shape(x)
    is_matrix = check_ops.assert_rank_at_least(
        x, 2, message="Input must have rank at least 2.")
    is_square = check_ops.assert_equal(
        shape[-2], shape[-1], message="Input must be a square matrix.")
    # Zero the diagonal, then keep the upper triangle: what remains is the
    # strictly-upper part of `x`.  The zeros must match the dtype of `x`
    # because `matrix_set_diag` requires both operands to share a dtype.
    zero_diagonal = array_ops.zeros(shape[:-1], dtype=x.dtype.base_dtype)
    above_diagonal = array_ops.matrix_band_part(
        array_ops.matrix_set_diag(x, zero_diagonal), 0, -1)
    is_lower_triangular = check_ops.assert_equal(
        above_diagonal, array_ops.zeros_like(above_diagonal),
        message="Input must be lower triangular.")
    # A lower triangular matrix is nonsingular iff all its diagonal entries are
    # nonzero.
    diag_part = array_ops.matrix_diag_part(x)
    is_nonsingular = check_ops.assert_none_equal(
        diag_part, array_ops.zeros_like(diag_part),
        message="Input must have all diagonal entries nonzero.")
    return [is_matrix, is_square, is_lower_triangular, is_nonsingular]