author    Ian Langmore <langmore@google.com>  2016-07-27 07:29:12 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-07-27 08:32:47 -0700
commit 6339f8348cb19b9393447b2a59352f44a2c1a518 (patch)
tree   c48b903dd52778a9e471f0893775d05ba7035dc0
parent bcd81a4ceb171de978c8abc4379056a96ebee370 (diff)
OperatorPDVDVTSqrtUpdate, MultivariateNormalDiagPlusVDVT. Define a multivariate normal by the sqrt of covariance: S = M + V D V^T, M is diagonal.
Change: 128587968
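Editor's illustration (not part of the commit): a minimal usage sketch of the new distribution, following the docstring example added to `mvn.py` below. The values and shapes here are hypothetical.

```python
# Hypothetical usage of MultivariateNormalDiagPlusVDVT.  The covariance square
# root is S = M + V D V^T, with M = diag(diag_large), D = diag(diag_small),
# and V of shape [k, r], r typically much smaller than k.
import tensorflow as tf

distributions = tf.contrib.distributions

mu = [1., 2., 3.]                    # Means, k = 3.
diag_large = [1.1, 2.2, 3.3]         # Diagonal of M.
v = [[1., 0.], [0., 1.], [1., 1.]]   # V is k x r, here r = 2.
diag_small = [4., 5.]                # Diagonal of D.

dist = distributions.MultivariateNormalDiagPlusVDVT(
    mu, diag_large, v, diag_small=diag_small)

with tf.Session():
  print(dist.pdf([-1., 0., 1.]).eval())  # Density at one point in R^3.
```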
-rw-r--r--  tensorflow/contrib/distributions/BUILD                                               |  22
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py                     |  55
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py        |  39
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py    | 115
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py | 273
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn.py                                   | 193
-rw-r--r--  tensorflow/contrib/distributions/python/ops/operator_pd_diag.py                      | 231
-rw-r--r--  tensorflow/contrib/distributions/python/ops/operator_pd_identity.py                  | 207
-rw-r--r--  tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py               | 475
9 files changed, 1525 insertions(+), 85 deletions(-)
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 4027ae3ef8..868e1947dd 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -54,6 +54,28 @@ cuda_py_tests(
],
)
+cuda_py_tests(
+ name = "operator_pd_identity_test",
+ size = "small",
+ srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_tests(
+ name = "operator_pd_vdvt_update_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
py_library(
name = "distributions_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
index d8e75e4be2..a985477242 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
@@ -117,6 +117,61 @@ class MultivariateNormalDiagTest(tf.test.TestCase):
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
+class MultivariateNormalDiagPlusVDVTTest(tf.test.TestCase):
+ """Well tested because this is a simple override of the base class."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testMean(self):
+ mu = [-1.0, 1.0]
+ diag_large = [1.0, 5.0]
+ v = [[2.0], [3.0]]
+ diag_small = [3.0]
+ with self.test_session():
+ dist = distributions.MultivariateNormalDiagPlusVDVT(
+ mu, diag_large, v, diag_small=diag_small)
+ self.assertAllEqual(mu, dist.mean().eval())
+
+ def testNonmatchingMuAndSigmaDimensionFailsStatic(self):
+ mu = self._rng.rand(2)
+ # With this diag_large and v, the covariance is 3 x 3
+ diag_large = self._rng.rand(3)
+ v = self._rng.rand(3, 2) # v works with diag_large.
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, "shape.*should match"):
+ distributions.MultivariateNormalDiagPlusVDVT(
+ mu, diag_large, v)
+
+ def testNonmatchingMuDiagDimensionsFailsDynamic(self):
+ mu = self._rng.rand(2)
+ # With this diag_large and v, the covariance is 3 x 3
+ diag_large = self._rng.rand(3)
+ v = self._rng.rand(3, 2) # v works with diag_large.
+
+ with self.test_session():
+ mu_ph = tf.placeholder(tf.float32, name="mu_ph")
+ v_ph = tf.placeholder(tf.float32, name="v_ph")
+ diag_ph = tf.placeholder(tf.float32, name="diag_ph")
+ dist = distributions.MultivariateNormalDiagPlusVDVT(
+ mu_ph, diag_ph, v_ph)
+ with self.assertRaisesOpError("mu.*cov.*shape"):
+ dist.mean().eval(feed_dict={mu_ph: mu, diag_ph: diag_large, v_ph: v})
+
+ def testSample(self):
+ mu = [-1.0, 1.0]
+ diag_large = [1.0, 0.5]
+ v = [[0.2], [0.3]]
+ with self.test_session():
+ dist = distributions.MultivariateNormalDiagPlusVDVT(mu, diag_large, v)
+
+ samps = dist.sample_n(1000, seed=0).eval()
+ cov_mat = dist.sigma.eval()
+
+ self.assertAllClose(mu, samps.mean(axis=0), atol=0.1)
+ self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
+
+
class MultivariateNormalCholeskyTest(tf.test.TestCase):
def setUp(self):
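Editor's illustration: the tests above compare `dist.sigma` against a sample covariance; a small numpy cross-check, using the same values as `testMean`, that the covariance implied by the square root `S = M + V D V^T` is `C = S S^T`:

```python
# Numpy cross-check of the covariance construction used by the tests above.
import numpy as np

diag_large = np.array([1.0, 5.0])  # M = diag(diag_large), k = 2.
v = np.array([[2.0], [3.0]])       # V is k x r with r = 1.
diag_small = np.array([3.0])       # D = diag(diag_small).

s = np.diag(diag_large) + v.dot(np.diag(diag_small)).dot(v.T)
c = s.dot(s.T)  # What `dist.sigma` should evaluate to.

assert np.allclose(c, c.T)                # Symmetric.
assert np.all(np.linalg.eigvalsh(c) > 0)  # Positive definite.
```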
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py
index 3a0f6e1d5a..c11b0357e7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_diag_test.py
@@ -17,14 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import numpy as np
+import six
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_test_util
-class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
+@six.add_metaclass(abc.ABCMeta)
+class OperatorPDDiagBaseTest(object):
def setUp(self):
self._rng = np.random.RandomState(42)
@@ -32,8 +35,14 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
def _random_pd_diag(self, diag_shape):
return self._rng.rand(*diag_shape) + 0.1
+ @abc.abstractmethod
def _diag_to_matrix(self, diag):
- return tf.batch_matrix_diag(diag**2).eval()
+ pass
+
+ @abc.abstractproperty
+ def operator_class(self):
+ # Return the operator class that this tests.
+ pass
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
# Create a diagonal matrix explicitly.
@@ -46,7 +55,7 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
# The diag is the square root.
diag = self._random_pd_diag(diag_shape).astype(dtype)
mat = self._diag_to_matrix(diag).astype(dtype)
- operator = operator_pd_diag.OperatorPDSqrtDiag(diag)
+ operator = self.operator_class(diag)
return operator, mat
@@ -66,5 +75,29 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
operator.to_dense().eval() # Should not raise
+class OperatorPDDiagTest(
+ OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest):
+ """Most tests done in the base classes."""
+
+ def _diag_to_matrix(self, diag):
+ return tf.batch_matrix_diag(diag).eval()
+
+ @property
+ def operator_class(self):
+ return operator_pd_diag.OperatorPDDiag
+
+
+class OperatorPDSqrtDiagTest(
+ OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest):
+ """Most tests done in the base classes."""
+
+ def _diag_to_matrix(self, diag):
+ return tf.batch_matrix_diag(diag**2).eval()
+
+ @property
+ def operator_class(self):
+ return operator_pd_diag.OperatorPDSqrtDiag
+
+
if __name__ == '__main__':
tf.test.main()
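Editor's note: the refactoring above follows a common test-mixin pattern. The shared checks live in an abstract base that does not inherit from the test case, so the base is never collected and run on its own; each concrete class mixes the base with the `TestCase`-derived class and fills in the abstract hooks. A self-contained sketch with hypothetical classes:

```python
# Minimal sketch of the test-mixin pattern (hypothetical classes).
import abc
import unittest

import six


@six.add_metaclass(abc.ABCMeta)
class TransformBaseTest(object):
  """Shared test logic; not a TestCase, so it never runs by itself."""

  @abc.abstractmethod
  def _transform(self, x):
    pass

  def test_transform_of_two(self):
    self.assertEqual(self._expected_for_two, self._transform(2))


class IdentityTransformTest(TransformBaseTest, unittest.TestCase):
  _expected_for_two = 2

  def _transform(self, x):
    return x


class SquareTransformTest(TransformBaseTest, unittest.TestCase):
  _expected_for_two = 4

  def _transform(self, x):
    return x**2


if __name__ == '__main__':
  unittest.main()
```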
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py
new file mode 100644
index 0000000000..7f411105fb
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py
@@ -0,0 +1,115 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.distributions.python.ops import operator_pd_identity
+from tensorflow.contrib.distributions.python.ops import operator_test_util
+
+distributions = tf.contrib.distributions
+
+
+class OperatorPDIdentityTest(operator_test_util.OperatorPDDerivedClassTest):
+ """Most tests done in the base class."""
+
+ def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
+ # Build an identity matrix with right shape and dtype.
+ # Build an operator that should act the same way.
+ batch_shape = list(batch_shape)
+ diag_shape = batch_shape + [k]
+ matrix_shape = batch_shape + [k, k]
+ diag = tf.ones(diag_shape, dtype=dtype)
+ identity_matrix = tf.batch_matrix_diag(diag)
+ operator = operator_pd_identity.OperatorPDIdentity(matrix_shape, dtype)
+ return operator, identity_matrix.eval()
+
+ def test_bad_dtype_args_raise(self):
+ dtype = np.float32
+ batch_shape = [2, 3]
+ k = 4
+ with self.test_session():
+ operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
+
+ x_good_shape = batch_shape + [k, 5]
+ x_good = self._rng.randn(*x_good_shape).astype(dtype)
+ x_bad = x_good.astype(np.float64)
+
+ operator.matmul(x_good).eval() # Should not raise.
+
+ with self.assertRaisesRegexp(TypeError, 'dtype'):
+ operator.matmul(x_bad)
+
+ with self.assertRaisesRegexp(TypeError, 'dtype'):
+ operator.solve(x_bad)
+
+ with self.assertRaisesRegexp(TypeError, 'dtype'):
+ operator.sqrt_solve(x_bad)
+
+ def test_bad_rank_args_raise(self):
+ # Prepend a singleton dimension, changing the rank of 'x', but not the size.
+ dtype = np.float32
+ batch_shape = [2, 3]
+ k = 4
+ with self.test_session():
+ operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
+
+ x_good_shape = batch_shape + [k, 5]
+ x_good = self._rng.randn(*x_good_shape).astype(dtype)
+ x_bad = x_good.reshape(1, 2, 3, 4, 5)
+
+ operator.matmul(x_good).eval() # Should not raise.
+
+ with self.assertRaisesRegexp(ValueError, 'tensor rank'):
+ operator.matmul(x_bad)
+
+ with self.assertRaisesRegexp(ValueError, 'tensor rank'):
+ operator.solve(x_bad)
+
+ with self.assertRaisesRegexp(ValueError, 'tensor rank'):
+ operator.sqrt_solve(x_bad)
+
+ def test_incompatible_shape_args_raise(self):
+ # Test shapes that are the same rank but incompatible for matrix
+ # multiplication.
+ dtype = np.float32
+ batch_shape = [2, 3]
+ k = 4
+ with self.test_session():
+ operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
+
+ x_good_shape = batch_shape + [k, 5]
+ x_good = self._rng.randn(*x_good_shape).astype(dtype)
+ x_bad_shape = batch_shape + [5, k]
+ x_bad = x_good.reshape(*x_bad_shape)
+
+ operator.matmul(x_good).eval() # Should not raise.
+
+ with self.assertRaisesRegexp(ValueError, 'Incompatible'):
+ operator.matmul(x_bad)
+
+ with self.assertRaisesRegexp(ValueError, 'Incompatible'):
+ operator.solve(x_bad)
+
+ with self.assertRaisesRegexp(ValueError, 'Incompatible'):
+ operator.sqrt_solve(x_bad)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py
new file mode 100644
index 0000000000..66d54561b5
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_vdvt_update_test.py
@@ -0,0 +1,273 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.distributions.python.ops import operator_pd_full
+from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update
+from tensorflow.contrib.distributions.python.ops import operator_test_util
+
+distributions = tf.contrib.distributions
+
+
+class OperatorPDSqrtVDVTUpdateTest(
+ operator_test_util.OperatorPDDerivedClassTest):
+ """Most tests done in the base class."""
+ _diag_is_none = False
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def _random_pd_matrix(self, shape):
+ # With probability 1 this is positive definite.
+ sqrt = self._rng.randn(*shape)
+ mat = tf.batch_matmul(sqrt, sqrt, adj_y=True)
+ return mat.eval()
+
+ def _random_v_and_diag(self, mat_shape, v_matrix_rank):
+ # Get the necessary elements to make the sqrt update.
+ mat_shape = list(mat_shape)
+ batch_shape = mat_shape[:-2]
+ diag_shape = mat_shape[:-2] + [v_matrix_rank]
+ k = mat_shape[-1]
+ assert k == mat_shape[-2], 'Must be a square matrix'
+ v_shape = batch_shape + [k, v_matrix_rank]
+ v = self._rng.randn(*v_shape) # anything goes with "v"!
+
+ if self._diag_is_none:
+ diag = None
+ else:
+ diag = self._rng.rand(*diag_shape) + 0.1 # Positive diag!
+ return v, diag
+
+ def _updated_mat(self, mat, v, diag):
+ # Get dense matrix defined by its square root, which is an update of `mat`:
+ # A = (mat + v D v^T) (mat + v D v^T)^T
+ # D is the diagonal matrix with `diag` on the diagonal.
+
+ # If diag is None, then it defaults to the identity matrix, so DV^T = V^T
+ if diag is None:
+ diag_vt = tf.batch_matrix_transpose(v)
+ else:
+ diag_mat = tf.batch_matrix_diag(diag)
+ diag_vt = tf.batch_matmul(diag_mat, v, adj_y=True)
+
+ v_diag_vt = tf.batch_matmul(v, diag_vt)
+ sqrt = mat + v_diag_vt
+ a = tf.batch_matmul(sqrt, sqrt, adj_y=True)
+ return a.eval()
+
+ def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
+ """This method is called by base class, enabling many standard tests."""
+ # Create a matrix then explicitly update it with v and diag.
+ # Create an OperatorPDSqrtVDVTUpdate from the matrix and v and diag
+ # The operator should have the same behavior.
+ #
+ # The low-rank matrix V will have rank 1/2 of k, unless k is 1, in which
+ # case it will be 1 as well.
+ if k == 1:
+ v_matrix_rank = k
+ else:
+ v_matrix_rank = k // 2
+ mat_shape = list(batch_shape) + [k, k]
+ mat = self._random_pd_matrix(mat_shape)
+ v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
+
+ # Set dtypes
+ mat = mat.astype(dtype)
+ v = v.astype(dtype)
+ if diag is not None:
+ diag = diag.astype(dtype)
+
+ # The matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T
+ # Our final updated operator should behave like this.
+ updated_mat = self._updated_mat(mat, v, diag)
+
+ # Represents the matrix: `mat`, before updating.
+ # This is the Operator that we will update.
+ o_made_with_mat = operator_pd_full.OperatorPDFull(mat)
+
+ # Represents the matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T,
+ # achieved by updating the operator "o_made_with_mat".
+ # This is the operator we're testing.
+ operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ o_made_with_mat, v, diag)
+
+ return operator, updated_mat
+
+ def test_to_dense_placeholder(self):
+ # Test simple functionality when the inputs are placeholders.
+ mat_shape = [3, 3]
+ v_matrix_rank = 2
+ with self.test_session():
+ # Make an OperatorPDFull with a matrix placeholder.
+ mat_ph = tf.placeholder(tf.float64, name='mat_ph')
+ mat = self._random_pd_matrix(mat_shape)
+ o_made_with_mat = operator_pd_full.OperatorPDFull(mat_ph)
+
+ # Make the placeholders and arrays for the updated operator.
+ v_ph = tf.placeholder(tf.float64, name='v_ph')
+ v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
+ if self._diag_is_none:
+ diag_ph = None
+ feed_dict = {v_ph: v, mat_ph: mat}
+ else:
+ diag_ph = tf.placeholder(tf.float64, name='diag_ph')
+ feed_dict = {v_ph: v, diag_ph: diag, mat_ph: mat}
+
+ # Make the OperatorPDSqrtVDVTUpdate with v and diag placeholders.
+ operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ o_made_with_mat, v_ph, diag=diag_ph)
+
+ # Should not fail
+ operator.to_dense().eval(feed_dict=feed_dict)
+ operator.log_det().eval(feed_dict=feed_dict)
+
+ def test_operator_not_subclass_of_operator_pd_raises(self):
+ # We enforce that `operator` is an `OperatorPDBase`.
+ with self.test_session():
+ v, diag = self._random_v_and_diag((3, 3), 2)
+ operator_m = 'I am not a subclass of OperatorPDBase'
+
+ with self.assertRaisesRegexp(TypeError, 'not instance'):
+ operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
+
+ def test_non_pos_def_diag_raises(self):
+ if self._diag_is_none:
+ return
+ # We enforce that the diag is positive definite.
+ with self.test_session():
+ matrix_shape = (3, 3)
+ v_rank = 2
+ v, diag = self._random_v_and_diag(matrix_shape, v_rank)
+ mat = self._random_pd_matrix(matrix_shape)
+ diag[0] = 0.0
+
+ operator_m = operator_pd_full.OperatorPDFull(mat)
+ operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ operator_m, v, diag)
+
+ with self.assertRaisesOpError('positive'):
+ operator.to_dense().eval()
+
+ def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self):
+ # We enforce that the diag is positive definite.
+ if self._diag_is_none:
+ return
+ with self.test_session():
+ matrix_shape = (3, 3)
+ v_rank = 2
+ v, diag = self._random_v_and_diag(matrix_shape, v_rank)
+ mat = self._random_pd_matrix(matrix_shape)
+ diag[0] = 0.0
+
+ operator_m = operator_pd_full.OperatorPDFull(mat)
+ operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ operator_m, v, diag, verify_pd=False)
+
+ operator.to_dense().eval() # Should not raise.
+
+ def test_event_shape_mismatch_v_and_diag_raises_static(self):
+ v = self._rng.rand(4, 3, 2)
+ diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v.
+ with self.test_session():
+
+ mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
+ operator_m = operator_pd_full.OperatorPDFull(mat)
+ with self.assertRaisesRegexp(ValueError, 'diag.*v.*last dimension'):
+ operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
+
+ def test_batch_shape_mismatch_v_and_diag_raises_static(self):
+ v = self._rng.rand(4, 3, 2)
+ diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v.
+ with self.test_session():
+
+ mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
+ operator_m = operator_pd_full.OperatorPDFull(mat)
+ with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'):
+ operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
+
+ def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self):
+ v = self._rng.rand(1, 2, 2, 2)
+ diag = self._rng.rand(5, 1) # Should have rank 1 less than v.
+ with self.test_session():
+
+ mat = self._random_pd_matrix((1, 2, 2, 2)) # mat and v match
+ operator_m = operator_pd_full.OperatorPDFull(mat)
+ with self.assertRaisesRegexp(ValueError, 'diag.*rank'):
+ operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
+
+ def test_event_shape_mismatch_v_and_diag_raises_dynamic(self):
+ with self.test_session():
+
+ v = self._rng.rand(4, 3, 2)
+ diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v.
+ mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
+
+ v_ph = tf.placeholder(tf.float32, name='v_ph')
+ diag_ph = tf.placeholder(tf.float32, name='diag_ph')
+ mat_ph = tf.placeholder(tf.float32, name='mat_ph')
+
+ operator_m = operator_pd_full.OperatorPDFull(mat_ph)
+ updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ operator_m, v_ph, diag_ph)
+ with self.assertRaisesOpError('x == y'):
+ updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
+
+ def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self):
+ with self.test_session():
+ v = self._rng.rand(4, 3, 2)
+ diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v.
+ mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
+
+ v_ph = tf.placeholder(tf.float32, name='v_ph')
+ diag_ph = tf.placeholder(tf.float32, name='diag_ph')
+ mat_ph = tf.placeholder(tf.float32, name='mat_ph')
+
+ operator_m = operator_pd_full.OperatorPDFull(mat_ph)
+ updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ operator_m, v_ph, diag_ph)
+ with self.assertRaisesOpError('x == y'):
+ updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
+
+ def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self):
+ with self.test_session():
+
+ v = self._rng.rand(2, 2, 2, 2)
+ diag = self._rng.rand(2, 2) # Should have rank 1 less than v.
+ mat = self._random_pd_matrix((2, 2, 2, 2)) # mat and v match
+
+ v_ph = tf.placeholder(tf.float32, name='v_ph')
+ diag_ph = tf.placeholder(tf.float32, name='diag_ph')
+ mat_ph = tf.placeholder(tf.float32, name='mat_ph')
+
+ operator_m = operator_pd_full.OperatorPDFull(mat_ph)
+ updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ operator_m, v_ph, diag_ph)
+ with self.assertRaisesOpError('rank'):
+ updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
+
+
+class OperatorPDSqrtVDVTUpdateNoneDiagTest(OperatorPDSqrtVDVTUpdateTest):
+ _diag_is_none = True
+
+
+if __name__ == '__main__':
+ tf.test.main()
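Editor's illustration: `OperatorPDSqrtVDVTUpdateNoneDiagTest` reruns the whole suite with `diag=None`; as `_updated_mat` above shows, a missing `D` defaults to the identity, so the square root reduces to `S = mat + V V^T`. In numpy terms:

```python
# With diag=None, D defaults to the identity: mat + V D V^T == mat + V V^T.
import numpy as np

rng = np.random.RandomState(42)
mat = rng.randn(3, 3)
v = rng.randn(3, 2)

assert np.allclose(mat + v.dot(np.eye(2)).dot(v.T), mat + v.dot(v.T))
```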
diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py
index 90e26336d7..a3b1baeba5 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_pd_full
+from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@@ -40,6 +41,7 @@ __all__ = [
"MultivariateNormalDiag",
"MultivariateNormalCholesky",
"MultivariateNormalFull",
+ "MultivariateNormalDiagPlusVDVT",
]
@@ -52,14 +54,13 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
#### Mathematical details
- The PDF of this distribution is:
+ With `C` the covariance matrix represented by the operator, the PDF of this
+ distribution is:
```
- f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
+ f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
- where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
-
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@@ -109,10 +110,10 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not
guaranteed.
- allow_nan_stats: Boolean, default False. If False, raise an exception if
- a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
- If True, batch members with valid parameters leading to undefined
- statistics will return NaN for this statistic.
+ allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@@ -170,12 +171,12 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
@property
def validate_args(self):
- """Boolean describing behavior on invalid input."""
+ """`Boolean` describing behavior on invalid input."""
return self._validate_args
@property
def allow_nan_stats(self):
- """Boolean describing behavior when a stat is undefined for batch member."""
+ """`Boolean` describing behavior when stats are undefined."""
return self._allow_nan_stats
@property
@@ -417,7 +418,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
determined by `diag_stdev`: `C_{ii} = diag_stdev[i]**2`.
```
- f(x) = (2*pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 * (x - mu)^T C^{-1} (x - mu))
+ f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
#### Examples
@@ -467,14 +468,14 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
- representing the standard deviations.
+ representing the standard deviations. Must be positive.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
- allow_nan_stats: Boolean, default False. If False, raise an exception if
- a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
- If True, batch members with valid parameters leading to undefined
- statistics will return NaN for this statistic.
+ allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@@ -487,6 +488,125 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
name=name)
+class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
+ """The multivariate normal distribution on `R^k`.
+
+ Every batch member of this distribution is defined by a mean and a lightweight
+ covariance matrix `C`.
+
+ #### Mathematical details
+
+ The PDF of this distribution in terms of the mean `mu` and covariance `C` is:
+
+ ```
+ f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
+ ```
+
+ For every batch member, this distribution represents `k` random variables
+ `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
+ `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`
+
+ The user initializes this class by providing the mean `mu`, and a lightweight
+ definition of `C`:
+
+ ```
+ C = SS^T = SS = (M + V D V^T) (M + V D V^T)
+ M is diagonal (k x k)
+ V is shape (k x r), typically r << k
+ D is diagonal (r x r), optional (defaults to identity).
+ ```
+
+ This allows for `O(kr + r^3)` pdf evaluation and determinant, and `O(kr)`
+ sampling and storage (per batch member).
+
+ #### Examples
+
+ A single multi-variate Gaussian distribution is defined by a vector of means
+ of length `k`, and square root of the covariance `S = M + V D V^T`. Extra
+ leading dimensions, if provided, allow for batches.
+
+ ```python
+ # Initialize a single 3-variate Gaussian with covariance square root
+ # S = M + V D V^T, where V D V^T is a matrix-rank 2 update.
+ mu = [1, 2, 3.]
+ diag_large = [1.1, 2.2, 3.3]
+ v = ... # shape 3 x 2
+ diag_small = [4., 5.]
+ dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT(
+ mu, diag_large, v, diag_small=diag_small)
+
+ # Evaluate this on an observation in R^3, returning a scalar.
+ dist.pdf([-1, 0, 1])
+
+ # Initialize a batch of two 3-variate Gaussians. This time, don't provide
+ # diag_small. This means S = M + V V^T.
+ mu = [[1, 2, 3], [11, 22, 33]] # shape 2 x 3
+ diag_large = ... # shape 2 x 3
+ v = ... # shape 2 x 3 x 1, a matrix-rank 1 update.
+ dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT(
+ mu, diag_large, v)
+
+ # Evaluate this on two observations, each in R^3, returning a length two
+ # tensor.
+ x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
+ dist.pdf(x)
+ ```
+
+ """
+
+ def __init__(
+ self,
+ mu,
+ diag_large,
+ v,
+ diag_small=None,
+ validate_args=True,
+ allow_nan_stats=False,
+ name="MultivariateNormalDiagPlusVDVT"):
+ """Multivariate Normal distributions on `R^k`.
+
+ For every batch member, this distribution represents `k` random variables
+ `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
+ `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`
+
+ The user initializes this class by providing the mean `mu`, and a
+ lightweight definition of `C`:
+
+ ```
+ C = SS^T = SS = (M + V D V^T) (M + V D V^T)
+ M is diagonal (k x k)
+ V is shape (k x r), typically r << k
+ D is diagonal (r x r), optional (defaults to identity).
+ ```
+
+ Args:
+ mu: Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`,
+ `n >= 0`. The means.
+ diag_large: Rank `n + 1` `float` or `double` tensor, shape
+ `[N1,...,Nn, k]`, `n >= 0`. Defines the diagonal matrix `M`.
+ v: Rank `n + 2` `float` or `double` tensor, shape `[N1,...,Nn, k, r]`,
+ `n >= 0`. Defines the matrix `V`.
+ diag_small: Rank `n + 1` `float` or `double` tensor, shape
+ `[N1,...,Nn, r]`, `n >= 0`. Defines the diagonal matrix `D`. Default
+ is `None`, which means `D` will be the identity matrix.
+ validate_args: Whether to validate input with asserts. If `validate_args`
+ is `False`,
+ and the inputs are invalid, correct behavior is not guaranteed.
+ allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: The name to give Ops created by the initializer.
+ """
+ m = operator_pd_diag.OperatorPDDiag(diag_large, verify_pd=validate_args)
+ cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
+ m, v, diag=diag_small, verify_pd=validate_args,
+ verify_shapes=validate_args)
+ super(MultivariateNormalDiagPlusVDVT, self).__init__(
+ mu, cov, allow_nan_stats=allow_nan_stats, validate_args=validate_args,
+ name=name)
+
+
class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
"""The multivariate normal distribution on `R^k`.
@@ -496,14 +616,14 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
#### Mathematical details
- The PDF of this distribution is:
+ The Cholesky factor `chol` defines the covariance matrix: `C = chol chol^T`.
+
+ The PDF of this distribution is then:
```
- f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
+ f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
- where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
-
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@@ -546,20 +666,21 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
"""Multivariate Normal distributions on `R^k`.
User must provide means `mu` and `chol` which holds the (batch) Cholesky
- factors `S`, such that the covariance of each batch member is `S S^*`.
+ factors, such that the covariance of each batch member is `chol chol^T`.
Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
- `[N1,...,Nb, k, k]`.
+ `[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as
+ though it is zero), and the diagonal must be positive.
validate_args: Whether to validate input with asserts. If `validate_args`
- is `False`,
- and the inputs are invalid, correct behavior is not guaranteed.
- allow_nan_stats: Boolean, default False. If False, raise an exception if
- a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
- If True, batch members with valid parameters leading to undefined
- statistics will return NaN for this statistic.
+ is `False`, and the inputs are invalid, correct behavior is not
+ guaranteed.
+ allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@@ -582,14 +703,12 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
#### Mathematical details
- The PDF of this distribution is:
+ With `C = sigma`, the PDF of this distribution is:
```
- f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
+ f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
- where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
-
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@@ -633,14 +752,14 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
- `[N1,...,Nb, k, k]`.
+ `[N1,...,Nb, k, k]`. Each batch member must be positive definite.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not
guaranteed.
- allow_nan_stats: Boolean, default False. If False, raise an exception if
- a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
- If True, batch members with valid parameters leading to undefined
- statistics will return NaN for this statistic.
+ allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
diff --git a/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py b/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py
index ea5aa3c386..5e019355f7 100644
--- a/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/operator_pd_diag.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
+import six
+
from tensorflow.contrib.distributions.python.ops import operator_pd
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -26,11 +29,190 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-class OperatorPDSqrtDiag(operator_pd.OperatorPDBase):
+@six.add_metaclass(abc.ABCMeta)
+class OperatorPDDiagBase(operator_pd.OperatorPDBase):
+ """Base class for diagonal operators."""
+
+ def __init__(self, diag, verify_pd=True, name='OperatorPDDiagBase'):
+ self._verify_pd = verify_pd
+ self._name = name
+ with ops.name_scope(name):
+ with ops.op_scope([diag], 'init'):
+ self._diag = self._check_diag(diag)
+
+ def _check_diag(self, diag):
+ """Verify that `diag` is positive."""
+ diag = ops.convert_to_tensor(diag, name='diag')
+ if not self.verify_pd:
+ return diag
+ deps = [check_ops.assert_positive(diag)]
+ return control_flow_ops.with_dependencies(deps, diag)
+
+ @property
+ def name(self):
+ """String name identifying this `Operator`."""
+ return self._name
+
+ @property
+ def verify_pd(self):
+ """Whether to verify that this `Operator` is positive definite."""
+ return self._verify_pd
+
+ @property
+ def dtype(self):
+ """Data type of matrix elements of `A`."""
+ return self._diag.dtype
+
+ @property
+ def inputs(self):
+ """Initialization arguments."""
+ return [self._diag]
+
+ def get_shape(self):
+ """`TensorShape` giving static shape."""
+ # If d_shape = [5, 3], we return [5, 3, 3].
+ d_shape = self._diag.get_shape()
+ return d_shape.concatenate(d_shape[-1:])
+
+ def _shape(self):
+ d_shape = array_ops.shape(self._diag)
+ k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1)
+ return array_ops.concat(0, (d_shape, [k]))
+
+ @abc.abstractmethod
+ def _batch_log_det(self):
+ pass
+
+ @abc.abstractmethod
+ def _inv_quadratic_form_on_vectors(self, x):
+ pass
+
+ @abc.abstractmethod
+ def _batch_matmul(self, x, transpose_x=False):
+ pass
+
+ @abc.abstractmethod
+ def _batch_sqrt_matmul(self, x, transpose_x=False):
+ pass
+
+ @abc.abstractmethod
+ def _batch_solve(self, rhs):
+ pass
+
+ @abc.abstractmethod
+ def _batch_sqrt_solve(self, rhs):
+ pass
+
+ @abc.abstractmethod
+ def _to_dense(self):
+ pass
+
+ @abc.abstractmethod
+ def _sqrt_to_dense(self):
+ pass
+
+ @abc.abstractmethod
+ def _add_to_tensor(self, mat):
+ pass
+
+
+class OperatorPDDiag(OperatorPDDiagBase):
+ """Class representing a (batch) of positive definite matrices `A`.
+
+ This class provides access to functions of a batch of symmetric positive
+ definite (PD) matrices `A` in `R^{k x k}`.
+
+ In this case, `A` is diagonal and is defined by a provided tensor `diag`,
+ `A_{ii} = diag[i]`.
+
+ Determinants, solves, and storage are `O(k)`.
+
+ In practice, this operator represents a (batch) matrix `A` with shape
+ `[N1,...,Nn, k, k]` for some `n >= 0`. The first `n` indices designate a
+ batch member. For every batch member `(i1,...,in)`, `A[i1,...,in, :, :]` is
+ a `k x k` matrix.
+
+ For example,
+
+ ```python
+ distributions = tf.contrib.distributions
+ diag = [1.0, 2.0]
+ operator = OperatorPDDiag(diag)
+ operator.det() # ==> (1 * 2)
+
+ # Compute the quadratic form x^T A^{-1} x for vector x.
+ x = [1.0, 2.0]
+ operator.inv_quadratic_form_on_vectors(x)
+
+ # Matrix multiplication by the square root, S w, with A = S S^T.
+ # Recall A is diagonal, and so is S, with S_{ij} = sqrt(A_{ij}).
+ # If w is iid normal, S w has covariance A.
+ w = [[1.0],
+ [2.0]]
+ operator.sqrt_matmul(w)
+ ```
+
+ The above three methods, `log_det`, `inv_quadratic_form_on_vectors`, and
+ `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
+ in a multi-variate normal distribution. See the class
+ `MultivariateNormalDiag`.
+ """
+
+ def __init__(self, diag, verify_pd=True, name='OperatorPDDiag'):
+ """Initialize an OperatorPDDiag.
+
+ Args:
+ diag: Shape `[N1,...,Nn, k]` positive tensor with `n >= 0`, `k >= 1`.
+ verify_pd: Whether to check `diag` is positive.
+ name: A name to prepend to all ops created by this class.
+ """
+ super(OperatorPDDiag, self).__init__(
+ diag, verify_pd=verify_pd, name=name)
+
+ def _batch_log_det(self):
+ return math_ops.reduce_sum(
+ math_ops.log(self._diag), reduction_indices=[-1])
+
+ def _inv_quadratic_form_on_vectors(self, x):
+ return self._iqfov_via_solve(x)
+
+ def _batch_matmul(self, x, transpose_x=False):
+ if transpose_x:
+ x = array_ops.batch_matrix_transpose(x)
+ diag_mat = array_ops.expand_dims(self._diag, -1)
+ return diag_mat * x
+
+ def _batch_sqrt_matmul(self, x, transpose_x=False):
+ if transpose_x:
+ x = array_ops.batch_matrix_transpose(x)
+ diag_mat = array_ops.expand_dims(self._diag, -1)
+ return math_ops.sqrt(diag_mat) * x
+
+ def _batch_solve(self, rhs):
+ diag_mat = array_ops.expand_dims(self._diag, -1)
+ return rhs / diag_mat
+
+ def _batch_sqrt_solve(self, rhs):
+ diag_mat = array_ops.expand_dims(self._diag, -1)
+ return rhs / math_ops.sqrt(diag_mat)
+
+ def _to_dense(self):
+ return array_ops.batch_matrix_diag(self._diag)
+
+ def _sqrt_to_dense(self):
+ return array_ops.batch_matrix_diag(math_ops.sqrt(self._diag))
+
+ def _add_to_tensor(self, mat):
+ mat_diag = array_ops.batch_matrix_diag_part(mat)
+ new_diag = self._diag + mat_diag
+ return array_ops.batch_matrix_set_diag(mat, new_diag)
+
+
+class OperatorPDSqrtDiag(OperatorPDDiagBase):
"""Class representing a (batch) of positive definite matrices `A`.
This class provides access to functions of a batch of symmetric positive
- definite (PD) matrices `A` in `R^{k x k}` defined by their their square root,
+ definite (PD) matrices `A` in `R^{k x k}` defined by their square root,
`S`, such that `A = SS^T`.
In this case, `S` is diagonal and is defined by a provided tensor `diag`,
@@ -75,58 +257,17 @@ class OperatorPDSqrtDiag(operator_pd.OperatorPDBase):
verify_pd: Whether to check `diag` is positive.
name: A name to prepend to all ops created by this class.
"""
- self._verify_pd = verify_pd
- self._name = name
- with ops.name_scope(name):
- with ops.op_scope([diag], 'init'):
- self._diag = self._check_diag(diag)
-
- def _check_diag(self, diag):
- """Verify that `diag` is positive."""
- diag = ops.convert_to_tensor(diag, name='diag')
- if not self.verify_pd:
- return diag
- deps = [check_ops.assert_positive(diag)]
- return control_flow_ops.with_dependencies(deps, diag)
-
- @property
- def name(self):
- """String name identifying this `Operator`."""
- return self._name
-
- @property
- def verify_pd(self):
- """Whether to verify that this `Operator` is positive definite."""
- return self._verify_pd
-
- @property
- def dtype(self):
- """Data type of matrix elements of `A`."""
- return self._diag.dtype
+ super(OperatorPDSqrtDiag, self).__init__(
+ diag, verify_pd=verify_pd, name=name)
def _batch_log_det(self):
return 2 * math_ops.reduce_sum(
math_ops.log(self._diag), reduction_indices=[-1])
- @property
- def inputs(self):
- """List of tensors that were provided as initialization inputs."""
- return [self._diag]
-
def _inv_quadratic_form_on_vectors(self, x):
# This Operator is defined in terms of diagonal entries of the sqrt.
return self._iqfov_via_sqrt_solve(x)
- def get_shape(self):
- """`TensorShape` giving static shape."""
- d_shape = self._diag.get_shape()
- return d_shape.concatenate(d_shape[-1:])
-
- def _shape(self):
- d_shape = array_ops.shape(self._diag)
- k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1)
- return array_ops.concat(0, (d_shape, [k]))
-
def _batch_matmul(self, x, transpose_x=False):
if transpose_x:
x = array_ops.batch_matrix_transpose(x)
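Editor's note: the two concrete classes above differ only in where the square root sits. `OperatorPDDiag` represents `A = diag(d)` directly, while `OperatorPDSqrtDiag` represents `A = S S^T` with `S = diag(d)`, i.e. `A = diag(d)^2`; hence the factor of 2 in its `_batch_log_det`. A numpy check of the two formulas:

```python
# log|A| for the two diagonal operators:
#   OperatorPDDiag:     A = diag(d)     =>  log|A| = sum(log(d))
#   OperatorPDSqrtDiag: A = diag(d)**2  =>  log|A| = 2 * sum(log(d))
import numpy as np

d = np.array([1.0, 2.0, 5.0])

assert np.allclose(np.sum(np.log(d)), np.linalg.slogdet(np.diag(d))[1])
assert np.allclose(2.0 * np.sum(np.log(d)),
                   np.linalg.slogdet(np.diag(d)**2)[1])
```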
diff --git a/tensorflow/contrib/distributions/python/ops/operator_pd_identity.py b/tensorflow/contrib/distributions/python/ops/operator_pd_identity.py
new file mode 100644
index 0000000000..f1b750351c
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/operator_pd_identity.py
@@ -0,0 +1,207 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Identity operator in `R^k`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.contrib.distributions.python.ops import operator_pd
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+
+
+class OperatorPDIdentity(operator_pd.OperatorPDBase):
+ """Identity operator in `R^k`: `Ax = x`.
+
+ This provides an efficient implementation of the identity as an `OperatorPD`.
+ Storage, solves, and matmul are all `O(1)`, independent of batch size.
+
+ In order to be a drop-in replacement for other operators, shape and dtype
+ of arguments (e.g. to `matmul`) are checked statically as though this operator
+ were an instantiated matrix.
+
+ Dynamic shape checks of arguments are not done since that could impede
+ performance.
+ """
+
+ def __init__(self, shape, dtype, verify_pd=True, name='OperatorPDIdentity'):
+ """Initialize an `OperatorPDIdentity`.
+
+ Args:
+ shape: `int32` rank 1 `Tensor` of length at least 2, and with the last
+ two entries equal (since this is a square matrix).
+ dtype: Data type of the matrix that this operator represents.
+ verify_pd: `Boolean`, if `True`, asserts are added to the initialization
+ args to ensure they define this operator as a square (batch) matrix.
+ name: Name to prepend to `Ops`.
+ """
+
+ # Grab static shape if available now.
+ with ops.name_scope(name):
+ with ops.op_scope([shape], 'init'):
+ self._dtype = dtypes.as_dtype(dtype)
+ self._verify_pd = verify_pd
+ self._name = name
+
+ # Store the static shape (if possible) right now before adding the
+ # asserts, since the asserts prevent .constant_value from working.
+ shape = ops.convert_to_tensor(shape, name='shape')
+ self._get_shape = tensor_shape.TensorShape(
+ tensor_util.constant_value(shape))
+ self._shape_arg = self._check_shape(shape)
+
+ def _check_shape(self, shape):
+ """Check that the init arg `shape` defines a valid operator."""
+ shape = ops.convert_to_tensor(shape, name='shape')
+ if not self._verify_pd:
+ return shape
+
+ # The further checks are equivalent to verifying that this is positive
+ # definite. Why? They check that this is a square matrix; a square matrix
+ # maps R^k onto itself, and combined with the behavior of .matmul() (which
+ # returns its argument unchanged), that makes this the identity operator,
+ # which is positive definite.
+ rank = array_ops.size(shape)
+ assert_matrix = check_ops.assert_less_equal(2, rank)
+ with ops.control_dependencies([assert_matrix]):
+ last_dim = array_ops.gather(shape, rank - 1)
+ second_to_last_dim = array_ops.gather(shape, rank - 2)
+ assert_square = check_ops.assert_equal(last_dim, second_to_last_dim)
+ return control_flow_ops.with_dependencies([assert_matrix, assert_square],
+ shape)
+
+ def _check_x(self, x):
+ """Static check that the argument `x` is proper `shape`, `dtype`."""
+ # x is a typical argument e.g. to matmul or solve. In both cases, x should
+ # have the same type/shape since this is a square matrix. These checks are
+ # ususally not needed since we ususally have some tensor backing this
+ # distribution, and the calls to tf.matmul do a shape/type check.
+ #
+ # Static checks only for efficiency, the identity should be fast.
+ #
+ # Why check at all? Because we want this operator to be swappable for a
+ # real Operator.
+ if self.dtype != x.dtype:
+ raise TypeError(
+ 'Expected argument "x" to have same dtype as this operator (%s). '
+ 'Found: %s' % (self.dtype, x.dtype))
+
+ x_shape = x.get_shape()
+ self_shape = self.get_shape()
+ found_msg = (
+ 'Found: operator.shape = %s, x.shape = %s' % (self_shape, x_shape))
+ if x_shape.ndims is not None and self_shape.ndims is not None:
+ if x_shape.ndims != self_shape.ndims:
+ raise ValueError(
+ 'Expected argument "x" to have same tensor rank as this operator. '
+ + found_msg)
+ if x_shape.is_fully_defined() and self_shape.is_fully_defined():
+ if x_shape[-2] != self_shape[-1]:
+ raise ValueError(
+ 'Incompatible shapes for matrix-matrix operation. ' + found_msg)
+
+ @property
+ def name(self):
+ """String name identifying this `Operator`."""
+ return self._name
+
+ @property
+ def verify_pd(self):
+ """Whether to verify that this `Operator` is positive definite."""
+ return self._verify_pd
+
+ @property
+ def dtype(self):
+ """Data type of matrix elements of `A`."""
+ return self._dtype
+
+ def _add_to_tensor(self, mat):
+ # Add to a tensor in O(k) time!
+ mat_diag = array_ops.batch_matrix_diag_part(mat)
+ new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
+ return array_ops.batch_matrix_set_diag(mat, new_diag)
+
+ def _inv_quadratic_form_on_vectors(self, x):
+ self._check_x(x)
+ return self._iqfov_via_sqrt_solve(x)
+
+ @property
+ def inputs(self):
+ """List of tensors that were provided as initialization inputs."""
+ return [self._shape_arg]
+
+ def get_shape(self):
+ """Static `TensorShape` of entire operator.
+
+ If this operator represents the batch matrix `A` with
+ `A.shape = [N1,...,Nn, k, k]`, then this returns
+ `TensorShape([N1,...,Nn, k, k])`
+
+ Returns:
+ `TensorShape`, statically determined, may be undefined.
+ """
+ return self._get_shape
+
+ def _shape(self):
+ return self._shape_arg
+
+ def _det(self):
+ det = array_ops.ones(self.batch_shape(), dtype=self.dtype)
+ det.set_shape(self.get_batch_shape())
+ return det
+
+ def _batch_log_det(self):
+ log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype)
+ log_det.set_shape(self.get_batch_shape())
+ return log_det
+
+ def _batch_sqrt_log_det(self):
+ s_log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype)
+ s_log_det.set_shape(self.get_batch_shape())
+ return s_log_det
+
+ def _batch_matmul(self, x, transpose_x=False):
+ if transpose_x:
+ x = array_ops.batch_matrix_transpose(x)
+ self._check_x(x)
+ return x
+
+ def _batch_sqrt_matmul(self, x, transpose_x=False):
+ return self._batch_matmul(x, transpose_x=transpose_x)
+
+ def _batch_solve(self, rhs):
+ self._check_x(rhs)
+ return rhs
+
+ def _batch_sqrt_solve(self, rhs):
+ self._check_x(rhs)
+ return rhs
+
+ def _to_dense(self):
+ diag = array_ops.ones(self.vector_shape(), dtype=self.dtype)
+ dense = array_ops.batch_matrix_diag(diag)
+ dense.set_shape(self.get_shape())
+ return dense
+
+ def _sqrt_to_dense(self):
+ return self.to_dense()
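Editor's illustration: `_add_to_tensor` above relies on the fact that adding the identity perturbs only the diagonal, which is why it runs in `O(k)` per matrix rather than `O(k^2)`. A numpy analogue:

```python
# Adding the identity touches only the diagonal: new_diag = 1 + mat_diag.
import numpy as np

mat = np.arange(9.0).reshape(3, 3)
out = mat.copy()
out[np.diag_indices(3)] += 1.0

assert np.allclose(out, mat + np.eye(3))
```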
diff --git a/tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py b/tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py
new file mode 100644
index 0000000000..3c934e721c
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/operator_pd_vdvt_update.py
@@ -0,0 +1,475 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operator defined: `A = SS^T` where `S = M + VDV^T`, for `OperatorPD` `M`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import operator_pd
+from tensorflow.contrib.distributions.python.ops import operator_pd_diag
+from tensorflow.contrib.distributions.python.ops import operator_pd_identity
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
+ r"""Operator defined by `A=SS^T`, where `S = M + VDV^T` for `OperatorPD` `M`.
+
+ This provides efficient low-rank updates of arbitrary `OperatorPD`.
+
+ Some math:
+
+ Given positive definite operator representing positive definite (batch) matrix
+ `M` in `R^{k x k}`, diagonal matrix `D` in `R^{r x r}`, and low rank `V` in
+ `R^{k x r}` this class represents the batch matrix `A`, defined by its square
+ root `S` as follows:
+
+ ```
+ A = SS^T, where
+ S := M + VDV^T
+ ```
+
+ Defining an operator in terms of its square root means that
+ `A_{ij} = S_i S_j^T`, where `S_i` is the ith row of `S`. The update
+ `VDV^T` has `ij` coordinate equal to `sum_k V_{ik} D_{kk} V_{jk}`.
+
+ Computational efficiency:
+
+ Defining `A` via its square root eliminates the need to compute the square
+ root.
+
+ Performance depends on the operator representing `M`, the batch size `B`, and
+ the width `L` of the matrix being multiplied or of the systems being solved.
+
+ Since `V` is rank `r`, the update adds
+
+ * `O(B L k r)` to matmul, which requires a call to `M.matmul`.
+ * `O(B L r^3)` to solves, which require a call to `M.solve` as well as the
+ solution to a batch of rank `r` systems.
+ * `O(B r^3)` to determinants, which require a call to `M.solve` as well as the
+ solution to a batch of rank `r` systems.
+
+ The rank `r` solve and determinant are both done through a Cholesky
+ factorization, thus some computation is shared.
+
+ See
+ https://en.wikipedia.org/wiki/Woodbury_matrix_identity
+ https://en.wikipedia.org/wiki/Matrix_determinant_lemma
+ """
+
+ # Note that diag must be nonsingular to use Woodbury lemma, and must be
+ # positive def to use a Cholesky factorization, so we enforce that here.
+ def __init__(self,
+ operator,
+ v,
+ diag=None,
+ verify_pd=True,
+ verify_shapes=True,
+ name='OperatorPDSqrtVDVTUpdate'):
+ """Initialize an `OperatorPDSqrtVDVTUpdate`.
+
+ Args:
+ operator: Subclass of `OperatorPDBase`. Represents the (batch) positive
+ definite matrix `M` in `R^{k x k}`.
+ v: `Tensor` defining batch matrix of same `dtype` and `batch_shape` as
+ `operator`, and last two dimensions of shape `(k, r)`.
+ diag: Optional `Tensor` defining batch vector of same `dtype` and
+ `batch_shape` as `operator`, and last dimension of size `r`. If `None`,
+ the update becomes `VV^T` rather than `VDV^T`.
+ verify_pd: `Boolean`. If `True`, add asserts that `diag > 0`, which,
+ along with the positive definiteness of `operator`, is sufficient to
+ make the resulting operator positive definite.
+ verify_shapes: `Boolean`. If `True`, check that `operator`, `v`, and
+ `diag` have compatible shapes.
+ name: A name to prepend to `Op` names.
+ """
+
+ if not isinstance(operator, operator_pd.OperatorPDBase):
+ raise TypeError('operator was not instance of OperatorPDBase.')
+
+ with ops.name_scope(name):
+ with ops.op_scope(operator.inputs + [v, diag], 'init'):
+ self._operator = operator
+ self._v = ops.convert_to_tensor(v, name='v')
+ self._verify_pd = verify_pd
+ self._verify_shapes = verify_shapes
+ self._name = name
+
+ # This operator will be PD so long as the diag is PSD, but Woodbury
+ # and determinant lemmas require diag to be PD. So require diag PD
+ # whenever we ask to "verify_pd".
+ if diag is not None:
+ self._diag = ops.convert_to_tensor(diag, name='diag')
+ self._diag_operator = operator_pd_diag.OperatorPDDiag(
+ diag, verify_pd=self.verify_pd)
+ # No need to verify that the inverse of a PD is PD.
+ self._diag_inv_operator = operator_pd_diag.OperatorPDDiag(
+ 1 / self._diag, verify_pd=False)
+ else:
+ self._diag = None
+ self._diag_operator = self._get_identity_operator(self._v)
+ self._diag_inv_operator = self._diag_operator
+
+ self._check_types(operator, self._v, self._diag)
+ # Always check static.
+ checked = self._check_shapes_static(operator, self._v, self._diag)
+ if not checked and self._verify_shapes:
+ self._v, self._diag = self._check_shapes_dynamic(
+ operator, self._v, self._diag)
+
+ def _get_identity_operator(self, v):
+ """Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`."""
+ with ops.op_scope([v], 'get_identity_operator'):
+ if v.get_shape().is_fully_defined():
+ v_shape = v.get_shape().as_list()
+ v_batch_shape = v_shape[:-2]
+ r = v_shape[-1]
+ id_shape = v_batch_shape + [r, r]
+ else:
+ v_shape = array_ops.shape(v)
+ v_rank = array_ops.rank(v)
+ v_batch_shape = array_ops.slice(v_shape, [0], [v_rank - 2])
+ r = array_ops.gather(v_shape, v_rank - 1) # Last dim of v
+ id_shape = array_ops.concat(0, (v_batch_shape, [r, r]))
+ return operator_pd_identity.OperatorPDIdentity(
+ id_shape, v.dtype, verify_pd=self._verify_pd)
+
+ def _check_types(self, operator, v, diag):
+ def msg():
+ string = (
+ 'dtypes must match: Found operator.dtype = %s, v.dtype = %s'
+ % (operator.dtype, v.dtype))
+ return string
+
+ if operator.dtype != v.dtype:
+ raise TypeError(msg())
+ if diag is not None:
+ if diag.dtype != v.dtype:
+ raise TypeError('%s, diag.dtype = %s' % (msg(), diag.dtype))
+
+ def _check_shapes_static(self, operator, v, diag):
+ """True if they are compatible. Raise if not. False if could not check."""
+ def msg():
+ # Error message when shapes don't match.
+ string = ' Found: operator.shape = %s, v.shape = %s' % (s_op, s_v)
+ if diag is not None:
+ string += ', diag.shape = %s' % s_d
+ return string
+
+ s_op = operator.get_shape()
+ s_v = v.get_shape()
+
+ # If the shapes are not fully defined, return False: we cannot check here.
+ if not (s_op.is_fully_defined() and s_v.is_fully_defined()):
+ return False
+ if diag is not None:
+ s_d = diag.get_shape()
+ if not s_d.is_fully_defined():
+ return False
+
+ # Now perform the checks, raising ValueError if they fail.
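+ # For example (a sketch), with batch shape [N1, N2] these are compatible:
+ # operator.shape = [N1, N2, k, k]
+ # v.shape = [N1, N2, k, r]
+ # diag.shape = [N1, N2, r]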
+
+ # Check tensor rank.
+ if s_v.ndims != s_op.ndims:
+ raise ValueError('v should have same rank as operator' + msg())
+ if diag is not None:
+ if s_d.ndims != s_op.ndims - 1:
+ raise ValueError('diag should have rank 1 less than operator' + msg())
+
+ # Check batch shape
+ if s_v[:-2] != s_op[:-2]:
+ raise ValueError('v and operator should have same batch shape' + msg())
+ if diag is not None:
+ if s_d[:-1] != s_op[:-2]:
+ raise ValueError(
+ 'diag and operator should have same batch shape' + msg())
+
+ # Check event shape
+ if s_v[-2] != s_op[-1]:
+ raise ValueError(
+ 'v and operator should be compatible for matmul' + msg())
+ if diag is not None:
+ if s_d[-1] != s_v[-1]:
+ raise ValueError('diag and v should have same last dimension' + msg())
+
+ return True
+
+ def _check_shapes_dynamic(self, operator, v, diag):
+ """Return (v, diag) with Assert dependencies, which check shape."""
+ checks = []
+ with ops.op_scope([operator, v, diag], 'check_shapes'):
+ s_v = array_ops.shape(v)
+ r_op = operator.rank()
+ r_v = array_ops.rank(v)
+ if diag is not None:
+ s_d = array_ops.shape(diag)
+ r_d = array_ops.rank(diag)
+
+ # Check tensor rank.
+ checks.append(check_ops.assert_rank(v, r_op))
+ if diag is not None:
+ checks.append(check_ops.assert_rank(diag, r_op - 1))
+
+ # Check batch shape
+ checks.append(check_ops.assert_equal(
+ operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
+ if diag is not None:
+ checks.append(check_ops.assert_equal(
+ operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))
+
+ # Check event shape
+ checks.append(check_ops.assert_equal(
+ operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
+ if diag is not None:
+ checks.append(check_ops.assert_equal(
+ array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))
+
+ v = control_flow_ops.with_dependencies(checks, v)
+ if diag is not None:
+ diag = control_flow_ops.with_dependencies(checks, diag)
+ return v, diag
+
+ @property
+ def name(self):
+ """String name identifying this `Operator`."""
+ return self._name
+
+ @property
+ def verify_pd(self):
+ """Whether to verify that this `Operator` is positive definite."""
+ return self._verify_pd
+
+ @property
+ def dtype(self):
+ """Data type of matrix elements of `A`."""
+ return self._v.dtype
+
+ def _inv_quadratic_form_on_vectors(self, x):
+ return self._iqfov_via_sqrt_solve(x)
+
+ @property
+ def inputs(self):
+ """List of tensors that were provided as initialization inputs."""
+ return self._operator.inputs + self._diag_operator.inputs + [self._v]
+
+ def get_shape(self):
+ """Static `TensorShape` of entire operator.
+
+ If this operator represents the batch matrix `A` with
+ `A.shape = [N1,...,Nn, k, k]`, then this returns
+ `TensorShape([N1,...,Nn, k, k])`.
+
+ Returns:
+ `TensorShape`, statically determined, may be undefined.
+ """
+ return self._operator.get_shape()
+
+ def _shape(self):
+ return self._operator.shape()
+
+ def _det(self):
+ return math_ops.exp(self.log_det())
+
+ def _batch_log_det(self):
+ return 2 * self._batch_sqrt_log_det()
+
+ def _log_det(self):
+ return 2 * self._sqrt_log_det()
+
+ def _sqrt_log_det(self):
+ # The matrix determinant lemma states:
+ # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
+ # = det(C) * det(D) * det(M)
+ #
+ # Here we compute the Cholesky factor of "C", then pass the result on.
+ diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
+ batch_mode=False))
+ return self._sqrt_log_det_core(diag_chol_c)
+
+ def _batch_sqrt_log_det(self):
+ # Here we compute the Cholesky factor of "C", then pass the result on.
+ diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
+ batch_mode=True))
+ return self._sqrt_log_det_core(diag_chol_c)
+
+ def _chol_capacitance(self, batch_mode):
+ """Cholesky factorization of the capacitance term."""
+ # Cholesky factor for (D^{-1} + V^T M^{-1} V), which is sometimes
+ # known as the "capacitance" matrix.
+
+ # self._operator automatically uses batch mode if need be; we cannot force
+ # that here.
+ # M^{-1} V
+ minv_v = self._operator.solve(self._v)
+ # V^T M^{-1} V
+ if batch_mode:
+ vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True)
+ else:
+ vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True)
+
+ # D^{-1} + V^T M^{-1} V
+ capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v)
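+ # Note the capacitance is r x r, so the Cholesky below costs O(r^3) per
+ # batch member, versus O(k^3) to factor the full k x k matrix directly.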
+ # Cholesky[D^{-1} + V^T M^{-1} V]
+ if batch_mode:
+ return linalg_ops.batch_cholesky(capacitance)
+ else:
+ return linalg_ops.cholesky(capacitance)
+
+ def _sqrt_log_det_core(self, diag_chol_c):
+ """Finish computation of Sqrt[Log[Det]]."""
+ # Complete computation of ._log_det and ._batch_log_det, after the initial
+ # Cholesky factor has been taken with the appropriate batch/non-batch method
+
+ # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
+ # = det(C) * det(D) * det(M)
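+ # Worked example (a sketch): M = I_2, V = [[1.], [0.]], D = [2.] gives
+ # S = M + V D V^T = diag(3, 1), so det(S) = 3; the lemma agrees:
+ # det(C) * det(D) * det(M) = (1/2 + 1) * 2 * 1 = 3.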
+ # Multiply by 2 here because this is the log-det of the Cholesky factor of C
+ log_det_c = 2 * math_ops.reduce_sum(
+ math_ops.log(diag_chol_c),
+ reduction_indices=[-1])
+ # Add together to get Log[det(M + VDV^T)], the Log-det of the updated square
+ # root.
+ log_det_updated_sqrt = (
+ log_det_c + self._diag_operator.log_det() + self._operator.log_det())
+ return log_det_updated_sqrt
+
+ def _batch_matmul(self, x, transpose_x=False):
+ # Since the square root is PD, it is symmetric, and so A = SS^T = SS.
+ s_x = self._batch_sqrt_matmul(x, transpose_x=transpose_x)
+ return self._batch_sqrt_matmul(s_x)
+
+ def _matmul(self, x, transpose_x=False):
+ # Since the square root is PD, it is symmetric, and so A = SS^T = SS.
+ s_x = self._sqrt_matmul(x, transpose_x=transpose_x)
+ return self._sqrt_matmul(s_x)
+
+ def _batch_sqrt_matmul(self, x, transpose_x=False):
+ v = self._v
+ m = self._operator
+ d = self._diag_operator
+ # The operators dispatch to the appropriate matmul/batch_matmul
+ # automatically; we cannot override that here.
+ # batch_matmul is defined as x * y, so adj_x and adj_y transpose the left
+ # and right arguments, respectively.
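+ # Computing S x = M x + V (D (V^T x)) this way never materializes the
+ # k x k update V D V^T; the update's extra cost is O(k r) per column of x.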
+ mx = m.matmul(x, transpose_x=transpose_x)
+ vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x)
+ d_vt_x = d.matmul(vt_x)
+ v_d_vt_x = math_ops.batch_matmul(v, d_vt_x)
+
+ return mx + v_d_vt_x
+
+ def _sqrt_matmul(self, x, transpose_x=False):
+ v = self._v
+ m = self._operator
+ d = self._diag_operator
+ # The operators dispatch to the appropriate matmul/batch_matmul
+ # automatically; we cannot override that here.
+ # matmul is defined as a * b, so transpose_a and transpose_b transpose the
+ # left and right arguments, respectively.
+ mx = m.matmul(x, transpose_x=transpose_x)
+ vt_x = math_ops.matmul(v, x, transpose_a=True, transpose_b=transpose_x)
+ d_vt_x = d.matmul(vt_x)
+ v_d_vt_x = math_ops.matmul(v, d_vt_x)
+
+ return mx + v_d_vt_x
+
+ def _solve(self, rhs):
+ # This operator represents A = SS^T, but S is symmetric, so A = SS,
+ # which means A^{-1} = S^{-1} S^{-1}, i.e. two applications of S^{-1}.
+ # S^{-1} rhs
+ sqrtinv_rhs = self._sqrt_solve(rhs)
+ return self._sqrt_solve(sqrtinv_rhs)
+
+ def _batch_solve(self, rhs):
+ sqrtinv_rhs = self._batch_sqrt_solve(rhs)
+ return self._batch_sqrt_solve(sqrtinv_rhs)
+
+ def _sqrt_solve(self, rhs):
+ # Recall the square root of this operator is M + VDV^T.
+ # The Woodbury formula gives:
+ # (M + VDV^T)^{-1}
+ # = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
+ # = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
+ # where C is the capacitance matrix.
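+ # Worked example (a sketch): M = I_2, V = [[1.], [0.]], D = [2.] gives
+ # S = diag(3, 1) and C = 1/2 + 1 = 3/2; Woodbury yields
+ # S^{-1} = I - V (2/3) V^T = diag(1/3, 1), the direct inverse of S.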
+ # TODO(jvdillon) Determine if recursively applying rank-1 updates is more
+ # efficient. May not be possible because a general n x n matrix can be
+ # represented as n rank-1 updates, and solving with this matrix is always
+ # done in O(n^3) time.
+ m = self._operator
+ v = self._v
+ cchol = self._chol_capacitance(batch_mode=False)
+
+ # The operators will use batch/singleton mode automatically. We don't
+ # override.
+ # M^{-1} rhs
+ minv_rhs = m.solve(rhs)
+ # V^T M^{-1} rhs
+ vt_minv_rhs = math_ops.matmul(v, minv_rhs, transpose_a=True)
+ # C^{-1} V^T M^{-1} rhs
+ cinv_vt_minv_rhs = linalg_ops.cholesky_solve(cchol, vt_minv_rhs)
+ # V C^{-1} V^T M^{-1} rhs
+ v_cinv_vt_minv_rhs = math_ops.matmul(v, cinv_vt_minv_rhs)
+ # M^{-1} V C^{-1} V^T M^{-1} rhs
+ minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)
+
+ # M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
+ return minv_rhs - minv_v_cinv_vt_minv_rhs
+
+ def _batch_sqrt_solve(self, rhs):
+ # Recall the square root of this operator is M + VDV^T.
+ # The Woodbury formula gives:
+ # (M + VDV^T)^{-1}
+ # = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
+ # = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
+ # where C is the capacitance matrix.
+ m = self._operator
+ v = self._v
+ cchol = self._chol_capacitance(batch_mode=True)
+
+ # The operators will use batch/singleton mode automatically. We don't
+ # override.
+ # M^{-1} rhs
+ minv_rhs = m.solve(rhs)
+ # V^T M^{-1} rhs
+ vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True)
+ # C^{-1} V^T M^{-1} rhs
+ cinv_vt_minv_rhs = linalg_ops.batch_cholesky_solve(cchol, vt_minv_rhs)
+ # V C^{-1} V^T M^{-1} rhs
+ v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs)
+ # M^{-1} V C^{-1} V^T M^{-1} rhs
+ minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)
+
+ # M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
+ return minv_rhs - minv_v_cinv_vt_minv_rhs
+
+ def _to_dense(self):
+ sqrt = self.sqrt_to_dense()
+ return math_ops.batch_matmul(sqrt, sqrt, adj_y=True)
+
+ def _sqrt_to_dense(self):
+ v = self._v
+ d = self._diag_operator
+ m = self._operator
+
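+ # Numpy cross-check of what this method represents (a sketch, assuming
+ # numpy arrays m_np, v_np, d_np for M, V, and diag(D)):
+ # s = m_np + v_np.dot(d_np).dot(v_np.T) # sqrt_to_dense()
+ # a = s.dot(s.T) # to_dense()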
+ d_vt = d.matmul(v, transpose_x=True)
+ # The batch op won't be efficient for singletons, and currently we don't
+ # break to_dense into batch/singleton methods.
+ v_d_vt = math_ops.batch_matmul(v, d_vt)
+ m_plus_v_d_vt = m.to_dense() + v_d_vt
+ return m_plus_v_d_vt