diff options
author | Ian Langmore <langmore@google.com> | 2016-07-27 07:29:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-27 08:32:47 -0700 |
commit | 6339f8348cb19b9393447b2a59352f44a2c1a518 (patch) | |
tree | c48b903dd52778a9e471f0893775d05ba7035dc0 | |
parent | bcd81a4ceb171de978c8abc4379056a96ebee370 (diff) |
OperatorPDVDVTSqrtUpdate, MultivariateNormalDiagPlusVDVT. Define a multivariate normal by the sqrt of covariance: S = M + V D V^T, M is diagonal.
Change: 128587968
9 files changed, 1525 insertions, 85 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 4027ae3ef8..868e1947dd 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -54,6 +54,28 @@ cuda_py_tests( ], ) +cuda_py_tests( + name = "operator_pd_identity_test", + size = "small", + srcs = ["python/kernel_tests/operator_pd_identity_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "operator_pd_vdvt_update_test", + size = "medium", + srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py index d8e75e4be2..a985477242 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py @@ -117,6 +117,61 @@ class MultivariateNormalDiagTest(tf.test.TestCase): self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1) +class MultivariateNormalDiagPlusVDVTTest(tf.test.TestCase): + """Well tested because this is a simple override of the base class.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testMean(self): + mu = [-1.0, 1.0] + diag_large = [1.0, 5.0] + v = [[2.0], [3.0]] + diag_small = [3.0] + with self.test_session(): + dist = distributions.MultivariateNormalDiagPlusVDVT( + mu, diag_large, v, diag_small=diag_small) + self.assertAllEqual(mu, dist.mean().eval()) + + def testNonmatchingMuAndSigmaDimensionFailsStatic(self): + mu = self._rng.rand(2) + # With this diag_large and v, the covariance is 3 x 3 + diag_large = self._rng.rand(3) + v = self._rng.rand(3, 2) # v works with diag_large. + with self.test_session(): + with self.assertRaisesRegexp(ValueError, "shape.*should match"): + distributions.MultivariateNormalDiagPlusVDVT( + mu, diag_large, v) + + def testNonmatchingMuDiagDimensionsFailsDynamic(self): + mu = self._rng.rand(2) + # With this diag_large and v, the covariance is 3 x 3 + diag_large = self._rng.rand(3) + v = self._rng.rand(3, 2) # v works with diag_large. + + with self.test_session(): + mu_ph = tf.placeholder(tf.float32, name="mu_ph") + v_ph = tf.placeholder(tf.float32, name="v_ph") + diag_ph = tf.placeholder(tf.float32, name="diag_ph") + dist = distributions.MultivariateNormalDiagPlusVDVT( + mu_ph, diag_ph, v_ph) + with self.assertRaisesOpError("mu.*cov.*shape"): + dist.mean().eval(feed_dict={mu_ph: mu, diag_ph: diag_large, v_ph: v}) + + def testSample(self): + mu = [-1.0, 1.0] + diag_large = [1.0, 0.5] + v = [[0.2], [0.3]] + with self.test_session(): + dist = distributions.MultivariateNormalDiagPlusVDVT(mu, diag_large, v) + + samps = dist.sample_n(1000, seed=0).eval() + cov_mat = dist.sigma.eval() + + self.assertAllClose(mu, samps.mean(axis=0), atol=0.1) + self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1) + + class MultivariateNormalCholeskyTest(tf.test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py index 3a0f6e1d5a..c11b0357e7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py @@ -17,14 +17,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import numpy as np +import six import tensorflow as tf from tensorflow.contrib.distributions.python.ops import operator_pd_diag from tensorflow.contrib.distributions.python.ops import operator_test_util -class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest): +@six.add_metaclass(abc.ABCMeta) +class OperatorPDDiagBaseTest(object): def setUp(self): self._rng = np.random.RandomState(42) @@ -32,8 +35,14 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest): def _random_pd_diag(self, diag_shape): return self._rng.rand(*diag_shape) + 0.1 + @abc.abstractmethod def _diag_to_matrix(self, diag): - return tf.batch_matrix_diag(diag**2).eval() + pass + + @abc.abstractproperty + def operator_class(self): + # Return the operator class that this tests. + pass def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64): # Create a diagonal matrix explicitly. @@ -46,7 +55,7 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest): # The diag is the square root. diag = self._random_pd_diag(diag_shape).astype(dtype) mat = self._diag_to_matrix(diag).astype(dtype) - operator = operator_pd_diag.OperatorPDSqrtDiag(diag) + operator = self.operator_class(diag) return operator, mat @@ -66,5 +75,29 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest): operator.to_dense().eval() # Should not raise +class OperatorPDDiagTest( + OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest): + """Most tests done in the base classes.""" + + def _diag_to_matrix(self, diag): + return tf.batch_matrix_diag(diag).eval() + + @property + def operator_class(self): + return operator_pd_diag.OperatorPDDiag + + +class OperatorPDSqrtDiagTest( + OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest): + """Most tests done in the base classes.""" + + def _diag_to_matrix(self, diag): + return tf.batch_matrix_diag(diag**2).eval() + + @property + def operator_class(self): + return operator_pd_diag.OperatorPDSqrtDiag + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py new file mode 100644 index 0000000000..7f411105fb --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py @@ -0,0 +1,115 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.distributions.python.ops import operator_pd_identity +from tensorflow.contrib.distributions.python.ops import operator_test_util + +distributions = tf.contrib.distributions + + +class OperatorPDIdentityTest(operator_test_util.OperatorPDDerivedClassTest): + """Most tests done in the base class.""" + + def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64): + # Build an identity matrix with right shape and dtype. + # Build an operator that should act the same way. + batch_shape = list(batch_shape) + diag_shape = batch_shape + [k] + matrix_shape = batch_shape + [k, k] + diag = tf.ones(diag_shape, dtype=dtype) + identity_matrix = tf.batch_matrix_diag(diag) + operator = operator_pd_identity.OperatorPDIdentity(matrix_shape, dtype) + return operator, identity_matrix.eval() + + def test_bad_dtype_args_raise(self): + dtype = np.float32 + batch_shape = [2, 3] + k = 4 + with self.test_session(): + operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype) + + x_good_shape = batch_shape + [k, 5] + x_good = self._rng.randn(*x_good_shape).astype(dtype) + x_bad = x_good.astype(np.float64) + + operator.matmul(x_good).eval() # Should not raise. + + with self.assertRaisesRegexp(TypeError, 'dtype'): + operator.matmul(x_bad) + + with self.assertRaisesRegexp(TypeError, 'dtype'): + operator.solve(x_bad) + + with self.assertRaisesRegexp(TypeError, 'dtype'): + operator.sqrt_solve(x_bad) + + def test_bad_rank_args_raise(self): + # Prepend a singleton dimension, changing the rank of 'x', but not the size. + dtype = np.float32 + batch_shape = [2, 3] + k = 4 + with self.test_session(): + operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype) + + x_good_shape = batch_shape + [k, 5] + x_good = self._rng.randn(*x_good_shape).astype(dtype) + x_bad = x_good.reshape(1, 2, 3, 4, 5) + + operator.matmul(x_good).eval() # Should not raise. + + with self.assertRaisesRegexp(ValueError, 'tensor rank'): + operator.matmul(x_bad) + + with self.assertRaisesRegexp(ValueError, 'tensor rank'): + operator.solve(x_bad) + + with self.assertRaisesRegexp(ValueError, 'tensor rank'): + operator.sqrt_solve(x_bad) + + def test_incompatible_shape_args_raise(self): + # Test shapes that are the same rank but incompatible for matrix + # multiplication. + dtype = np.float32 + batch_shape = [2, 3] + k = 4 + with self.test_session(): + operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype) + + x_good_shape = batch_shape + [k, 5] + x_good = self._rng.randn(*x_good_shape).astype(dtype) + x_bad_shape = batch_shape + [5, k] + x_bad = x_good.reshape(*x_bad_shape) + + operator.matmul(x_good).eval() # Should not raise. + + with self.assertRaisesRegexp(ValueError, 'Incompatible'): + operator.matmul(x_bad) + + with self.assertRaisesRegexp(ValueError, 'Incompatible'): + operator.solve(x_bad) + + with self.assertRaisesRegexp(ValueError, 'Incompatible'): + operator.sqrt_solve(x_bad) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py new file mode 100644 index 0000000000..66d54561b5 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py @@ -0,0 +1,273 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.distributions.python.ops import operator_pd_full +from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update +from tensorflow.contrib.distributions.python.ops import operator_test_util + +distributions = tf.contrib.distributions + + +class OperatorPDSqrtVDVTUpdateTest( + operator_test_util.OperatorPDDerivedClassTest): + """Most tests done in the base class.""" + _diag_is_none = False + + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_pd_matrix(self, shape): + # With probability 1 this is positive definite. + sqrt = self._rng.randn(*shape) + mat = tf.batch_matmul(sqrt, sqrt, adj_y=True) + return mat.eval() + + def _random_v_and_diag(self, mat_shape, v_matrix_rank): + # Get the necessary elements to make the sqrt update. + mat_shape = list(mat_shape) + batch_shape = mat_shape[:-2] + diag_shape = mat_shape[:-2] + [v_matrix_rank] + k = mat_shape[-1] + assert k == mat_shape[-2], 'Must be a square matrix' + v_shape = batch_shape + [k, v_matrix_rank] + v = self._rng.randn(*v_shape) # anything goes with "v"! + + if self._diag_is_none: + diag = None + else: + diag = self._rng.rand(*diag_shape) + 0.1 # Positive diag! + return v, diag + + def _updated_mat(self, mat, v, diag): + # Get dense matrix defined by its square root, which is an update of `mat`: + # A = (mat + v D v^T) (mat + v D v^T)^T + # D is the diagonal matrix with `diag` on the diagonal. + + # If diag is None, then it defaults to the identity matrix, so DV^T = V^T + if diag is None: + diag_vt = tf.batch_matrix_transpose(v) + else: + diag_mat = tf.batch_matrix_diag(diag) + diag_vt = tf.batch_matmul(diag_mat, v, adj_y=True) + + v_diag_vt = tf.batch_matmul(v, diag_vt) + sqrt = mat + v_diag_vt + a = tf.batch_matmul(sqrt, sqrt, adj_y=True) + return a.eval() + + def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64): + """This method is called by base class, enabling many standard tests.""" + # Create a matrix then explicitly update it with v and diag. + # Create an OperatorPDSqrtVDVTUpdate from the matrix and v and diag + # The operator should have the same behavior. + # + # The low-rank matrix V will have rank 1/2 of k, unless k is 1, in which + # case it will be 1 as well. + if k == 1: + v_matrix_rank = k + else: + v_matrix_rank = k // 2 + mat_shape = list(batch_shape) + [k, k] + mat = self._random_pd_matrix(mat_shape) + v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank) + + # Set dtypes + mat = mat.astype(dtype) + v = v.astype(dtype) + if diag is not None: + diag = diag.astype(dtype) + + # The matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T + # Our final updated operator should behave like this. + updated_mat = self._updated_mat(mat, v, diag) + + # Represents the matrix: `mat`, before updating. + # This is the Operator that we will update. + o_made_with_mat = operator_pd_full.OperatorPDFull(mat) + + # Represents the matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T, + # achieved by updating the operator "o_made_with_mat". + # This is the operator we're testing. + operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + o_made_with_mat, v, diag) + + return operator, updated_mat + + def test_to_dense_placeholder(self): + # Test simple functionality when the inputs are placeholders. + mat_shape = [3, 3] + v_matrix_rank = 2 + with self.test_session(): + # Make an OperatorPDFull with a matrix placeholder. + mat_ph = tf.placeholder(tf.float64, name='mat_ph') + mat = self._random_pd_matrix(mat_shape) + o_made_with_mat = operator_pd_full.OperatorPDFull(mat_ph) + + # Make the placeholders and arrays for the updated operator. + v_ph = tf.placeholder(tf.float64, name='v_ph') + v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank) + if self._diag_is_none: + diag_ph = None + feed_dict = {v_ph: v, mat_ph: mat} + else: + diag_ph = tf.placeholder(tf.float64, name='diag_ph') + feed_dict = {v_ph: v, diag_ph: diag, mat_ph: mat} + + # Make the OperatorPDSqrtVDVTUpdate with v and diag placeholders. + operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + o_made_with_mat, v_ph, diag=diag_ph) + + # Should not fail + operator.to_dense().eval(feed_dict=feed_dict) + operator.log_det().eval(feed_dict=feed_dict) + + def test_operator_not_subclass_of_operator_pd_raises(self): + # We enforce that `operator` is an `OperatorPDBase`. + with self.test_session(): + v, diag = self._random_v_and_diag((3, 3), 2) + operator_m = 'I am not a subclass of OperatorPDBase' + + with self.assertRaisesRegexp(TypeError, 'not instance'): + operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag) + + def test_non_pos_def_diag_raises(self): + if self._diag_is_none: + return + # We enforce that the diag is positive definite. + with self.test_session(): + matrix_shape = (3, 3) + v_rank = 2 + v, diag = self._random_v_and_diag(matrix_shape, v_rank) + mat = self._random_pd_matrix(matrix_shape) + diag[0] = 0.0 + + operator_m = operator_pd_full.OperatorPDFull(mat) + operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + operator_m, v, diag) + + with self.assertRaisesOpError('positive'): + operator.to_dense().eval() + + def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self): + # We enforce that the diag is positive definite. + if self._diag_is_none: + return + with self.test_session(): + matrix_shape = (3, 3) + v_rank = 2 + v, diag = self._random_v_and_diag(matrix_shape, v_rank) + mat = self._random_pd_matrix(matrix_shape) + diag[0] = 0.0 + + operator_m = operator_pd_full.OperatorPDFull(mat) + operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + operator_m, v, diag, verify_pd=False) + + operator.to_dense().eval() # Should not raise. + + def test_event_shape_mismatch_v_and_diag_raises_static(self): + v = self._rng.rand(4, 3, 2) + diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v. + with self.test_session(): + + mat = self._random_pd_matrix((4, 3, 3)) # mat and v match + operator_m = operator_pd_full.OperatorPDFull(mat) + with self.assertRaisesRegexp(ValueError, 'diag.*v.*last dimension'): + operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag) + + def test_batch_shape_mismatch_v_and_diag_raises_static(self): + v = self._rng.rand(4, 3, 2) + diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v. + with self.test_session(): + + mat = self._random_pd_matrix((4, 3, 3)) # mat and v match + operator_m = operator_pd_full.OperatorPDFull(mat) + with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'): + operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag) + + def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self): + v = self._rng.rand(1, 2, 2, 2) + diag = self._rng.rand(5, 1) # Should have rank 1 less than v. + with self.test_session(): + + mat = self._random_pd_matrix((1, 2, 2, 2)) # mat and v match + operator_m = operator_pd_full.OperatorPDFull(mat) + with self.assertRaisesRegexp(ValueError, 'diag.*rank'): + operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag) + + def test_event_shape_mismatch_v_and_diag_raises_dynamic(self): + with self.test_session(): + + v = self._rng.rand(4, 3, 2) + diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v. + mat = self._random_pd_matrix((4, 3, 3)) # mat and v match + + v_ph = tf.placeholder(tf.float32, name='v_ph') + diag_ph = tf.placeholder(tf.float32, name='diag_ph') + mat_ph = tf.placeholder(tf.float32, name='mat_ph') + + operator_m = operator_pd_full.OperatorPDFull(mat_ph) + updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + operator_m, v_ph, diag_ph) + with self.assertRaisesOpError('x == y'): + updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat}) + + def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self): + with self.test_session(): + v = self._rng.rand(4, 3, 2) + diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v. + mat = self._random_pd_matrix((4, 3, 3)) # mat and v match + + v_ph = tf.placeholder(tf.float32, name='v_ph') + diag_ph = tf.placeholder(tf.float32, name='diag_ph') + mat_ph = tf.placeholder(tf.float32, name='mat_ph') + + operator_m = operator_pd_full.OperatorPDFull(mat_ph) + updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + operator_m, v_ph, diag_ph) + with self.assertRaisesOpError('x == y'): + updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat}) + + def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self): + with self.test_session(): + + v = self._rng.rand(2, 2, 2, 2) + diag = self._rng.rand(2, 2) # Should have rank 1 less than v. + mat = self._random_pd_matrix((2, 2, 2, 2)) # mat and v match + + v_ph = tf.placeholder(tf.float32, name='v_ph') + diag_ph = tf.placeholder(tf.float32, name='diag_ph') + mat_ph = tf.placeholder(tf.float32, name='mat_ph') + + operator_m = operator_pd_full.OperatorPDFull(mat_ph) + updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + operator_m, v_ph, diag_ph) + with self.assertRaisesOpError('rank'): + updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat}) + + +class OperatorPDSqrtVDVTUpdateNoneDiagTest(OperatorPDSqrtVDVTUpdateTest): + _diag_is_none = True + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py index 90e26336d7..a3b1baeba5 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn.py +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -24,6 +24,7 @@ from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky from tensorflow.contrib.distributions.python.ops import operator_pd_diag from tensorflow.contrib.distributions.python.ops import operator_pd_full +from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -40,6 +41,7 @@ __all__ = [ "MultivariateNormalDiag", "MultivariateNormalCholesky", "MultivariateNormalFull", + "MultivariateNormalDiagPlusVDVT", ] @@ -52,14 +54,13 @@ class MultivariateNormalOperatorPD(distribution.Distribution): #### Mathematical details - The PDF of this distribution is: + With `C` the covariance matrix represented by the operator, the PDF of this + distribution is: ``` - f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu)) + f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu)) ``` - where `.` denotes the inner product on `R^k` and `^*` denotes transpose. - #### Examples A single multi-variate Gaussian distribution is defined by a vector of means @@ -109,10 +110,10 @@ class MultivariateNormalOperatorPD(distribution.Distribution): validate_args: Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: `Boolean`, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: @@ -170,12 +171,12 @@ class MultivariateNormalOperatorPD(distribution.Distribution): @property def validate_args(self): - """Boolean describing behavior on invalid input.""" + """`Boolean` describing behavior on invalid input.""" return self._validate_args @property def allow_nan_stats(self): - """Boolean describing behavior when a stat is undefined for batch member.""" + """`Boolean` describing behavior when stats are undefined.""" return self._allow_nan_stats @property @@ -417,7 +418,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD): determined by `diag_stdev`: `C_{ii} = diag_stdev[i]**2`. ``` - f(x) = (2*pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 * (x - mu)^T C^{-1} (x - mu)) + f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu)) ``` #### Examples @@ -467,14 +468,14 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD): mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`. diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`, - representing the standard deviations. + representing the standard deviations. Must be positive. validate_args: Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: `Boolean`, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: @@ -487,6 +488,125 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD): name=name) +class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD): + """The multivariate normal distribution on `R^k`. + + Every batch member of this distribution is defined by a mean and a lightweight + covariance matrix `C`. + + #### Mathematical details + + The PDF of this distribution in terms of the mean `mu` and covariance `C` is: + + ``` + f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu)) + ``` + + For every batch member, this distribution represents `k` random variables + `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix + `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]` + + The user initializes this class by providing the mean `mu`, and a lightweight + definition of `C`: + + ``` + C = SS^T = SS = (M + V D V^T) (M + V D V^T) + M is diagonal (k x k) + V = is shape (k x r), typically r << k + D = is diagonal (r x r), optional (defaults to identity). + ``` + + This allows for `O(kr + r^3)` pdf evaluation and determinant, and `O(kr)` + sampling and storage (per batch member). + + #### Examples + + A single multi-variate Gaussian distribution is defined by a vector of means + of length `k`, and square root of the covariance `S = M + V D V^T`. Extra + leading dimensions, if provided, allow for batches. + + ```python + # Initialize a single 3-variate Gaussian with covariance square root + # S = M + V D V^T, where V D V^T is a matrix-rank 2 update. + mu = [1, 2, 3.] + diag_large = [1.1, 2.2, 3.3] + v = ... # shape 3 x 2 + diag_small = [4., 5.] + dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT( + mu, diag_large, v, diag_small=diag_small) + + # Evaluate this on an observation in R^3, returning a scalar. + dist.pdf([-1, 0, 1]) + + # Initialize a batch of two 3-variate Gaussians. This time, don't provide + # diag_small. This means S = M + V V^T. + mu = [[1, 2, 3], [11, 22, 33]] # shape 2 x 3 + diag_large = ... # shape 2 x 3 + v = ... # shape 2 x 3 x 1, a matrix-rank 1 update. + dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT( + mu, diag_large, v) + + # Evaluate this on a two observations, each in R^3, returning a length two + # tensor. + x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3. + dist.pdf(x) + ``` + + """ + + def __init__( + self, + mu, + diag_large, + v, + diag_small=None, + validate_args=True, + allow_nan_stats=False, + name="MultivariateNormalDiagPlusVDVT"): + """Multivariate Normal distributions on `R^k`. + + For every batch member, this distribution represents `k` random variables + `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix + `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]` + + The user initializes this class by providing the mean `mu`, and a + lightweight definition of `C`: + + ``` + C = SS^T = SS = (M + V D V^T) (M + V D V^T) + M is diagonal (k x k) + V = is shape (k x r), typically r << k + D = is diagonal (r x r), optional (defaults to identity). + ``` + + Args: + mu: Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`, + `n >= 0`. The means. + diag_large: Optional rank `n + 1` `float` or `double` tensor, shape + `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`. + v: Rank `n + 1` `float` or `double` tensor, shape `[N1,...,Nn, k, r]` + `n >= 0`. Defines the matrix `V`. + diag_small: Rank `n + 1` `float` or `double` tensor, shape + `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default + is `None`, which means `D` will be the identity matrix. + validate_args: Whether to validate input with asserts. If `validate_args` + is `False`, + and the inputs are invalid, correct behavior is not guaranteed. + allow_nan_stats: `Boolean`, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to give Ops created by the initializer. + """ + m = operator_pd_diag.OperatorPDDiag(diag_large, verify_pd=validate_args) + cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate( + m, v, diag=diag_small, verify_pd=validate_args, + verify_shapes=validate_args) + super(MultivariateNormalDiagPlusVDVT, self).__init__( + mu, cov, allow_nan_stats=allow_nan_stats, validate_args=validate_args, + name=name) + + class MultivariateNormalCholesky(MultivariateNormalOperatorPD): """The multivariate normal distribution on `R^k`. @@ -496,14 +616,14 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD): #### Mathematical details - The PDF of this distribution is: + The Cholesky factor `chol` defines the covariance matrix: `C = chol chol^T`. + + The PDF of this distribution is then: ``` - f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu)) + f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu)) ``` - where `.` denotes the inner product on `R^k` and `^*` denotes transpose. - #### Examples A single multi-variate Gaussian distribution is defined by a vector of means @@ -546,20 +666,21 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD): """Multivariate Normal distributions on `R^k`. User must provide means `mu` and `chol` which holds the (batch) Cholesky - factors `S`, such that the covariance of each batch member is `S S^*`. + factors, such that the covariance of each batch member is `chol chol^T`. Args: mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`. chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape - `[N1,...,Nb, k, k]`. + `[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as + though it is zero), and the diagonal must be positive. validate_args: Whether to validate input with asserts. If `validate_args` - is `False`, - and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + is `False`, and the inputs are invalid, correct behavior is not + guaranteed. + allow_nan_stats: `Boolean`, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: @@ -582,14 +703,12 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD): #### Mathematical details - The PDF of this distribution is: + With `C = sigma`, the PDF of this distribution is: ``` - f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu)) + f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu)) ``` - where `.` denotes the inner product on `R^k` and `^*` denotes transpose. - #### Examples A single multi-variate Gaussian distribution is defined by a vector of means @@ -633,14 +752,14 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD): mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`. sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape - `[N1,...,Nb, k, k]`. + `[N1,...,Nb, k, k]`. Each batch member must be positive definite. validate_args: Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: `Boolean`, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: diff --git a/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py b/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py index ea5aa3c386..5e019355f7 100644 --- a/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py +++ b/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc +import six + from tensorflow.contrib.distributions.python.ops import operator_pd from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -26,11 +29,190 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -class OperatorPDSqrtDiag(operator_pd.OperatorPDBase): +@six.add_metaclass(abc.ABCMeta) +class OperatorPDDiagBase(operator_pd.OperatorPDBase): + """Base class for diagonal operators.""" + + def __init__(self, diag, verify_pd=True, name='OperatorPDDiagBase'): + self._verify_pd = verify_pd + self._name = name + with ops.name_scope(name): + with ops.op_scope([diag], 'init'): + self._diag = self._check_diag(diag) + + def _check_diag(self, diag): + """Verify that `diag` is positive.""" + diag = ops.convert_to_tensor(diag, name='diag') + if not self.verify_pd: + return diag + deps = [check_ops.assert_positive(diag)] + return control_flow_ops.with_dependencies(deps, diag) + + @property + def name(self): + """String name identifying this `Operator`.""" + return self._name + + @property + def verify_pd(self): + """Whether to verify that this `Operator` is positive definite.""" + return self._verify_pd + + @property + def dtype(self): + """Data type of matrix elements of `A`.""" + return self._diag.dtype + + @property + def inputs(self): + """Initialization arguments.""" + return [self._diag] + + def get_shape(self): + """`TensorShape` giving static shape.""" + # If d_shape = [5, 3], we return [5, 3, 3]. + d_shape = self._diag.get_shape() + return d_shape.concatenate(d_shape[-1:]) + + def _shape(self): + d_shape = array_ops.shape(self._diag) + k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1) + return array_ops.concat(0, (d_shape, [k])) + + @abc.abstractmethod + def _batch_log_det(self): + pass + + @abc.abstractmethod + def _inv_quadratic_form_on_vectors(self, x): + pass + + @abc.abstractmethod + def _batch_matmul(self, x, transpose_x=False): + pass + + @abc.abstractmethod + def _batch_sqrt_matmul(self, x, transpose_x=False): + pass + + @abc.abstractmethod + def _batch_solve(self, rhs): + pass + + @abc.abstractmethod + def _batch_sqrt_solve(self, rhs): + pass + + @abc.abstractmethod + def _to_dense(self): + pass + + @abc.abstractmethod + def _sqrt_to_dense(self): + pass + + @abc.abstractmethod + def _add_to_tensor(self, mat): + pass + + +class OperatorPDDiag(OperatorPDDiagBase): + """Class representing a (batch) of positive definite matrices `A`. + + This class provides access to functions of a batch of symmetric positive + definite (PD) matrices `A` in `R^{k x k}`. + + In this case, `A` is diagonal and is defined by a provided tensor `diag`, + `A_{ii} = diag[i]`. + + Determinants, solves, and storage are `O(k)`. + + In practice, this operator represents a (batch) matrix `A` with shape + `[N1,...,Nn, k, k]` for some `n >= 0`. The first `n` indices designate a + batch member. For every batch member `(i1,...,ib)`, `A[i1,...,ib, : :]` is + a `k x k` matrix. + + For example, + + ```python + distributions = tf.contrib.distributions + diag = [1.0, 2.0] + operator = OperatorPDDiag(diag) + operator.det() # ==> (1 * 2) + + # Compute the quadratic form x^T A^{-1} x for vector x. + x = [1.0, 2.0] + operator.inv_quadratic_form_on_vectors(x) + + # Matrix multiplication by the square root, S w, with A = S S^T. + # Recall A is diagonal, and so then is S, with S_{ij} = sqrt(A_{ij}). + # If w is iid normal, S w has covariance A. + w = [[1.0], + [2.0]] + operator.sqrt_matmul(w) + ``` + + The above three methods, `log_det`, `inv_quadratic_form_on_vectors`, and + `sqrt_matmul` provide "all" that is necessary to use a covariance matrix + in a multi-variate normal distribution. See the class + `MultivariateNormalDiag`. + """ + + def __init__(self, diag, verify_pd=True, name='OperatorPDDiag'): + """Initialize an OperatorPDDiag. + + Args: + diag: Shape `[N1,...,Nn, k]` positive tensor with `n >= 0`, `k >= 1`. + verify_pd: Whether to check `diag` is positive. + name: A name to prepend to all ops created by this class. + """ + super(OperatorPDDiag, self).__init__( + diag, verify_pd=verify_pd, name=name) + + def _batch_log_det(self): + return math_ops.reduce_sum( + math_ops.log(self._diag), reduction_indices=[-1]) + + def _inv_quadratic_form_on_vectors(self, x): + return self._iqfov_via_solve(x) + + def _batch_matmul(self, x, transpose_x=False): + if transpose_x: + x = array_ops.batch_matrix_transpose(x) + diag_mat = array_ops.expand_dims(self._diag, -1) + return diag_mat * x + + def _batch_sqrt_matmul(self, x, transpose_x=False): + if transpose_x: + x = array_ops.batch_matrix_transpose(x) + diag_mat = array_ops.expand_dims(self._diag, -1) + return math_ops.sqrt(diag_mat) * x + + def _batch_solve(self, rhs): + diag_mat = array_ops.expand_dims(self._diag, -1) + return rhs / diag_mat + + def _batch_sqrt_solve(self, rhs): + diag_mat = array_ops.expand_dims(self._diag, -1) + return rhs / math_ops.sqrt(diag_mat) + + def _to_dense(self): + return array_ops.batch_matrix_diag(self._diag) + + def _sqrt_to_dense(self): + return array_ops.batch_matrix_diag(math_ops.sqrt(self._diag)) + + def _add_to_tensor(self, mat): + mat_diag = array_ops.batch_matrix_diag_part(mat) + new_diag = self._diag + mat_diag + return array_ops.batch_matrix_set_diag(mat, new_diag) + + +class OperatorPDSqrtDiag(OperatorPDDiagBase): """Class representing a (batch) of positive definite matrices `A`. This class provides access to functions of a batch of symmetric positive - definite (PD) matrices `A` in `R^{k x k}` defined by their their square root, + definite (PD) matrices `A` in `R^{k x k}` defined by their square root, `S`, such that `A = SS^T`. In this case, `S` is diagonal and is defined by a provided tensor `diag`, @@ -75,58 +257,17 @@ class OperatorPDSqrtDiag(operator_pd.OperatorPDBase): verify_pd: Whether to check `diag` is positive. name: A name to prepend to all ops created by this class. """ - self._verify_pd = verify_pd - self._name = name - with ops.name_scope(name): - with ops.op_scope([diag], 'init'): - self._diag = self._check_diag(diag) - - def _check_diag(self, diag): - """Verify that `diag` is positive.""" - diag = ops.convert_to_tensor(diag, name='diag') - if not self.verify_pd: - return diag - deps = [check_ops.assert_positive(diag)] - return control_flow_ops.with_dependencies(deps, diag) - - @property - def name(self): - """String name identifying this `Operator`.""" - return self._name - - @property - def verify_pd(self): - """Whether to verify that this `Operator` is positive definite.""" - return self._verify_pd - - @property - def dtype(self): - """Data type of matrix elements of `A`.""" - return self._diag.dtype + super(OperatorPDSqrtDiag, self).__init__( + diag, verify_pd=verify_pd, name=name) def _batch_log_det(self): return 2 * math_ops.reduce_sum( math_ops.log(self._diag), reduction_indices=[-1]) - @property - def inputs(self): - """List of tensors that were provided as initialization inputs.""" - return [self._diag] - def _inv_quadratic_form_on_vectors(self, x): # This Operator is defined in terms of diagonal entries of the sqrt. return self._iqfov_via_sqrt_solve(x) - def get_shape(self): - """`TensorShape` giving static shape.""" - d_shape = self._diag.get_shape() - return d_shape.concatenate(d_shape[-1:]) - - def _shape(self): - d_shape = array_ops.shape(self._diag) - k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1) - return array_ops.concat(0, (d_shape, [k])) - def _batch_matmul(self, x, transpose_x=False): if transpose_x: x = array_ops.batch_matrix_transpose(x) diff --git a/tensorflow/contrib/distributions/python/ops/operator_pd_identity.py b/tensorflow/contrib/distributions/python/ops/operator_pd_identity.py new file mode 100644 index 0000000000..f1b750351c --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/operator_pd_identity.py @@ -0,0 +1,207 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Identity operator in `R^k`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.contrib.distributions.python.ops import operator_pd +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops + + +class OperatorPDIdentity(operator_pd.OperatorPDBase): + """Identity operator in `R^k`: `Ax = x`. + + This provides an efficient implementation of the identity as an `OperatorPD`. + Storage, solves, and matmul are all `O(1)`, independent of batch size. + + In order to be a drop-in replacement for other operators, shape and dtype + of arguments (e.g. to `matmul`) are checked statically as though this operator + was an instantiated matrix. + + Dynamic shape checks of arguments are not done since that could impede + performance. + """ + + def __init__(self, shape, dtype, verify_pd=True, name='OperatorPDIdentity'): + """Initialize an `OperatorPDIdentity`. + + Args: + shape: `int32` rank 1 `Tensor` of length at least 2, and with the last + two entries equal (since this is a square matrix). + dtype: Data type of the matrix that this operator represents. + verify_pd: `Boolean`, if `True`, asserts are added to the initialization + args to ensure they define this operator as a square (batch) matrix. + name: Name to prepend to `Ops`. + """ + + # Grab static shape if available now. + with ops.name_scope(name): + with ops.op_scope([shape], 'init'): + self._dtype = dtypes.as_dtype(dtype) + self._verify_pd = verify_pd + self._name = name + + # Store the static shape (if possible) right now before adding the + # asserts, since the asserts prevent .constant_value from working. + shape = ops.convert_to_tensor(shape, name='shape') + self._get_shape = tensor_shape.TensorShape( + tensor_util.constant_value(shape)) + self._shape_arg = self._check_shape(shape) + + def _check_shape(self, shape): + """Check that the init arg `shape` defines a valid operator.""" + shape = ops.convert_to_tensor(shape, name='shape') + if not self._verify_pd: + return shape + + # Further checks are equivalent to verification that this is positive + # definite. Why? Because the further checks simply check that this is a + # square matrix, and combining the fact that this is square (and thus maps + # a vector space R^k onto itself), with the behavior of .matmul(), this must + # be the identity operator. + rank = array_ops.size(shape) + assert_matrix = check_ops.assert_less_equal(2, rank) + with ops.control_dependencies([assert_matrix]): + last_dim = array_ops.gather(shape, rank - 1) + second_to_last_dim = array_ops.gather(shape, rank - 2) + assert_square = check_ops.assert_equal(last_dim, second_to_last_dim) + return control_flow_ops.with_dependencies([assert_matrix, assert_square], + shape) + + def _check_x(self, x): + """Static check that the argument `x` is proper `shape`, `dtype`.""" + # x is a typical argument e.g. to matmul or solve. In both cases, x should + # have the same type/shape since this is a square matrix. These checks are + # ususally not needed since we ususally have some tensor backing this + # distribution, and the calls to tf.matmul do a shape/type check. + # + # Static checks only for efficiency, the identity should be fast. + # + # Why check at all? Because we want this operator to be swappable for a + # real Operator. + if self.dtype != x.dtype: + raise TypeError( + 'Expected argument "x" to have same dtype as this operator (%s). ' + 'Found: %s' % (self.dtype, x.dtype)) + + x_shape = x.get_shape() + self_shape = self.get_shape() + found_msg = ( + 'Found: operator.shape = %s, x.shape = %s' % (self_shape, x_shape)) + if x_shape.ndims is not None and self_shape.ndims is not None: + if x_shape.ndims != self_shape.ndims: + raise ValueError( + 'Expected argument "x" to have same tensor rank as this operator. ' + + found_msg) + if x_shape.is_fully_defined() and self_shape.is_fully_defined(): + if x_shape[-2] != self_shape[-1]: + raise ValueError( + 'Incompatible shapes for matrix-matrix operation. ' + found_msg) + + @property + def name(self): + """String name identifying this `Operator`.""" + return self._name + + @property + def verify_pd(self): + """Whether to verify that this `Operator` is positive definite.""" + return self._verify_pd + + @property + def dtype(self): + """Data type of matrix elements of `A`.""" + return self._dtype + + def _add_to_tensor(self, mat): + # Add to a tensor in O(k) time! + mat_diag = array_ops.batch_matrix_diag_part(mat) + new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag + return array_ops.batch_matrix_set_diag(mat, new_diag) + + def _inv_quadratic_form_on_vectors(self, x): + self._check_x(x) + return self._iqfov_via_sqrt_solve(x) + + @property + def inputs(self): + """List of tensors that were provided as initialization inputs.""" + return [self._shape] + + def get_shape(self): + """Static `TensorShape` of entire operator. + + If this operator represents the batch matrix `A` with + `A.shape = [N1,...,Nn, k, k]`, then this returns + `TensorShape([N1,...,Nn, k, k])` + + Returns: + `TensorShape`, statically determined, may be undefined. + """ + return self._get_shape + + def _shape(self): + return self._shape_arg + + def _det(self): + det = array_ops.ones(self.batch_shape(), dtype=self.dtype) + det.set_shape(self.get_batch_shape()) + return det + + def _batch_log_det(self): + log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype) + log_det.set_shape(self.get_batch_shape()) + return log_det + + def _batch_sqrt_log_det(self): + s_log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype) + s_log_det.set_shape(self.get_batch_shape()) + return s_log_det + + def _batch_matmul(self, x, transpose_x=False): + if transpose_x: + x = array_ops.batch_matrix_transpose(x) + self._check_x(x) + return x + + def _batch_sqrt_matmul(self, x, transpose_x=False): + return self._batch_matmul(x, transpose_x=transpose_x) + + def _batch_solve(self, rhs): + self._check_x(rhs) + return rhs + + def _batch_sqrt_solve(self, rhs): + self._check_x(rhs) + return rhs + + def _to_dense(self): + diag = array_ops.ones(self.vector_shape(), dtype=self.dtype) + dense = array_ops.batch_matrix_diag(diag) + dense.set_shape(self.get_shape()) + return dense + + def _sqrt_to_dense(self): + return self.to_dense() diff --git a/tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py b/tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py new file mode 100644 index 0000000000..3c934e721c --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py @@ -0,0 +1,475 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operator defined: `A = SS^T` where `S = M + VDV^T`, for `OperatorPD` `M`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import operator_pd +from tensorflow.contrib.distributions.python.ops import operator_pd_diag +from tensorflow.contrib.distributions.python.ops import operator_pd_identity +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + + +class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase): + r"""Operator defined by `A=SS^T`, where `S = M + VDV^T` for `OperatorPD` `M`. + + This provides efficient low-rank updates of arbitrary `OperatorPD`. + + Some math: + + Given positive definite operator representing positive definite (batch) matrix + `M` in `R^{k x k}`, diagonal matrix `D` in `R^{r x r}`, and low rank `V` in + `R^{k x r}` this class represents the batch matrix `A`, defined by its square + root `S` as follows: + + ``` + A = SS^T, where + S := M + VDV^T + ``` + + Defining an operator in terms of its square root means that + `A_{ij} = S_i S_j^T`, where `S_i` is the ith row of `S`. The update + `VDV^T` has `ij` coordinate equal to `sum_k V_{ik} D_{kk} V_{jk}`. + + Computational efficiency: + + Defining `A` via its square root eliminates the need to compute the square + root. + + Performance depends on the operator representing `M`, the batch size `B`, and + the width of the matrix being multiplied, or systems being solved `L`. + + Since `V` is rank `r`, the update adds + + * `O(B L k r)` to matmul, which requires a call to `M.matmul`. + * `O(B L r^3)` to solves, which require a call to `M.solve` as well as the + solution to a batch of rank `r` systems. + * `O(B r^3)` to determinants, which require a call to `M.solve` as well as the + solution to a batch of rank `r` systems. + + The rank `r` solve and determinant are both done through a Cholesky + factorization, thus some computation is shared. + + See + https://en.wikipedia.org/wiki/Woodbury_matrix_identity + https://en.wikipedia.org/wiki/Matrix_determinant_lemma + """ + + # Note that diag must be nonsingular to use Woodbury lemma, and must be + # positive def to use a Cholesky factorization, so we enforce that here. + def __init__(self, + operator, + v, + diag=None, + verify_pd=True, + verify_shapes=True, + name='OperatorPDSqrtVDVTUpdate'): + """Initialize an `OperatorPDSqrtVDVTUpdate`. + + Args: + operator: Subclass of `OperatorPDBase`. Represents the (batch) positive + definite matrix `M` in `R^{k x k}`. + v: `Tensor` defining batch matrix of same `dtype` and `batch_shape` as + `operator`, and last two dimensions of shape `(k, r)`. + diag: Optional `Tensor` defining batch vector of same `dtype` and + `batch_shape` as `operator`, and last dimension of size `r`. If `None`, + the update becomes `VV^T` rather than `VDV^T`. + verify_pd: `Boolean`. If `True`, add asserts that `diag > 0`, which, + along with the positive definiteness of `operator`, is sufficient to + make the resulting operator positive definite. + verify_shapes: `Boolean`. If `True`, check that `operator`, `v`, and + `diag` have compatible shapes. + name: A name to prepend to `Op` names. + """ + + if not isinstance(operator, operator_pd.OperatorPDBase): + raise TypeError('operator was not instance of OperatorPDBase.') + + with ops.name_scope(name): + with ops.op_scope(operator.inputs + [v, diag], 'init'): + self._operator = operator + self._v = ops.convert_to_tensor(v, name='v') + self._verify_pd = verify_pd + self._verify_shapes = verify_shapes + self._name = name + + # This operator will be PD so long as the diag is PSD, but Woodbury + # and determinant lemmas require diag to be PD. So require diag PD + # whenever we ask to "verify_pd". + if diag is not None: + self._diag = ops.convert_to_tensor(diag, name='diag') + self._diag_operator = operator_pd_diag.OperatorPDDiag( + diag, verify_pd=self.verify_pd) + # No need to verify that the inverse of a PD is PD. + self._diag_inv_operator = operator_pd_diag.OperatorPDDiag( + 1 / self._diag, verify_pd=False) + else: + self._diag = None + self._diag_operator = self._get_identity_operator(self._v) + self._diag_inv_operator = self._diag_operator + + self._check_types(operator, self._v, self._diag) + # Always check static. + checked = self._check_shapes_static(operator, self._v, self._diag) + if not checked and self._verify_shapes: + self._v, self._diag = self._check_shapes_dynamic( + operator, self._v, self._diag) + + def _get_identity_operator(self, v): + """Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`.""" + with ops.op_scope([v], 'get_identity_operator'): + if v.get_shape().is_fully_defined(): + v_shape = v.get_shape().as_list() + v_batch_shape = v_shape[:-2] + r = v_shape[-1] + id_shape = v_batch_shape + [r, r] + else: + v_shape = array_ops.shape(v) + v_rank = array_ops.rank(v) + v_batch_shape = array_ops.slice(v_shape, [0], [v_rank - 2]) + r = array_ops.gather(v_shape, v_rank - 1) # Last dim of v + id_shape = array_ops.concat(0, (v_batch_shape, [r, r])) + return operator_pd_identity.OperatorPDIdentity( + id_shape, v.dtype, verify_pd=self._verify_pd) + + def _check_types(self, operator, v, diag): + def msg(): + string = ( + 'dtypes must match: Found operator.dtype = %s, v.dtype = %s' + % (operator.dtype, v.dtype)) + return string + + if operator.dtype != v.dtype: + raise TypeError(msg()) + if diag is not None: + if diag.dtype != v.dtype: + raise TypeError('%s, diag.dtype = %s' % (msg(), diag.dtype)) + + def _check_shapes_static(self, operator, v, diag): + """True if they are compatible. Raise if not. False if could not check.""" + def msg(): + # Error message when shapes don't match. + string = ' Found: operator.shape = %s, v.shape = %s' % (s_op, s_v) + if diag is not None: + string += ', diag.shape = ' % s_d + return string + + s_op = operator.get_shape() + s_v = v.get_shape() + + # If everything is not fully defined, return False because we couldn't check + if not (s_op.is_fully_defined() and s_v.is_fully_defined()): + return False + if diag is not None: + s_d = diag.get_shape() + if not s_d.is_fully_defined(): + return False + + # Now perform the checks, raising ValueError if they fail. + + # Check tensor rank. + if s_v.ndims != s_op.ndims: + raise ValueError('v should have same rank as operator' + msg()) + if diag is not None: + if s_d.ndims != s_op.ndims - 1: + raise ValueError('diag should have rank 1 less than operator' + msg()) + + # Check batch shape + if s_v[:-2] != s_op[:-2]: + raise ValueError('v and operator should have same batch shape' + msg()) + if diag is not None: + if s_d[:-1] != s_op[:-2]: + raise ValueError( + 'diag and operator should have same batch shape' + msg()) + + # Check event shape + if s_v[-2] != s_op[-1]: + raise ValueError( + 'v and operator should be compatible for matmul' + msg()) + if diag is not None: + if s_d[-1] != s_v[-1]: + raise ValueError('diag and v should have same last dimension' + msg()) + + return True + + def _check_shapes_dynamic(self, operator, v, diag): + """Return (v, diag) with Assert dependencies, which check shape.""" + checks = [] + with ops.op_scope([operator, v, diag], 'check_shapes'): + s_v = array_ops.shape(v) + r_op = operator.rank() + r_v = array_ops.rank(v) + if diag is not None: + s_d = array_ops.shape(diag) + r_d = array_ops.rank(diag) + + # Check tensor rank. + checks.append(check_ops.assert_rank(v, r_op)) + if diag is not None: + checks.append(check_ops.assert_rank(diag, r_op - 1)) + + # Check batch shape + checks.append(check_ops.assert_equal( + operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2]))) + if diag is not None: + checks.append(check_ops.assert_equal( + operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1]))) + + # Check event shape + checks.append(check_ops.assert_equal( + operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2))) + if diag is not None: + checks.append(check_ops.assert_equal( + array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1))) + + v = control_flow_ops.with_dependencies(checks, v) + if diag is not None: + diag = control_flow_ops.with_dependencies(checks, diag) + return v, diag + + @property + def name(self): + """String name identifying this `Operator`.""" + return self._name + + @property + def verify_pd(self): + """Whether to verify that this `Operator` is positive definite.""" + return self._verify_pd + + @property + def dtype(self): + """Data type of matrix elements of `A`.""" + return self._v.dtype + + def _inv_quadratic_form_on_vectors(self, x): + return self._iqfov_via_sqrt_solve(x) + + @property + def inputs(self): + """List of tensors that were provided as initialization inputs.""" + return self._operator.inputs + self._diag_operator.inputs + [self._v] + + def get_shape(self): + """Static `TensorShape` of entire operator. + + If this operator represents the batch matrix `A` with + `A.shape = [N1,...,Nn, k, k]`, then this returns + `TensorShape([N1,...,Nn, k, k])` + + Returns: + `TensorShape`, statically determined, may be undefined. + """ + return self._operator.get_shape() + + def _shape(self): + return self._operator.shape() + + def _det(self): + return math_ops.exp(self.log_det()) + + def _batch_log_det(self): + return 2 * self._batch_sqrt_log_det() + + def _log_det(self): + return 2 * self._sqrt_log_det() + + def _sqrt_log_det(self): + # The matrix determinant lemma states: + # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M) + # = det(C) * det(D) * det(M) + # + # Here we compute the Cholesky factor of "C", then pass the result on. + diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance( + batch_mode=False)) + return self._sqrt_log_det_core(diag_chol_c) + + def _batch_sqrt_log_det(self): + # Here we compute the Cholesky factor of "C", then pass the result on. + diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance( + batch_mode=True)) + return self._sqrt_log_det_core(diag_chol_c) + + def _chol_capacitance(self, batch_mode): + """Cholesky factorization of the capacitance term.""" + # Cholesky factor for (D^{-1} + V^T M^{-1} V), which is sometimes + # known as the "capacitance" matrix. + + # self._operator will use batch if need be. Automatically. We cannot force + # that here. + # M^{-1} V + minv_v = self._operator.solve(self._v) + # V^T M^{-1} V + if batch_mode: + vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True) + else: + vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True) + + # D^{-1} + V^T M^{-1} V + capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v) + # Cholesky[D^{-1} + V^T M^{-1} V] + if batch_mode: + return linalg_ops.batch_cholesky(capacitance) + else: + return linalg_ops.cholesky(capacitance) + + def _sqrt_log_det_core(self, diag_chol_c): + """Finish computation of Sqrt[Log[Det]].""" + # Complete computation of ._log_det and ._batch_log_det, after the initial + # Cholesky factor has been taken with the appropriate batch/non-batch method + + # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M) + # = det(C) * det(D) * det(M) + # Multiply by 2 here because this is the log-det of the Cholesky factor of C + log_det_c = 2 * math_ops.reduce_sum( + math_ops.log(diag_chol_c), + reduction_indices=[-1]) + # Add together to get Log[det(M + VDV^T)], the Log-det of the updated square + # root. + log_det_updated_sqrt = ( + log_det_c + self._diag_operator.log_det() + self._operator.log_det()) + return log_det_updated_sqrt + + def _batch_matmul(self, x, transpose_x=False): + # Since the square root is PD, it is symmetric, and so A = SS^T = SS. + s_x = self._batch_sqrt_matmul(x, transpose_x=transpose_x) + return self._batch_sqrt_matmul(s_x) + + def _matmul(self, x, transpose_x=False): + # Since the square root is PD, it is symmetric, and so A = SS^T = SS. + s_x = self._sqrt_matmul(x, transpose_x=transpose_x) + return self._sqrt_matmul(s_x) + + def _batch_sqrt_matmul(self, x, transpose_x=False): + v = self._v + m = self._operator + d = self._diag_operator + # The operators call the appropriate matmul/batch_matmul automatically. We + # cannot override. + # batch_matmul is defined as: x * y, so adj_x and adj_y are the ways to + # transpose the left and right. + mx = m.matmul(x, transpose_x=transpose_x) + vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x) + d_vt_x = d.matmul(vt_x) + v_d_vt_x = math_ops.batch_matmul(v, d_vt_x) + + return mx + v_d_vt_x + + def _sqrt_matmul(self, x, transpose_x=False): + v = self._v + m = self._operator + d = self._diag_operator + # The operators call the appropriate matmul/batch_matmul automatically. We + # cannot override. + # matmul is defined as: a * b, so transpose_a, transpose_b are used. + # transpose the left and right. + mx = m.matmul(x, transpose_x=transpose_x) + vt_x = math_ops.matmul(v, x, transpose_a=True, transpose_b=transpose_x) + d_vt_x = d.matmul(vt_x) + v_d_vt_x = math_ops.matmul(v, d_vt_x) + + return mx + v_d_vt_x + + def _solve(self, rhs): + # This operator represents A = SS^T, but S is symmetric, so A = SS, + # which means A^{-1} = S^{-1}S^{-2} + # S^{-1} rhs + sqrtinv_rhs = self._sqrt_solve(rhs) + return self._sqrt_solve(sqrtinv_rhs) + + def _batch_solve(self, rhs): + sqrtinv_rhs = self._batch_sqrt_solve(rhs) + return self._batch_sqrt_solve(sqrtinv_rhs) + + def _sqrt_solve(self, rhs): + # Recall the square root of this operator is M + VDV^T. + # The Woodbury formula gives: + # (M + VDV^T)^{-1} + # = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1} + # = M^{-1} - M^{-1} V C^{-1} V^T M^{-1} + # where C is the capacitance matrix. + # TODO(jvdillon) Determine if recursively applying rank-1 updates is more + # efficient. May not be possible because a general n x n matrix can be + # represeneted as n rank-1 updates, and solving with this matrix is always + # done in O(n^3) time. + m = self._operator + v = self._v + cchol = self._chol_capacitance(batch_mode=False) + + # The operators will use batch/singleton mode automatically. We don't + # override. + # M^{-1} rhs + minv_rhs = m.solve(rhs) + # V^T M^{-1} rhs + vt_minv_rhs = math_ops.matmul(v, minv_rhs, transpose_a=True) + # C^{-1} V^T M^{-1} rhs + cinv_vt_minv_rhs = linalg_ops.cholesky_solve(cchol, vt_minv_rhs) + # V C^{-1} V^T M^{-1} rhs + v_cinv_vt_minv_rhs = math_ops.matmul(v, cinv_vt_minv_rhs) + # M^{-1} V C^{-1} V^T M^{-1} rhs + minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs) + + # M^{-1} - M^{-1} V C^{-1} V^T M^{-1} + return minv_rhs - minv_v_cinv_vt_minv_rhs + + def _batch_sqrt_solve(self, rhs): + # Recall the square root of this operator is M + VDV^T. + # The Woodbury formula gives: + # (M + VDV^T)^{-1} + # = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1} + # = M^{-1} - M^{-1} V C^{-1} V^T M^{-1} + # where C is the capacitance matrix. + m = self._operator + v = self._v + cchol = self._chol_capacitance(batch_mode=True) + + # The operators will use batch/singleton mode automatically. We don't + # override. + # M^{-1} rhs + minv_rhs = m.solve(rhs) + # V^T M^{-1} rhs + vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True) + # C^{-1} V^T M^{-1} rhs + cinv_vt_minv_rhs = linalg_ops.batch_cholesky_solve(cchol, vt_minv_rhs) + # V C^{-1} V^T M^{-1} rhs + v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs) + # M^{-1} V C^{-1} V^T M^{-1} rhs + minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs) + + # M^{-1} - M^{-1} V C^{-1} V^T M^{-1} + return minv_rhs - minv_v_cinv_vt_minv_rhs + + def _to_dense(self): + sqrt = self.sqrt_to_dense() + return math_ops.batch_matmul(sqrt, sqrt, adj_y=True) + + def _sqrt_to_dense(self): + v = self._v + d = self._diag_operator + m = self._operator + + d_vt = d.matmul(v, transpose_x=True) + # Batch op won't be efficient for singletons. Currently we don't break + # to_dense into batch/singleton methods. + v_d_vt = math_ops.batch_matmul(v, d_vt) + m_plus_v_d_vt = m.to_dense() + v_d_vt + return m_plus_v_d_vt |