diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-30 14:44:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-30 14:47:19 -0700 |
commit | 5810723cc8f25fcf651be56c5b0271f70011fc2d (patch) | |
tree | bf7183e53d9ba004602ff6d09de85885e3aee150 /tensorflow/contrib/distributions | |
parent | 144c2b4a5fadb6cfed371dc9d72119826dbaf418 (diff) |
Add `tf.contrib.distributions.bijectors.MatrixInverseTriL`: Bijector that inverts a lower-triangular matrix.
PiperOrigin-RevId: 198622553
Diffstat (limited to 'tensorflow/contrib/distributions')
4 files changed, 356 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 6192f04c8b..23d9dbcd91 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -1033,6 +1033,25 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "matrix_inverse_tril_test",
+    size = "medium",
+    srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"],
+    additional_deps = [  # NOTE(review): six, linalg_py, array_ops and math_ops are not imported by the test file in this diff; confirm they are needed.
+        ":bijectors_py",
+        ":distributions_py",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+        "//tensorflow/contrib/linalg:linalg_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_test(
     name = "real_nvp_test",
     size = "small",
     srcs = ["python/kernel_tests/bijectors/real_nvp_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
new file mode 100644
index 0000000000..1839703557
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
@@ -0,0 +1,190 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MatrixInverseTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class MatrixInverseTriLBijectorTest(test.TestCase):
+  """Tests the correctness of the Y = inv(tril) transformation."""
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testComputesCorrectValues(self):
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    self.assertEqual("matrix_inverse_tril", inv.name)
+    x_ = np.array([[0.7, 0., 0.],
+                   [0.1, -1., 0.],
+                   [0.3, 0.25, 0.5]], dtype=np.float32)
+    x_inv_ = np.linalg.inv(x_)
+    expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_))))  # -2 * n, n = 3
+
+    y = inv.forward(x_)
+    x_back = inv.inverse(x_inv_)
+    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
+    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)
+
+    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])
+
+    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
+    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
+    self.assertNear(expected_fldj_, fldj_, err=1e-3)
+    self.assertNear(-expected_fldj_, ildj_, err=1e-3)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testOneByOneMatrix(self):
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    x_ = np.array([[5.]], dtype=np.float32)
+    x_inv_ = np.array([[0.2]], dtype=np.float32)
+    expected_fldj_ = np.log(0.04)  # Equals -2 * log(5.).
+
+    y = inv.forward(x_)
+    x_back = inv.inverse(x_inv_)
+    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
+    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)
+
+    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])
+
+    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
+    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
+    self.assertNear(expected_fldj_, fldj_, err=1e-3)
+    self.assertNear(-expected_fldj_, ildj_, err=1e-3)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testZeroByZeroMatrix(self):
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    x_ = np.eye(0, dtype=np.float32)
+    x_inv_ = np.eye(0, dtype=np.float32)
+    expected_fldj_ = 0.  # The diagonal is empty, so the log-sum is 0.
+
+    y = inv.forward(x_)
+    x_back = inv.inverse(x_inv_)
+    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
+    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)
+
+    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])
+
+    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
+    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
+    self.assertNear(expected_fldj_, fldj_, err=1e-3)
+    self.assertNear(-expected_fldj_, ildj_, err=1e-3)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testBatch(self):
+    # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
+    # (2, 1).
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    x_ = np.array([[[[1., 0.],
+                     [2., 3.]]],
+                   [[[4., 0.],
+                     [5., -6.]]]], dtype=np.float32)
+    x_inv_ = np.linalg.inv(x_)
+    expected_fldj_ = -4. * np.sum(  # -2 * n, n = 2
+        np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1)
+
+    y = inv.forward(x_)
+    x_back = inv.inverse(x_inv_)
+    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
+    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)
+
+    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])
+
+    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
+    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
+    self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
+    self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testErrorOnInputRankTooLow(self):
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    x_ = np.array([0.1], dtype=np.float32)  # Rank-1 input.
+    rank_error_msg = "must have rank at least 2"
+    with self.test_session():
+      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+        inv.forward(x_).eval()
+      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+        inv.inverse(x_).eval()
+      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+        inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
+      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+        inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+
+  # TODO(b/80481923): Figure out why these assertions fail, and fix them.
+  ## def testErrorOnInputNonSquare(self):
+  ##   inv = bijectors.MatrixInverseTriL(validate_args=True)
+  ##   x_ = np.array([[1., 2., 3.],
+  ##                  [4., 5., 6.]], dtype=np.float32)
+  ##   square_error_msg = "must be a square matrix"
+  ##   with self.test_session():
+  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+  ##                                              square_error_msg):
+  ##       inv.forward(x_).eval()
+  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+  ##                                              square_error_msg):
+  ##       inv.inverse(x_).eval()
+  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+  ##                                              square_error_msg):
+  ##       inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
+  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+  ##                                              square_error_msg):
+  ##       inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testErrorOnInputNotLowerTriangular(self):
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    x_ = np.array([[1., 2.],  # Nonzero entry above the diagonal.
+                   [3., 4.]], dtype=np.float32)
+    triangular_error_msg = "must be lower triangular"
+    with self.test_session():
+      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+                                               triangular_error_msg):
+        inv.forward(x_).eval()
+      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+                                               triangular_error_msg):
+        inv.inverse(x_).eval()
+      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+                                               triangular_error_msg):
+        inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
+      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+                                               triangular_error_msg):
+        inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testErrorOnInputSingular(self):
+    inv = bijectors.MatrixInverseTriL(validate_args=True)
+    x_ = np.array([[1., 0.],
+                   [0., 0.]], dtype=np.float32)  # Zero on the diagonal.
+    nonsingular_error_msg = "must have all diagonal entries nonzero"
+    with self.test_session():
+      with self.assertRaisesOpError(nonsingular_error_msg):
+        inv.forward(x_).eval()
+      with self.assertRaisesOpError(nonsingular_error_msg):
+        inv.inverse(x_).eval()
+      with self.assertRaisesOpError(nonsingular_error_msg):
+        inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
+      with self.assertRaisesOpError(nonsingular_error_msg):
+        inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 51478dbeff..4965381ef3 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -30,6 +30,7 @@
 @@Invert
 @@Kumaraswamy
 @@MaskedAutoregressiveFlow
+@@MatrixInverseTriL
 @@Ordered
 @@Permute
 @@PowerTransform
@@ -68,6 +69,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
 from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
 from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import *
 from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
+from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import *
 from tensorflow.contrib.distributions.python.ops.bijectors.ordered import *
 from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
 from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
new file mode 100644
index 0000000000..71903f7052
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
@@ -0,0 +1,145 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MatrixInverseTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+    "MatrixInverseTriL",
+]
+
+
+class MatrixInverseTriL(bijector.Bijector):
+  """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix.
+
+  `L` must be nonsingular; equivalently, all diagonal entries of `L` must be
+  nonzero.
+
+  The input must have `rank >= 2`.  The input is treated as a batch of matrices
+  with batch shape `input.shape[:-2]`, where each matrix has dimensions
+  `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal
+  `input.shape[-1]`).
+
+  Helper tensors (the identity right-hand side and the validation zeros) are
+  created with the dtype of the input.
+
+  #### Examples
+
+  ```python
+  tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]])
+  # Result: [[1., 0], [-2, 1]], i.e., inv(x)
+
+  tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]])
+  # Result: [[1., 0], [2, 1]], i.e., inv(y).
+  ```
+
+  """
+
+  def __init__(self, validate_args=False, name="matrix_inverse_tril"):
+    """Instantiates the `MatrixInverseTriL` bijector.
+
+    Args:
+      validate_args: Python `bool` indicating whether arguments should be
+        checked for correctness.
+      name: Python `str` name given to ops managed by this object.
+    """
+    self._graph_parents = []
+    self._name = name
+    super(MatrixInverseTriL, self).__init__(
+        forward_min_event_ndims=2,
+        validate_args=validate_args,
+        name=name)
+
+  def _forward(self, x):
+    """Returns `inv(x)` by solving `x @ Y = I` with a lower-triangular solve.
+
+    Args:
+      x: `Tensor` of lower-triangular matrices, shape `[..., n, n]`.
+
+    Returns:
+      `Tensor` with the same shape and dtype as `x` holding the inverses.
+    """
+    with ops.control_dependencies(self._assertions(x)):
+      shape = array_ops.shape(x)
+      # `eye` defaults to float32; the solve requires both operands to share a
+      # dtype, so build the identity with the dtype of `x`.
+      identity = linalg_ops.eye(
+          shape[-1], batch_shape=shape[:-2], dtype=x.dtype.base_dtype)
+      return linalg_ops.matrix_triangular_solve(x, identity, lower=True)
+
+  def _inverse(self, y):
+    """Returns `inv(y)`; matrix inversion is its own inverse."""
+    return self._forward(y)
+
+  def _forward_log_det_jacobian(self, x):
+    """Returns `-2 * n * sum(log|diag(x)|)` over the last axis, `n = x.shape[-1]`.
+
+    Args:
+      x: `Tensor` of lower-triangular matrices, shape `[..., n, n]`.
+
+    Returns:
+      `Tensor` of shape `x.shape[:-2]` with the dtype of `x`.
+    """
+    # Calculation of the Jacobian:
+    #
+    # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates.  Let Z =
+    # X^{-1} where Z = (z_{ij}).  Then
+    #
+    #     dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1},
+    #
+    # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j)
+    # entry and zeros elsewhere.  By the product rule,
+    #
+    #     0 = d/dt [Identity matrix]
+    #       = d/dt [Y Y^{-1}]
+    #       = Y d/dt[Y^{-1}] + dY/dt Y^{-1}
+    #
+    # so
+    #
+    #     d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1}
+    #                  = -Y^{-1} E_{ij} Y^{-1}.
+    #
+    # Evaluating at t=0,
+    #
+    #     dZ/dx_{ij} = -Z E_{ij} Z.
+    #
+    # Taking the (r,s) entry of each side,
+    #
+    #     dz_{rs}/dx_{ij} = -z_{ri}z_{sj}.
+    #
+    # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose
+    # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}.  Considering J as an n-by-n
+    # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij}
+    # shows that the block at position (r,i) is -z_{ri}Z.  Hence
+    #
+    #     J = -KroneckerProduct(Z, Z),
+    #     det(J) = (-1)^(n^2) (det Z)^(2n)
+    #            = (-1)^n (det X)^(-2n).
+    #
+    # NOTE(review): this derivation differentiates with respect to all n^2
+    # entries of X.  Restricted to the n(n+1)/2 free entries of a
+    # lower-triangular matrix, the Jacobian appears to satisfy
+    # |det| = |det X|^(-(n+1)) instead (the two agree for n = 1) -- confirm
+    # which convention callers expect before changing the value below.
+    with ops.control_dependencies(self._assertions(x)):
+      return (-2. * math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) *
+              math_ops.reduce_sum(
+                  math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))),
+                  axis=-1))
+
+  def _assertions(self, x):
+    """Builds the validation ops for `x`.
+
+    Args:
+      x: `Tensor` to validate.
+
+    Returns:
+      An empty list when `validate_args` is `False`; otherwise a list of
+      assertion ops checking that `x` has rank at least 2, is square, is lower
+      triangular, and has no zero on its diagonal.
+    """
+    if not self.validate_args:
+      return []
+    shape = array_ops.shape(x)
+    is_matrix = check_ops.assert_rank_at_least(
+        x, 2, message="Input must have rank at least 2.")
+    is_square = check_ops.assert_equal(
+        shape[-2], shape[-1], message="Input must be a square matrix.")
+    # Zero out the diagonal and keep the upper band: what is left is the
+    # strictly-upper part of x.  The replacement diagonal must have the dtype
+    # of x, otherwise `matrix_set_diag` rejects non-float32 inputs.
+    above_diagonal = array_ops.matrix_band_part(
+        array_ops.matrix_set_diag(
+            x, array_ops.zeros(shape[:-1], dtype=x.dtype.base_dtype)),
+        0, -1)
+    is_lower_triangular = check_ops.assert_equal(
+        above_diagonal, array_ops.zeros_like(above_diagonal),
+        message="Input must be lower triangular.")
+    # A lower triangular matrix is nonsingular iff all its diagonal entries are
+    # nonzero.
+    diag_part = array_ops.matrix_diag_part(x)
+    is_nonsingular = check_ops.assert_none_equal(
+        diag_part, array_ops.zeros_like(diag_part),
+        message="Input must have all diagonal entries nonzero.")
+    return [is_matrix, is_square, is_lower_triangular, is_nonsingular]