diff options
12 files changed, 2243 insertions, 2048 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 714e48a35c..7ac72f07e9 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -349,9 +349,43 @@ cuda_py_test( ) cuda_py_test( - name = "mvn_test", + name = "mvn_diag_test", + size = "small", + srcs = ["python/kernel_tests/mvn_diag_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "mvn_diag_plus_low_rank_test", size = "medium", - srcs = ["python/kernel_tests/mvn_test.py"], + srcs = ["python/kernel_tests/mvn_diag_plus_low_rank_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "mvn_tril_test", + size = "small", + srcs = ["python/kernel_tests/mvn_tril_test.py"], additional_deps = [ ":distributions_py", "//third_party/py/numpy", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index db02ea60f8..fa92d3c156 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -134,7 +134,9 @@ from tensorflow.contrib.distributions.python.ops.laplace import * from tensorflow.contrib.distributions.python.ops.logistic import * from tensorflow.contrib.distributions.python.ops.mixture import * from tensorflow.contrib.distributions.python.ops.multinomial import * -from tensorflow.contrib.distributions.python.ops.mvn import * +from tensorflow.contrib.distributions.python.ops.mvn_diag import * +from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * +from tensorflow.contrib.distributions.python.ops.mvn_tril import * from tensorflow.contrib.distributions.python.ops.normal import * from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * from tensorflow.contrib.distributions.python.ops.onehot_categorical import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py new file mode 100644 index 0000000000..834877769e --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py @@ -0,0 +1,420 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultivariateNormal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib import distributions +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +ds = distributions + + +class MultivariateNormalDiagPlusLowRankTest(test.TestCase): + """Well tested because this is a simple override of the base class.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testDiagBroadcastBothBatchAndEvent(self): + # batch_shape: [3], event_shape: [2] + diag = np.array([[1., 2], [3, 4], [5, 6]]) + # batch_shape: [1], event_shape: [] + identity_multiplier = np.array([5.]) + with self.test_session(): + dist = ds.MultivariateNormalDiagPlusLowRank( + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=True) + self.assertAllClose( + np.array([[[1. + 5, 0], + [0, 2 + 5]], + [[3 + 5, 0], + [0, 4 + 5]], + [[5 + 5, 0], + [0, 6 + 5]]]), + dist.scale.to_dense().eval()) + + def testDiagBroadcastBothBatchAndEvent2(self): + # This test differs from `testDiagBroadcastBothBatchAndEvent` in that it + # broadcasts batch_shape's from both the `scale_diag` and + # `scale_identity_multiplier` args. + # batch_shape: [3], event_shape: [2] + diag = np.array([[1., 2], [3, 4], [5, 6]]) + # batch_shape: [3, 1], event_shape: [] + identity_multiplier = np.array([[5.], [4], [3]]) + with self.test_session(): + dist = ds.MultivariateNormalDiagPlusLowRank( + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=True) + self.assertAllEqual( + [3, 3, 2, 2], + dist.scale.to_dense().get_shape()) + + def testDiagBroadcastOnlyEvent(self): + # batch_shape: [3], event_shape: [2] + diag = np.array([[1., 2], [3, 4], [5, 6]]) + # batch_shape: [3], event_shape: [] + identity_multiplier = np.array([5., 4, 3]) + with self.test_session(): + dist = ds.MultivariateNormalDiagPlusLowRank( + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=True) + self.assertAllClose( + np.array([[[1. + 5, 0], + [0, 2 + 5]], + [[3 + 4, 0], + [0, 4 + 4]], + [[5 + 3, 0], + [0, 6 + 3]]]), # shape: [3, 2, 2] + dist.scale.to_dense().eval()) + + def testDiagBroadcastMultiplierAndLoc(self): + # batch_shape: [], event_shape: [3] + loc = np.array([1., 0, -1]) + # batch_shape: [3], event_shape: [] + identity_multiplier = np.array([5., 4, 3]) + with self.test_session(): + dist = ds.MultivariateNormalDiagPlusLowRank( + loc=loc, + scale_identity_multiplier=identity_multiplier, + validate_args=True) + self.assertAllClose( + np.array([[[5, 0, 0], + [0, 5, 0], + [0, 0, 5]], + [[4, 0, 0], + [0, 4, 0], + [0, 0, 4]], + [[3, 0, 0], + [0, 3, 0], + [0, 0, 3]]]), + dist.scale.to_dense().eval()) + + def testMean(self): + mu = [-1.0, 1.0] + diag_large = [1.0, 5.0] + v = [[2.0], [3.0]] + diag_small = [3.0] + with self.test_session(): + dist = ds.MultivariateNormalDiagPlusLowRank( + loc=mu, + scale_diag=diag_large, + scale_perturb_factor=v, + scale_perturb_diag=diag_small, + validate_args=True) + self.assertAllEqual(mu, dist.mean().eval()) + + def testSample(self): + # TODO(jvdillon): This test should be the basis of a new test fixture which + # is applied to every distribution. When we make this fixture, we'll also + # separate the analytical- and sample-based tests as well as for each + # function tested. For now, we group things so we can recycle one batch of + # samples (thus saving resources). + + mu = np.array([-1., 1, 0.5], dtype=np.float32) + diag_large = np.array([1., 0.5, 0.75], dtype=np.float32) + diag_small = np.array([-1.1, 1.2], dtype=np.float32) + v = np.array([[0.7, 0.8], + [0.9, 1], + [0.5, 0.6]], dtype=np.float32) # shape: [k, r] = [3, 2] + + true_mean = mu + true_scale = np.diag(diag_large) + np.matmul(np.matmul( + v, np.diag(diag_small)), v.T) + true_covariance = np.matmul(true_scale, true_scale.T) + true_variance = np.diag(true_covariance) + true_stddev = np.sqrt(true_variance) + true_det_covariance = np.linalg.det(true_covariance) + true_log_det_covariance = np.log(true_det_covariance) + + with self.test_session() as sess: + dist = ds.MultivariateNormalDiagPlusLowRank( + loc=mu, + scale_diag=diag_large, + scale_perturb_factor=v, + scale_perturb_diag=diag_small, + validate_args=True) + + # The following distributions will test the KL divergence calculation. + mvn_identity = ds.MultivariateNormalDiag( + loc=np.array([1., 2, 0.25], dtype=np.float32), + validate_args=True) + mvn_scaled = ds.MultivariateNormalDiag( + loc=mvn_identity.loc, + scale_identity_multiplier=2.2, + validate_args=True) + mvn_diag = ds.MultivariateNormalDiag( + loc=mvn_identity.loc, + scale_diag=np.array([0.5, 1.5, 1.], dtype=np.float32), + validate_args=True) + mvn_chol = ds.MultivariateNormalTriL( + loc=np.array([1., 2, -1], dtype=np.float32), + scale_tril=np.array([[6., 0, 0], + [2, 5, 0], + [1, 3, 4]], dtype=np.float32) / 10., + validate_args=True) + + scale = dist.scale.to_dense() + + n = int(30e3) + samps = dist.sample(n, seed=0) + sample_mean = math_ops.reduce_mean(samps, 0) + x = samps - sample_mean + sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n + + sample_kl_identity = math_ops.reduce_mean( + dist.log_prob(samps) - mvn_identity.log_prob(samps), 0) + analytical_kl_identity = ds.kl(dist, mvn_identity) + + sample_kl_scaled = math_ops.reduce_mean( + dist.log_prob(samps) - mvn_scaled.log_prob(samps), 0) + analytical_kl_scaled = ds.kl(dist, mvn_scaled) + + sample_kl_diag = math_ops.reduce_mean( + dist.log_prob(samps) - mvn_diag.log_prob(samps), 0) + analytical_kl_diag = ds.kl(dist, mvn_diag) + + sample_kl_chol = math_ops.reduce_mean( + dist.log_prob(samps) - mvn_chol.log_prob(samps), 0) + analytical_kl_chol = ds.kl(dist, mvn_chol) + + n = int(10e3) + baseline = ds.MultivariateNormalDiag( + loc=np.array([-1., 0.25, 1.25], dtype=np.float32), + scale_diag=np.array([1.5, 0.5, 1.], dtype=np.float32), + validate_args=True) + samps = baseline.sample(n, seed=0) + + sample_kl_identity_diag_baseline = math_ops.reduce_mean( + baseline.log_prob(samps) - mvn_identity.log_prob(samps), 0) + analytical_kl_identity_diag_baseline = ds.kl(baseline, mvn_identity) + + sample_kl_scaled_diag_baseline = math_ops.reduce_mean( + baseline.log_prob(samps) - mvn_scaled.log_prob(samps), 0) + analytical_kl_scaled_diag_baseline = ds.kl(baseline, mvn_scaled) + + sample_kl_diag_diag_baseline = math_ops.reduce_mean( + baseline.log_prob(samps) - mvn_diag.log_prob(samps), 0) + analytical_kl_diag_diag_baseline = ds.kl(baseline, mvn_diag) + + sample_kl_chol_diag_baseline = math_ops.reduce_mean( + baseline.log_prob(samps) - mvn_chol.log_prob(samps), 0) + analytical_kl_chol_diag_baseline = ds.kl(baseline, mvn_chol) + + [ + sample_mean_, + analytical_mean_, + sample_covariance_, + analytical_covariance_, + analytical_variance_, + analytical_stddev_, + analytical_log_det_covariance_, + analytical_det_covariance_, + scale_, + sample_kl_identity_, analytical_kl_identity_, + sample_kl_scaled_, analytical_kl_scaled_, + sample_kl_diag_, analytical_kl_diag_, + sample_kl_chol_, analytical_kl_chol_, + sample_kl_identity_diag_baseline_, + analytical_kl_identity_diag_baseline_, + sample_kl_scaled_diag_baseline_, analytical_kl_scaled_diag_baseline_, + sample_kl_diag_diag_baseline_, analytical_kl_diag_diag_baseline_, + sample_kl_chol_diag_baseline_, analytical_kl_chol_diag_baseline_, + ] = sess.run([ + sample_mean, + dist.mean(), + sample_covariance, + dist.covariance(), + dist.variance(), + dist.stddev(), + dist.log_det_covariance(), + dist.det_covariance(), + scale, + sample_kl_identity, analytical_kl_identity, + sample_kl_scaled, analytical_kl_scaled, + sample_kl_diag, analytical_kl_diag, + sample_kl_chol, analytical_kl_chol, + sample_kl_identity_diag_baseline, + analytical_kl_identity_diag_baseline, + sample_kl_scaled_diag_baseline, analytical_kl_scaled_diag_baseline, + sample_kl_diag_diag_baseline, analytical_kl_diag_diag_baseline, + sample_kl_chol_diag_baseline, analytical_kl_chol_diag_baseline, + ]) + + sample_variance_ = np.diag(sample_covariance_) + sample_stddev_ = np.sqrt(sample_variance_) + sample_det_covariance_ = np.linalg.det(sample_covariance_) + sample_log_det_covariance_ = np.log(sample_det_covariance_) + + print("true_mean:\n{} ".format(true_mean)) + print("sample_mean:\n{}".format(sample_mean_)) + print("analytical_mean:\n{}".format(analytical_mean_)) + + print("true_covariance:\n{}".format(true_covariance)) + print("sample_covariance:\n{}".format(sample_covariance_)) + print("analytical_covariance:\n{}".format( + analytical_covariance_)) + + print("true_variance:\n{}".format(true_variance)) + print("sample_variance:\n{}".format(sample_variance_)) + print("analytical_variance:\n{}".format(analytical_variance_)) + + print("true_stddev:\n{}".format(true_stddev)) + print("sample_stddev:\n{}".format(sample_stddev_)) + print("analytical_stddev:\n{}".format(analytical_stddev_)) + + print("true_log_det_covariance:\n{}".format( + true_log_det_covariance)) + print("sample_log_det_covariance:\n{}".format( + sample_log_det_covariance_)) + print("analytical_log_det_covariance:\n{}".format( + analytical_log_det_covariance_)) + + print("true_det_covariance:\n{}".format( + true_det_covariance)) + print("sample_det_covariance:\n{}".format( + sample_det_covariance_)) + print("analytical_det_covariance:\n{}".format( + analytical_det_covariance_)) + + print("true_scale:\n{}".format(true_scale)) + print("scale:\n{}".format(scale_)) + + print("kl_identity: analytical:{} sample:{}".format( + analytical_kl_identity_, sample_kl_identity_)) + + print("kl_scaled: analytical:{} sample:{}".format( + analytical_kl_scaled_, sample_kl_scaled_)) + + print("kl_diag: analytical:{} sample:{}".format( + analytical_kl_diag_, sample_kl_diag_)) + + print("kl_chol: analytical:{} sample:{}".format( + analytical_kl_chol_, sample_kl_chol_)) + + print("kl_identity_diag_baseline: analytical:{} sample:{}".format( + analytical_kl_identity_diag_baseline_, + sample_kl_identity_diag_baseline_)) + + print("kl_scaled_diag_baseline: analytical:{} sample:{}".format( + analytical_kl_scaled_diag_baseline_, + sample_kl_scaled_diag_baseline_)) + + print("kl_diag_diag_baseline: analytical:{} sample:{}".format( + analytical_kl_diag_diag_baseline_, + sample_kl_diag_diag_baseline_)) + + print("kl_chol_diag_baseline: analytical:{} sample:{}".format( + analytical_kl_chol_diag_baseline_, + sample_kl_chol_diag_baseline_)) + + self.assertAllClose(true_mean, sample_mean_, + atol=0., rtol=0.02) + self.assertAllClose(true_mean, analytical_mean_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_covariance, sample_covariance_, + atol=0., rtol=0.02) + self.assertAllClose(true_covariance, analytical_covariance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_variance, sample_variance_, + atol=0., rtol=0.02) + self.assertAllClose(true_variance, analytical_variance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_stddev, sample_stddev_, + atol=0., rtol=0.02) + self.assertAllClose(true_stddev, analytical_stddev_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_, + atol=0., rtol=0.02) + self.assertAllClose(true_log_det_covariance, + analytical_log_det_covariance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_det_covariance, sample_det_covariance_, + atol=0., rtol=0.02) + self.assertAllClose(true_det_covariance, analytical_det_covariance_, + atol=0., rtol=1e-5) + + self.assertAllClose(true_scale, scale_, + atol=0., rtol=1e-6) + + self.assertAllClose(sample_kl_identity_, analytical_kl_identity_, + atol=0., rtol=0.02) + self.assertAllClose(sample_kl_scaled_, analytical_kl_scaled_, + atol=0., rtol=0.02) + self.assertAllClose(sample_kl_diag_, analytical_kl_diag_, + atol=0., rtol=0.02) + self.assertAllClose(sample_kl_chol_, analytical_kl_chol_, + atol=0., rtol=0.02) + + self.assertAllClose( + sample_kl_identity_diag_baseline_, + analytical_kl_identity_diag_baseline_, + atol=0., rtol=0.02) + self.assertAllClose( + sample_kl_scaled_diag_baseline_, + analytical_kl_scaled_diag_baseline_, + atol=0., rtol=0.02) + self.assertAllClose( + sample_kl_diag_diag_baseline_, + analytical_kl_diag_diag_baseline_, + atol=0., rtol=0.04) + self.assertAllClose( + sample_kl_chol_diag_baseline_, + analytical_kl_chol_diag_baseline_, + atol=0., rtol=0.02) + + def testImplicitLargeDiag(self): + mu = np.array([[1., 2, 3], + [11, 22, 33]]) # shape: [b, k] = [2, 3] + u = np.array([[[1., 2], + [3, 4], + [5, 6]], + [[0.5, 0.75], + [1, 0.25], + [1.5, 1.25]]]) # shape: [b, k, r] = [2, 3, 2] + m = np.array([[0.1, 0.2], + [0.4, 0.5]]) # shape: [b, r] = [2, 2] + scale = np.stack([ + np.eye(3) + np.matmul(np.matmul(u[0], np.diag(m[0])), + np.transpose(u[0])), + np.eye(3) + np.matmul(np.matmul(u[1], np.diag(m[1])), + np.transpose(u[1])), + ]) + cov = np.stack([np.matmul(scale[0], scale[0].T), + np.matmul(scale[1], scale[1].T)]) + print("expected_cov:\n{}".format(cov)) + with self.test_session(): + mvn = ds.MultivariateNormalDiagPlusLowRank( + loc=mu, + scale_perturb_factor=u, + scale_perturb_diag=m) + self.assertAllClose(cov, mvn.covariance().eval(), atol=0., rtol=1e-6) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py new file mode 100644 index 0000000000..3838cccb22 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -0,0 +1,181 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultivariateNormal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +from tensorflow.contrib import distributions +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +ds = distributions + + +class MultivariateNormalDiagTest(test.TestCase): + """Well tested because this is a simple override of the base class.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testScalarParams(self): + mu = -1. + diag = -5. + with self.test_session(): + # TODO(b/35244539): Choose better exception, once LinOp throws it. + with self.assertRaises(IndexError): + ds.MultivariateNormalDiag(mu, diag) + + def testVectorParams(self): + mu = [-1.] + diag = [-5.] + with self.test_session(): + dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) + self.assertAllEqual([3, 1], dist.sample(3).get_shape()) + + def testMean(self): + mu = [-1., 1] + diag = [1., -5] + with self.test_session(): + dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) + self.assertAllEqual(mu, dist.mean().eval()) + + def testEntropy(self): + mu = [-1., 1] + diag = [-1., 5] + diag_mat = np.diag(diag) + scipy_mvn = stats.multivariate_normal(mean=mu, cov=diag_mat**2) + with self.test_session(): + dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) + self.assertAllClose(scipy_mvn.entropy(), dist.entropy().eval(), atol=1e-4) + + def testSample(self): + mu = [-1., 1] + diag = [1., -2] + with self.test_session(): + dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) + samps = dist.sample(int(1e3), seed=0).eval() + cov_mat = array_ops.matrix_diag(diag).eval()**2 + + self.assertAllClose(mu, samps.mean(axis=0), + atol=0., rtol=0.05) + self.assertAllClose(cov_mat, np.cov(samps.T), + atol=0.05, rtol=0.05) + + def testCovariance(self): + with self.test_session(): + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) + self.assertAllClose( + np.diag(np.ones([3], dtype=np.float32)), + mvn.covariance().eval()) + + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([3], dtype=dtypes.float32), + scale_identity_multiplier=[3., 2.]) + self.assertAllEqual([2], mvn.batch_shape) + self.assertAllEqual([3], mvn.event_shape) + self.assertAllClose( + np.array([[[3., 0, 0], + [0, 3, 0], + [0, 0, 3]], + [[2, 0, 0], + [0, 2, 0], + [0, 0, 2]]])**2., + mvn.covariance().eval()) + + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([3], dtype=dtypes.float32), + scale_diag=[[3., 2, 1], [4, 5, 6]]) + self.assertAllEqual([2], mvn.batch_shape) + self.assertAllEqual([3], mvn.event_shape) + self.assertAllClose( + np.array([[[3., 0, 0], + [0, 2, 0], + [0, 0, 1]], + [[4, 0, 0], + [0, 5, 0], + [0, 0, 6]]])**2., + mvn.covariance().eval()) + + def testVariance(self): + with self.test_session(): + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) + self.assertAllClose( + np.ones([3], dtype=np.float32), + mvn.variance().eval()) + + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([3], dtype=dtypes.float32), + scale_identity_multiplier=[3., 2.]) + self.assertAllClose( + np.array([[3., 3, 3], + [2, 2, 2]])**2., + mvn.variance().eval()) + + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([3], dtype=dtypes.float32), + scale_diag=[[3., 2, 1], + [4, 5, 6]]) + self.assertAllClose( + np.array([[3., 2, 1], + [4, 5, 6]])**2., + mvn.variance().eval()) + + def testStddev(self): + with self.test_session(): + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) + self.assertAllClose( + np.ones([3], dtype=np.float32), + mvn.stddev().eval()) + + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([3], dtype=dtypes.float32), + scale_identity_multiplier=[3., 2.]) + self.assertAllClose( + np.array([[3., 3, 3], + [2, 2, 2]]), + mvn.stddev().eval()) + + mvn = ds.MultivariateNormalDiag( + loc=array_ops.zeros([3], dtype=dtypes.float32), + scale_diag=[[3., 2, 1], [4, 5, 6]]) + self.assertAllClose( + np.array([[3., 2, 1], + [4, 5, 6]]), + mvn.stddev().eval()) + + def testMultivariateNormalDiagWithSoftplusScale(self): + mu = [-1.0, 1.0] + diag = [-1.0, -2.0] + with self.test_session(): + dist = ds.MultivariateNormalDiagWithSoftplusScale( + mu, diag, validate_args=True) + samps = dist.sample(1000, seed=0).eval() + cov_mat = array_ops.matrix_diag(nn_ops.softplus(diag)).eval()**2 + + self.assertAllClose(mu, samps.mean(axis=0), atol=0.1) + self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py deleted file mode 100644 index 79fdb149b4..0000000000 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py +++ /dev/null @@ -1,981 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MultivariateNormal.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from scipy import stats -from tensorflow.contrib import distributions -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.platform import test - - -ds = distributions - - -class MultivariateNormalDiagTest(test.TestCase): - """Well tested because this is a simple override of the base class.""" - - def setUp(self): - self._rng = np.random.RandomState(42) - - def testScalarParams(self): - mu = -1. - diag = -5. - with self.test_session(): - # TODO(b/35244539): Choose better exception, once LinOp throws it. - with self.assertRaises(IndexError): - ds.MultivariateNormalDiag(mu, diag) - - def testVectorParams(self): - mu = [-1.] - diag = [-5.] - with self.test_session(): - dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) - self.assertAllEqual([3, 1], dist.sample(3).get_shape()) - - def testMean(self): - mu = [-1., 1] - diag = [1., -5] - with self.test_session(): - dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) - self.assertAllEqual(mu, dist.mean().eval()) - - def testEntropy(self): - mu = [-1., 1] - diag = [-1., 5] - diag_mat = np.diag(diag) - scipy_mvn = stats.multivariate_normal(mean=mu, cov=diag_mat**2) - with self.test_session(): - dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) - self.assertAllClose(scipy_mvn.entropy(), dist.entropy().eval(), atol=1e-4) - - def testSample(self): - mu = [-1., 1] - diag = [1., -2] - with self.test_session(): - dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) - samps = dist.sample(int(1e3), seed=0).eval() - cov_mat = array_ops.matrix_diag(diag).eval()**2 - - self.assertAllClose(mu, samps.mean(axis=0), - atol=0., rtol=0.05) - self.assertAllClose(cov_mat, np.cov(samps.T), - atol=0.05, rtol=0.05) - - def testCovariance(self): - with self.test_session(): - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) - self.assertAllClose( - np.diag(np.ones([3], dtype=np.float32)), - mvn.covariance().eval()) - - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_identity_multiplier=[3., 2.]) - self.assertAllEqual([2], mvn.batch_shape) - self.assertAllEqual([3], mvn.event_shape) - self.assertAllClose( - np.array([[[3., 0, 0], - [0, 3, 0], - [0, 0, 3]], - [[2, 0, 0], - [0, 2, 0], - [0, 0, 2]]])**2., - mvn.covariance().eval()) - - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_diag=[[3., 2, 1], [4, 5, 6]]) - self.assertAllEqual([2], mvn.batch_shape) - self.assertAllEqual([3], mvn.event_shape) - self.assertAllClose( - np.array([[[3., 0, 0], - [0, 2, 0], - [0, 0, 1]], - [[4, 0, 0], - [0, 5, 0], - [0, 0, 6]]])**2., - mvn.covariance().eval()) - - def testVariance(self): - with self.test_session(): - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) - self.assertAllClose( - np.ones([3], dtype=np.float32), - mvn.variance().eval()) - - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_identity_multiplier=[3., 2.]) - self.assertAllClose( - np.array([[3., 3, 3], - [2, 2, 2]])**2., - mvn.variance().eval()) - - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_diag=[[3., 2, 1], - [4, 5, 6]]) - self.assertAllClose( - np.array([[3., 2, 1], - [4, 5, 6]])**2., - mvn.variance().eval()) - - def testStddev(self): - with self.test_session(): - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) - self.assertAllClose( - np.ones([3], dtype=np.float32), - mvn.stddev().eval()) - - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_identity_multiplier=[3., 2.]) - self.assertAllClose( - np.array([[3., 3, 3], - [2, 2, 2]]), - mvn.stddev().eval()) - - mvn = ds.MultivariateNormalDiag( - loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_diag=[[3., 2, 1], [4, 5, 6]]) - self.assertAllClose( - np.array([[3., 2, 1], - [4, 5, 6]]), - mvn.stddev().eval()) - - def testMultivariateNormalDiagWithSoftplusScale(self): - mu = [-1.0, 1.0] - diag = [-1.0, -2.0] - with self.test_session(): - dist = ds.MultivariateNormalDiagWithSoftplusScale( - mu, diag, validate_args=True) - samps = dist.sample(1000, seed=0).eval() - cov_mat = array_ops.matrix_diag(nn_ops.softplus(diag)).eval()**2 - - self.assertAllClose(mu, samps.mean(axis=0), atol=0.1) - self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1) - - -class MultivariateNormalDiagPlusLowRankTest(test.TestCase): - """Well tested because this is a simple override of the base class.""" - - def setUp(self): - self._rng = np.random.RandomState(42) - - def testDiagBroadcastBothBatchAndEvent(self): - # batch_shape: [3], event_shape: [2] - diag = np.array([[1., 2], [3, 4], [5, 6]]) - # batch_shape: [1], event_shape: [] - identity_multiplier = np.array([5.]) - with self.test_session(): - dist = ds.MultivariateNormalDiagPlusLowRank( - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=True) - self.assertAllClose( - np.array([[[1. + 5, 0], - [0, 2 + 5]], - [[3 + 5, 0], - [0, 4 + 5]], - [[5 + 5, 0], - [0, 6 + 5]]]), - dist.scale.to_dense().eval()) - - def testDiagBroadcastBothBatchAndEvent2(self): - # This test differs from `testDiagBroadcastBothBatchAndEvent` in that it - # broadcasts batch_shape's from both the `scale_diag` and - # `scale_identity_multiplier` args. - # batch_shape: [3], event_shape: [2] - diag = np.array([[1., 2], [3, 4], [5, 6]]) - # batch_shape: [3, 1], event_shape: [] - identity_multiplier = np.array([[5.], [4], [3]]) - with self.test_session(): - dist = ds.MultivariateNormalDiagPlusLowRank( - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=True) - self.assertAllEqual( - [3, 3, 2, 2], - dist.scale.to_dense().get_shape()) - - def testDiagBroadcastOnlyEvent(self): - # batch_shape: [3], event_shape: [2] - diag = np.array([[1., 2], [3, 4], [5, 6]]) - # batch_shape: [3], event_shape: [] - identity_multiplier = np.array([5., 4, 3]) - with self.test_session(): - dist = ds.MultivariateNormalDiagPlusLowRank( - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=True) - self.assertAllClose( - np.array([[[1. + 5, 0], - [0, 2 + 5]], - [[3 + 4, 0], - [0, 4 + 4]], - [[5 + 3, 0], - [0, 6 + 3]]]), # shape: [3, 2, 2] - dist.scale.to_dense().eval()) - - def testDiagBroadcastMultiplierAndLoc(self): - # batch_shape: [], event_shape: [3] - loc = np.array([1., 0, -1]) - # batch_shape: [3], event_shape: [] - identity_multiplier = np.array([5., 4, 3]) - with self.test_session(): - dist = ds.MultivariateNormalDiagPlusLowRank( - loc=loc, - scale_identity_multiplier=identity_multiplier, - validate_args=True) - self.assertAllClose( - np.array([[[5, 0, 0], - [0, 5, 0], - [0, 0, 5]], - [[4, 0, 0], - [0, 4, 0], - [0, 0, 4]], - [[3, 0, 0], - [0, 3, 0], - [0, 0, 3]]]), - dist.scale.to_dense().eval()) - - def testMean(self): - mu = [-1.0, 1.0] - diag_large = [1.0, 5.0] - v = [[2.0], [3.0]] - diag_small = [3.0] - with self.test_session(): - dist = ds.MultivariateNormalDiagPlusLowRank( - loc=mu, - scale_diag=diag_large, - scale_perturb_factor=v, - scale_perturb_diag=diag_small, - validate_args=True) - self.assertAllEqual(mu, dist.mean().eval()) - - def testSample(self): - # TODO(jvdillon): This test should be the basis of a new test fixture which - # is applied to every distribution. When we make this fixture, we'll also - # separate the analytical- and sample-based tests as well as for each - # function tested. For now, we group things so we can recycle one batch of - # samples (thus saving resources). - - mu = np.array([-1., 1, 0.5], dtype=np.float32) - diag_large = np.array([1., 0.5, 0.75], dtype=np.float32) - diag_small = np.array([-1.1, 1.2], dtype=np.float32) - v = np.array([[0.7, 0.8], - [0.9, 1], - [0.5, 0.6]], dtype=np.float32) # shape: [k, r] = [3, 2] - - true_mean = mu - true_scale = np.diag(diag_large) + np.matmul(np.matmul( - v, np.diag(diag_small)), v.T) - true_covariance = np.matmul(true_scale, true_scale.T) - true_variance = np.diag(true_covariance) - true_stddev = np.sqrt(true_variance) - true_det_covariance = np.linalg.det(true_covariance) - true_log_det_covariance = np.log(true_det_covariance) - - with self.test_session() as sess: - dist = ds.MultivariateNormalDiagPlusLowRank( - loc=mu, - scale_diag=diag_large, - scale_perturb_factor=v, - scale_perturb_diag=diag_small, - validate_args=True) - - # The following distributions will test the KL divergence calculation. - mvn_identity = ds.MultivariateNormalDiag( - loc=np.array([1., 2, 0.25], dtype=np.float32), - validate_args=True) - mvn_scaled = ds.MultivariateNormalDiag( - loc=mvn_identity.loc, - scale_identity_multiplier=2.2, - validate_args=True) - mvn_diag = ds.MultivariateNormalDiag( - loc=mvn_identity.loc, - scale_diag=np.array([0.5, 1.5, 1.], dtype=np.float32), - validate_args=True) - mvn_chol = ds.MultivariateNormalTriL( - loc=np.array([1., 2, -1], dtype=np.float32), - scale_tril=np.array([[6., 0, 0], - [2, 5, 0], - [1, 3, 4]], dtype=np.float32) / 10., - validate_args=True) - - scale = dist.scale.to_dense() - - n = int(30e3) - samps = dist.sample(n, seed=0) - sample_mean = math_ops.reduce_mean(samps, 0) - x = samps - sample_mean - sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n - - sample_kl_identity = math_ops.reduce_mean( - dist.log_prob(samps) - mvn_identity.log_prob(samps), 0) - analytical_kl_identity = ds.kl(dist, mvn_identity) - - sample_kl_scaled = math_ops.reduce_mean( - dist.log_prob(samps) - mvn_scaled.log_prob(samps), 0) - analytical_kl_scaled = ds.kl(dist, mvn_scaled) - - sample_kl_diag = math_ops.reduce_mean( - dist.log_prob(samps) - mvn_diag.log_prob(samps), 0) - analytical_kl_diag = ds.kl(dist, mvn_diag) - - sample_kl_chol = math_ops.reduce_mean( - dist.log_prob(samps) - mvn_chol.log_prob(samps), 0) - analytical_kl_chol = ds.kl(dist, mvn_chol) - - n = int(10e3) - baseline = ds.MultivariateNormalDiag( - loc=np.array([-1., 0.25, 1.25], dtype=np.float32), - scale_diag=np.array([1.5, 0.5, 1.], dtype=np.float32), - validate_args=True) - samps = baseline.sample(n, seed=0) - - sample_kl_identity_diag_baseline = math_ops.reduce_mean( - baseline.log_prob(samps) - mvn_identity.log_prob(samps), 0) - analytical_kl_identity_diag_baseline = ds.kl(baseline, mvn_identity) - - sample_kl_scaled_diag_baseline = math_ops.reduce_mean( - baseline.log_prob(samps) - mvn_scaled.log_prob(samps), 0) - analytical_kl_scaled_diag_baseline = ds.kl(baseline, mvn_scaled) - - sample_kl_diag_diag_baseline = math_ops.reduce_mean( - baseline.log_prob(samps) - mvn_diag.log_prob(samps), 0) - analytical_kl_diag_diag_baseline = ds.kl(baseline, mvn_diag) - - sample_kl_chol_diag_baseline = math_ops.reduce_mean( - baseline.log_prob(samps) - mvn_chol.log_prob(samps), 0) - analytical_kl_chol_diag_baseline = ds.kl(baseline, mvn_chol) - - [ - sample_mean_, - analytical_mean_, - sample_covariance_, - analytical_covariance_, - analytical_variance_, - analytical_stddev_, - analytical_log_det_covariance_, - analytical_det_covariance_, - scale_, - sample_kl_identity_, analytical_kl_identity_, - sample_kl_scaled_, analytical_kl_scaled_, - sample_kl_diag_, analytical_kl_diag_, - sample_kl_chol_, analytical_kl_chol_, - sample_kl_identity_diag_baseline_, - analytical_kl_identity_diag_baseline_, - sample_kl_scaled_diag_baseline_, analytical_kl_scaled_diag_baseline_, - sample_kl_diag_diag_baseline_, analytical_kl_diag_diag_baseline_, - sample_kl_chol_diag_baseline_, analytical_kl_chol_diag_baseline_, - ] = sess.run([ - sample_mean, - dist.mean(), - sample_covariance, - dist.covariance(), - dist.variance(), - dist.stddev(), - dist.log_det_covariance(), - dist.det_covariance(), - scale, - sample_kl_identity, analytical_kl_identity, - sample_kl_scaled, analytical_kl_scaled, - sample_kl_diag, analytical_kl_diag, - sample_kl_chol, analytical_kl_chol, - sample_kl_identity_diag_baseline, - analytical_kl_identity_diag_baseline, - sample_kl_scaled_diag_baseline, analytical_kl_scaled_diag_baseline, - sample_kl_diag_diag_baseline, analytical_kl_diag_diag_baseline, - sample_kl_chol_diag_baseline, analytical_kl_chol_diag_baseline, - ]) - - sample_variance_ = np.diag(sample_covariance_) - sample_stddev_ = np.sqrt(sample_variance_) - sample_det_covariance_ = np.linalg.det(sample_covariance_) - sample_log_det_covariance_ = np.log(sample_det_covariance_) - - print("true_mean:\n{} ".format(true_mean)) - print("sample_mean:\n{}".format(sample_mean_)) - print("analytical_mean:\n{}".format(analytical_mean_)) - - print("true_covariance:\n{}".format(true_covariance)) - print("sample_covariance:\n{}".format(sample_covariance_)) - print("analytical_covariance:\n{}".format( - analytical_covariance_)) - - print("true_variance:\n{}".format(true_variance)) - print("sample_variance:\n{}".format(sample_variance_)) - print("analytical_variance:\n{}".format(analytical_variance_)) - - print("true_stddev:\n{}".format(true_stddev)) - print("sample_stddev:\n{}".format(sample_stddev_)) - print("analytical_stddev:\n{}".format(analytical_stddev_)) - - print("true_log_det_covariance:\n{}".format( - true_log_det_covariance)) - print("sample_log_det_covariance:\n{}".format( - sample_log_det_covariance_)) - print("analytical_log_det_covariance:\n{}".format( - analytical_log_det_covariance_)) - - print("true_det_covariance:\n{}".format( - true_det_covariance)) - print("sample_det_covariance:\n{}".format( - sample_det_covariance_)) - print("analytical_det_covariance:\n{}".format( - analytical_det_covariance_)) - - print("true_scale:\n{}".format(true_scale)) - print("scale:\n{}".format(scale_)) - - print("kl_identity: analytical:{} sample:{}".format( - analytical_kl_identity_, sample_kl_identity_)) - - print("kl_scaled: analytical:{} sample:{}".format( - analytical_kl_scaled_, sample_kl_scaled_)) - - print("kl_diag: analytical:{} sample:{}".format( - analytical_kl_diag_, sample_kl_diag_)) - - print("kl_chol: analytical:{} sample:{}".format( - analytical_kl_chol_, sample_kl_chol_)) - - print("kl_identity_diag_baseline: analytical:{} sample:{}".format( - analytical_kl_identity_diag_baseline_, - sample_kl_identity_diag_baseline_)) - - print("kl_scaled_diag_baseline: analytical:{} sample:{}".format( - analytical_kl_scaled_diag_baseline_, - sample_kl_scaled_diag_baseline_)) - - print("kl_diag_diag_baseline: analytical:{} sample:{}".format( - analytical_kl_diag_diag_baseline_, - sample_kl_diag_diag_baseline_)) - - print("kl_chol_diag_baseline: analytical:{} sample:{}".format( - analytical_kl_chol_diag_baseline_, - sample_kl_chol_diag_baseline_)) - - self.assertAllClose(true_mean, sample_mean_, - atol=0., rtol=0.02) - self.assertAllClose(true_mean, analytical_mean_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_covariance, sample_covariance_, - atol=0., rtol=0.02) - self.assertAllClose(true_covariance, analytical_covariance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_variance, sample_variance_, - atol=0., rtol=0.02) - self.assertAllClose(true_variance, analytical_variance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_stddev, sample_stddev_, - atol=0., rtol=0.02) - self.assertAllClose(true_stddev, analytical_stddev_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_, - atol=0., rtol=0.02) - self.assertAllClose(true_log_det_covariance, - analytical_log_det_covariance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_det_covariance, sample_det_covariance_, - atol=0., rtol=0.02) - self.assertAllClose(true_det_covariance, analytical_det_covariance_, - atol=0., rtol=1e-5) - - self.assertAllClose(true_scale, scale_, - atol=0., rtol=1e-6) - - self.assertAllClose(sample_kl_identity_, analytical_kl_identity_, - atol=0., rtol=0.02) - self.assertAllClose(sample_kl_scaled_, analytical_kl_scaled_, - atol=0., rtol=0.02) - self.assertAllClose(sample_kl_diag_, analytical_kl_diag_, - atol=0., rtol=0.02) - self.assertAllClose(sample_kl_chol_, analytical_kl_chol_, - atol=0., rtol=0.02) - - self.assertAllClose( - sample_kl_identity_diag_baseline_, - analytical_kl_identity_diag_baseline_, - atol=0., rtol=0.02) - self.assertAllClose( - sample_kl_scaled_diag_baseline_, - analytical_kl_scaled_diag_baseline_, - atol=0., rtol=0.02) - self.assertAllClose( - sample_kl_diag_diag_baseline_, - analytical_kl_diag_diag_baseline_, - atol=0., rtol=0.04) - self.assertAllClose( - sample_kl_chol_diag_baseline_, - analytical_kl_chol_diag_baseline_, - atol=0., rtol=0.02) - - def testImplicitLargeDiag(self): - mu = np.array([[1., 2, 3], - [11, 22, 33]]) # shape: [b, k] = [2, 3] - u = np.array([[[1., 2], - [3, 4], - [5, 6]], - [[0.5, 0.75], - [1, 0.25], - [1.5, 1.25]]]) # shape: [b, k, r] = [2, 3, 2] - m = np.array([[0.1, 0.2], - [0.4, 0.5]]) # shape: [b, r] = [2, 2] - scale = np.stack([ - np.eye(3) + np.matmul(np.matmul(u[0], np.diag(m[0])), - np.transpose(u[0])), - np.eye(3) + np.matmul(np.matmul(u[1], np.diag(m[1])), - np.transpose(u[1])), - ]) - cov = np.stack([np.matmul(scale[0], scale[0].T), - np.matmul(scale[1], scale[1].T)]) - print("expected_cov:\n{}".format(cov)) - with self.test_session(): - mvn = ds.MultivariateNormalDiagPlusLowRank( - loc=mu, - scale_perturb_factor=u, - scale_perturb_diag=m) - self.assertAllClose(cov, mvn.covariance().eval(), atol=0., rtol=1e-6) - - -class MultivariateNormalTriLTest(test.TestCase): - - def setUp(self): - self._rng = np.random.RandomState(42) - - def _random_chol(self, *shape): - mat = self._rng.rand(*shape) - chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus) - chol = array_ops.matrix_band_part(chol, -1, 0) - sigma = math_ops.matmul(chol, chol, adjoint_b=True) - return chol.eval(), sigma.eval() - - def testLogPDFScalarBatch(self): - with self.test_session(): - mu = self._rng.rand(2) - chol, sigma = self._random_chol(2, 2) - chol[1, 1] = -chol[1, 1] - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - x = self._rng.rand(2) - - log_pdf = mvn.log_prob(x) - pdf = mvn.prob(x) - - scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) - - expected_log_pdf = scipy_mvn.logpdf(x) - expected_pdf = scipy_mvn.pdf(x) - self.assertEqual((), log_pdf.get_shape()) - self.assertEqual((), pdf.get_shape()) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(expected_pdf, pdf.eval()) - - def testLogPDFXIsHigherRank(self): - with self.test_session(): - mu = self._rng.rand(2) - chol, sigma = self._random_chol(2, 2) - chol[0, 0] = -chol[0, 0] - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - x = self._rng.rand(3, 2) - - log_pdf = mvn.log_prob(x) - pdf = mvn.prob(x) - - scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) - - expected_log_pdf = scipy_mvn.logpdf(x) - expected_pdf = scipy_mvn.pdf(x) - self.assertEqual((3,), log_pdf.get_shape()) - self.assertEqual((3,), pdf.get_shape()) - self.assertAllClose(expected_log_pdf, log_pdf.eval(), atol=0., rtol=0.02) - self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03) - - def testLogPDFXLowerDimension(self): - with self.test_session(): - mu = self._rng.rand(3, 2) - chol, sigma = self._random_chol(3, 2, 2) - chol[0, 0, 0] = -chol[0, 0, 0] - chol[2, 1, 1] = -chol[2, 1, 1] - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - x = self._rng.rand(2) - - log_pdf = mvn.log_prob(x) - pdf = mvn.prob(x) - - self.assertEqual((3,), log_pdf.get_shape()) - self.assertEqual((3,), pdf.get_shape()) - - # scipy can't do batches, so just test one of them. - scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :]) - expected_log_pdf = scipy_mvn.logpdf(x) - expected_pdf = scipy_mvn.pdf(x) - - self.assertAllClose(expected_log_pdf, log_pdf.eval()[1]) - self.assertAllClose(expected_pdf, pdf.eval()[1]) - - def testEntropy(self): - with self.test_session(): - mu = self._rng.rand(2) - chol, sigma = self._random_chol(2, 2) - chol[0, 0] = -chol[0, 0] - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - entropy = mvn.entropy() - - scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) - expected_entropy = scipy_mvn.entropy() - self.assertEqual(entropy.get_shape(), ()) - self.assertAllClose(expected_entropy, entropy.eval()) - - def testEntropyMultidimensional(self): - with self.test_session(): - mu = self._rng.rand(3, 5, 2) - chol, sigma = self._random_chol(3, 5, 2, 2) - chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] - chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - entropy = mvn.entropy() - - # Scipy doesn't do batches, so test one of them. - expected_entropy = stats.multivariate_normal( - mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy() - self.assertEqual(entropy.get_shape(), (3, 5)) - self.assertAllClose(expected_entropy, entropy.eval()[1, 1]) - - def testSample(self): - with self.test_session(): - mu = self._rng.rand(2) - chol, sigma = self._random_chol(2, 2) - chol[0, 0] = -chol[0, 0] - sigma[0, 1] = -sigma[0, 1] - sigma[1, 0] = -sigma[1, 0] - - n = constant_op.constant(100000) - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - samples = mvn.sample(n, seed=137) - sample_values = samples.eval() - self.assertEqual(samples.get_shape(), [int(100e3), 2]) - self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2) - self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06) - - def testSampleWithSampleShape(self): - with self.test_session(): - mu = self._rng.rand(3, 5, 2) - chol, sigma = self._random_chol(3, 5, 2, 2) - chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] - chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] - - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - samples_val = mvn.sample((10, 11, 12), seed=137).eval() - - # Check sample shape - self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape) - - # Check sample means - x = samples_val[:, :, :, 1, 1, :] - self.assertAllClose( - x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05) - - # Check that log_prob(samples) works - log_prob_val = mvn.log_prob(samples_val).eval() - x_log_pdf = log_prob_val[:, :, :, 1, 1] - expected_log_pdf = stats.multivariate_normal( - mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x) - self.assertAllClose(expected_log_pdf, x_log_pdf) - - def testSampleMultiDimensional(self): - with self.test_session(): - mu = self._rng.rand(3, 5, 2) - chol, sigma = self._random_chol(3, 5, 2, 2) - chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] - chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] - - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - n = constant_op.constant(100000) - samples = mvn.sample(n, seed=137) - sample_values = samples.eval() - - self.assertEqual(samples.get_shape(), (100000, 3, 5, 2)) - self.assertAllClose( - sample_values[:, 1, 1, :].mean(axis=0), mu[1, 1, :], atol=0.05) - self.assertAllClose( - np.cov(sample_values[:, 1, 1, :], rowvar=0), - sigma[1, 1, :, :], - atol=1e-1) - - def testShapes(self): - with self.test_session(): - mu = self._rng.rand(3, 5, 2) - chol, _ = self._random_chol(3, 5, 2, 2) - chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] - chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] - - mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) - - # Shapes known at graph construction time. - self.assertEqual((2,), tuple(mvn.event_shape.as_list())) - self.assertEqual((3, 5), tuple(mvn.batch_shape.as_list())) - - # Shapes known at runtime. - self.assertEqual((2,), tuple(mvn.event_shape_tensor().eval())) - self.assertEqual((3, 5), tuple(mvn.batch_shape_tensor().eval())) - - def _random_mu_and_sigma(self, batch_shape, event_shape): - # This ensures sigma is positive def. - mat_shape = batch_shape + event_shape + event_shape - mat = self._rng.randn(*mat_shape) - perm = np.arange(mat.ndim) - perm[-2:] = [perm[-1], perm[-2]] - sigma = np.matmul(mat, np.transpose(mat, perm)) - - mu_shape = batch_shape + event_shape - mu = self._rng.randn(*mu_shape) - - return mu, sigma - - def testKLNonBatch(self): - batch_shape = () - event_shape = (2,) - with self.test_session(): - mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) - mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) - mvn_a = ds.MultivariateNormalTriL( - loc=mu_a, - scale_tril=np.linalg.cholesky(sigma_a), - validate_args=True) - mvn_b = ds.MultivariateNormalTriL( - loc=mu_b, - scale_tril=np.linalg.cholesky(sigma_b), - validate_args=True) - - kl = ds.kl(mvn_a, mvn_b) - self.assertEqual(batch_shape, kl.get_shape()) - - kl_v = kl.eval() - expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b) - self.assertAllClose(expected_kl, kl_v) - - def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) - with self.test_session(): - mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) - mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) - mvn_a = ds.MultivariateNormalTriL( - loc=mu_a, - scale_tril=np.linalg.cholesky(sigma_a), - validate_args=True) - mvn_b = ds.MultivariateNormalTriL( - loc=mu_b, - scale_tril=np.linalg.cholesky(sigma_b), - validate_args=True) - - kl = ds.kl(mvn_a, mvn_b) - self.assertEqual(batch_shape, kl.get_shape()) - - kl_v = kl.eval() - expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], - mu_b[0, :], sigma_b[0, :]) - expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], - mu_b[1, :], sigma_b[1, :]) - self.assertAllClose(expected_kl_0, kl_v[0]) - self.assertAllClose(expected_kl_1, kl_v[1]) - - def testKLTwoIdenticalDistributionsIsZero(self): - batch_shape = (2,) - event_shape = (3,) - with self.test_session(): - mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) - mvn_a = ds.MultivariateNormalTriL( - loc=mu_a, - scale_tril=np.linalg.cholesky(sigma_a), - validate_args=True) - - # Should be zero since KL(p || p) = =. - kl = ds.kl(mvn_a, mvn_a) - self.assertEqual(batch_shape, kl.get_shape()) - - kl_v = kl.eval() - self.assertAllClose(np.zeros(*batch_shape), kl_v) - - def testSampleLarge(self): - mu = np.array([-1., 1], dtype=np.float32) - scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3. - - true_mean = mu - true_scale = scale_tril - true_covariance = np.matmul(true_scale, true_scale.T) - true_variance = np.diag(true_covariance) - true_stddev = np.sqrt(true_variance) - true_det_covariance = np.linalg.det(true_covariance) - true_log_det_covariance = np.log(true_det_covariance) - - with self.test_session() as sess: - dist = ds.MultivariateNormalTriL( - loc=mu, - scale_tril=scale_tril, - validate_args=True) - - # The following distributions will test the KL divergence calculation. - mvn_chol = ds.MultivariateNormalTriL( - loc=np.array([0.5, 1.2], dtype=np.float32), - scale_tril=np.array([[3., 0], [1, 2]], dtype=np.float32), - validate_args=True) - - n = int(10e3) - samps = dist.sample(n, seed=0) - sample_mean = math_ops.reduce_mean(samps, 0) - x = samps - sample_mean - sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n - - sample_kl_chol = math_ops.reduce_mean( - dist.log_prob(samps) - mvn_chol.log_prob(samps), 0) - analytical_kl_chol = ds.kl(dist, mvn_chol) - - scale = dist.scale.to_dense() - - [ - sample_mean_, - analytical_mean_, - sample_covariance_, - analytical_covariance_, - analytical_variance_, - analytical_stddev_, - analytical_log_det_covariance_, - analytical_det_covariance_, - sample_kl_chol_, analytical_kl_chol_, - scale_, - ] = sess.run([ - sample_mean, - dist.mean(), - sample_covariance, - dist.covariance(), - dist.variance(), - dist.stddev(), - dist.log_det_covariance(), - dist.det_covariance(), - sample_kl_chol, analytical_kl_chol, - scale, - ]) - - sample_variance_ = np.diag(sample_covariance_) - sample_stddev_ = np.sqrt(sample_variance_) - sample_det_covariance_ = np.linalg.det(sample_covariance_) - sample_log_det_covariance_ = np.log(sample_det_covariance_) - - print("true_mean:\n{} ".format(true_mean)) - print("sample_mean:\n{}".format(sample_mean_)) - print("analytical_mean:\n{}".format(analytical_mean_)) - - print("true_covariance:\n{}".format(true_covariance)) - print("sample_covariance:\n{}".format(sample_covariance_)) - print("analytical_covariance:\n{}".format(analytical_covariance_)) - - print("true_variance:\n{}".format(true_variance)) - print("sample_variance:\n{}".format(sample_variance_)) - print("analytical_variance:\n{}".format(analytical_variance_)) - - print("true_stddev:\n{}".format(true_stddev)) - print("sample_stddev:\n{}".format(sample_stddev_)) - print("analytical_stddev:\n{}".format(analytical_stddev_)) - - print("true_log_det_covariance:\n{}".format(true_log_det_covariance)) - print("sample_log_det_covariance:\n{}".format(sample_log_det_covariance_)) - print("analytical_log_det_covariance:\n{}".format( - analytical_log_det_covariance_)) - - print("true_det_covariance:\n{}".format(true_det_covariance)) - print("sample_det_covariance:\n{}".format(sample_det_covariance_)) - print("analytical_det_covariance:\n{}".format(analytical_det_covariance_)) - - print("true_scale:\n{}".format(true_scale)) - print("scale:\n{}".format(scale_)) - - print("kl_chol: analytical:{} sample:{}".format( - analytical_kl_chol_, sample_kl_chol_)) - - self.assertAllClose(true_mean, sample_mean_, - atol=0., rtol=0.03) - self.assertAllClose(true_mean, analytical_mean_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_covariance, sample_covariance_, - atol=0., rtol=0.03) - self.assertAllClose(true_covariance, analytical_covariance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_variance, sample_variance_, - atol=0., rtol=0.02) - self.assertAllClose(true_variance, analytical_variance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_stddev, sample_stddev_, - atol=0., rtol=0.01) - self.assertAllClose(true_stddev, analytical_stddev_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_, - atol=0., rtol=0.04) - self.assertAllClose(true_log_det_covariance, - analytical_log_det_covariance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_det_covariance, sample_det_covariance_, - atol=0., rtol=0.03) - self.assertAllClose(true_det_covariance, analytical_det_covariance_, - atol=0., rtol=1e-6) - - self.assertAllClose(true_scale, scale_, - atol=0., rtol=1e-6) - - self.assertAllClose(sample_kl_chol_, analytical_kl_chol_, - atol=0., rtol=0.02) - - -def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): - """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" - # Check using numpy operations - # This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy. - # So it is important to also check that KL(mvn, mvn) = 0. - sigma_b_inv = np.linalg.inv(sigma_b) - - t = np.trace(sigma_b_inv.dot(sigma_a)) - q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a) - k = mu_a.shape[0] - l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a)) - - return 0.5 * (t + q - k + l) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py new file mode 100644 index 0000000000..994b9877fb --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -0,0 +1,443 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultivariateNormal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +from tensorflow.contrib import distributions +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +ds = distributions + + +class MultivariateNormalTriLTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_chol(self, *shape): + mat = self._rng.rand(*shape) + chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus) + chol = array_ops.matrix_band_part(chol, -1, 0) + sigma = math_ops.matmul(chol, chol, adjoint_b=True) + return chol.eval(), sigma.eval() + + def testLogPDFScalarBatch(self): + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + chol[1, 1] = -chol[1, 1] + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + x = self._rng.rand(2) + + log_pdf = mvn.log_prob(x) + pdf = mvn.prob(x) + + scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) + + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual((), log_pdf.get_shape()) + self.assertEqual((), pdf.get_shape()) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testLogPDFXIsHigherRank(self): + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + chol[0, 0] = -chol[0, 0] + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + x = self._rng.rand(3, 2) + + log_pdf = mvn.log_prob(x) + pdf = mvn.prob(x) + + scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) + + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual((3,), log_pdf.get_shape()) + self.assertEqual((3,), pdf.get_shape()) + self.assertAllClose(expected_log_pdf, log_pdf.eval(), atol=0., rtol=0.02) + self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03) + + def testLogPDFXLowerDimension(self): + with self.test_session(): + mu = self._rng.rand(3, 2) + chol, sigma = self._random_chol(3, 2, 2) + chol[0, 0, 0] = -chol[0, 0, 0] + chol[2, 1, 1] = -chol[2, 1, 1] + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + x = self._rng.rand(2) + + log_pdf = mvn.log_prob(x) + pdf = mvn.prob(x) + + self.assertEqual((3,), log_pdf.get_shape()) + self.assertEqual((3,), pdf.get_shape()) + + # scipy can't do batches, so just test one of them. + scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :]) + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + + self.assertAllClose(expected_log_pdf, log_pdf.eval()[1]) + self.assertAllClose(expected_pdf, pdf.eval()[1]) + + def testEntropy(self): + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + chol[0, 0] = -chol[0, 0] + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + entropy = mvn.entropy() + + scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) + expected_entropy = scipy_mvn.entropy() + self.assertEqual(entropy.get_shape(), ()) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testEntropyMultidimensional(self): + with self.test_session(): + mu = self._rng.rand(3, 5, 2) + chol, sigma = self._random_chol(3, 5, 2, 2) + chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] + chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + entropy = mvn.entropy() + + # Scipy doesn't do batches, so test one of them. + expected_entropy = stats.multivariate_normal( + mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy() + self.assertEqual(entropy.get_shape(), (3, 5)) + self.assertAllClose(expected_entropy, entropy.eval()[1, 1]) + + def testSample(self): + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + chol[0, 0] = -chol[0, 0] + sigma[0, 1] = -sigma[0, 1] + sigma[1, 0] = -sigma[1, 0] + + n = constant_op.constant(100000) + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + samples = mvn.sample(n, seed=137) + sample_values = samples.eval() + self.assertEqual(samples.get_shape(), [int(100e3), 2]) + self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2) + self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06) + + def testSampleWithSampleShape(self): + with self.test_session(): + mu = self._rng.rand(3, 5, 2) + chol, sigma = self._random_chol(3, 5, 2, 2) + chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] + chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] + + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + samples_val = mvn.sample((10, 11, 12), seed=137).eval() + + # Check sample shape + self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape) + + # Check sample means + x = samples_val[:, :, :, 1, 1, :] + self.assertAllClose( + x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05) + + # Check that log_prob(samples) works + log_prob_val = mvn.log_prob(samples_val).eval() + x_log_pdf = log_prob_val[:, :, :, 1, 1] + expected_log_pdf = stats.multivariate_normal( + mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x) + self.assertAllClose(expected_log_pdf, x_log_pdf) + + def testSampleMultiDimensional(self): + with self.test_session(): + mu = self._rng.rand(3, 5, 2) + chol, sigma = self._random_chol(3, 5, 2, 2) + chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] + chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] + + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + n = constant_op.constant(100000) + samples = mvn.sample(n, seed=137) + sample_values = samples.eval() + + self.assertEqual(samples.get_shape(), (100000, 3, 5, 2)) + self.assertAllClose( + sample_values[:, 1, 1, :].mean(axis=0), mu[1, 1, :], atol=0.05) + self.assertAllClose( + np.cov(sample_values[:, 1, 1, :], rowvar=0), + sigma[1, 1, :, :], + atol=1e-1) + + def testShapes(self): + with self.test_session(): + mu = self._rng.rand(3, 5, 2) + chol, _ = self._random_chol(3, 5, 2, 2) + chol[1, 0, 0, 0] = -chol[1, 0, 0, 0] + chol[2, 3, 1, 1] = -chol[2, 3, 1, 1] + + mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True) + + # Shapes known at graph construction time. + self.assertEqual((2,), tuple(mvn.event_shape.as_list())) + self.assertEqual((3, 5), tuple(mvn.batch_shape.as_list())) + + # Shapes known at runtime. + self.assertEqual((2,), tuple(mvn.event_shape_tensor().eval())) + self.assertEqual((3, 5), tuple(mvn.batch_shape_tensor().eval())) + + def _random_mu_and_sigma(self, batch_shape, event_shape): + # This ensures sigma is positive def. + mat_shape = batch_shape + event_shape + event_shape + mat = self._rng.randn(*mat_shape) + perm = np.arange(mat.ndim) + perm[-2:] = [perm[-1], perm[-2]] + sigma = np.matmul(mat, np.transpose(mat, perm)) + + mu_shape = batch_shape + event_shape + mu = self._rng.randn(*mu_shape) + + return mu, sigma + + def testKLNonBatch(self): + batch_shape = () + event_shape = (2,) + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + mvn_b = ds.MultivariateNormalTriL( + loc=mu_b, + scale_tril=np.linalg.cholesky(sigma_b), + validate_args=True) + + kl = ds.kl(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b) + self.assertAllClose(expected_kl, kl_v) + + def testKLBatch(self): + batch_shape = (2,) + event_shape = (3,) + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + mvn_b = ds.MultivariateNormalTriL( + loc=mu_b, + scale_tril=np.linalg.cholesky(sigma_b), + validate_args=True) + + kl = ds.kl(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b[0, :], sigma_b[0, :]) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b[1, :], sigma_b[1, :]) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + + def testKLTwoIdenticalDistributionsIsZero(self): + batch_shape = (2,) + event_shape = (3,) + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + + # Should be zero since KL(p || p) = =. + kl = ds.kl(mvn_a, mvn_a) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + self.assertAllClose(np.zeros(*batch_shape), kl_v) + + def testSampleLarge(self): + mu = np.array([-1., 1], dtype=np.float32) + scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3. + + true_mean = mu + true_scale = scale_tril + true_covariance = np.matmul(true_scale, true_scale.T) + true_variance = np.diag(true_covariance) + true_stddev = np.sqrt(true_variance) + true_det_covariance = np.linalg.det(true_covariance) + true_log_det_covariance = np.log(true_det_covariance) + + with self.test_session() as sess: + dist = ds.MultivariateNormalTriL( + loc=mu, + scale_tril=scale_tril, + validate_args=True) + + # The following distributions will test the KL divergence calculation. + mvn_chol = ds.MultivariateNormalTriL( + loc=np.array([0.5, 1.2], dtype=np.float32), + scale_tril=np.array([[3., 0], [1, 2]], dtype=np.float32), + validate_args=True) + + n = int(10e3) + samps = dist.sample(n, seed=0) + sample_mean = math_ops.reduce_mean(samps, 0) + x = samps - sample_mean + sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n + + sample_kl_chol = math_ops.reduce_mean( + dist.log_prob(samps) - mvn_chol.log_prob(samps), 0) + analytical_kl_chol = ds.kl(dist, mvn_chol) + + scale = dist.scale.to_dense() + + [ + sample_mean_, + analytical_mean_, + sample_covariance_, + analytical_covariance_, + analytical_variance_, + analytical_stddev_, + analytical_log_det_covariance_, + analytical_det_covariance_, + sample_kl_chol_, analytical_kl_chol_, + scale_, + ] = sess.run([ + sample_mean, + dist.mean(), + sample_covariance, + dist.covariance(), + dist.variance(), + dist.stddev(), + dist.log_det_covariance(), + dist.det_covariance(), + sample_kl_chol, analytical_kl_chol, + scale, + ]) + + sample_variance_ = np.diag(sample_covariance_) + sample_stddev_ = np.sqrt(sample_variance_) + sample_det_covariance_ = np.linalg.det(sample_covariance_) + sample_log_det_covariance_ = np.log(sample_det_covariance_) + + print("true_mean:\n{} ".format(true_mean)) + print("sample_mean:\n{}".format(sample_mean_)) + print("analytical_mean:\n{}".format(analytical_mean_)) + + print("true_covariance:\n{}".format(true_covariance)) + print("sample_covariance:\n{}".format(sample_covariance_)) + print("analytical_covariance:\n{}".format(analytical_covariance_)) + + print("true_variance:\n{}".format(true_variance)) + print("sample_variance:\n{}".format(sample_variance_)) + print("analytical_variance:\n{}".format(analytical_variance_)) + + print("true_stddev:\n{}".format(true_stddev)) + print("sample_stddev:\n{}".format(sample_stddev_)) + print("analytical_stddev:\n{}".format(analytical_stddev_)) + + print("true_log_det_covariance:\n{}".format(true_log_det_covariance)) + print("sample_log_det_covariance:\n{}".format(sample_log_det_covariance_)) + print("analytical_log_det_covariance:\n{}".format( + analytical_log_det_covariance_)) + + print("true_det_covariance:\n{}".format(true_det_covariance)) + print("sample_det_covariance:\n{}".format(sample_det_covariance_)) + print("analytical_det_covariance:\n{}".format(analytical_det_covariance_)) + + print("true_scale:\n{}".format(true_scale)) + print("scale:\n{}".format(scale_)) + + print("kl_chol: analytical:{} sample:{}".format( + analytical_kl_chol_, sample_kl_chol_)) + + self.assertAllClose(true_mean, sample_mean_, + atol=0., rtol=0.03) + self.assertAllClose(true_mean, analytical_mean_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_covariance, sample_covariance_, + atol=0., rtol=0.03) + self.assertAllClose(true_covariance, analytical_covariance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_variance, sample_variance_, + atol=0., rtol=0.02) + self.assertAllClose(true_variance, analytical_variance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_stddev, sample_stddev_, + atol=0., rtol=0.01) + self.assertAllClose(true_stddev, analytical_stddev_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_, + atol=0., rtol=0.04) + self.assertAllClose(true_log_det_covariance, + analytical_log_det_covariance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_det_covariance, sample_det_covariance_, + atol=0., rtol=0.03) + self.assertAllClose(true_det_covariance, analytical_det_covariance_, + atol=0., rtol=1e-6) + + self.assertAllClose(true_scale, scale_, + atol=0., rtol=1e-6) + + self.assertAllClose(sample_kl_chol_, analytical_kl_chol_, + atol=0., rtol=0.02) + + +def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): + """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" + # Check using numpy operations + # This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy. + # So it is important to also check that KL(mvn, mvn) = 0. + sigma_b_inv = np.linalg.inv(sigma_b) + + t = np.trace(sigma_b_inv.dot(sigma_a)) + q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a) + k = mu_a.shape[0] + l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a)) + + return 0.5 * (t + q - k + l) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 3695ff007a..10b4a6ceab 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -24,6 +24,7 @@ import math import numpy as np from tensorflow.contrib import framework as contrib_framework +from tensorflow.contrib import linalg from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +33,6 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -613,6 +613,82 @@ def softplus_inverse(x, name=None): array_ops.where(is_too_large, too_large_value, y)) +# TODO(b/35290280): Add unit-tests. +def dimension_size(x, axis): + """Returns the size of a specific dimension.""" + # Since tf.gather isn't "constant-in, constant-out", we must first check the + # static shape or fallback to dynamic shape. + num_rows = (None if x.get_shape().ndims is None + else x.get_shape()[axis].value) + if num_rows is not None: + return num_rows + return array_ops.shape(x)[axis] + + +# TODO(b/35290280): Add unit-tests. +def make_diag_scale(loc, scale_diag, scale_identity_multiplier, + validate_args, assert_positive, name=None): + """Creates a LinOp from `scale_diag`, `scale_identity_multiplier` kwargs.""" + def _convert_to_tensor(x, name): + return None if x is None else ops.convert_to_tensor(x, name=name) + + def _maybe_attach_assertion(x): + if not validate_args: + return x + if assert_positive: + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + x, message="diagonal part must be positive"), + ], x) + # TODO(b/35157376): Use `assert_none_equal` once it exists. + return control_flow_ops.with_dependencies([ + check_ops.assert_greater( + math_ops.abs(x), + array_ops.zeros([], x.dtype), + message="diagonal part must be non-zero"), + ], x) + + with ops.name_scope(name, "make_diag_scale", + values=[loc, scale_diag, scale_identity_multiplier]): + loc = _convert_to_tensor(loc, name="loc") + scale_diag = _convert_to_tensor(scale_diag, name="scale_diag") + scale_identity_multiplier = _convert_to_tensor( + scale_identity_multiplier, + name="scale_identity_multiplier") + + if scale_diag is not None: + if scale_identity_multiplier is not None: + scale_diag += scale_identity_multiplier[..., array_ops.newaxis] + return linalg.LinearOperatorDiag( + diag=_maybe_attach_assertion(scale_diag), + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=assert_positive) + + # TODO(b/35290280): Consider inferring shape from scale_perturb_factor. + if loc is None: + raise ValueError( + "Cannot infer `event_shape` unless `loc` is specified.") + + num_rows = dimension_size(loc, -1) + + if scale_identity_multiplier is None: + return linalg.LinearOperatorIdentity( + num_rows=num_rows, + dtype=loc.dtype.base_dtype, + is_self_adjoint=True, + is_positive_definite=True, + assert_proper_shapes=validate_args) + + return linalg.LinearOperatorScaledIdentity( + num_rows=num_rows, + multiplier=_maybe_attach_assertion(scale_identity_multiplier), + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=assert_positive, + assert_proper_shapes=validate_args) + + class AppendDocstring(object): """Helper class to promote private subclass docstring to public counterpart. diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py deleted file mode 100644 index a451f3a04c..0000000000 --- a/tensorflow/contrib/distributions/python/ops/mvn.py +++ /dev/null @@ -1,1063 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Multivariate Normal distribution classes.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import linalg -from tensorflow.contrib.distributions.python.ops import bijector as bijectors -from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.distributions.python.ops import kullback_leibler -from tensorflow.contrib.distributions.python.ops import normal -from tensorflow.contrib.distributions.python.ops import transformed_distribution -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn - - -__all__ = [ - "MultivariateNormalDiag", - "MultivariateNormalDiagWithSoftplusScale", - "MultivariateNormalDiagPlusLowRank", - "MultivariateNormalTriL", -] - -_mvn_sample_note = """ -`value` is a batch vector with compatible shape if `value` is a `Tensor` whose -shape can be broadcast up to either: - -```python -self.batch_shape + self.event_shape -``` - -or - -```python -[M1, ..., Mm] + self.batch_shape + self.event_shape -``` - -""" - - -def _convert_to_tensor(x, name): - """Helper; same as `ops.convert_to_tensor` but passes-through `None`.""" - return x if x is None else ops.convert_to_tensor(x, name=name) - - -def _event_size_from_loc(loc): - """Helper; returns the shape of the last dimension of `loc`.""" - # Since tf.gather isn't "constant-in, constant-out", we must first check the - # static shape or fallback to dynamic shape. - num_rows = loc.get_shape().with_rank_at_least(1)[-1].value - if num_rows is not None: - return num_rows - return array_ops.shape(loc)[-1] - - -def _make_diag_scale(loc, scale_diag, scale_identity_multiplier, - validate_args, assert_positive): - """Creates a LinOp from `scale_diag`, `scale_identity_multiplier` kwargs.""" - loc = _convert_to_tensor(loc, name="loc") - scale_diag = _convert_to_tensor(scale_diag, name="scale_diag") - scale_identity_multiplier = _convert_to_tensor( - scale_identity_multiplier, - name="scale_identity_multiplier") - - def _maybe_attach_assertion(x): - if not validate_args: - return x - if assert_positive: - return control_flow_ops.with_dependencies([ - check_ops.assert_positive( - x, message="diagonal part must be positive"), - ], x) - # TODO(b/35157376): Use `assert_none_equal` once it exists. - return control_flow_ops.with_dependencies([ - check_ops.assert_greater( - math_ops.abs(x), - array_ops.zeros([], x.dtype), - message="diagonal part must be non-zero"), - ], x) - - if scale_diag is not None: - if scale_identity_multiplier is not None: - scale_diag += scale_identity_multiplier[..., array_ops.newaxis] - return linalg.LinearOperatorDiag( - diag=_maybe_attach_assertion(scale_diag), - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=assert_positive) - - # TODO(b/34878297): Consider inferring shape from scale_perturb_factor. - if loc is None: - raise ValueError( - "Cannot infer `event_shape` unless `loc` is specified.") - - num_rows = _event_size_from_loc(loc) - - if scale_identity_multiplier is None: - return linalg.LinearOperatorIdentity( - num_rows=num_rows, - dtype=loc.dtype.base_dtype, - is_self_adjoint=True, - is_positive_definite=True, - assert_proper_shapes=validate_args) - - return linalg.LinearOperatorScaledIdentity( - num_rows=num_rows, - multiplier=_maybe_attach_assertion(scale_identity_multiplier), - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=assert_positive, - assert_proper_shapes=validate_args) - - -class _MultivariateNormalLinearOperator( - transformed_distribution.TransformedDistribution): - """The multivariate normal distribution on `R^k`. - - The Multivariate Normal distribution is defined over `R^k` and parameterized - by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` - `scale` matrix; `covariance = scale @ scale.T` where `@` denotes - matrix-multiplication. - - #### Mathematical Details - - The probability density function (pdf) is, - - ```none - pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, - y = inv(scale) @ (x - loc), - Z = (2 pi)**(0.5 k) |det(scale)|, - ``` - - where: - - * `loc` is a vector in `R^k`, - * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, - * `Z` denotes the normalization constant, and, - * `||y||**2` denotes the squared Euclidean norm of `y`. - - The MultivariateNormal distribution is a member of the [location-scale - family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be - constructed as, - - ```none - X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. - Y = scale @ X + loc - ``` - - #### Examples - - ```python - ds = tf.contrib.distributions - la = tf.contrib.linalg - - # Initialize a single 3-variate Gaussian. - mu = [1., 2, 3] - cov = [[ 0.36, 0.12, 0.06], - [ 0.12, 0.29, -0.13], - [ 0.06, -0.13, 0.26]] - scale = tf.cholesky(cov) - # ==> [[ 0.6, 0. , 0. ], - # [ 0.2, 0.5, 0. ], - # [ 0.1, -0.3, 0.4]]) - - mvn = ds._MultivariateNormalLinearOperator( - loc=mu, - scale=la.LinearOperatorTriL(scale)) - - # Covariance agrees with cholesky(cov) parameterization. - mvn.covariance().eval() - # ==> [[ 0.36, 0.12, 0.06], - # [ 0.12, 0.29, -0.13], - # [ 0.06, -0.13, 0.26]] - - # Compute the pdf of an`R^3` observation; return a scalar. - mvn.prob([-1., 0, 1]).eval() # shape: [] - - # Initialize a 2-batch of 3-variate Gaussians. - mu = [[1., 2, 3], - [11, 22, 33]] # shape: [2, 3] - scale_diag = [[1., 2, 3], - [0.5, 1, 1.5]] # shape: [2, 3] - - mvn = ds._MultivariateNormalLinearOperator( - loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) - - # Compute the pdf of two `R^3` observations; return a length-2 vector. - x = [[-0.9, 0, 0.1], - [-10, 0, 9]] # shape: [2, 3] - mvn.prob(x).eval() # shape: [2] - ``` - - """ - - def __init__(self, - loc=None, - scale=None, - validate_args=False, - allow_nan_stats=True, - name="MultivariateNormalLinearOperator"): - """Construct Multivariate Normal distribution on `R^k`. - - The `batch_shape` is the broadcast shape between `loc` and `scale` - arguments. - - The `event_shape` is given by the last dimension of `loc` or the last - dimension of the matrix implied by `scale`. - - Recall that `covariance = scale @ scale.T`. - - Additional leading dimensions (if any) will index batches. - - Args: - loc: Floating-point `Tensor`. If this is set to `None`, `loc` is - implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where - `b >= 0` and `k` represents the event size. - scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape - `[B1, ..., Bb, k, k]`. - validate_args: `Boolean`, default `False`. Whether to validate input - with asserts. If `validate_args` is `False`, and the inputs are - invalid, correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. - - Raises: - ValueError: if `scale` is unspecified. - TypeError: if not `scale.dtype.is_floating` - """ - parameters = locals() - if scale is None: - raise ValueError("Missing required `scale` parameter.") - if not scale.dtype.is_floating: - raise TypeError("`scale` parameter must have floating-point dtype.") - - # Since expand_dims doesn't preserve constant-ness, we obtain the - # non-dynamic value if possible. - event_shape = scale.domain_dimension_tensor() - if tensor_util.constant_value(event_shape) is not None: - event_shape = tensor_util.constant_value(event_shape) - event_shape = event_shape[array_ops.newaxis] - - super(_MultivariateNormalLinearOperator, self).__init__( - distribution=normal.Normal( - loc=array_ops.zeros([], dtype=scale.dtype), - scale=array_ops.ones([], dtype=scale.dtype)), - bijector=bijectors.AffineLinearOperator( - shift=loc, scale=scale, validate_args=validate_args), - batch_shape=scale.batch_shape_tensor(), - event_shape=event_shape, - validate_args=validate_args, - name=name) - self._parameters = parameters - - @property - def loc(self): - """The `loc` `Tensor` in `Y = scale @ X + loc`.""" - return self.bijector.shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + loc`.""" - return self.bijector.scale - - def log_det_covariance(self, name="log_det_covariance"): - """Log of determinant of covariance matrix.""" - with ops.name_scope(self.name): - with ops.name_scope(name, values=self.scale.graph_parents): - return 2. * self.scale.log_abs_determinant() - - def det_covariance(self, name="det_covariance"): - """Determinant of covariance matrix.""" - with ops.name_scope(self.name): - with ops.name_scope(name, values=self.scale.graph_parents): - return math_ops.exp(2.* self.scale.log_abs_determinant()) - - @distribution_util.AppendDocstring(_mvn_sample_note) - def _log_prob(self, x): - return super(_MultivariateNormalLinearOperator, self)._log_prob(x) - - @distribution_util.AppendDocstring(_mvn_sample_note) - def _prob(self, x): - return super(_MultivariateNormalLinearOperator, self)._prob(x) - - def _mean(self): - if self.loc is None: - shape = array_ops.concat([ - self.batch_shape_tensor(), - self.event_shape_tensor(), - ], 0) - return array_ops.zeros(shape, self.dtype) - return array_ops.identity(self.loc) - - def _covariance(self): - # TODO(b/35041434): Remove special-case logic once LinOp supports - # `diag_part`. - if (isinstance(self.scale, linalg.LinearOperatorIdentity) or - isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or - isinstance(self.scale, linalg.LinearOperatorDiag)): - shape = array_ops.concat([self.batch_shape_tensor(), - self.event_shape_tensor()], 0) - diag_part = array_ops.ones(shape, self.scale.dtype) - if isinstance(self.scale, linalg.LinearOperatorScaledIdentity): - diag_part *= math_ops.square( - self.scale.multiplier[..., array_ops.newaxis]) - elif isinstance(self.scale, linalg.LinearOperatorDiag): - diag_part *= math_ops.square(self.scale.diag) - return array_ops.matrix_diag(diag_part) - else: - # TODO(b/35040238): Remove transpose once LinOp supports `transpose`. - return self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense())) - - def _variance(self): - # TODO(b/35041434): Remove special-case logic once LinOp supports - # `diag_part`. - if (isinstance(self.scale, linalg.LinearOperatorIdentity) or - isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or - isinstance(self.scale, linalg.LinearOperatorDiag)): - shape = array_ops.concat([self.batch_shape_tensor(), - self.event_shape_tensor()], 0) - diag_part = array_ops.ones(shape, self.scale.dtype) - if isinstance(self.scale, linalg.LinearOperatorScaledIdentity): - diag_part *= math_ops.square( - self.scale.multiplier[..., array_ops.newaxis]) - elif isinstance(self.scale, linalg.LinearOperatorDiag): - diag_part *= math_ops.square(self.scale.diag) - return diag_part - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): - return array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense())) - else: - # TODO(b/35040238): Remove transpose once LinOp supports `transpose`. - return array_ops.matrix_diag_part( - self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense()))) - - def _stddev(self): - # TODO(b/35041434): Remove special-case logic once LinOp supports - # `diag_part`. - if (isinstance(self.scale, linalg.LinearOperatorIdentity) or - isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or - isinstance(self.scale, linalg.LinearOperatorDiag)): - shape = array_ops.concat([self.batch_shape_tensor(), - self.event_shape_tensor()], 0) - diag_part = array_ops.ones(shape, self.scale.dtype) - if isinstance(self.scale, linalg.LinearOperatorScaledIdentity): - diag_part *= self.scale.multiplier[..., array_ops.newaxis] - elif isinstance(self.scale, linalg.LinearOperatorDiag): - diag_part *= self.scale.diag - return math_ops.abs(diag_part) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): - return math_ops.sqrt(array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense()))) - else: - # TODO(b/35040238): Remove transpose once LinOp supports `transpose`. - return math_ops.sqrt(array_ops.matrix_diag_part( - self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense())))) - - def _mode(self): - return self._mean() - - -class MultivariateNormalDiagPlusLowRank(_MultivariateNormalLinearOperator): - """The multivariate normal distribution on `R^k`. - - The Multivariate Normal distribution is defined over `R^k` and parameterized - by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` - `scale` matrix; `covariance = scale @ scale.T` where `@` denotes - matrix-multiplication. - - #### Mathematical Details - - The probability density function (pdf) is, - - ```none - pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, - y = inv(scale) @ (x - loc), - Z = (2 pi)**(0.5 k) |det(scale)|, - ``` - - where: - - * `loc` is a vector in `R^k`, - * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, - * `Z` denotes the normalization constant, and, - * `||y||**2` denotes the squared Euclidean norm of `y`. - - A (non-batch) `scale` matrix is: - - ```none - scale = diag(scale_diag + scale_identity_multiplier ones(k)) + - scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T - ``` - - where: - - * `scale_diag.shape = [k]`, - * `scale_identity_multiplier.shape = []`, - * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and, - * `scale_perturb_diag.shape = [r]`. - - Additional leading dimensions (if any) will index batches. - - If both `scale_diag` and `scale_identity_multiplier` are `None`, then - `scale` is the Identity matrix. - - The MultivariateNormal distribution is a member of the [location-scale - family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be - constructed as, - - ```none - X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. - Y = scale @ X + loc - ``` - - #### Examples - - ```python - ds = tf.contrib.distributions - - # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, - # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is - # a rank-2 update. - mu = [-0.5., 0, 0.5] # shape: [3] - d = [1.5, 0.5, 2] # shape: [3] - U = [[1., 2], - [-1, 1], - [2, -0.5]] # shape: [3, 2] - m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( - loc=mu - scale_diag=d - scale_perturb_factor=U, - scale_perturb_diag=m) - - # Evaluate this on an observation in `R^3`, returning a scalar. - mvn.prob([-1, 0, 1]).eval() # shape: [] - - # Initialize a 2-batch of 3-variate Gaussians; `S = diag(d) + U @ U.T`. - mu = [[1., 2, 3], - [11, 22, 33]] # shape: [b, k] = [2, 3] - U = [[[1., 2], - [3, 4], - [5, 6]], - [[0.5, 0.75], - [1,0, 0.25], - [1.5, 1.25]]] # shape: [b, k, r] = [2, 3, 2] - m = [[0.1, 0.2], - [0.4, 0.5]] # shape: [b, r] = [2, 2] - - mvn = ds.MultivariateNormalDiagPlusLowRank( - loc=mu, - scale_perturb_factor=U, - scale_perturb_diag=m) - - mvn.covariance().eval() # shape: [2, 3, 3] - # ==> [[[ 15.63 31.57 48.51] - # [ 31.57 69.31 105.05] - # [ 48.51 105.05 162.59]] - # - # [[ 2.59 1.41 3.35] - # [ 1.41 2.71 3.34] - # [ 3.35 3.34 8.35]]] - - # Compute the pdf of two `R^3` observations (one from each batch); - # return a length-2 vector. - x = [[-0.9, 0, 0.1], - [-10, 0, 9]] # shape: [2, 3] - mvn.prob(x).eval() # shape: [2] - ``` - - """ - - def __init__(self, - loc=None, - scale_diag=None, - scale_identity_multiplier=None, - scale_perturb_factor=None, - scale_perturb_diag=None, - validate_args=False, - allow_nan_stats=True, - name="MultivariateNormalDiagPlusLowRank"): - """Construct Multivariate Normal distribution on `R^k`. - - The `batch_shape` is the broadcast shape between `loc` and `scale` - arguments. - - The `event_shape` is given by the last dimension of `loc` or the last - dimension of the matrix implied by `scale`. - - Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is: - - ```none - scale = diag(scale_diag + scale_identity_multiplier ones(k)) + - scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T - ``` - - where: - - * `scale_diag.shape = [k]`, - * `scale_identity_multiplier.shape = []`, - * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and, - * `scale_perturb_diag.shape = [r]`. - - Additional leading dimensions (if any) will index batches. - - If both `scale_diag` and `scale_identity_multiplier` are `None`, then - `scale` is the Identity matrix. - - Args: - loc: Floating-point `Tensor`. If this is set to `None`, `loc` is - implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where - `b >= 0` and `k` represents the event size. - scale_diag: Non-zero, floating-point `Tensor` representing a diagonal - matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`, - and characterizes `b`-batches of `k x k` diagonal matrices added to - `scale`. When both `scale_identity_multiplier` and `scale_diag` are - `None` then `scale` is the `Identity`. - scale_identity_multiplier: Non-zero, floating-point `Tensor` representing - a scaled-identity-matrix added to `scale`. May have shape - `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled - `k x k` identity matrices added to `scale`. When both - `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is - the `Identity`. - scale_perturb_factor: Floating-point `Tensor` representing a rank-`r` - perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`, - `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`. - When `None`, no rank-`r` update is added to `scale`. - scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix - inside the rank-`r` perturbation added to `scale`. May have shape - `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r` - diagonal matrices inside the perturbation added to `scale`. When - `None`, an identity matrix is used inside the perturbation. Can only be - specified if `scale_perturb_factor` is also specified. - validate_args: Python `Boolean`, default `False`. When `True` distribution - parameters are checked for validity despite possibly degrading runtime - performance. When `False` invalid inputs may silently render incorrect - outputs. - allow_nan_stats: Python `Boolean`, default `True`. When `True`, - statistics (e.g., mean, mode, variance) use the value "`NaN`" to - indicate the result is undefined. When `False`, an exception is raised - if one or more of the statistic's batch members are undefined. - name: `String` name prefixed to Ops created by this class. - - Raises: - ValueError: if at most `scale_identity_multiplier` is specified. - """ - parameters = locals() - with ops.name_scope(name) as ns: - with ops.name_scope("init", values=[ - loc, scale_diag, scale_identity_multiplier, scale_perturb_factor, - scale_perturb_diag]): - has_low_rank = (scale_perturb_factor is not None or - scale_perturb_diag is not None) - scale = _make_diag_scale( - loc=loc, - scale_diag=scale_diag, - scale_identity_multiplier=scale_identity_multiplier, - validate_args=validate_args, - assert_positive=has_low_rank) - scale_perturb_factor = _convert_to_tensor( - scale_perturb_factor, - name="scale_perturb_factor") - scale_perturb_diag = _convert_to_tensor( - scale_perturb_diag, - name="scale_perturb_diag") - if has_low_rank: - scale = linalg.LinearOperatorUDVHUpdate( - scale, - u=scale_perturb_factor, - diag=scale_perturb_diag, - is_diag_positive=scale_perturb_diag is None, - is_non_singular=True, # Implied by is_positive_definite=True. - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - super(MultivariateNormalDiagPlusLowRank, self).__init__( - loc=loc, - scale=scale, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - name=ns) - self._parameters = parameters - - -class MultivariateNormalDiag(_MultivariateNormalLinearOperator): - """The multivariate normal distribution on `R^k`. - - The Multivariate Normal distribution is defined over `R^k` and parameterized - by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` - `scale` matrix; `covariance = scale @ scale.T` where `@` denotes - matrix-multiplication. - - #### Mathematical Details - - The probability density function (pdf) is, - - ```none - pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, - y = inv(scale) @ (x - loc), - Z = (2 pi)**(0.5 k) |det(scale)|, - ``` - - where: - - * `loc` is a vector in `R^k`, - * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, - * `Z` denotes the normalization constant, and, - * `||y||**2` denotes the squared Euclidean norm of `y`. - - A (non-batch) `scale` matrix is: - - ```none - scale = diag(scale_diag + scale_identity_multiplier * ones(k)) - ``` - - where: - - * `scale_diag.shape = [k]`, and, - * `scale_identity_multiplier.shape = []`. - - Additional leading dimensions (if any) will index batches. - - If both `scale_diag` and `scale_identity_multiplier` are `None`, then - `scale` is the Identity matrix. - - The MultivariateNormal distribution is a member of the [location-scale - family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be - constructed as, - - ```none - X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. - Y = scale @ X + loc - ``` - - #### Examples - - ```python - ds = tf.contrib.distributions - - # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( - loc=[1., -1], - scale_diag=[1, 2.]) - - mvn.mean().eval() - # ==> [1., -1] - - mvn.stddev().eval() - # ==> [1., 2] - - # Evaluate this on an observation in `R^2`, returning a scalar. - mvn.prob([-1., 0]).eval() # shape: [] - - # Initialize a 3-batch, 2-variate scaled-identity Gaussian. - mvn = ds.MultivariateNormalDiag( - loc=[1., -1], - scale_identity_multiplier=[1, 2., 3]) - - mvn.mean().eval() # shape: [3, 2] - # ==> [[1., -1] - # [1, -1], - # [1, -1]] - - mvn.stddev().eval() # shape: [3, 2] - # ==> [[1., 1], - # [2, 2], - # [3, 3]] - - # Evaluate this on an observation in `R^2`, returning a length-3 vector. - mvn.prob([-1., 0]).eval() # shape: [3] - - # Initialize a 2-batch of 3-variate Gaussians. - mvn = ds.MultivariateNormalDiag( - loc=[[1., 2, 3], - [11, 22, 33]] # shape: [2, 3] - scale_diag=[[1., 2, 3], - [0.5, 1, 1.5]]) # shape: [2, 3] - - # Evaluate this on a two observations, each in `R^3`, returning a length-2 - # vector. - x = [[-1., 0, 1], - [-11, 0, 11.]] # shape: [2, 3]. - mvn.prob(x).eval() # shape: [2] - ``` - - """ - - def __init__(self, - loc=None, - scale_diag=None, - scale_identity_multiplier=None, - validate_args=False, - allow_nan_stats=True, - name="MultivariateNormalDiag"): - """Construct Multivariate Normal distribution on `R^k`. - - The `batch_shape` is the broadcast shape between `loc` and `scale` - arguments. - - The `event_shape` is given by the last dimension of `loc` or the last - dimension of the matrix implied by `scale`. - - Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix - is: - - ```none - scale = diag(scale_diag + scale_identity_multiplier * ones(k)) - ``` - - where: - - * `scale_diag.shape = [k]`, and, - * `scale_identity_multiplier.shape = []`. - - Additional leading dimensions (if any) will index batches. - - If both `scale_diag` and `scale_identity_multiplier` are `None`, then - `scale` is the Identity matrix. - - Args: - loc: Floating-point `Tensor`. If this is set to `None`, `loc` is - implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where - `b >= 0` and `k` represents the event size. - scale_diag: Non-zero, floating-point `Tensor` representing a diagonal - matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`, - and characterizes `b`-batches of `k x k` diagonal matrices added to - `scale`. When both `scale_identity_multiplier` and `scale_diag` are - `None` then `scale` is the `Identity`. - scale_identity_multiplier: Non-zero, floating-point `Tensor` representing - a scaled-identity-matrix added to `scale`. May have shape - `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled - `k x k` identity matrices added to `scale`. When both - `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is - the `Identity`. - validate_args: Python `Boolean`, default `False`. When `True` distribution - parameters are checked for validity despite possibly degrading runtime - performance. When `False` invalid inputs may silently render incorrect - outputs. - allow_nan_stats: Python `Boolean`, default `True`. When `True`, - statistics (e.g., mean, mode, variance) use the value "`NaN`" to - indicate the result is undefined. When `False`, an exception is raised - if one or more of the statistic's batch members are undefined. - name: `String` name prefixed to Ops created by this class. - - Raises: - ValueError: if at most `scale_identity_multiplier` is specified. - """ - parameters = locals() - with ops.name_scope(name) as ns: - with ops.name_scope("init", values=[ - loc, scale_diag, scale_identity_multiplier]): - scale = _make_diag_scale( - loc=loc, - scale_diag=scale_diag, - scale_identity_multiplier=scale_identity_multiplier, - validate_args=validate_args, - assert_positive=False) - super(MultivariateNormalDiag, self).__init__( - loc=loc, - scale=scale, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - name=ns) - self._parameters = parameters - - -class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): - """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" - - def __init__(self, - loc, - scale_diag, - validate_args=False, - allow_nan_stats=True, - name="MultivariateNormalDiagWithSoftplusScale"): - parameters = locals() - with ops.name_scope(name, values=[scale_diag]) as ns: - super(MultivariateNormalDiagWithSoftplusScale, self).__init__( - loc=loc, - scale_diag=nn.softplus(scale_diag), - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - name=ns) - self._parameters = parameters - - -class MultivariateNormalTriL(_MultivariateNormalLinearOperator): - """The multivariate normal distribution on `R^k`. - - The Multivariate Normal distribution is defined over `R^k` and parameterized - by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` - `scale` matrix; `covariance = scale @ scale.T` where `@` denotes - matrix-multiplication. - - #### Mathematical Details - - The probability density function (pdf) is, - - ```none - pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, - y = inv(scale) @ (x - loc), - Z = (2 pi)**(0.5 k) |det(scale)|, - ``` - - where: - - * `loc` is a vector in `R^k`, - * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, - * `Z` denotes the normalization constant, and, - * `||y||**2` denotes the squared Euclidean norm of `y`. - - A (non-batch) `scale` matrix is: - - ```none - scale = scale_tril - ``` - - where `scale_tril` is lower-triangular `k x k` matrix with non-zero diagonal, - i.e., `tf.diag_part(scale_tril) != 0`. - - Additional leading dimensions (if any) will index batches. - - The MultivariateNormal distribution is a member of the [location-scale - family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be - constructed as, - - ```none - X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. - Y = scale @ X + loc - ``` - - Trainable (batch) Cholesky matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_lower_triangular()` - - #### Examples - - ```python - ds = tf.contrib.distributions - - # Initialize a single 3-variate Gaussian. - mu = [1., 2, 3] - cov = [[ 0.36, 0.12, 0.06], - [ 0.12, 0.29, -0.13], - [ 0.06, -0.13, 0.26]] - scale = tf.cholesky(cov) - # ==> [[ 0.6, 0. , 0. ], - # [ 0.2, 0.5, 0. ], - # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( - loc=mu, - scale_tril=scale) - - mvn.mean().eval() - # ==> [1., 2, 3] - - # Covariance agrees with cholesky(cov) parameterization. - mvn.covariance().eval() - # ==> [[ 0.36, 0.12, 0.06], - # [ 0.12, 0.29, -0.13], - # [ 0.06, -0.13, 0.26]] - - # Compute the pdf of an observation in `R^3` ; return a scalar. - mvn.prob([-1., 0, 1]).eval() # shape: [] - - # Initialize a 2-batch of 3-variate Gaussians. - mu = [[1., 2, 3], - [11, 22, 33]] # shape: [2, 3] - tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. - mvn = ds.MultivariateNormalTriL( - loc=mu, - scale_tril=tril) - - # Compute the pdf of two `R^3` observations; return a length-2 vector. - x = [[-0.9, 0, 0.1], - [-10, 0, 9]] # shape: [2, 3] - mvn.prob(x).eval() # shape: [2] - - ``` - - """ - - def __init__(self, - loc=None, - scale_tril=None, - validate_args=False, - allow_nan_stats=True, - name="MultivariateNormalTriL"): - """Construct Multivariate Normal distribution on `R^k`. - - The `batch_shape` is the broadcast shape between `loc` and `scale` - arguments. - - The `event_shape` is given by the last dimension of `loc` or the last - dimension of the matrix implied by `scale`. - - Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix - is: - - ```none - scale = scale_tril - ``` - - where `scale_tril` is lower-triangular `k x k` matrix with non-zero - diagonal, i.e., `tf.diag_part(scale_tril) != 0`. - - Additional leading dimensions (if any) will index batches. - - Args: - loc: Floating-point `Tensor`. If this is set to `None`, `loc` is - implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where - `b >= 0` and `k` is the event size. - scale_tril: Floating-point, lower-triangular `Tensor` with non-zero - diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where - `b >= 0` and `k` is the event size. - validate_args: Python `Boolean`, default `False`. When `True` distribution - parameters are checked for validity despite possibly degrading runtime - performance. When `False` invalid inputs may silently render incorrect - outputs. - allow_nan_stats: Python `Boolean`, default `True`. When `True`, - statistics (e.g., mean, mode, variance) use the value "`NaN`" to - indicate the result is undefined. When `False`, an exception is raised - if one or more of the statistic's batch members are undefined. - name: `String` name prefixed to Ops created by this class. - - Raises: - ValueError: if neither `loc` nor `scale_tril` are specified. - """ - parameters = locals() - if loc is None and scale_tril is None: - raise ValueError("Must specify one or both of `loc`, `scale_tril`.") - with ops.name_scope(name) as ns: - with ops.name_scope("init", values=[loc, scale_tril]): - loc = _convert_to_tensor(loc, name="loc") - scale_tril = _convert_to_tensor(scale_tril, name="scale_tril") - if scale_tril is None: - scale = linalg.LinearOperatorIdentity( - num_rows=_event_size_from_loc(loc), - dtype=loc.dtype, - is_self_adjoint=True, - is_positive_definite=True, - assert_proper_shapes=validate_args) - else: - if validate_args: - scale_tril = control_flow_ops.with_dependencies([ - # TODO(b/35157376): Use `assert_none_equal` once it exists. - check_ops.assert_greater( - math_ops.abs(array_ops.matrix_diag_part(scale_tril)), - array_ops.zeros([], scale_tril.dtype), - message="`scale_tril` must have non-zero diagonal"), - ], scale_tril) - scale = linalg.LinearOperatorTriL( - scale_tril, - is_non_singular=True, - is_self_adjoint=False, - is_positive_definite=False) - super(MultivariateNormalTriL, self).__init__( - loc=loc, - scale=scale, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - name=ns) - self._parameters = parameters - - -@kullback_leibler.RegisterKL(_MultivariateNormalLinearOperator, - _MultivariateNormalLinearOperator) -def _kl_brute_force(a, b, name=None): - """Batched KL divergence `KL(a || b)` for multivariate Normals. - - With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and - covariance `C_a`, `C_b` respectively, - - ``` - KL(a || b) = 0.5 * ( L - k + T + Q ), - L := Log[Det(C_b)] - Log[Det(C_a)] - T := trace(C_b^{-1} C_a), - Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a), - ``` - - This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient - methods for solving systems with `C_b` may be available, a dense version of - (the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B` - is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x` - and `y`. - - Args: - a: Instance of `_MultivariateNormalLinearOperator`. - b: Instance of `_MultivariateNormalLinearOperator`. - name: (optional) name to use for created ops. Default "kl_mvn". - - Returns: - Batchwise `KL(a || b)`. - """ - - def squared_frobenius_norm(x): - """Helper to make KL calculation slightly more readable.""" - # http://mathworld.wolfram.com/FrobeniusNorm.html - return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) - - # TODO(b/35041439): See also b/35040945. Remove this function once LinOp - # supports something like: - # A.inverse().solve(B).norm(order='fro', axis=[-1, -2]) - def is_diagonal(x): - """Helper to identify if `LinearOperator` has only a diagonal component.""" - return (isinstance(x, linalg.LinearOperatorIdentity) or - isinstance(x, linalg.LinearOperatorScaledIdentity) or - isinstance(x, linalg.LinearOperatorDiag)) - - with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] + - a.scale.graph_parents + b.scale.graph_parents): - # Calculation is based on: - # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians - # and, - # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm - # i.e., - # If Ca = AA', Cb = BB', then - # tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A'] - # = tr[inv(B) A A' inv(B)'] - # = tr[(inv(B) A) (inv(B) A)'] - # = sum_{ij} (inv(B) A)_{ij}^2 - # = ||inv(B) A||_F**2 - # where ||.||_F is the Frobenius norm and the second equality follows from - # the cyclic permutation property. - if is_diagonal(a.scale) and is_diagonal(b.scale): - # Using `stddev` because it handles expansion of Identity cases. - b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis] - else: - b_inv_a = b.scale.solve(a.scale.to_dense()) - kl_div = (b.scale.log_abs_determinant() - - a.scale.log_abs_determinant() - + 0.5 * ( - - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype) - + squared_frobenius_norm(b_inv_a) - + squared_frobenius_norm(b.scale.solve( - (b.mean() - a.mean())[..., array_ops.newaxis])))) - kl_div.set_shape(array_ops.broadcast_static_shape( - a.batch_shape, b.batch_shape)) - return kl_div diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py new file mode 100644 index 0000000000..edc5251769 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -0,0 +1,232 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multivariate Normal distribution classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn + + +__all__ = [ + "MultivariateNormalDiag", + "MultivariateNormalDiagWithSoftplusScale", +] + + +class MultivariateNormalDiag( + mvn_linop.MultivariateNormalLinearOperator): + """The multivariate normal distribution on `R^k`. + + The Multivariate Normal distribution is defined over `R^k` and parameterized + by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` + `scale` matrix; `covariance = scale @ scale.T` where `@` denotes + matrix-multiplication. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, + y = inv(scale) @ (x - loc), + Z = (2 pi)**(0.5 k) |det(scale)|, + ``` + + where: + + * `loc` is a vector in `R^k`, + * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, + * `Z` denotes the normalization constant, and, + * `||y||**2` denotes the squared Euclidean norm of `y`. + + A (non-batch) `scale` matrix is: + + ```none + scale = diag(scale_diag + scale_identity_multiplier * ones(k)) + ``` + + where: + + * `scale_diag.shape = [k]`, and, + * `scale_identity_multiplier.shape = []`. + + Additional leading dimensions (if any) will index batches. + + If both `scale_diag` and `scale_identity_multiplier` are `None`, then + `scale` is the Identity matrix. + + The MultivariateNormal distribution is a member of the [location-scale + family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. + Y = scale @ X + loc + ``` + + #### Examples + + ```python + ds = tf.contrib.distributions + + # Initialize a single 2-variate Gaussian. + mvn = ds.MultivariateNormalDiag( + loc=[1., -1], + scale_diag=[1, 2.]) + + mvn.mean().eval() + # ==> [1., -1] + + mvn.stddev().eval() + # ==> [1., 2] + + # Evaluate this on an observation in `R^2`, returning a scalar. + mvn.prob([-1., 0]).eval() # shape: [] + + # Initialize a 3-batch, 2-variate scaled-identity Gaussian. + mvn = ds.MultivariateNormalDiag( + loc=[1., -1], + scale_identity_multiplier=[1, 2., 3]) + + mvn.mean().eval() # shape: [3, 2] + # ==> [[1., -1] + # [1, -1], + # [1, -1]] + + mvn.stddev().eval() # shape: [3, 2] + # ==> [[1., 1], + # [2, 2], + # [3, 3]] + + # Evaluate this on an observation in `R^2`, returning a length-3 vector. + mvn.prob([-1., 0]).eval() # shape: [3] + + # Initialize a 2-batch of 3-variate Gaussians. + mvn = ds.MultivariateNormalDiag( + loc=[[1., 2, 3], + [11, 22, 33]] # shape: [2, 3] + scale_diag=[[1., 2, 3], + [0.5, 1, 1.5]]) # shape: [2, 3] + + # Evaluate this on a two observations, each in `R^3`, returning a length-2 + # vector. + x = [[-1., 0, 1], + [-11, 0, 11.]] # shape: [2, 3]. + mvn.prob(x).eval() # shape: [2] + ``` + + """ + + def __init__(self, + loc=None, + scale_diag=None, + scale_identity_multiplier=None, + validate_args=False, + allow_nan_stats=True, + name="MultivariateNormalDiag"): + """Construct Multivariate Normal distribution on `R^k`. + + The `batch_shape` is the broadcast shape between `loc` and `scale` + arguments. + + The `event_shape` is given by the last dimension of `loc` or the last + dimension of the matrix implied by `scale`. + + Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is: + + ```none + scale = diag(scale_diag + scale_identity_multiplier * ones(k)) + ``` + + where: + + * `scale_diag.shape = [k]`, and, + * `scale_identity_multiplier.shape = []`. + + Additional leading dimensions (if any) will index batches. + + If both `scale_diag` and `scale_identity_multiplier` are `None`, then + `scale` is the Identity matrix. + + Args: + loc: Floating-point `Tensor`. If this is set to `None`, `loc` is + implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where + `b >= 0` and `k` is the event size. + scale_diag: Non-zero, floating-point `Tensor` representing a diagonal + matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`, + and characterizes `b`-batches of `k x k` diagonal matrices added to + `scale`. When both `scale_identity_multiplier` and `scale_diag` are + `None` then `scale` is the `Identity`. + scale_identity_multiplier: Non-zero, floating-point `Tensor` representing + a scaled-identity-matrix added to `scale`. May have shape + `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled + `k x k` identity matrices added to `scale`. When both + `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is + the `Identity`. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. + + Raises: + ValueError: if at most `scale_identity_multiplier` is specified. + """ + parameters = locals() + with ops.name_scope(name) as ns: + with ops.name_scope("init", values=[ + loc, scale_diag, scale_identity_multiplier]): + scale = distribution_util.make_diag_scale( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=validate_args, + assert_positive=False) + super(MultivariateNormalDiag, self).__init__( + loc=loc, + scale=scale, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=ns) + self._parameters = parameters + + +class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): + """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" + + def __init__(self, + loc, + scale_diag, + validate_args=False, + allow_nan_stats=True, + name="MultivariateNormalDiagWithSoftplusScale"): + parameters = locals() + with ops.name_scope(name, values=[scale_diag]) as ns: + super(MultivariateNormalDiagWithSoftplusScale, self).__init__( + loc=loc, + scale_diag=nn.softplus(scale_diag), + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=ns) + self._parameters = parameters diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py new file mode 100644 index 0000000000..51487cf3a3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -0,0 +1,255 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multivariate Normal distribution classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop +from tensorflow.python.framework import ops + + +__all__ = [ + "MultivariateNormalDiagPlusLowRank", +] + + +class MultivariateNormalDiagPlusLowRank( + mvn_linop.MultivariateNormalLinearOperator): + """The multivariate normal distribution on `R^k`. + + The Multivariate Normal distribution is defined over `R^k` and parameterized + by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` + `scale` matrix; `covariance = scale @ scale.T` where `@` denotes + matrix-multiplication. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, + y = inv(scale) @ (x - loc), + Z = (2 pi)**(0.5 k) |det(scale)|, + ``` + + where: + + * `loc` is a vector in `R^k`, + * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, + * `Z` denotes the normalization constant, and, + * `||y||**2` denotes the squared Euclidean norm of `y`. + + A (non-batch) `scale` matrix is: + + ```none + scale = diag(scale_diag + scale_identity_multiplier ones(k)) + + scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T + ``` + + where: + + * `scale_diag.shape = [k]`, + * `scale_identity_multiplier.shape = []`, + * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and, + * `scale_perturb_diag.shape = [r]`. + + Additional leading dimensions (if any) will index batches. + + If both `scale_diag` and `scale_identity_multiplier` are `None`, then + `scale` is the Identity matrix. + + The MultivariateNormal distribution is a member of the [location-scale + family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. + Y = scale @ X + loc + ``` + + #### Examples + + ```python + ds = tf.contrib.distributions + + # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, + # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is + # a rank-2 update. + mu = [-0.5., 0, 0.5] # shape: [3] + d = [1.5, 0.5, 2] # shape: [3] + U = [[1., 2], + [-1, 1], + [2, -0.5]] # shape: [3, 2] + m = [4., 5] # shape: [2] + mvn = ds.MultivariateNormalDiagPlusLowRank( + loc=mu + scale_diag=d + scale_perturb_factor=U, + scale_perturb_diag=m) + + # Evaluate this on an observation in `R^3`, returning a scalar. + mvn.prob([-1, 0, 1]).eval() # shape: [] + + # Initialize a 2-batch of 3-variate Gaussians; `S = diag(d) + U @ U.T`. + mu = [[1., 2, 3], + [11, 22, 33]] # shape: [b, k] = [2, 3] + U = [[[1., 2], + [3, 4], + [5, 6]], + [[0.5, 0.75], + [1,0, 0.25], + [1.5, 1.25]]] # shape: [b, k, r] = [2, 3, 2] + m = [[0.1, 0.2], + [0.4, 0.5]] # shape: [b, r] = [2, 2] + + mvn = ds.MultivariateNormalDiagPlusLowRank( + loc=mu, + scale_perturb_factor=U, + scale_perturb_diag=m) + + mvn.covariance().eval() # shape: [2, 3, 3] + # ==> [[[ 15.63 31.57 48.51] + # [ 31.57 69.31 105.05] + # [ 48.51 105.05 162.59]] + # + # [[ 2.59 1.41 3.35] + # [ 1.41 2.71 3.34] + # [ 3.35 3.34 8.35]]] + + # Compute the pdf of two `R^3` observations (one from each batch); + # return a length-2 vector. + x = [[-0.9, 0, 0.1], + [-10, 0, 9]] # shape: [2, 3] + mvn.prob(x).eval() # shape: [2] + ``` + + """ + + def __init__(self, + loc=None, + scale_diag=None, + scale_identity_multiplier=None, + scale_perturb_factor=None, + scale_perturb_diag=None, + validate_args=False, + allow_nan_stats=True, + name="MultivariateNormalDiagPlusLowRank"): + """Construct Multivariate Normal distribution on `R^k`. + + The `batch_shape` is the broadcast shape between `loc` and `scale` + arguments. + + The `event_shape` is given by the last dimension of `loc` or the last + dimension of the matrix implied by `scale`. + + Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is: + + ```none + scale = diag(scale_diag + scale_identity_multiplier ones(k)) + + scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T + ``` + + where: + + * `scale_diag.shape = [k]`, + * `scale_identity_multiplier.shape = []`, + * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and, + * `scale_perturb_diag.shape = [r]`. + + Additional leading dimensions (if any) will index batches. + + If both `scale_diag` and `scale_identity_multiplier` are `None`, then + `scale` is the Identity matrix. + + Args: + loc: Floating-point `Tensor`. If this is set to `None`, `loc` is + implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where + `b >= 0` and `k` is the event size. + scale_diag: Non-zero, floating-point `Tensor` representing a diagonal + matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`, + and characterizes `b`-batches of `k x k` diagonal matrices added to + `scale`. When both `scale_identity_multiplier` and `scale_diag` are + `None` then `scale` is the `Identity`. + scale_identity_multiplier: Non-zero, floating-point `Tensor` representing + a scaled-identity-matrix added to `scale`. May have shape + `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled + `k x k` identity matrices added to `scale`. When both + `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is + the `Identity`. + scale_perturb_factor: Floating-point `Tensor` representing a rank-`r` + perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`, + `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`. + When `None`, no rank-`r` update is added to `scale`. + scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix + inside the rank-`r` perturbation added to `scale`. May have shape + `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r` + diagonal matrices inside the perturbation added to `scale`. When + `None`, an identity matrix is used inside the perturbation. Can only be + specified if `scale_perturb_factor` is also specified. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. + + Raises: + ValueError: if at most `scale_identity_multiplier` is specified. + """ + parameters = locals() + def _convert_to_tensor(x, name): + return None if x is None else ops.convert_to_tensor(x, name=name) + with ops.name_scope(name) as ns: + with ops.name_scope("init", values=[ + loc, scale_diag, scale_identity_multiplier, scale_perturb_factor, + scale_perturb_diag]): + has_low_rank = (scale_perturb_factor is not None or + scale_perturb_diag is not None) + scale = distribution_util.make_diag_scale( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=validate_args, + assert_positive=has_low_rank) + scale_perturb_factor = _convert_to_tensor( + scale_perturb_factor, + name="scale_perturb_factor") + scale_perturb_diag = _convert_to_tensor( + scale_perturb_diag, + name="scale_perturb_diag") + if has_low_rank: + scale = linalg.LinearOperatorUDVHUpdate( + scale, + u=scale_perturb_factor, + diag=scale_perturb_diag, + is_diag_positive=scale_perturb_diag is None, + is_non_singular=True, # Implied by is_positive_definite=True. + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + super(MultivariateNormalDiagPlusLowRank, self).__init__( + loc=loc, + scale=scale, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=ns) + self._parameters = parameters diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py new file mode 100644 index 0000000000..f6f26a0b1d --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -0,0 +1,383 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multivariate Normal distribution classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import bijector as bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import kullback_leibler +from tensorflow.contrib.distributions.python.ops import normal +from tensorflow.contrib.distributions.python.ops import transformed_distribution +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + + +__all__ = [ + "MultivariateNormalLinearOperator", +] + + +_mvn_sample_note = """ +`value` is a batch vector with compatible shape if `value` is a `Tensor` whose +shape can be broadcast up to either: + +```python +self.batch_shape + self.event_shape +``` + +or + +```python +[M1, ..., Mm] + self.batch_shape + self.event_shape +``` + +""" + + +# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests. +class MultivariateNormalLinearOperator( + transformed_distribution.TransformedDistribution): + """The multivariate normal distribution on `R^k`. + + The Multivariate Normal distribution is defined over `R^k` and parameterized + by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` + `scale` matrix; `covariance = scale @ scale.T` where `@` denotes + matrix-multiplication. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, + y = inv(scale) @ (x - loc), + Z = (2 pi)**(0.5 k) |det(scale)|, + ``` + + where: + + * `loc` is a vector in `R^k`, + * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, + * `Z` denotes the normalization constant, and, + * `||y||**2` denotes the squared Euclidean norm of `y`. + + The MultivariateNormal distribution is a member of the [location-scale + family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. + Y = scale @ X + loc + ``` + + #### Examples + + ```python + ds = tf.contrib.distributions + la = tf.contrib.linalg + + # Initialize a single 3-variate Gaussian. + mu = [1., 2, 3] + cov = [[ 0.36, 0.12, 0.06], + [ 0.12, 0.29, -0.13], + [ 0.06, -0.13, 0.26]] + scale = tf.cholesky(cov) + # ==> [[ 0.6, 0. , 0. ], + # [ 0.2, 0.5, 0. ], + # [ 0.1, -0.3, 0.4]]) + + mvn = ds.MultivariateNormalLinearOperator( + loc=mu, + scale=la.LinearOperatorTriL(scale)) + + # Covariance agrees with cholesky(cov) parameterization. + mvn.covariance().eval() + # ==> [[ 0.36, 0.12, 0.06], + # [ 0.12, 0.29, -0.13], + # [ 0.06, -0.13, 0.26]] + + # Compute the pdf of an`R^3` observation; return a scalar. + mvn.prob([-1., 0, 1]).eval() # shape: [] + + # Initialize a 2-batch of 3-variate Gaussians. + mu = [[1., 2, 3], + [11, 22, 33]] # shape: [2, 3] + scale_diag = [[1., 2, 3], + [0.5, 1, 1.5]] # shape: [2, 3] + + mvn = ds.MultivariateNormalLinearOperator( + loc=mu, + scale=la.LinearOperatorDiag(scale_diag)) + + # Compute the pdf of two `R^3` observations; return a length-2 vector. + x = [[-0.9, 0, 0.1], + [-10, 0, 9]] # shape: [2, 3] + mvn.prob(x).eval() # shape: [2] + ``` + + """ + + def __init__(self, + loc=None, + scale=None, + validate_args=False, + allow_nan_stats=True, + name="MultivariateNormalLinearOperator"): + """Construct Multivariate Normal distribution on `R^k`. + + The `batch_shape` is the broadcast shape between `loc` and `scale` + arguments. + + The `event_shape` is given by the last dimension of `loc` or the last + dimension of the matrix implied by `scale`. + + Recall that `covariance = scale @ scale.T`. + + Additional leading dimensions (if any) will index batches. + + Args: + loc: Floating-point `Tensor`. If this is set to `None`, `loc` is + implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where + `b >= 0` and `k` is the event size. + scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape + `[B1, ..., Bb, k, k]`. + validate_args: `Boolean`, default `False`. Whether to validate input + with asserts. If `validate_args` is `False`, and the inputs are + invalid, correct behavior is not guaranteed. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to give Ops created by the initializer. + + Raises: + ValueError: if `scale` is unspecified. + TypeError: if not `scale.dtype.is_floating` + """ + parameters = locals() + if scale is None: + raise ValueError("Missing required `scale` parameter.") + if not scale.dtype.is_floating: + raise TypeError("`scale` parameter must have floating-point dtype.") + + # Since expand_dims doesn't preserve constant-ness, we obtain the + # non-dynamic value if possible. + event_shape = scale.domain_dimension_tensor() + if tensor_util.constant_value(event_shape) is not None: + event_shape = tensor_util.constant_value(event_shape) + event_shape = event_shape[array_ops.newaxis] + + super(MultivariateNormalLinearOperator, self).__init__( + distribution=normal.Normal( + loc=array_ops.zeros([], dtype=scale.dtype), + scale=array_ops.ones([], dtype=scale.dtype)), + bijector=bijectors.AffineLinearOperator( + shift=loc, scale=scale, validate_args=validate_args), + batch_shape=scale.batch_shape_tensor(), + event_shape=event_shape, + validate_args=validate_args, + name=name) + self._parameters = parameters + + @property + def loc(self): + """The `loc` `Tensor` in `Y = scale @ X + loc`.""" + return self.bijector.shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + loc`.""" + return self.bijector.scale + + def log_det_covariance(self, name="log_det_covariance"): + """Log of determinant of covariance matrix.""" + with ops.name_scope(self.name): + with ops.name_scope(name, values=self.scale.graph_parents): + return 2. * self.scale.log_abs_determinant() + + def det_covariance(self, name="det_covariance"): + """Determinant of covariance matrix.""" + with ops.name_scope(self.name): + with ops.name_scope(name, values=self.scale.graph_parents): + return math_ops.exp(2.* self.scale.log_abs_determinant()) + + @distribution_util.AppendDocstring(_mvn_sample_note) + def _log_prob(self, x): + return super(MultivariateNormalLinearOperator, self)._log_prob(x) + + @distribution_util.AppendDocstring(_mvn_sample_note) + def _prob(self, x): + return super(MultivariateNormalLinearOperator, self)._prob(x) + + def _mean(self): + if self.loc is None: + shape = array_ops.concat([ + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], 0) + return array_ops.zeros(shape, self.dtype) + return array_ops.identity(self.loc) + + def _covariance(self): + # TODO(b/35041434): Remove special-case logic once LinOp supports + # `diag_part`. + if (isinstance(self.scale, linalg.LinearOperatorIdentity) or + isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or + isinstance(self.scale, linalg.LinearOperatorDiag)): + shape = array_ops.concat([self.batch_shape_tensor(), + self.event_shape_tensor()], 0) + diag_part = array_ops.ones(shape, self.scale.dtype) + if isinstance(self.scale, linalg.LinearOperatorScaledIdentity): + diag_part *= math_ops.square( + self.scale.multiplier[..., array_ops.newaxis]) + elif isinstance(self.scale, linalg.LinearOperatorDiag): + diag_part *= math_ops.square(self.scale.diag) + return array_ops.matrix_diag(diag_part) + else: + # TODO(b/35040238): Remove transpose once LinOp supports `transpose`. + return self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense())) + + def _variance(self): + # TODO(b/35041434): Remove special-case logic once LinOp supports + # `diag_part`. + if (isinstance(self.scale, linalg.LinearOperatorIdentity) or + isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or + isinstance(self.scale, linalg.LinearOperatorDiag)): + shape = array_ops.concat([self.batch_shape_tensor(), + self.event_shape_tensor()], 0) + diag_part = array_ops.ones(shape, self.scale.dtype) + if isinstance(self.scale, linalg.LinearOperatorScaledIdentity): + diag_part *= math_ops.square( + self.scale.multiplier[..., array_ops.newaxis]) + elif isinstance(self.scale, linalg.LinearOperatorDiag): + diag_part *= math_ops.square(self.scale.diag) + return diag_part + elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) + and self.scale.is_self_adjoint): + return array_ops.matrix_diag_part( + self.scale.apply(self.scale.to_dense())) + else: + # TODO(b/35040238): Remove transpose once LinOp supports `transpose`. + return array_ops.matrix_diag_part( + self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense()))) + + def _stddev(self): + # TODO(b/35041434): Remove special-case logic once LinOp supports + # `diag_part`. + if (isinstance(self.scale, linalg.LinearOperatorIdentity) or + isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or + isinstance(self.scale, linalg.LinearOperatorDiag)): + shape = array_ops.concat([self.batch_shape_tensor(), + self.event_shape_tensor()], 0) + diag_part = array_ops.ones(shape, self.scale.dtype) + if isinstance(self.scale, linalg.LinearOperatorScaledIdentity): + diag_part *= self.scale.multiplier[..., array_ops.newaxis] + elif isinstance(self.scale, linalg.LinearOperatorDiag): + diag_part *= self.scale.diag + return math_ops.abs(diag_part) + elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) + and self.scale.is_self_adjoint): + return math_ops.sqrt(array_ops.matrix_diag_part( + self.scale.apply(self.scale.to_dense()))) + else: + # TODO(b/35040238): Remove transpose once LinOp supports `transpose`. + return math_ops.sqrt(array_ops.matrix_diag_part( + self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense())))) + + def _mode(self): + return self._mean() + + +@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, + MultivariateNormalLinearOperator) +def _kl_brute_force(a, b, name=None): + """Batched KL divergence `KL(a || b)` for multivariate Normals. + + With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and + covariance `C_a`, `C_b` respectively, + + ``` + KL(a || b) = 0.5 * ( L - k + T + Q ), + L := Log[Det(C_b)] - Log[Det(C_a)] + T := trace(C_b^{-1} C_a), + Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a), + ``` + + This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient + methods for solving systems with `C_b` may be available, a dense version of + (the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B` + is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x` + and `y`. + + Args: + a: Instance of `MultivariateNormalLinearOperator`. + b: Instance of `MultivariateNormalLinearOperator`. + name: (optional) name to use for created ops. Default "kl_mvn". + + Returns: + Batchwise `KL(a || b)`. + """ + + def squared_frobenius_norm(x): + """Helper to make KL calculation slightly more readable.""" + # http://mathworld.wolfram.com/FrobeniusNorm.html + return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) + + # TODO(b/35041439): See also b/35040945. Remove this function once LinOp + # supports something like: + # A.inverse().solve(B).norm(order='fro', axis=[-1, -2]) + def is_diagonal(x): + """Helper to identify if `LinearOperator` has only a diagonal component.""" + return (isinstance(x, linalg.LinearOperatorIdentity) or + isinstance(x, linalg.LinearOperatorScaledIdentity) or + isinstance(x, linalg.LinearOperatorDiag)) + + with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] + + a.scale.graph_parents + b.scale.graph_parents): + # Calculation is based on: + # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians + # and, + # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm + # i.e., + # If Ca = AA', Cb = BB', then + # tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A'] + # = tr[inv(B) A A' inv(B)'] + # = tr[(inv(B) A) (inv(B) A)'] + # = sum_{ij} (inv(B) A)_{ij}^2 + # = ||inv(B) A||_F**2 + # where ||.||_F is the Frobenius norm and the second equality follows from + # the cyclic permutation property. + if is_diagonal(a.scale) and is_diagonal(b.scale): + # Using `stddev` because it handles expansion of Identity cases. + b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis] + else: + b_inv_a = b.scale.solve(a.scale.to_dense()) + kl_div = (b.scale.log_abs_determinant() + - a.scale.log_abs_determinant() + + 0.5 * ( + - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype) + + squared_frobenius_norm(b_inv_a) + + squared_frobenius_norm(b.scale.solve( + (b.mean() - a.mean())[..., array_ops.newaxis])))) + kl_div.set_shape(array_ops.broadcast_static_shape( + a.batch_shape, b.batch_shape)) + return kl_div diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py new file mode 100644 index 0000000000..8fdc0822c4 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -0,0 +1,213 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multivariate Normal distribution classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops + + +__all__ = [ + "MultivariateNormalTriL", +] + + +class MultivariateNormalTriL( + mvn_linop.MultivariateNormalLinearOperator): + """The multivariate normal distribution on `R^k`. + + The Multivariate Normal distribution is defined over `R^k` and parameterized + by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` + `scale` matrix; `covariance = scale @ scale.T` where `@` denotes + matrix-multiplication. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, + y = inv(scale) @ (x - loc), + Z = (2 pi)**(0.5 k) |det(scale)|, + ``` + + where: + + * `loc` is a vector in `R^k`, + * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, + * `Z` denotes the normalization constant, and, + * `||y||**2` denotes the squared Euclidean norm of `y`. + + A (non-batch) `scale` matrix is: + + ```none + scale = scale_tril + ``` + + where `scale_tril` is lower-triangular `k x k` matrix with non-zero diagonal, + i.e., `tf.diag_part(scale_tril) != 0`. + + Additional leading dimensions (if any) will index batches. + + The MultivariateNormal distribution is a member of the [location-scale + family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. + Y = scale @ X + loc + ``` + + Trainable (batch) lower-triangular matrices can be created with + `ds.matrix_diag_transform()` and/or `ds.fill_lower_triangular()` + + #### Examples + + ```python + ds = tf.contrib.distributions + + # Initialize a single 3-variate Gaussian. + mu = [1., 2, 3] + cov = [[ 0.36, 0.12, 0.06], + [ 0.12, 0.29, -0.13], + [ 0.06, -0.13, 0.26]] + scale = tf.cholesky(cov) + # ==> [[ 0.6, 0. , 0. ], + # [ 0.2, 0.5, 0. ], + # [ 0.1, -0.3, 0.4]]) + mvn = ds.MultivariateNormalTriL( + loc=mu, + scale_tril=scale) + + mvn.mean().eval() + # ==> [1., 2, 3] + + # Covariance agrees with cholesky(cov) parameterization. + mvn.covariance().eval() + # ==> [[ 0.36, 0.12, 0.06], + # [ 0.12, 0.29, -0.13], + # [ 0.06, -0.13, 0.26]] + + # Compute the pdf of an observation in `R^3` ; return a scalar. + mvn.prob([-1., 0, 1]).eval() # shape: [] + + # Initialize a 2-batch of 3-variate Gaussians. + mu = [[1., 2, 3], + [11, 22, 33]] # shape: [2, 3] + tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. + mvn = ds.MultivariateNormalTriL( + loc=mu, + scale_tril=tril) + + # Compute the pdf of two `R^3` observations; return a length-2 vector. + x = [[-0.9, 0, 0.1], + [-10, 0, 9]] # shape: [2, 3] + mvn.prob(x).eval() # shape: [2] + + ``` + + """ + + def __init__(self, + loc=None, + scale_tril=None, + validate_args=False, + allow_nan_stats=True, + name="MultivariateNormalTriL"): + """Construct Multivariate Normal distribution on `R^k`. + + The `batch_shape` is the broadcast shape between `loc` and `scale` + arguments. + + The `event_shape` is given by the last dimension of `loc` or the last + dimension of the matrix implied by `scale`. + + Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is: + + ```none + scale = scale_tril + ``` + + where `scale_tril` is lower-triangular `k x k` matrix with non-zero + diagonal, i.e., `tf.diag_part(scale_tril) != 0`. + + Additional leading dimensions (if any) will index batches. + + Args: + loc: Floating-point `Tensor`. If this is set to `None`, `loc` is + implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where + `b >= 0` and `k` is the event size. + scale_tril: Floating-point, lower-triangular `Tensor` with non-zero + diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where + `b >= 0` and `k` is the event size. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. + + Raises: + ValueError: if neither `loc` nor `scale_tril` are specified. + """ + parameters = locals() + def _convert_to_tensor(x, name): + return None if x is None else ops.convert_to_tensor(x, name=name) + if loc is None and scale_tril is None: + raise ValueError("Must specify one or both of `loc`, `scale_tril`.") + with ops.name_scope(name) as ns: + with ops.name_scope("init", values=[loc, scale_tril]): + loc = _convert_to_tensor(loc, name="loc") + scale_tril = _convert_to_tensor(scale_tril, name="scale_tril") + if scale_tril is None: + scale = linalg.LinearOperatorIdentity( + num_rows=distribution_util.dimension_size(loc, -1), + dtype=loc.dtype, + is_self_adjoint=True, + is_positive_definite=True, + assert_proper_shapes=validate_args) + else: + if validate_args: + scale_tril = control_flow_ops.with_dependencies([ + # TODO(b/35157376): Use `assert_none_equal` once it exists. + check_ops.assert_greater( + math_ops.abs(array_ops.matrix_diag_part(scale_tril)), + array_ops.zeros([], scale_tril.dtype), + message="`scale_tril` must have non-zero diagonal"), + ], scale_tril) + scale = linalg.LinearOperatorTriL( + scale_tril, + is_non_singular=True, + is_self_adjoint=False, + is_positive_definite=False) + super(MultivariateNormalTriL, self).__init__( + loc=loc, + scale=scale, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=ns) + self._parameters = parameters |