-rw-r--r--  tensorflow/contrib/distributions/BUILD                                                |   38
-rw-r--r--  tensorflow/contrib/distributions/__init__.py                                          |    4
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py   |  420
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py                 |  181
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py                      |  981
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py                 |  443
-rw-r--r--  tensorflow/contrib/distributions/python/ops/distribution_util.py                      |   78
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn.py                                    | 1063
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_diag.py                               |  232
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py                 |  255
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py                    |  383
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_tril.py                               |  213
12 files changed, 2243 insertions(+), 2048 deletions(-)
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 714e48a35c..7ac72f07e9 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -349,9 +349,43 @@ cuda_py_test(
)
cuda_py_test(
- name = "mvn_test",
+ name = "mvn_diag_test",
+ size = "small",
+ srcs = ["python/kernel_tests/mvn_diag_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
+ name = "mvn_diag_plus_low_rank_test",
size = "medium",
- srcs = ["python/kernel_tests/mvn_test.py"],
+ srcs = ["python/kernel_tests/mvn_diag_plus_low_rank_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
+ name = "mvn_tril_test",
+ size = "small",
+ srcs = ["python/kernel_tests/mvn_tril_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index db02ea60f8..fa92d3c156 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -134,7 +134,9 @@ from tensorflow.contrib.distributions.python.ops.laplace import *
from tensorflow.contrib.distributions.python.ops.logistic import *
from tensorflow.contrib.distributions.python.ops.mixture import *
from tensorflow.contrib.distributions.python.ops.multinomial import *
-from tensorflow.contrib.distributions.python.ops.mvn import *
+from tensorflow.contrib.distributions.python.ops.mvn_diag import *
+from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import *
+from tensorflow.contrib.distributions.python.ops.mvn_tril import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.onehot_categorical import *
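
Because __init__.py star-imports the three new modules, the public API surface is unchanged: the same class names remain importable from tensorflow.contrib.distributions after the split. A minimal sketch, assuming the contrib API of this TensorFlow generation (parameter names match the tests below; the numeric values are illustrative only):

  import numpy as np
  from tensorflow.contrib import distributions as ds

  # Formerly defined in mvn.py; now in mvn_diag.py and mvn_tril.py.
  mvn_diag = ds.MultivariateNormalDiag(loc=[-1., 1.], scale_diag=[1., 2.])
  mvn_tril = ds.MultivariateNormalTriL(
      loc=[0., 0.],
      scale_tril=[[1., 0.],
                  [0.5, 1.]])
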
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py
new file mode 100644
index 0000000000..834877769e
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py
@@ -0,0 +1,420 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MultivariateNormal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib import distributions
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+ds = distributions
+
+
+class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
+ """Well tested because this is a simple override of the base class."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testDiagBroadcastBothBatchAndEvent(self):
+ # batch_shape: [3], event_shape: [2]
+ diag = np.array([[1., 2], [3, 4], [5, 6]])
+ # batch_shape: [1], event_shape: []
+ identity_multiplier = np.array([5.])
+ with self.test_session():
+ dist = ds.MultivariateNormalDiagPlusLowRank(
+ scale_diag=diag,
+ scale_identity_multiplier=identity_multiplier,
+ validate_args=True)
+ self.assertAllClose(
+ np.array([[[1. + 5, 0],
+ [0, 2 + 5]],
+ [[3 + 5, 0],
+ [0, 4 + 5]],
+ [[5 + 5, 0],
+ [0, 6 + 5]]]),
+ dist.scale.to_dense().eval())
+
+ def testDiagBroadcastBothBatchAndEvent2(self):
+ # This test differs from `testDiagBroadcastBothBatchAndEvent` in that it
+ # broadcasts batch_shapes from both the `scale_diag` and
+ # `scale_identity_multiplier` args.
+ # batch_shape: [3], event_shape: [2]
+ diag = np.array([[1., 2], [3, 4], [5, 6]])
+ # batch_shape: [3, 1], event_shape: []
+ identity_multiplier = np.array([[5.], [4], [3]])
+ with self.test_session():
+ dist = ds.MultivariateNormalDiagPlusLowRank(
+ scale_diag=diag,
+ scale_identity_multiplier=identity_multiplier,
+ validate_args=True)
+ self.assertAllEqual(
+ [3, 3, 2, 2],
+ dist.scale.to_dense().get_shape())
+
+ def testDiagBroadcastOnlyEvent(self):
+ # batch_shape: [3], event_shape: [2]
+ diag = np.array([[1., 2], [3, 4], [5, 6]])
+ # batch_shape: [3], event_shape: []
+ identity_multiplier = np.array([5., 4, 3])
+ with self.test_session():
+ dist = ds.MultivariateNormalDiagPlusLowRank(
+ scale_diag=diag,
+ scale_identity_multiplier=identity_multiplier,
+ validate_args=True)
+ self.assertAllClose(
+ np.array([[[1. + 5, 0],
+ [0, 2 + 5]],
+ [[3 + 4, 0],
+ [0, 4 + 4]],
+ [[5 + 3, 0],
+ [0, 6 + 3]]]), # shape: [3, 2, 2]
+ dist.scale.to_dense().eval())
+
+ def testDiagBroadcastMultiplierAndLoc(self):
+ # batch_shape: [], event_shape: [3]
+ loc = np.array([1., 0, -1])
+ # batch_shape: [3], event_shape: []
+ identity_multiplier = np.array([5., 4, 3])
+ with self.test_session():
+ dist = ds.MultivariateNormalDiagPlusLowRank(
+ loc=loc,
+ scale_identity_multiplier=identity_multiplier,
+ validate_args=True)
+ self.assertAllClose(
+ np.array([[[5, 0, 0],
+ [0, 5, 0],
+ [0, 0, 5]],
+ [[4, 0, 0],
+ [0, 4, 0],
+ [0, 0, 4]],
+ [[3, 0, 0],
+ [0, 3, 0],
+ [0, 0, 3]]]),
+ dist.scale.to_dense().eval())
+
+ def testMean(self):
+ mu = [-1.0, 1.0]
+ diag_large = [1.0, 5.0]
+ v = [[2.0], [3.0]]
+ diag_small = [3.0]
+ with self.test_session():
+ dist = ds.MultivariateNormalDiagPlusLowRank(
+ loc=mu,
+ scale_diag=diag_large,
+ scale_perturb_factor=v,
+ scale_perturb_diag=diag_small,
+ validate_args=True)
+ self.assertAllEqual(mu, dist.mean().eval())
+
+ def testSample(self):
+ # TODO(jvdillon): This test should be the basis of a new test fixture which
+ # is applied to every distribution. When we make this fixture, we'll also
+ # separate the analytical- and sample-based tests, and split them per
+ # tested function. For now, we group things so we can recycle one batch of
+ # samples (thus saving resources).
+
+ mu = np.array([-1., 1, 0.5], dtype=np.float32)
+ diag_large = np.array([1., 0.5, 0.75], dtype=np.float32)
+ diag_small = np.array([-1.1, 1.2], dtype=np.float32)
+ v = np.array([[0.7, 0.8],
+ [0.9, 1],
+ [0.5, 0.6]], dtype=np.float32) # shape: [k, r] = [3, 2]
+
+ true_mean = mu
+ true_scale = np.diag(diag_large) + np.matmul(np.matmul(
+ v, np.diag(diag_small)), v.T)
+ true_covariance = np.matmul(true_scale, true_scale.T)
+ true_variance = np.diag(true_covariance)
+ true_stddev = np.sqrt(true_variance)
+ true_det_covariance = np.linalg.det(true_covariance)
+ true_log_det_covariance = np.log(true_det_covariance)
+
+ with self.test_session() as sess:
+ dist = ds.MultivariateNormalDiagPlusLowRank(
+ loc=mu,
+ scale_diag=diag_large,
+ scale_perturb_factor=v,
+ scale_perturb_diag=diag_small,
+ validate_args=True)
+
+ # The following distributions will test the KL divergence calculation.
+ mvn_identity = ds.MultivariateNormalDiag(
+ loc=np.array([1., 2, 0.25], dtype=np.float32),
+ validate_args=True)
+ mvn_scaled = ds.MultivariateNormalDiag(
+ loc=mvn_identity.loc,
+ scale_identity_multiplier=2.2,
+ validate_args=True)
+ mvn_diag = ds.MultivariateNormalDiag(
+ loc=mvn_identity.loc,
+ scale_diag=np.array([0.5, 1.5, 1.], dtype=np.float32),
+ validate_args=True)
+ mvn_chol = ds.MultivariateNormalTriL(
+ loc=np.array([1., 2, -1], dtype=np.float32),
+ scale_tril=np.array([[6., 0, 0],
+ [2, 5, 0],
+ [1, 3, 4]], dtype=np.float32) / 10.,
+ validate_args=True)
+
+ scale = dist.scale.to_dense()
+
+ n = int(30e3)
+ samps = dist.sample(n, seed=0)
+ sample_mean = math_ops.reduce_mean(samps, 0)
+ x = samps - sample_mean
+ sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n
+
+ sample_kl_identity = math_ops.reduce_mean(
+ dist.log_prob(samps) - mvn_identity.log_prob(samps), 0)
+ analytical_kl_identity = ds.kl(dist, mvn_identity)
+
+ sample_kl_scaled = math_ops.reduce_mean(
+ dist.log_prob(samps) - mvn_scaled.log_prob(samps), 0)
+ analytical_kl_scaled = ds.kl(dist, mvn_scaled)
+
+ sample_kl_diag = math_ops.reduce_mean(
+ dist.log_prob(samps) - mvn_diag.log_prob(samps), 0)
+ analytical_kl_diag = ds.kl(dist, mvn_diag)
+
+ sample_kl_chol = math_ops.reduce_mean(
+ dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
+ analytical_kl_chol = ds.kl(dist, mvn_chol)
+
+ n = int(10e3)
+ baseline = ds.MultivariateNormalDiag(
+ loc=np.array([-1., 0.25, 1.25], dtype=np.float32),
+ scale_diag=np.array([1.5, 0.5, 1.], dtype=np.float32),
+ validate_args=True)
+ samps = baseline.sample(n, seed=0)
+
+ sample_kl_identity_diag_baseline = math_ops.reduce_mean(
+ baseline.log_prob(samps) - mvn_identity.log_prob(samps), 0)
+ analytical_kl_identity_diag_baseline = ds.kl(baseline, mvn_identity)
+
+ sample_kl_scaled_diag_baseline = math_ops.reduce_mean(
+ baseline.log_prob(samps) - mvn_scaled.log_prob(samps), 0)
+ analytical_kl_scaled_diag_baseline = ds.kl(baseline, mvn_scaled)
+
+ sample_kl_diag_diag_baseline = math_ops.reduce_mean(
+ baseline.log_prob(samps) - mvn_diag.log_prob(samps), 0)
+ analytical_kl_diag_diag_baseline = ds.kl(baseline, mvn_diag)
+
+ sample_kl_chol_diag_baseline = math_ops.reduce_mean(
+ baseline.log_prob(samps) - mvn_chol.log_prob(samps), 0)
+ analytical_kl_chol_diag_baseline = ds.kl(baseline, mvn_chol)
+
+ [
+ sample_mean_,
+ analytical_mean_,
+ sample_covariance_,
+ analytical_covariance_,
+ analytical_variance_,
+ analytical_stddev_,
+ analytical_log_det_covariance_,
+ analytical_det_covariance_,
+ scale_,
+ sample_kl_identity_, analytical_kl_identity_,
+ sample_kl_scaled_, analytical_kl_scaled_,
+ sample_kl_diag_, analytical_kl_diag_,
+ sample_kl_chol_, analytical_kl_chol_,
+ sample_kl_identity_diag_baseline_,
+ analytical_kl_identity_diag_baseline_,
+ sample_kl_scaled_diag_baseline_, analytical_kl_scaled_diag_baseline_,
+ sample_kl_diag_diag_baseline_, analytical_kl_diag_diag_baseline_,
+ sample_kl_chol_diag_baseline_, analytical_kl_chol_diag_baseline_,
+ ] = sess.run([
+ sample_mean,
+ dist.mean(),
+ sample_covariance,
+ dist.covariance(),
+ dist.variance(),
+ dist.stddev(),
+ dist.log_det_covariance(),
+ dist.det_covariance(),
+ scale,
+ sample_kl_identity, analytical_kl_identity,
+ sample_kl_scaled, analytical_kl_scaled,
+ sample_kl_diag, analytical_kl_diag,
+ sample_kl_chol, analytical_kl_chol,
+ sample_kl_identity_diag_baseline,
+ analytical_kl_identity_diag_baseline,
+ sample_kl_scaled_diag_baseline, analytical_kl_scaled_diag_baseline,
+ sample_kl_diag_diag_baseline, analytical_kl_diag_diag_baseline,
+ sample_kl_chol_diag_baseline, analytical_kl_chol_diag_baseline,
+ ])
+
+ sample_variance_ = np.diag(sample_covariance_)
+ sample_stddev_ = np.sqrt(sample_variance_)
+ sample_det_covariance_ = np.linalg.det(sample_covariance_)
+ sample_log_det_covariance_ = np.log(sample_det_covariance_)
+
+ print("true_mean:\n{} ".format(true_mean))
+ print("sample_mean:\n{}".format(sample_mean_))
+ print("analytical_mean:\n{}".format(analytical_mean_))
+
+ print("true_covariance:\n{}".format(true_covariance))
+ print("sample_covariance:\n{}".format(sample_covariance_))
+ print("analytical_covariance:\n{}".format(
+ analytical_covariance_))
+
+ print("true_variance:\n{}".format(true_variance))
+ print("sample_variance:\n{}".format(sample_variance_))
+ print("analytical_variance:\n{}".format(analytical_variance_))
+
+ print("true_stddev:\n{}".format(true_stddev))
+ print("sample_stddev:\n{}".format(sample_stddev_))
+ print("analytical_stddev:\n{}".format(analytical_stddev_))
+
+ print("true_log_det_covariance:\n{}".format(
+ true_log_det_covariance))
+ print("sample_log_det_covariance:\n{}".format(
+ sample_log_det_covariance_))
+ print("analytical_log_det_covariance:\n{}".format(
+ analytical_log_det_covariance_))
+
+ print("true_det_covariance:\n{}".format(
+ true_det_covariance))
+ print("sample_det_covariance:\n{}".format(
+ sample_det_covariance_))
+ print("analytical_det_covariance:\n{}".format(
+ analytical_det_covariance_))
+
+ print("true_scale:\n{}".format(true_scale))
+ print("scale:\n{}".format(scale_))
+
+ print("kl_identity: analytical:{} sample:{}".format(
+ analytical_kl_identity_, sample_kl_identity_))
+
+ print("kl_scaled: analytical:{} sample:{}".format(
+ analytical_kl_scaled_, sample_kl_scaled_))
+
+ print("kl_diag: analytical:{} sample:{}".format(
+ analytical_kl_diag_, sample_kl_diag_))
+
+ print("kl_chol: analytical:{} sample:{}".format(
+ analytical_kl_chol_, sample_kl_chol_))
+
+ print("kl_identity_diag_baseline: analytical:{} sample:{}".format(
+ analytical_kl_identity_diag_baseline_,
+ sample_kl_identity_diag_baseline_))
+
+ print("kl_scaled_diag_baseline: analytical:{} sample:{}".format(
+ analytical_kl_scaled_diag_baseline_,
+ sample_kl_scaled_diag_baseline_))
+
+ print("kl_diag_diag_baseline: analytical:{} sample:{}".format(
+ analytical_kl_diag_diag_baseline_,
+ sample_kl_diag_diag_baseline_))
+
+ print("kl_chol_diag_baseline: analytical:{} sample:{}".format(
+ analytical_kl_chol_diag_baseline_,
+ sample_kl_chol_diag_baseline_))
+
+ self.assertAllClose(true_mean, sample_mean_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_mean, analytical_mean_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_covariance, sample_covariance_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_covariance, analytical_covariance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_variance, sample_variance_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_variance, analytical_variance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_stddev, sample_stddev_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_stddev, analytical_stddev_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_log_det_covariance,
+ analytical_log_det_covariance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_det_covariance, sample_det_covariance_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_det_covariance, analytical_det_covariance_,
+ atol=0., rtol=1e-5)
+
+ self.assertAllClose(true_scale, scale_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(sample_kl_identity_, analytical_kl_identity_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(sample_kl_scaled_, analytical_kl_scaled_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(sample_kl_diag_, analytical_kl_diag_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(sample_kl_chol_, analytical_kl_chol_,
+ atol=0., rtol=0.02)
+
+ self.assertAllClose(
+ sample_kl_identity_diag_baseline_,
+ analytical_kl_identity_diag_baseline_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(
+ sample_kl_scaled_diag_baseline_,
+ analytical_kl_scaled_diag_baseline_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(
+ sample_kl_diag_diag_baseline_,
+ analytical_kl_diag_diag_baseline_,
+ atol=0., rtol=0.04)
+ self.assertAllClose(
+ sample_kl_chol_diag_baseline_,
+ analytical_kl_chol_diag_baseline_,
+ atol=0., rtol=0.02)
+
+ def testImplicitLargeDiag(self):
+ mu = np.array([[1., 2, 3],
+ [11, 22, 33]]) # shape: [b, k] = [2, 3]
+ u = np.array([[[1., 2],
+ [3, 4],
+ [5, 6]],
+ [[0.5, 0.75],
+ [1, 0.25],
+ [1.5, 1.25]]]) # shape: [b, k, r] = [2, 3, 2]
+ m = np.array([[0.1, 0.2],
+ [0.4, 0.5]]) # shape: [b, r] = [2, 2]
+ scale = np.stack([
+ np.eye(3) + np.matmul(np.matmul(u[0], np.diag(m[0])),
+ np.transpose(u[0])),
+ np.eye(3) + np.matmul(np.matmul(u[1], np.diag(m[1])),
+ np.transpose(u[1])),
+ ])
+ cov = np.stack([np.matmul(scale[0], scale[0].T),
+ np.matmul(scale[1], scale[1].T)])
+ print("expected_cov:\n{}".format(cov))
+ with self.test_session():
+ mvn = ds.MultivariateNormalDiagPlusLowRank(
+ loc=mu,
+ scale_perturb_factor=u,
+ scale_perturb_diag=m)
+ self.assertAllClose(cov, mvn.covariance().eval(), atol=0., rtol=1e-6)
+
+
+if __name__ == "__main__":
+ test.main()
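
The arithmetic behind testSample's "true" values above is worth spelling out: MultivariateNormalDiagPlusLowRank parameterizes its scale as diag(scale_diag) plus the low-rank update V diag(M) V^T, and the covariance is scale times scale^T. A standalone numpy sketch reproducing the test's expected quantities (values copied from testSample; no TensorFlow required):

  import numpy as np

  diag_large = np.array([1., 0.5, 0.75])   # scale_diag
  diag_small = np.array([-1.1, 1.2])       # scale_perturb_diag (M)
  v = np.array([[0.7, 0.8],
                [0.9, 1.],
                [0.5, 0.6]])               # scale_perturb_factor, shape [k, r] = [3, 2]

  # scale = diag(D) + V diag(M) V^T, exactly the test's true_scale.
  scale = np.diag(diag_large) + np.matmul(np.matmul(v, np.diag(diag_small)), v.T)
  covariance = np.matmul(scale, scale.T)
  variance = np.diag(covariance)
  stddev = np.sqrt(variance)
  log_det_covariance = np.log(np.linalg.det(covariance))
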
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
new file mode 100644
index 0000000000..3838cccb22
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
@@ -0,0 +1,181 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MultivariateNormal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+from tensorflow.contrib import distributions
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+ds = distributions
+
+
+class MultivariateNormalDiagTest(test.TestCase):
+ """Well tested because this is a simple override of the base class."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testScalarParams(self):
+ mu = -1.
+ diag = -5.
+ with self.test_session():
+ # TODO(b/35244539): Choose better exception, once LinOp throws it.
+ with self.assertRaises(IndexError):
+ ds.MultivariateNormalDiag(mu, diag)
+
+ def testVectorParams(self):
+ mu = [-1.]
+ diag = [-5.]
+ with self.test_session():
+ dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
+ self.assertAllEqual([3, 1], dist.sample(3).get_shape())
+
+ def testMean(self):
+ mu = [-1., 1]
+ diag = [1., -5]
+ with self.test_session():
+ dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
+ self.assertAllEqual(mu, dist.mean().eval())
+
+ def testEntropy(self):
+ mu = [-1., 1]
+ diag = [-1., 5]
+ diag_mat = np.diag(diag)
+ scipy_mvn = stats.multivariate_normal(mean=mu, cov=diag_mat**2)
+ with self.test_session():
+ dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
+ self.assertAllClose(scipy_mvn.entropy(), dist.entropy().eval(), atol=1e-4)
+
+ def testSample(self):
+ mu = [-1., 1]
+ diag = [1., -2]
+ with self.test_session():
+ dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
+ samps = dist.sample(int(1e3), seed=0).eval()
+ cov_mat = array_ops.matrix_diag(diag).eval()**2
+
+ self.assertAllClose(mu, samps.mean(axis=0),
+ atol=0., rtol=0.05)
+ self.assertAllClose(cov_mat, np.cov(samps.T),
+ atol=0.05, rtol=0.05)
+
+ def testCovariance(self):
+ with self.test_session():
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
+ self.assertAllClose(
+ np.diag(np.ones([3], dtype=np.float32)),
+ mvn.covariance().eval())
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([3], dtype=dtypes.float32),
+ scale_identity_multiplier=[3., 2.])
+ self.assertAllEqual([2], mvn.batch_shape)
+ self.assertAllEqual([3], mvn.event_shape)
+ self.assertAllClose(
+ np.array([[[3., 0, 0],
+ [0, 3, 0],
+ [0, 0, 3]],
+ [[2, 0, 0],
+ [0, 2, 0],
+ [0, 0, 2]]])**2.,
+ mvn.covariance().eval())
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([3], dtype=dtypes.float32),
+ scale_diag=[[3., 2, 1], [4, 5, 6]])
+ self.assertAllEqual([2], mvn.batch_shape)
+ self.assertAllEqual([3], mvn.event_shape)
+ self.assertAllClose(
+ np.array([[[3., 0, 0],
+ [0, 2, 0],
+ [0, 0, 1]],
+ [[4, 0, 0],
+ [0, 5, 0],
+ [0, 0, 6]]])**2.,
+ mvn.covariance().eval())
+
+ def testVariance(self):
+ with self.test_session():
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
+ self.assertAllClose(
+ np.ones([3], dtype=np.float32),
+ mvn.variance().eval())
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([3], dtype=dtypes.float32),
+ scale_identity_multiplier=[3., 2.])
+ self.assertAllClose(
+ np.array([[3., 3, 3],
+ [2, 2, 2]])**2.,
+ mvn.variance().eval())
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([3], dtype=dtypes.float32),
+ scale_diag=[[3., 2, 1],
+ [4, 5, 6]])
+ self.assertAllClose(
+ np.array([[3., 2, 1],
+ [4, 5, 6]])**2.,
+ mvn.variance().eval())
+
+ def testStddev(self):
+ with self.test_session():
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
+ self.assertAllClose(
+ np.ones([3], dtype=np.float32),
+ mvn.stddev().eval())
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([3], dtype=dtypes.float32),
+ scale_identity_multiplier=[3., 2.])
+ self.assertAllClose(
+ np.array([[3., 3, 3],
+ [2, 2, 2]]),
+ mvn.stddev().eval())
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=array_ops.zeros([3], dtype=dtypes.float32),
+ scale_diag=[[3., 2, 1], [4, 5, 6]])
+ self.assertAllClose(
+ np.array([[3., 2, 1],
+ [4, 5, 6]]),
+ mvn.stddev().eval())
+
+ def testMultivariateNormalDiagWithSoftplusScale(self):
+ mu = [-1.0, 1.0]
+ diag = [-1.0, -2.0]
+ with self.test_session():
+ dist = ds.MultivariateNormalDiagWithSoftplusScale(
+ mu, diag, validate_args=True)
+ samps = dist.sample(1000, seed=0).eval()
+ cov_mat = array_ops.matrix_diag(nn_ops.softplus(diag)).eval()**2
+
+ self.assertAllClose(mu, samps.mean(axis=0), atol=0.1)
+ self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
+
+if __name__ == "__main__":
+ test.main()
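
The closed forms checked above all reduce to one fact: for MultivariateNormalDiag the covariance matrix is the squared scale, diag(scale_diag)**2, which is why negative scale entries are legal and why testEntropy and testSample square them. A small numpy/scipy sketch mirroring testEntropy:

  import numpy as np
  from scipy import stats

  mu = [-1., 1]
  diag = [-1., 5]           # scale entries; the sign is irrelevant once squared
  cov = np.diag(diag)**2    # covariance = diag(scale_diag)**2
  scipy_mvn = stats.multivariate_normal(mean=mu, cov=cov)
  print(scipy_mvn.entropy())  # the value dist.entropy() is checked against
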
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
deleted file mode 100644
index 79fdb149b4..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
+++ /dev/null
@@ -1,981 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MultivariateNormal."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-from scipy import stats
-from tensorflow.contrib import distributions
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.platform import test
-
-
-ds = distributions
-
-
-class MultivariateNormalDiagTest(test.TestCase):
- """Well tested because this is a simple override of the base class."""
-
- def setUp(self):
- self._rng = np.random.RandomState(42)
-
- def testScalarParams(self):
- mu = -1.
- diag = -5.
- with self.test_session():
- # TODO(b/35244539): Choose better exception, once LinOp throws it.
- with self.assertRaises(IndexError):
- ds.MultivariateNormalDiag(mu, diag)
-
- def testVectorParams(self):
- mu = [-1.]
- diag = [-5.]
- with self.test_session():
- dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
- self.assertAllEqual([3, 1], dist.sample(3).get_shape())
-
- def testMean(self):
- mu = [-1., 1]
- diag = [1., -5]
- with self.test_session():
- dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
- self.assertAllEqual(mu, dist.mean().eval())
-
- def testEntropy(self):
- mu = [-1., 1]
- diag = [-1., 5]
- diag_mat = np.diag(diag)
- scipy_mvn = stats.multivariate_normal(mean=mu, cov=diag_mat**2)
- with self.test_session():
- dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
- self.assertAllClose(scipy_mvn.entropy(), dist.entropy().eval(), atol=1e-4)
-
- def testSample(self):
- mu = [-1., 1]
- diag = [1., -2]
- with self.test_session():
- dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
- samps = dist.sample(int(1e3), seed=0).eval()
- cov_mat = array_ops.matrix_diag(diag).eval()**2
-
- self.assertAllClose(mu, samps.mean(axis=0),
- atol=0., rtol=0.05)
- self.assertAllClose(cov_mat, np.cov(samps.T),
- atol=0.05, rtol=0.05)
-
- def testCovariance(self):
- with self.test_session():
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
- self.assertAllClose(
- np.diag(np.ones([3], dtype=np.float32)),
- mvn.covariance().eval())
-
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([3], dtype=dtypes.float32),
- scale_identity_multiplier=[3., 2.])
- self.assertAllEqual([2], mvn.batch_shape)
- self.assertAllEqual([3], mvn.event_shape)
- self.assertAllClose(
- np.array([[[3., 0, 0],
- [0, 3, 0],
- [0, 0, 3]],
- [[2, 0, 0],
- [0, 2, 0],
- [0, 0, 2]]])**2.,
- mvn.covariance().eval())
-
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([3], dtype=dtypes.float32),
- scale_diag=[[3., 2, 1], [4, 5, 6]])
- self.assertAllEqual([2], mvn.batch_shape)
- self.assertAllEqual([3], mvn.event_shape)
- self.assertAllClose(
- np.array([[[3., 0, 0],
- [0, 2, 0],
- [0, 0, 1]],
- [[4, 0, 0],
- [0, 5, 0],
- [0, 0, 6]]])**2.,
- mvn.covariance().eval())
-
- def testVariance(self):
- with self.test_session():
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
- self.assertAllClose(
- np.ones([3], dtype=np.float32),
- mvn.variance().eval())
-
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([3], dtype=dtypes.float32),
- scale_identity_multiplier=[3., 2.])
- self.assertAllClose(
- np.array([[3., 3, 3],
- [2, 2, 2]])**2.,
- mvn.variance().eval())
-
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([3], dtype=dtypes.float32),
- scale_diag=[[3., 2, 1],
- [4, 5, 6]])
- self.assertAllClose(
- np.array([[3., 2, 1],
- [4, 5, 6]])**2.,
- mvn.variance().eval())
-
- def testStddev(self):
- with self.test_session():
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
- self.assertAllClose(
- np.ones([3], dtype=np.float32),
- mvn.stddev().eval())
-
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([3], dtype=dtypes.float32),
- scale_identity_multiplier=[3., 2.])
- self.assertAllClose(
- np.array([[3., 3, 3],
- [2, 2, 2]]),
- mvn.stddev().eval())
-
- mvn = ds.MultivariateNormalDiag(
- loc=array_ops.zeros([3], dtype=dtypes.float32),
- scale_diag=[[3., 2, 1], [4, 5, 6]])
- self.assertAllClose(
- np.array([[3., 2, 1],
- [4, 5, 6]]),
- mvn.stddev().eval())
-
- def testMultivariateNormalDiagWithSoftplusScale(self):
- mu = [-1.0, 1.0]
- diag = [-1.0, -2.0]
- with self.test_session():
- dist = ds.MultivariateNormalDiagWithSoftplusScale(
- mu, diag, validate_args=True)
- samps = dist.sample(1000, seed=0).eval()
- cov_mat = array_ops.matrix_diag(nn_ops.softplus(diag)).eval()**2
-
- self.assertAllClose(mu, samps.mean(axis=0), atol=0.1)
- self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
-
-
-class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
- """Well tested because this is a simple override of the base class."""
-
- def setUp(self):
- self._rng = np.random.RandomState(42)
-
- def testDiagBroadcastBothBatchAndEvent(self):
- # batch_shape: [3], event_shape: [2]
- diag = np.array([[1., 2], [3, 4], [5, 6]])
- # batch_shape: [1], event_shape: []
- identity_multiplier = np.array([5.])
- with self.test_session():
- dist = ds.MultivariateNormalDiagPlusLowRank(
- scale_diag=diag,
- scale_identity_multiplier=identity_multiplier,
- validate_args=True)
- self.assertAllClose(
- np.array([[[1. + 5, 0],
- [0, 2 + 5]],
- [[3 + 5, 0],
- [0, 4 + 5]],
- [[5 + 5, 0],
- [0, 6 + 5]]]),
- dist.scale.to_dense().eval())
-
- def testDiagBroadcastBothBatchAndEvent2(self):
- # This test differs from `testDiagBroadcastBothBatchAndEvent` in that it
- # broadcasts batch_shapes from both the `scale_diag` and
- # `scale_identity_multiplier` args.
- # batch_shape: [3], event_shape: [2]
- diag = np.array([[1., 2], [3, 4], [5, 6]])
- # batch_shape: [3, 1], event_shape: []
- identity_multiplier = np.array([[5.], [4], [3]])
- with self.test_session():
- dist = ds.MultivariateNormalDiagPlusLowRank(
- scale_diag=diag,
- scale_identity_multiplier=identity_multiplier,
- validate_args=True)
- self.assertAllEqual(
- [3, 3, 2, 2],
- dist.scale.to_dense().get_shape())
-
- def testDiagBroadcastOnlyEvent(self):
- # batch_shape: [3], event_shape: [2]
- diag = np.array([[1., 2], [3, 4], [5, 6]])
- # batch_shape: [3], event_shape: []
- identity_multiplier = np.array([5., 4, 3])
- with self.test_session():
- dist = ds.MultivariateNormalDiagPlusLowRank(
- scale_diag=diag,
- scale_identity_multiplier=identity_multiplier,
- validate_args=True)
- self.assertAllClose(
- np.array([[[1. + 5, 0],
- [0, 2 + 5]],
- [[3 + 4, 0],
- [0, 4 + 4]],
- [[5 + 3, 0],
- [0, 6 + 3]]]), # shape: [3, 2, 2]
- dist.scale.to_dense().eval())
-
- def testDiagBroadcastMultiplierAndLoc(self):
- # batch_shape: [], event_shape: [3]
- loc = np.array([1., 0, -1])
- # batch_shape: [3], event_shape: []
- identity_multiplier = np.array([5., 4, 3])
- with self.test_session():
- dist = ds.MultivariateNormalDiagPlusLowRank(
- loc=loc,
- scale_identity_multiplier=identity_multiplier,
- validate_args=True)
- self.assertAllClose(
- np.array([[[5, 0, 0],
- [0, 5, 0],
- [0, 0, 5]],
- [[4, 0, 0],
- [0, 4, 0],
- [0, 0, 4]],
- [[3, 0, 0],
- [0, 3, 0],
- [0, 0, 3]]]),
- dist.scale.to_dense().eval())
-
- def testMean(self):
- mu = [-1.0, 1.0]
- diag_large = [1.0, 5.0]
- v = [[2.0], [3.0]]
- diag_small = [3.0]
- with self.test_session():
- dist = ds.MultivariateNormalDiagPlusLowRank(
- loc=mu,
- scale_diag=diag_large,
- scale_perturb_factor=v,
- scale_perturb_diag=diag_small,
- validate_args=True)
- self.assertAllEqual(mu, dist.mean().eval())
-
- def testSample(self):
- # TODO(jvdillon): This test should be the basis of a new test fixture which
- # is applied to every distribution. When we make this fixture, we'll also
- # separate the analytical- and sample-based tests, and split them per
- # tested function. For now, we group things so we can recycle one batch of
- # samples (thus saving resources).
-
- mu = np.array([-1., 1, 0.5], dtype=np.float32)
- diag_large = np.array([1., 0.5, 0.75], dtype=np.float32)
- diag_small = np.array([-1.1, 1.2], dtype=np.float32)
- v = np.array([[0.7, 0.8],
- [0.9, 1],
- [0.5, 0.6]], dtype=np.float32) # shape: [k, r] = [3, 2]
-
- true_mean = mu
- true_scale = np.diag(diag_large) + np.matmul(np.matmul(
- v, np.diag(diag_small)), v.T)
- true_covariance = np.matmul(true_scale, true_scale.T)
- true_variance = np.diag(true_covariance)
- true_stddev = np.sqrt(true_variance)
- true_det_covariance = np.linalg.det(true_covariance)
- true_log_det_covariance = np.log(true_det_covariance)
-
- with self.test_session() as sess:
- dist = ds.MultivariateNormalDiagPlusLowRank(
- loc=mu,
- scale_diag=diag_large,
- scale_perturb_factor=v,
- scale_perturb_diag=diag_small,
- validate_args=True)
-
- # The following distributions will test the KL divergence calculation.
- mvn_identity = ds.MultivariateNormalDiag(
- loc=np.array([1., 2, 0.25], dtype=np.float32),
- validate_args=True)
- mvn_scaled = ds.MultivariateNormalDiag(
- loc=mvn_identity.loc,
- scale_identity_multiplier=2.2,
- validate_args=True)
- mvn_diag = ds.MultivariateNormalDiag(
- loc=mvn_identity.loc,
- scale_diag=np.array([0.5, 1.5, 1.], dtype=np.float32),
- validate_args=True)
- mvn_chol = ds.MultivariateNormalTriL(
- loc=np.array([1., 2, -1], dtype=np.float32),
- scale_tril=np.array([[6., 0, 0],
- [2, 5, 0],
- [1, 3, 4]], dtype=np.float32) / 10.,
- validate_args=True)
-
- scale = dist.scale.to_dense()
-
- n = int(30e3)
- samps = dist.sample(n, seed=0)
- sample_mean = math_ops.reduce_mean(samps, 0)
- x = samps - sample_mean
- sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n
-
- sample_kl_identity = math_ops.reduce_mean(
- dist.log_prob(samps) - mvn_identity.log_prob(samps), 0)
- analytical_kl_identity = ds.kl(dist, mvn_identity)
-
- sample_kl_scaled = math_ops.reduce_mean(
- dist.log_prob(samps) - mvn_scaled.log_prob(samps), 0)
- analytical_kl_scaled = ds.kl(dist, mvn_scaled)
-
- sample_kl_diag = math_ops.reduce_mean(
- dist.log_prob(samps) - mvn_diag.log_prob(samps), 0)
- analytical_kl_diag = ds.kl(dist, mvn_diag)
-
- sample_kl_chol = math_ops.reduce_mean(
- dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
- analytical_kl_chol = ds.kl(dist, mvn_chol)
-
- n = int(10e3)
- baseline = ds.MultivariateNormalDiag(
- loc=np.array([-1., 0.25, 1.25], dtype=np.float32),
- scale_diag=np.array([1.5, 0.5, 1.], dtype=np.float32),
- validate_args=True)
- samps = baseline.sample(n, seed=0)
-
- sample_kl_identity_diag_baseline = math_ops.reduce_mean(
- baseline.log_prob(samps) - mvn_identity.log_prob(samps), 0)
- analytical_kl_identity_diag_baseline = ds.kl(baseline, mvn_identity)
-
- sample_kl_scaled_diag_baseline = math_ops.reduce_mean(
- baseline.log_prob(samps) - mvn_scaled.log_prob(samps), 0)
- analytical_kl_scaled_diag_baseline = ds.kl(baseline, mvn_scaled)
-
- sample_kl_diag_diag_baseline = math_ops.reduce_mean(
- baseline.log_prob(samps) - mvn_diag.log_prob(samps), 0)
- analytical_kl_diag_diag_baseline = ds.kl(baseline, mvn_diag)
-
- sample_kl_chol_diag_baseline = math_ops.reduce_mean(
- baseline.log_prob(samps) - mvn_chol.log_prob(samps), 0)
- analytical_kl_chol_diag_baseline = ds.kl(baseline, mvn_chol)
-
- [
- sample_mean_,
- analytical_mean_,
- sample_covariance_,
- analytical_covariance_,
- analytical_variance_,
- analytical_stddev_,
- analytical_log_det_covariance_,
- analytical_det_covariance_,
- scale_,
- sample_kl_identity_, analytical_kl_identity_,
- sample_kl_scaled_, analytical_kl_scaled_,
- sample_kl_diag_, analytical_kl_diag_,
- sample_kl_chol_, analytical_kl_chol_,
- sample_kl_identity_diag_baseline_,
- analytical_kl_identity_diag_baseline_,
- sample_kl_scaled_diag_baseline_, analytical_kl_scaled_diag_baseline_,
- sample_kl_diag_diag_baseline_, analytical_kl_diag_diag_baseline_,
- sample_kl_chol_diag_baseline_, analytical_kl_chol_diag_baseline_,
- ] = sess.run([
- sample_mean,
- dist.mean(),
- sample_covariance,
- dist.covariance(),
- dist.variance(),
- dist.stddev(),
- dist.log_det_covariance(),
- dist.det_covariance(),
- scale,
- sample_kl_identity, analytical_kl_identity,
- sample_kl_scaled, analytical_kl_scaled,
- sample_kl_diag, analytical_kl_diag,
- sample_kl_chol, analytical_kl_chol,
- sample_kl_identity_diag_baseline,
- analytical_kl_identity_diag_baseline,
- sample_kl_scaled_diag_baseline, analytical_kl_scaled_diag_baseline,
- sample_kl_diag_diag_baseline, analytical_kl_diag_diag_baseline,
- sample_kl_chol_diag_baseline, analytical_kl_chol_diag_baseline,
- ])
-
- sample_variance_ = np.diag(sample_covariance_)
- sample_stddev_ = np.sqrt(sample_variance_)
- sample_det_covariance_ = np.linalg.det(sample_covariance_)
- sample_log_det_covariance_ = np.log(sample_det_covariance_)
-
- print("true_mean:\n{} ".format(true_mean))
- print("sample_mean:\n{}".format(sample_mean_))
- print("analytical_mean:\n{}".format(analytical_mean_))
-
- print("true_covariance:\n{}".format(true_covariance))
- print("sample_covariance:\n{}".format(sample_covariance_))
- print("analytical_covariance:\n{}".format(
- analytical_covariance_))
-
- print("true_variance:\n{}".format(true_variance))
- print("sample_variance:\n{}".format(sample_variance_))
- print("analytical_variance:\n{}".format(analytical_variance_))
-
- print("true_stddev:\n{}".format(true_stddev))
- print("sample_stddev:\n{}".format(sample_stddev_))
- print("analytical_stddev:\n{}".format(analytical_stddev_))
-
- print("true_log_det_covariance:\n{}".format(
- true_log_det_covariance))
- print("sample_log_det_covariance:\n{}".format(
- sample_log_det_covariance_))
- print("analytical_log_det_covariance:\n{}".format(
- analytical_log_det_covariance_))
-
- print("true_det_covariance:\n{}".format(
- true_det_covariance))
- print("sample_det_covariance:\n{}".format(
- sample_det_covariance_))
- print("analytical_det_covariance:\n{}".format(
- analytical_det_covariance_))
-
- print("true_scale:\n{}".format(true_scale))
- print("scale:\n{}".format(scale_))
-
- print("kl_identity: analytical:{} sample:{}".format(
- analytical_kl_identity_, sample_kl_identity_))
-
- print("kl_scaled: analytical:{} sample:{}".format(
- analytical_kl_scaled_, sample_kl_scaled_))
-
- print("kl_diag: analytical:{} sample:{}".format(
- analytical_kl_diag_, sample_kl_diag_))
-
- print("kl_chol: analytical:{} sample:{}".format(
- analytical_kl_chol_, sample_kl_chol_))
-
- print("kl_identity_diag_baseline: analytical:{} sample:{}".format(
- analytical_kl_identity_diag_baseline_,
- sample_kl_identity_diag_baseline_))
-
- print("kl_scaled_diag_baseline: analytical:{} sample:{}".format(
- analytical_kl_scaled_diag_baseline_,
- sample_kl_scaled_diag_baseline_))
-
- print("kl_diag_diag_baseline: analytical:{} sample:{}".format(
- analytical_kl_diag_diag_baseline_,
- sample_kl_diag_diag_baseline_))
-
- print("kl_chol_diag_baseline: analytical:{} sample:{}".format(
- analytical_kl_chol_diag_baseline_,
- sample_kl_chol_diag_baseline_))
-
- self.assertAllClose(true_mean, sample_mean_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_mean, analytical_mean_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_covariance, sample_covariance_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_covariance, analytical_covariance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_variance, sample_variance_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_variance, analytical_variance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_stddev, sample_stddev_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_stddev, analytical_stddev_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_log_det_covariance,
- analytical_log_det_covariance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_det_covariance, sample_det_covariance_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_det_covariance, analytical_det_covariance_,
- atol=0., rtol=1e-5)
-
- self.assertAllClose(true_scale, scale_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(sample_kl_identity_, analytical_kl_identity_,
- atol=0., rtol=0.02)
- self.assertAllClose(sample_kl_scaled_, analytical_kl_scaled_,
- atol=0., rtol=0.02)
- self.assertAllClose(sample_kl_diag_, analytical_kl_diag_,
- atol=0., rtol=0.02)
- self.assertAllClose(sample_kl_chol_, analytical_kl_chol_,
- atol=0., rtol=0.02)
-
- self.assertAllClose(
- sample_kl_identity_diag_baseline_,
- analytical_kl_identity_diag_baseline_,
- atol=0., rtol=0.02)
- self.assertAllClose(
- sample_kl_scaled_diag_baseline_,
- analytical_kl_scaled_diag_baseline_,
- atol=0., rtol=0.02)
- self.assertAllClose(
- sample_kl_diag_diag_baseline_,
- analytical_kl_diag_diag_baseline_,
- atol=0., rtol=0.04)
- self.assertAllClose(
- sample_kl_chol_diag_baseline_,
- analytical_kl_chol_diag_baseline_,
- atol=0., rtol=0.02)
-
- def testImplicitLargeDiag(self):
- mu = np.array([[1., 2, 3],
- [11, 22, 33]]) # shape: [b, k] = [2, 3]
- u = np.array([[[1., 2],
- [3, 4],
- [5, 6]],
- [[0.5, 0.75],
- [1, 0.25],
- [1.5, 1.25]]]) # shape: [b, k, r] = [2, 3, 2]
- m = np.array([[0.1, 0.2],
- [0.4, 0.5]]) # shape: [b, r] = [2, 2]
- scale = np.stack([
- np.eye(3) + np.matmul(np.matmul(u[0], np.diag(m[0])),
- np.transpose(u[0])),
- np.eye(3) + np.matmul(np.matmul(u[1], np.diag(m[1])),
- np.transpose(u[1])),
- ])
- cov = np.stack([np.matmul(scale[0], scale[0].T),
- np.matmul(scale[1], scale[1].T)])
- print("expected_cov:\n{}".format(cov))
- with self.test_session():
- mvn = ds.MultivariateNormalDiagPlusLowRank(
- loc=mu,
- scale_perturb_factor=u,
- scale_perturb_diag=m)
- self.assertAllClose(cov, mvn.covariance().eval(), atol=0., rtol=1e-6)
-
-
-class MultivariateNormalTriLTest(test.TestCase):
-
- def setUp(self):
- self._rng = np.random.RandomState(42)
-
- def _random_chol(self, *shape):
- mat = self._rng.rand(*shape)
- chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus)
- chol = array_ops.matrix_band_part(chol, -1, 0)
- sigma = math_ops.matmul(chol, chol, adjoint_b=True)
- return chol.eval(), sigma.eval()
-
- def testLogPDFScalarBatch(self):
- with self.test_session():
- mu = self._rng.rand(2)
- chol, sigma = self._random_chol(2, 2)
- chol[1, 1] = -chol[1, 1]
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- x = self._rng.rand(2)
-
- log_pdf = mvn.log_prob(x)
- pdf = mvn.prob(x)
-
- scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
-
- expected_log_pdf = scipy_mvn.logpdf(x)
- expected_pdf = scipy_mvn.pdf(x)
- self.assertEqual((), log_pdf.get_shape())
- self.assertEqual((), pdf.get_shape())
- self.assertAllClose(expected_log_pdf, log_pdf.eval())
- self.assertAllClose(expected_pdf, pdf.eval())
-
- def testLogPDFXIsHigherRank(self):
- with self.test_session():
- mu = self._rng.rand(2)
- chol, sigma = self._random_chol(2, 2)
- chol[0, 0] = -chol[0, 0]
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- x = self._rng.rand(3, 2)
-
- log_pdf = mvn.log_prob(x)
- pdf = mvn.prob(x)
-
- scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
-
- expected_log_pdf = scipy_mvn.logpdf(x)
- expected_pdf = scipy_mvn.pdf(x)
- self.assertEqual((3,), log_pdf.get_shape())
- self.assertEqual((3,), pdf.get_shape())
- self.assertAllClose(expected_log_pdf, log_pdf.eval(), atol=0., rtol=0.02)
- self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03)
-
- def testLogPDFXLowerDimension(self):
- with self.test_session():
- mu = self._rng.rand(3, 2)
- chol, sigma = self._random_chol(3, 2, 2)
- chol[0, 0, 0] = -chol[0, 0, 0]
- chol[2, 1, 1] = -chol[2, 1, 1]
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- x = self._rng.rand(2)
-
- log_pdf = mvn.log_prob(x)
- pdf = mvn.prob(x)
-
- self.assertEqual((3,), log_pdf.get_shape())
- self.assertEqual((3,), pdf.get_shape())
-
- # scipy can't do batches, so just test one of them.
- scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :])
- expected_log_pdf = scipy_mvn.logpdf(x)
- expected_pdf = scipy_mvn.pdf(x)
-
- self.assertAllClose(expected_log_pdf, log_pdf.eval()[1])
- self.assertAllClose(expected_pdf, pdf.eval()[1])
-
- def testEntropy(self):
- with self.test_session():
- mu = self._rng.rand(2)
- chol, sigma = self._random_chol(2, 2)
- chol[0, 0] = -chol[0, 0]
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- entropy = mvn.entropy()
-
- scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
- expected_entropy = scipy_mvn.entropy()
- self.assertEqual(entropy.get_shape(), ())
- self.assertAllClose(expected_entropy, entropy.eval())
-
- def testEntropyMultidimensional(self):
- with self.test_session():
- mu = self._rng.rand(3, 5, 2)
- chol, sigma = self._random_chol(3, 5, 2, 2)
- chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
- chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- entropy = mvn.entropy()
-
- # Scipy doesn't do batches, so test one of them.
- expected_entropy = stats.multivariate_normal(
- mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
- self.assertEqual(entropy.get_shape(), (3, 5))
- self.assertAllClose(expected_entropy, entropy.eval()[1, 1])
-
- def testSample(self):
- with self.test_session():
- mu = self._rng.rand(2)
- chol, sigma = self._random_chol(2, 2)
- chol[0, 0] = -chol[0, 0]
- sigma[0, 1] = -sigma[0, 1]
- sigma[1, 0] = -sigma[1, 0]
-
- n = constant_op.constant(100000)
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- samples = mvn.sample(n, seed=137)
- sample_values = samples.eval()
- self.assertEqual(samples.get_shape(), [int(100e3), 2])
- self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
- self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
-
- def testSampleWithSampleShape(self):
- with self.test_session():
- mu = self._rng.rand(3, 5, 2)
- chol, sigma = self._random_chol(3, 5, 2, 2)
- chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
- chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
-
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- samples_val = mvn.sample((10, 11, 12), seed=137).eval()
-
- # Check sample shape
- self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape)
-
- # Check sample means
- x = samples_val[:, :, :, 1, 1, :]
- self.assertAllClose(
- x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05)
-
- # Check that log_prob(samples) works
- log_prob_val = mvn.log_prob(samples_val).eval()
- x_log_pdf = log_prob_val[:, :, :, 1, 1]
- expected_log_pdf = stats.multivariate_normal(
- mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x)
- self.assertAllClose(expected_log_pdf, x_log_pdf)
-
- def testSampleMultiDimensional(self):
- with self.test_session():
- mu = self._rng.rand(3, 5, 2)
- chol, sigma = self._random_chol(3, 5, 2, 2)
- chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
- chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
-
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
- n = constant_op.constant(100000)
- samples = mvn.sample(n, seed=137)
- sample_values = samples.eval()
-
- self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
- self.assertAllClose(
- sample_values[:, 1, 1, :].mean(axis=0), mu[1, 1, :], atol=0.05)
- self.assertAllClose(
- np.cov(sample_values[:, 1, 1, :], rowvar=0),
- sigma[1, 1, :, :],
- atol=1e-1)
-
- def testShapes(self):
- with self.test_session():
- mu = self._rng.rand(3, 5, 2)
- chol, _ = self._random_chol(3, 5, 2, 2)
- chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
- chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
-
- mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
-
- # Shapes known at graph construction time.
- self.assertEqual((2,), tuple(mvn.event_shape.as_list()))
- self.assertEqual((3, 5), tuple(mvn.batch_shape.as_list()))
-
- # Shapes known at runtime.
- self.assertEqual((2,), tuple(mvn.event_shape_tensor().eval()))
- self.assertEqual((3, 5), tuple(mvn.batch_shape_tensor().eval()))
-
- def _random_mu_and_sigma(self, batch_shape, event_shape):
- # This ensures sigma is positive definite.
- mat_shape = batch_shape + event_shape + event_shape
- mat = self._rng.randn(*mat_shape)
- perm = np.arange(mat.ndim)
- perm[-2:] = [perm[-1], perm[-2]]
- sigma = np.matmul(mat, np.transpose(mat, perm))
-
- mu_shape = batch_shape + event_shape
- mu = self._rng.randn(*mu_shape)
-
- return mu, sigma
-
- def testKLNonBatch(self):
- batch_shape = ()
- event_shape = (2,)
- with self.test_session():
- mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
- mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
- mvn_a = ds.MultivariateNormalTriL(
- loc=mu_a,
- scale_tril=np.linalg.cholesky(sigma_a),
- validate_args=True)
- mvn_b = ds.MultivariateNormalTriL(
- loc=mu_b,
- scale_tril=np.linalg.cholesky(sigma_b),
- validate_args=True)
-
- kl = ds.kl(mvn_a, mvn_b)
- self.assertEqual(batch_shape, kl.get_shape())
-
- kl_v = kl.eval()
- expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
- self.assertAllClose(expected_kl, kl_v)
-
- def testKLBatch(self):
- batch_shape = (2,)
- event_shape = (3,)
- with self.test_session():
- mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
- mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
- mvn_a = ds.MultivariateNormalTriL(
- loc=mu_a,
- scale_tril=np.linalg.cholesky(sigma_a),
- validate_args=True)
- mvn_b = ds.MultivariateNormalTriL(
- loc=mu_b,
- scale_tril=np.linalg.cholesky(sigma_b),
- validate_args=True)
-
- kl = ds.kl(mvn_a, mvn_b)
- self.assertEqual(batch_shape, kl.get_shape())
-
- kl_v = kl.eval()
- expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
- mu_b[0, :], sigma_b[0, :])
- expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
- mu_b[1, :], sigma_b[1, :])
- self.assertAllClose(expected_kl_0, kl_v[0])
- self.assertAllClose(expected_kl_1, kl_v[1])
-
- def testKLTwoIdenticalDistributionsIsZero(self):
- batch_shape = (2,)
- event_shape = (3,)
- with self.test_session():
- mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
- mvn_a = ds.MultivariateNormalTriL(
- loc=mu_a,
- scale_tril=np.linalg.cholesky(sigma_a),
- validate_args=True)
-
- # Should be zero since KL(p || p) = 0.
- kl = ds.kl(mvn_a, mvn_a)
- self.assertEqual(batch_shape, kl.get_shape())
-
- kl_v = kl.eval()
- self.assertAllClose(np.zeros(*batch_shape), kl_v)
-
- def testSampleLarge(self):
- mu = np.array([-1., 1], dtype=np.float32)
- scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3.
-
- true_mean = mu
- true_scale = scale_tril
- true_covariance = np.matmul(true_scale, true_scale.T)
- true_variance = np.diag(true_covariance)
- true_stddev = np.sqrt(true_variance)
- true_det_covariance = np.linalg.det(true_covariance)
- true_log_det_covariance = np.log(true_det_covariance)
-
- with self.test_session() as sess:
- dist = ds.MultivariateNormalTriL(
- loc=mu,
- scale_tril=scale_tril,
- validate_args=True)
-
- # The following distributions will test the KL divergence calculation.
- mvn_chol = ds.MultivariateNormalTriL(
- loc=np.array([0.5, 1.2], dtype=np.float32),
- scale_tril=np.array([[3., 0], [1, 2]], dtype=np.float32),
- validate_args=True)
-
- n = int(10e3)
- samps = dist.sample(n, seed=0)
- sample_mean = math_ops.reduce_mean(samps, 0)
- x = samps - sample_mean
- sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n
-
- sample_kl_chol = math_ops.reduce_mean(
- dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
- analytical_kl_chol = ds.kl(dist, mvn_chol)
-
- scale = dist.scale.to_dense()
-
- [
- sample_mean_,
- analytical_mean_,
- sample_covariance_,
- analytical_covariance_,
- analytical_variance_,
- analytical_stddev_,
- analytical_log_det_covariance_,
- analytical_det_covariance_,
- sample_kl_chol_, analytical_kl_chol_,
- scale_,
- ] = sess.run([
- sample_mean,
- dist.mean(),
- sample_covariance,
- dist.covariance(),
- dist.variance(),
- dist.stddev(),
- dist.log_det_covariance(),
- dist.det_covariance(),
- sample_kl_chol, analytical_kl_chol,
- scale,
- ])
-
- sample_variance_ = np.diag(sample_covariance_)
- sample_stddev_ = np.sqrt(sample_variance_)
- sample_det_covariance_ = np.linalg.det(sample_covariance_)
- sample_log_det_covariance_ = np.log(sample_det_covariance_)
-
- print("true_mean:\n{} ".format(true_mean))
- print("sample_mean:\n{}".format(sample_mean_))
- print("analytical_mean:\n{}".format(analytical_mean_))
-
- print("true_covariance:\n{}".format(true_covariance))
- print("sample_covariance:\n{}".format(sample_covariance_))
- print("analytical_covariance:\n{}".format(analytical_covariance_))
-
- print("true_variance:\n{}".format(true_variance))
- print("sample_variance:\n{}".format(sample_variance_))
- print("analytical_variance:\n{}".format(analytical_variance_))
-
- print("true_stddev:\n{}".format(true_stddev))
- print("sample_stddev:\n{}".format(sample_stddev_))
- print("analytical_stddev:\n{}".format(analytical_stddev_))
-
- print("true_log_det_covariance:\n{}".format(true_log_det_covariance))
- print("sample_log_det_covariance:\n{}".format(sample_log_det_covariance_))
- print("analytical_log_det_covariance:\n{}".format(
- analytical_log_det_covariance_))
-
- print("true_det_covariance:\n{}".format(true_det_covariance))
- print("sample_det_covariance:\n{}".format(sample_det_covariance_))
- print("analytical_det_covariance:\n{}".format(analytical_det_covariance_))
-
- print("true_scale:\n{}".format(true_scale))
- print("scale:\n{}".format(scale_))
-
- print("kl_chol: analytical:{} sample:{}".format(
- analytical_kl_chol_, sample_kl_chol_))
-
- self.assertAllClose(true_mean, sample_mean_,
- atol=0., rtol=0.03)
- self.assertAllClose(true_mean, analytical_mean_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_covariance, sample_covariance_,
- atol=0., rtol=0.03)
- self.assertAllClose(true_covariance, analytical_covariance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_variance, sample_variance_,
- atol=0., rtol=0.02)
- self.assertAllClose(true_variance, analytical_variance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_stddev, sample_stddev_,
- atol=0., rtol=0.01)
- self.assertAllClose(true_stddev, analytical_stddev_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_,
- atol=0., rtol=0.04)
- self.assertAllClose(true_log_det_covariance,
- analytical_log_det_covariance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_det_covariance, sample_det_covariance_,
- atol=0., rtol=0.03)
- self.assertAllClose(true_det_covariance, analytical_det_covariance_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(true_scale, scale_,
- atol=0., rtol=1e-6)
-
- self.assertAllClose(sample_kl_chol_, analytical_kl_chol_,
- atol=0., rtol=0.02)
-
-
-def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
- """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
- # Check using numpy operations
- # This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
- # So it is important to also check that KL(mvn, mvn) = 0.
- sigma_b_inv = np.linalg.inv(sigma_b)
-
- t = np.trace(sigma_b_inv.dot(sigma_a))
- q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
- k = mu_a.shape[0]
- l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
-
- return 0.5 * (t + q - k + l)
-
-
-if __name__ == "__main__":
- test.main()
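
For reference, the closed form implemented by _compute_non_batch_kl (deleted above with the rest of mvn_test.py) is the standard KL divergence between two multivariate normals; in LaTeX:

  KL(\mathcal{N}_a \,\|\, \mathcal{N}_b)
    = \tfrac{1}{2}\left[\operatorname{tr}(\Sigma_b^{-1}\Sigma_a)
    + (\mu_b - \mu_a)^\top \Sigma_b^{-1} (\mu_b - \mu_a)
    - k + \ln\frac{\det\Sigma_b}{\det\Sigma_a}\right]

whose four bracketed terms are the helper's t, q, k, and l, with k the event dimension. Since the expression vanishes when the two distributions coincide, the helper's comment about checking KL(mvn, mvn) = 0 follows directly.
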
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
new file mode 100644
index 0000000000..994b9877fb
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
@@ -0,0 +1,443 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MultivariateNormal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+from tensorflow.contrib import distributions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+ds = distributions
+
+
+class MultivariateNormalTriLTest(test.TestCase):
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def _random_chol(self, *shape):
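+    # Draw a random matrix, softplus its diagonal so it is strictly
+    # positive, and keep only the lower triangle; `sigma = chol @ chol.T`
+    # is the covariance implied by the Cholesky factor `chol`.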
+ mat = self._rng.rand(*shape)
+ chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus)
+ chol = array_ops.matrix_band_part(chol, -1, 0)
+ sigma = math_ops.matmul(chol, chol, adjoint_b=True)
+ return chol.eval(), sigma.eval()
+
+ def testLogPDFScalarBatch(self):
+ with self.test_session():
+ mu = self._rng.rand(2)
+ chol, sigma = self._random_chol(2, 2)
+ chol[1, 1] = -chol[1, 1]
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ x = self._rng.rand(2)
+
+ log_pdf = mvn.log_prob(x)
+ pdf = mvn.prob(x)
+
+ scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
+
+ expected_log_pdf = scipy_mvn.logpdf(x)
+ expected_pdf = scipy_mvn.pdf(x)
+ self.assertEqual((), log_pdf.get_shape())
+ self.assertEqual((), pdf.get_shape())
+ self.assertAllClose(expected_log_pdf, log_pdf.eval())
+ self.assertAllClose(expected_pdf, pdf.eval())
+
+ def testLogPDFXIsHigherRank(self):
+ with self.test_session():
+ mu = self._rng.rand(2)
+ chol, sigma = self._random_chol(2, 2)
+ chol[0, 0] = -chol[0, 0]
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ x = self._rng.rand(3, 2)
+
+ log_pdf = mvn.log_prob(x)
+ pdf = mvn.prob(x)
+
+ scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
+
+ expected_log_pdf = scipy_mvn.logpdf(x)
+ expected_pdf = scipy_mvn.pdf(x)
+ self.assertEqual((3,), log_pdf.get_shape())
+ self.assertEqual((3,), pdf.get_shape())
+ self.assertAllClose(expected_log_pdf, log_pdf.eval(), atol=0., rtol=0.02)
+ self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03)
+
+ def testLogPDFXLowerDimension(self):
+ with self.test_session():
+ mu = self._rng.rand(3, 2)
+ chol, sigma = self._random_chol(3, 2, 2)
+ chol[0, 0, 0] = -chol[0, 0, 0]
+ chol[2, 1, 1] = -chol[2, 1, 1]
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ x = self._rng.rand(2)
+
+ log_pdf = mvn.log_prob(x)
+ pdf = mvn.prob(x)
+
+ self.assertEqual((3,), log_pdf.get_shape())
+ self.assertEqual((3,), pdf.get_shape())
+
+ # scipy can't do batches, so just test one of them.
+ scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :])
+ expected_log_pdf = scipy_mvn.logpdf(x)
+ expected_pdf = scipy_mvn.pdf(x)
+
+ self.assertAllClose(expected_log_pdf, log_pdf.eval()[1])
+ self.assertAllClose(expected_pdf, pdf.eval()[1])
+
+ def testEntropy(self):
+ with self.test_session():
+ mu = self._rng.rand(2)
+ chol, sigma = self._random_chol(2, 2)
+ chol[0, 0] = -chol[0, 0]
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ entropy = mvn.entropy()
+
+ scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
+ expected_entropy = scipy_mvn.entropy()
+ self.assertEqual(entropy.get_shape(), ())
+ self.assertAllClose(expected_entropy, entropy.eval())
+
+ def testEntropyMultidimensional(self):
+ with self.test_session():
+ mu = self._rng.rand(3, 5, 2)
+ chol, sigma = self._random_chol(3, 5, 2, 2)
+ chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
+ chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ entropy = mvn.entropy()
+
+ # Scipy doesn't do batches, so test one of them.
+ expected_entropy = stats.multivariate_normal(
+ mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
+ self.assertEqual(entropy.get_shape(), (3, 5))
+ self.assertAllClose(expected_entropy, entropy.eval()[1, 1])
+
+ def testSample(self):
+ with self.test_session():
+ mu = self._rng.rand(2)
+ chol, sigma = self._random_chol(2, 2)
+ chol[0, 0] = -chol[0, 0]
+ sigma[0, 1] = -sigma[0, 1]
+ sigma[1, 0] = -sigma[1, 0]
+
+ n = constant_op.constant(100000)
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ samples = mvn.sample(n, seed=137)
+ sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (100000, 2))
+ self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
+ self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
+
+ def testSampleWithSampleShape(self):
+ with self.test_session():
+ mu = self._rng.rand(3, 5, 2)
+ chol, sigma = self._random_chol(3, 5, 2, 2)
+ chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
+ chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
+
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ samples_val = mvn.sample((10, 11, 12), seed=137).eval()
+
+ # Check sample shape
+ self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape)
+
+ # Check sample means
+ x = samples_val[:, :, :, 1, 1, :]
+ self.assertAllClose(
+ x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05)
+
+ # Check that log_prob(samples) works
+ log_prob_val = mvn.log_prob(samples_val).eval()
+ x_log_pdf = log_prob_val[:, :, :, 1, 1]
+ expected_log_pdf = stats.multivariate_normal(
+ mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x)
+ self.assertAllClose(expected_log_pdf, x_log_pdf)
+
+ def testSampleMultiDimensional(self):
+ with self.test_session():
+ mu = self._rng.rand(3, 5, 2)
+ chol, sigma = self._random_chol(3, 5, 2, 2)
+ chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
+ chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
+
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ n = constant_op.constant(100000)
+ samples = mvn.sample(n, seed=137)
+ sample_values = samples.eval()
+
+ self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
+ self.assertAllClose(
+ sample_values[:, 1, 1, :].mean(axis=0), mu[1, 1, :], atol=0.05)
+ self.assertAllClose(
+ np.cov(sample_values[:, 1, 1, :], rowvar=0),
+ sigma[1, 1, :, :],
+ atol=1e-1)
+
+ def testShapes(self):
+ with self.test_session():
+ mu = self._rng.rand(3, 5, 2)
+ chol, _ = self._random_chol(3, 5, 2, 2)
+ chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
+ chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
+
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+
+ # Shapes known at graph construction time.
+ self.assertEqual((2,), tuple(mvn.event_shape.as_list()))
+ self.assertEqual((3, 5), tuple(mvn.batch_shape.as_list()))
+
+ # Shapes known at runtime.
+ self.assertEqual((2,), tuple(mvn.event_shape_tensor().eval()))
+ self.assertEqual((3, 5), tuple(mvn.batch_shape_tensor().eval()))
+
+ def _random_mu_and_sigma(self, batch_shape, event_shape):
+    # This ensures sigma is positive definite.
+ mat_shape = batch_shape + event_shape + event_shape
+ mat = self._rng.randn(*mat_shape)
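+    # Swap only the two trailing (matrix) dimensions so that, batch-wise,
+    # sigma = mat @ mat.T, which is symmetric and positive definite with
+    # probability one for a random `mat`.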
+ perm = np.arange(mat.ndim)
+ perm[-2:] = [perm[-1], perm[-2]]
+ sigma = np.matmul(mat, np.transpose(mat, perm))
+
+ mu_shape = batch_shape + event_shape
+ mu = self._rng.randn(*mu_shape)
+
+ return mu, sigma
+
+ def testKLNonBatch(self):
+ batch_shape = ()
+ event_shape = (2,)
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
+ mvn_a = ds.MultivariateNormalTriL(
+ loc=mu_a,
+ scale_tril=np.linalg.cholesky(sigma_a),
+ validate_args=True)
+ mvn_b = ds.MultivariateNormalTriL(
+ loc=mu_b,
+ scale_tril=np.linalg.cholesky(sigma_b),
+ validate_args=True)
+
+ kl = ds.kl(mvn_a, mvn_b)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+ expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
+ self.assertAllClose(expected_kl, kl_v)
+
+ def testKLBatch(self):
+ batch_shape = (2,)
+ event_shape = (3,)
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
+ mvn_a = ds.MultivariateNormalTriL(
+ loc=mu_a,
+ scale_tril=np.linalg.cholesky(sigma_a),
+ validate_args=True)
+ mvn_b = ds.MultivariateNormalTriL(
+ loc=mu_b,
+ scale_tril=np.linalg.cholesky(sigma_b),
+ validate_args=True)
+
+ kl = ds.kl(mvn_a, mvn_b)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+      expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
+                                            mu_b[0, :], sigma_b[0, :, :])
+      expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
+                                            mu_b[1, :], sigma_b[1, :, :])
+ self.assertAllClose(expected_kl_0, kl_v[0])
+ self.assertAllClose(expected_kl_1, kl_v[1])
+
+ def testKLTwoIdenticalDistributionsIsZero(self):
+ batch_shape = (2,)
+ event_shape = (3,)
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ mvn_a = ds.MultivariateNormalTriL(
+ loc=mu_a,
+ scale_tril=np.linalg.cholesky(sigma_a),
+ validate_args=True)
+
+      # Should be zero since KL(p || p) = 0.
+ kl = ds.kl(mvn_a, mvn_a)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+ self.assertAllClose(np.zeros(*batch_shape), kl_v)
+
+ def testSampleLarge(self):
+ mu = np.array([-1., 1], dtype=np.float32)
+ scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3.
+
+ true_mean = mu
+ true_scale = scale_tril
+ true_covariance = np.matmul(true_scale, true_scale.T)
+ true_variance = np.diag(true_covariance)
+ true_stddev = np.sqrt(true_variance)
+ true_det_covariance = np.linalg.det(true_covariance)
+ true_log_det_covariance = np.log(true_det_covariance)
+
+ with self.test_session() as sess:
+ dist = ds.MultivariateNormalTriL(
+ loc=mu,
+ scale_tril=scale_tril,
+ validate_args=True)
+
+      # The following distribution exercises the KL divergence calculation.
+ mvn_chol = ds.MultivariateNormalTriL(
+ loc=np.array([0.5, 1.2], dtype=np.float32),
+ scale_tril=np.array([[3., 0], [1, 2]], dtype=np.float32),
+ validate_args=True)
+
+ n = int(10e3)
+ samps = dist.sample(n, seed=0)
+ sample_mean = math_ops.reduce_mean(samps, 0)
+ x = samps - sample_mean
+ sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n
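+      # Dividing by `n` rather than `n - 1` gives the maximum-likelihood
+      # covariance estimate; the bias is negligible at this sample size.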
+
+ sample_kl_chol = math_ops.reduce_mean(
+ dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
+ analytical_kl_chol = ds.kl(dist, mvn_chol)
+
+ scale = dist.scale.to_dense()
+
+ [
+ sample_mean_,
+ analytical_mean_,
+ sample_covariance_,
+ analytical_covariance_,
+ analytical_variance_,
+ analytical_stddev_,
+ analytical_log_det_covariance_,
+ analytical_det_covariance_,
+ sample_kl_chol_, analytical_kl_chol_,
+ scale_,
+ ] = sess.run([
+ sample_mean,
+ dist.mean(),
+ sample_covariance,
+ dist.covariance(),
+ dist.variance(),
+ dist.stddev(),
+ dist.log_det_covariance(),
+ dist.det_covariance(),
+ sample_kl_chol, analytical_kl_chol,
+ scale,
+ ])
+
+ sample_variance_ = np.diag(sample_covariance_)
+ sample_stddev_ = np.sqrt(sample_variance_)
+ sample_det_covariance_ = np.linalg.det(sample_covariance_)
+ sample_log_det_covariance_ = np.log(sample_det_covariance_)
+
+ print("true_mean:\n{} ".format(true_mean))
+ print("sample_mean:\n{}".format(sample_mean_))
+ print("analytical_mean:\n{}".format(analytical_mean_))
+
+ print("true_covariance:\n{}".format(true_covariance))
+ print("sample_covariance:\n{}".format(sample_covariance_))
+ print("analytical_covariance:\n{}".format(analytical_covariance_))
+
+ print("true_variance:\n{}".format(true_variance))
+ print("sample_variance:\n{}".format(sample_variance_))
+ print("analytical_variance:\n{}".format(analytical_variance_))
+
+ print("true_stddev:\n{}".format(true_stddev))
+ print("sample_stddev:\n{}".format(sample_stddev_))
+ print("analytical_stddev:\n{}".format(analytical_stddev_))
+
+ print("true_log_det_covariance:\n{}".format(true_log_det_covariance))
+ print("sample_log_det_covariance:\n{}".format(sample_log_det_covariance_))
+ print("analytical_log_det_covariance:\n{}".format(
+ analytical_log_det_covariance_))
+
+ print("true_det_covariance:\n{}".format(true_det_covariance))
+ print("sample_det_covariance:\n{}".format(sample_det_covariance_))
+ print("analytical_det_covariance:\n{}".format(analytical_det_covariance_))
+
+ print("true_scale:\n{}".format(true_scale))
+ print("scale:\n{}".format(scale_))
+
+ print("kl_chol: analytical:{} sample:{}".format(
+ analytical_kl_chol_, sample_kl_chol_))
+
+ self.assertAllClose(true_mean, sample_mean_,
+ atol=0., rtol=0.03)
+ self.assertAllClose(true_mean, analytical_mean_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_covariance, sample_covariance_,
+ atol=0., rtol=0.03)
+ self.assertAllClose(true_covariance, analytical_covariance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_variance, sample_variance_,
+ atol=0., rtol=0.02)
+ self.assertAllClose(true_variance, analytical_variance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_stddev, sample_stddev_,
+ atol=0., rtol=0.01)
+ self.assertAllClose(true_stddev, analytical_stddev_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_log_det_covariance, sample_log_det_covariance_,
+ atol=0., rtol=0.04)
+ self.assertAllClose(true_log_det_covariance,
+ analytical_log_det_covariance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_det_covariance, sample_det_covariance_,
+ atol=0., rtol=0.03)
+ self.assertAllClose(true_det_covariance, analytical_det_covariance_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(true_scale, scale_,
+ atol=0., rtol=1e-6)
+
+ self.assertAllClose(sample_kl_chol_, analytical_kl_chol_,
+ atol=0., rtol=0.02)
+
+
+def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
+ """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
+  # Check using numpy operations. This mostly repeats the TensorFlow code
+  # _kl_brute_force(), but in numpy, so it is important to also check that
+  # KL(mvn, mvn) = 0.
+ sigma_b_inv = np.linalg.inv(sigma_b)
+
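+  # t: trace term tr(inv(sigma_b) sigma_a); q: squared Mahalanobis distance
+  # between the means; k: event dimension; l: log-determinant ratio
+  # log(det(sigma_b) / det(sigma_a)). KL(a || b) = 0.5 * (t + q - k + l).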
+ t = np.trace(sigma_b_inv.dot(sigma_a))
+ q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
+ k = mu_a.shape[0]
+ l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
+
+ return 0.5 * (t + q - k + l)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 3695ff007a..10b4a6ceab 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -24,6 +24,7 @@ import math
import numpy as np
from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import linalg
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -32,7 +33,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -613,6 +613,82 @@ def softplus_inverse(x, name=None):
array_ops.where(is_too_large, too_large_value, y))
+# TODO(b/35290280): Add unit-tests.
+def dimension_size(x, axis):
+ """Returns the size of a specific dimension."""
+  # Since tf.gather isn't "constant-in, constant-out", we must first check the
+  # static shape or fall back to the dynamic shape.
+ num_rows = (None if x.get_shape().ndims is None
+ else x.get_shape()[axis].value)
+ if num_rows is not None:
+ return num_rows
+ return array_ops.shape(x)[axis]
+
+
+# TODO(b/35290280): Add unit-tests.
+def make_diag_scale(loc, scale_diag, scale_identity_multiplier,
+ validate_args, assert_positive, name=None):
+ """Creates a LinOp from `scale_diag`, `scale_identity_multiplier` kwargs."""
+ def _convert_to_tensor(x, name):
+ return None if x is None else ops.convert_to_tensor(x, name=name)
+
+ def _maybe_attach_assertion(x):
+ if not validate_args:
+ return x
+ if assert_positive:
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ x, message="diagonal part must be positive"),
+ ], x)
+ # TODO(b/35157376): Use `assert_none_equal` once it exists.
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_greater(
+ math_ops.abs(x),
+ array_ops.zeros([], x.dtype),
+ message="diagonal part must be non-zero"),
+ ], x)
+
+ with ops.name_scope(name, "make_diag_scale",
+ values=[loc, scale_diag, scale_identity_multiplier]):
+ loc = _convert_to_tensor(loc, name="loc")
+ scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
+ scale_identity_multiplier = _convert_to_tensor(
+ scale_identity_multiplier,
+ name="scale_identity_multiplier")
+
+ if scale_diag is not None:
+ if scale_identity_multiplier is not None:
+ scale_diag += scale_identity_multiplier[..., array_ops.newaxis]
+ return linalg.LinearOperatorDiag(
+ diag=_maybe_attach_assertion(scale_diag),
+ is_non_singular=True,
+ is_self_adjoint=True,
+ is_positive_definite=assert_positive)
+
+ # TODO(b/35290280): Consider inferring shape from scale_perturb_factor.
+ if loc is None:
+ raise ValueError(
+ "Cannot infer `event_shape` unless `loc` is specified.")
+
+ num_rows = dimension_size(loc, -1)
+
+ if scale_identity_multiplier is None:
+ return linalg.LinearOperatorIdentity(
+ num_rows=num_rows,
+ dtype=loc.dtype.base_dtype,
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ assert_proper_shapes=validate_args)
+
+ return linalg.LinearOperatorScaledIdentity(
+ num_rows=num_rows,
+ multiplier=_maybe_attach_assertion(scale_identity_multiplier),
+ is_non_singular=True,
+ is_self_adjoint=True,
+ is_positive_definite=assert_positive,
+ assert_proper_shapes=validate_args)
+
+
class AppendDocstring(object):
"""Helper class to promote private subclass docstring to public counterpart.
diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py
deleted file mode 100644
index a451f3a04c..0000000000
--- a/tensorflow/contrib/distributions/python/ops/mvn.py
+++ /dev/null
@@ -1,1063 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Multivariate Normal distribution classes."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib import linalg
-from tensorflow.contrib.distributions.python.ops import bijector as bijectors
-from tensorflow.contrib.distributions.python.ops import distribution_util
-from tensorflow.contrib.distributions.python.ops import kullback_leibler
-from tensorflow.contrib.distributions.python.ops import normal
-from tensorflow.contrib.distributions.python.ops import transformed_distribution
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-
-
-__all__ = [
- "MultivariateNormalDiag",
- "MultivariateNormalDiagWithSoftplusScale",
- "MultivariateNormalDiagPlusLowRank",
- "MultivariateNormalTriL",
-]
-
-_mvn_sample_note = """
-`value` is a batch vector with compatible shape if `value` is a `Tensor` whose
-shape can be broadcast up to either:
-
-```python
-self.batch_shape + self.event_shape
-```
-
-or
-
-```python
-[M1, ..., Mm] + self.batch_shape + self.event_shape
-```
-
-"""
-
-
-def _convert_to_tensor(x, name):
- """Helper; same as `ops.convert_to_tensor` but passes-through `None`."""
- return x if x is None else ops.convert_to_tensor(x, name=name)
-
-
-def _event_size_from_loc(loc):
- """Helper; returns the shape of the last dimension of `loc`."""
- # Since tf.gather isn't "constant-in, constant-out", we must first check the
- # static shape or fallback to dynamic shape.
- num_rows = loc.get_shape().with_rank_at_least(1)[-1].value
- if num_rows is not None:
- return num_rows
- return array_ops.shape(loc)[-1]
-
-
-def _make_diag_scale(loc, scale_diag, scale_identity_multiplier,
- validate_args, assert_positive):
- """Creates a LinOp from `scale_diag`, `scale_identity_multiplier` kwargs."""
- loc = _convert_to_tensor(loc, name="loc")
- scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
- scale_identity_multiplier = _convert_to_tensor(
- scale_identity_multiplier,
- name="scale_identity_multiplier")
-
- def _maybe_attach_assertion(x):
- if not validate_args:
- return x
- if assert_positive:
- return control_flow_ops.with_dependencies([
- check_ops.assert_positive(
- x, message="diagonal part must be positive"),
- ], x)
- # TODO(b/35157376): Use `assert_none_equal` once it exists.
- return control_flow_ops.with_dependencies([
- check_ops.assert_greater(
- math_ops.abs(x),
- array_ops.zeros([], x.dtype),
- message="diagonal part must be non-zero"),
- ], x)
-
- if scale_diag is not None:
- if scale_identity_multiplier is not None:
- scale_diag += scale_identity_multiplier[..., array_ops.newaxis]
- return linalg.LinearOperatorDiag(
- diag=_maybe_attach_assertion(scale_diag),
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=assert_positive)
-
- # TODO(b/34878297): Consider inferring shape from scale_perturb_factor.
- if loc is None:
- raise ValueError(
- "Cannot infer `event_shape` unless `loc` is specified.")
-
- num_rows = _event_size_from_loc(loc)
-
- if scale_identity_multiplier is None:
- return linalg.LinearOperatorIdentity(
- num_rows=num_rows,
- dtype=loc.dtype.base_dtype,
- is_self_adjoint=True,
- is_positive_definite=True,
- assert_proper_shapes=validate_args)
-
- return linalg.LinearOperatorScaledIdentity(
- num_rows=num_rows,
- multiplier=_maybe_attach_assertion(scale_identity_multiplier),
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=assert_positive,
- assert_proper_shapes=validate_args)
-
-
-class _MultivariateNormalLinearOperator(
- transformed_distribution.TransformedDistribution):
- """The multivariate normal distribution on `R^k`.
-
- The Multivariate Normal distribution is defined over `R^k` and parameterized
- by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
- `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
- matrix-multiplication.
-
- #### Mathematical Details
-
- The probability density function (pdf) is,
-
- ```none
- pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
- y = inv(scale) @ (x - loc),
- Z = (2 pi)**(0.5 k) |det(scale)|,
- ```
-
- where:
-
- * `loc` is a vector in `R^k`,
- * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
- * `Z` denotes the normalization constant, and,
- * `||y||**2` denotes the squared Euclidean norm of `y`.
-
- The MultivariateNormal distribution is a member of the [location-scale
- family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
- constructed as,
-
- ```none
- X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
- Y = scale @ X + loc
- ```
-
- #### Examples
-
- ```python
- ds = tf.contrib.distributions
- la = tf.contrib.linalg
-
- # Initialize a single 3-variate Gaussian.
- mu = [1., 2, 3]
- cov = [[ 0.36, 0.12, 0.06],
- [ 0.12, 0.29, -0.13],
- [ 0.06, -0.13, 0.26]]
- scale = tf.cholesky(cov)
- # ==> [[ 0.6, 0. , 0. ],
- # [ 0.2, 0.5, 0. ],
- # [ 0.1, -0.3, 0.4]])
-
- mvn = ds._MultivariateNormalLinearOperator(
- loc=mu,
- scale=la.LinearOperatorTriL(scale))
-
- # Covariance agrees with cholesky(cov) parameterization.
- mvn.covariance().eval()
- # ==> [[ 0.36, 0.12, 0.06],
- # [ 0.12, 0.29, -0.13],
- # [ 0.06, -0.13, 0.26]]
-
-  # Compute the pdf of an `R^3` observation; return a scalar.
- mvn.prob([-1., 0, 1]).eval() # shape: []
-
- # Initialize a 2-batch of 3-variate Gaussians.
- mu = [[1., 2, 3],
- [11, 22, 33]] # shape: [2, 3]
- scale_diag = [[1., 2, 3],
- [0.5, 1, 1.5]] # shape: [2, 3]
-
- mvn = ds._MultivariateNormalLinearOperator(
- loc=mu,
- scale=la.LinearOperatorDiag(scale_diag))
-
- # Compute the pdf of two `R^3` observations; return a length-2 vector.
- x = [[-0.9, 0, 0.1],
- [-10, 0, 9]] # shape: [2, 3]
- mvn.prob(x).eval() # shape: [2]
- ```
-
- """
-
- def __init__(self,
- loc=None,
- scale=None,
- validate_args=False,
- allow_nan_stats=True,
- name="MultivariateNormalLinearOperator"):
- """Construct Multivariate Normal distribution on `R^k`.
-
- The `batch_shape` is the broadcast shape between `loc` and `scale`
- arguments.
-
- The `event_shape` is given by the last dimension of `loc` or the last
- dimension of the matrix implied by `scale`.
-
- Recall that `covariance = scale @ scale.T`.
-
- Additional leading dimensions (if any) will index batches.
-
- Args:
- loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
- implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
- `b >= 0` and `k` represents the event size.
- scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
- `[B1, ..., Bb, k, k]`.
- validate_args: `Boolean`, default `False`. Whether to validate input
- with asserts. If `validate_args` is `False`, and the inputs are
- invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
-      batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
-
- Raises:
- ValueError: if `scale` is unspecified.
- TypeError: if not `scale.dtype.is_floating`
- """
- parameters = locals()
- if scale is None:
- raise ValueError("Missing required `scale` parameter.")
- if not scale.dtype.is_floating:
- raise TypeError("`scale` parameter must have floating-point dtype.")
-
- # Since expand_dims doesn't preserve constant-ness, we obtain the
- # non-dynamic value if possible.
- event_shape = scale.domain_dimension_tensor()
- if tensor_util.constant_value(event_shape) is not None:
- event_shape = tensor_util.constant_value(event_shape)
- event_shape = event_shape[array_ops.newaxis]
-
- super(_MultivariateNormalLinearOperator, self).__init__(
- distribution=normal.Normal(
- loc=array_ops.zeros([], dtype=scale.dtype),
- scale=array_ops.ones([], dtype=scale.dtype)),
- bijector=bijectors.AffineLinearOperator(
- shift=loc, scale=scale, validate_args=validate_args),
- batch_shape=scale.batch_shape_tensor(),
- event_shape=event_shape,
- validate_args=validate_args,
- name=name)
- self._parameters = parameters
-
- @property
- def loc(self):
- """The `loc` `Tensor` in `Y = scale @ X + loc`."""
- return self.bijector.shift
-
- @property
- def scale(self):
- """The `scale` `LinearOperator` in `Y = scale @ X + loc`."""
- return self.bijector.scale
-
- def log_det_covariance(self, name="log_det_covariance"):
- """Log of determinant of covariance matrix."""
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=self.scale.graph_parents):
- return 2. * self.scale.log_abs_determinant()
-
- def det_covariance(self, name="det_covariance"):
- """Determinant of covariance matrix."""
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=self.scale.graph_parents):
-        return math_ops.exp(2. * self.scale.log_abs_determinant())
-
- @distribution_util.AppendDocstring(_mvn_sample_note)
- def _log_prob(self, x):
- return super(_MultivariateNormalLinearOperator, self)._log_prob(x)
-
- @distribution_util.AppendDocstring(_mvn_sample_note)
- def _prob(self, x):
- return super(_MultivariateNormalLinearOperator, self)._prob(x)
-
- def _mean(self):
- if self.loc is None:
- shape = array_ops.concat([
- self.batch_shape_tensor(),
- self.event_shape_tensor(),
- ], 0)
- return array_ops.zeros(shape, self.dtype)
- return array_ops.identity(self.loc)
-
- def _covariance(self):
- # TODO(b/35041434): Remove special-case logic once LinOp supports
- # `diag_part`.
- if (isinstance(self.scale, linalg.LinearOperatorIdentity) or
- isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or
- isinstance(self.scale, linalg.LinearOperatorDiag)):
- shape = array_ops.concat([self.batch_shape_tensor(),
- self.event_shape_tensor()], 0)
- diag_part = array_ops.ones(shape, self.scale.dtype)
- if isinstance(self.scale, linalg.LinearOperatorScaledIdentity):
- diag_part *= math_ops.square(
- self.scale.multiplier[..., array_ops.newaxis])
- elif isinstance(self.scale, linalg.LinearOperatorDiag):
- diag_part *= math_ops.square(self.scale.diag)
- return array_ops.matrix_diag(diag_part)
- else:
- # TODO(b/35040238): Remove transpose once LinOp supports `transpose`.
- return self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense()))
-
- def _variance(self):
- # TODO(b/35041434): Remove special-case logic once LinOp supports
- # `diag_part`.
- if (isinstance(self.scale, linalg.LinearOperatorIdentity) or
- isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or
- isinstance(self.scale, linalg.LinearOperatorDiag)):
- shape = array_ops.concat([self.batch_shape_tensor(),
- self.event_shape_tensor()], 0)
- diag_part = array_ops.ones(shape, self.scale.dtype)
- if isinstance(self.scale, linalg.LinearOperatorScaledIdentity):
- diag_part *= math_ops.square(
- self.scale.multiplier[..., array_ops.newaxis])
- elif isinstance(self.scale, linalg.LinearOperatorDiag):
- diag_part *= math_ops.square(self.scale.diag)
- return diag_part
- elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate)
- and self.scale.is_self_adjoint):
- return array_ops.matrix_diag_part(
- self.scale.apply(self.scale.to_dense()))
- else:
- # TODO(b/35040238): Remove transpose once LinOp supports `transpose`.
- return array_ops.matrix_diag_part(
- self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense())))
-
- def _stddev(self):
- # TODO(b/35041434): Remove special-case logic once LinOp supports
- # `diag_part`.
- if (isinstance(self.scale, linalg.LinearOperatorIdentity) or
- isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or
- isinstance(self.scale, linalg.LinearOperatorDiag)):
- shape = array_ops.concat([self.batch_shape_tensor(),
- self.event_shape_tensor()], 0)
- diag_part = array_ops.ones(shape, self.scale.dtype)
- if isinstance(self.scale, linalg.LinearOperatorScaledIdentity):
- diag_part *= self.scale.multiplier[..., array_ops.newaxis]
- elif isinstance(self.scale, linalg.LinearOperatorDiag):
- diag_part *= self.scale.diag
- return math_ops.abs(diag_part)
- elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate)
- and self.scale.is_self_adjoint):
- return math_ops.sqrt(array_ops.matrix_diag_part(
- self.scale.apply(self.scale.to_dense())))
- else:
- # TODO(b/35040238): Remove transpose once LinOp supports `transpose`.
- return math_ops.sqrt(array_ops.matrix_diag_part(
- self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense()))))
-
- def _mode(self):
- return self._mean()
-
-
-class MultivariateNormalDiagPlusLowRank(_MultivariateNormalLinearOperator):
- """The multivariate normal distribution on `R^k`.
-
- The Multivariate Normal distribution is defined over `R^k` and parameterized
- by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
- `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
- matrix-multiplication.
-
- #### Mathematical Details
-
- The probability density function (pdf) is,
-
- ```none
- pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
- y = inv(scale) @ (x - loc),
- Z = (2 pi)**(0.5 k) |det(scale)|,
- ```
-
- where:
-
- * `loc` is a vector in `R^k`,
- * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
- * `Z` denotes the normalization constant, and,
- * `||y||**2` denotes the squared Euclidean norm of `y`.
-
- A (non-batch) `scale` matrix is:
-
- ```none
- scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
- scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
- ```
-
- where:
-
- * `scale_diag.shape = [k]`,
- * `scale_identity_multiplier.shape = []`,
- * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
- * `scale_perturb_diag.shape = [r]`.
-
- Additional leading dimensions (if any) will index batches.
-
- If both `scale_diag` and `scale_identity_multiplier` are `None`, then
- `scale` is the Identity matrix.
-
- The MultivariateNormal distribution is a member of the [location-scale
- family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
- constructed as,
-
- ```none
- X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
- Y = scale @ X + loc
- ```
-
- #### Examples
-
- ```python
- ds = tf.contrib.distributions
-
- # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
- # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
- # a rank-2 update.
-  mu = [-0.5, 0, 0.5] # shape: [3]
- d = [1.5, 0.5, 2] # shape: [3]
- U = [[1., 2],
- [-1, 1],
- [2, -0.5]] # shape: [3, 2]
- m = [4., 5] # shape: [2]
- mvn = ds.MultivariateNormalDiagPlusLowRank(
-      loc=mu,
-      scale_diag=d,
- scale_perturb_factor=U,
- scale_perturb_diag=m)
-
- # Evaluate this on an observation in `R^3`, returning a scalar.
- mvn.prob([-1, 0, 1]).eval() # shape: []
-
- # Initialize a 2-batch of 3-variate Gaussians; `S = diag(d) + U @ U.T`.
- mu = [[1., 2, 3],
- [11, 22, 33]] # shape: [b, k] = [2, 3]
- U = [[[1., 2],
- [3, 4],
- [5, 6]],
- [[0.5, 0.75],
-       [1.0, 0.25],
- [1.5, 1.25]]] # shape: [b, k, r] = [2, 3, 2]
- m = [[0.1, 0.2],
- [0.4, 0.5]] # shape: [b, r] = [2, 2]
-
- mvn = ds.MultivariateNormalDiagPlusLowRank(
- loc=mu,
- scale_perturb_factor=U,
- scale_perturb_diag=m)
-
- mvn.covariance().eval() # shape: [2, 3, 3]
- # ==> [[[ 15.63 31.57 48.51]
- # [ 31.57 69.31 105.05]
- # [ 48.51 105.05 162.59]]
- #
- # [[ 2.59 1.41 3.35]
- # [ 1.41 2.71 3.34]
- # [ 3.35 3.34 8.35]]]
-
- # Compute the pdf of two `R^3` observations (one from each batch);
- # return a length-2 vector.
- x = [[-0.9, 0, 0.1],
- [-10, 0, 9]] # shape: [2, 3]
- mvn.prob(x).eval() # shape: [2]
- ```
-
- """
-
- def __init__(self,
- loc=None,
- scale_diag=None,
- scale_identity_multiplier=None,
- scale_perturb_factor=None,
- scale_perturb_diag=None,
- validate_args=False,
- allow_nan_stats=True,
- name="MultivariateNormalDiagPlusLowRank"):
- """Construct Multivariate Normal distribution on `R^k`.
-
- The `batch_shape` is the broadcast shape between `loc` and `scale`
- arguments.
-
- The `event_shape` is given by the last dimension of `loc` or the last
- dimension of the matrix implied by `scale`.
-
- Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:
-
- ```none
- scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
- scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
- ```
-
- where:
-
- * `scale_diag.shape = [k]`,
- * `scale_identity_multiplier.shape = []`,
- * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
- * `scale_perturb_diag.shape = [r]`.
-
- Additional leading dimensions (if any) will index batches.
-
- If both `scale_diag` and `scale_identity_multiplier` are `None`, then
- `scale` is the Identity matrix.
-
- Args:
- loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
- implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
- `b >= 0` and `k` represents the event size.
- scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
- matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
- and characterizes `b`-batches of `k x k` diagonal matrices added to
- `scale`. When both `scale_identity_multiplier` and `scale_diag` are
- `None` then `scale` is the `Identity`.
- scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
- a scaled-identity-matrix added to `scale`. May have shape
- `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
- `k x k` identity matrices added to `scale`. When both
- `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
- the `Identity`.
- scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
- perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
- `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
- When `None`, no rank-`r` update is added to `scale`.
- scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
- inside the rank-`r` perturbation added to `scale`. May have shape
- `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r`
- diagonal matrices inside the perturbation added to `scale`. When
- `None`, an identity matrix is used inside the perturbation. Can only be
- specified if `scale_perturb_factor` is also specified.
- validate_args: Python `Boolean`, default `False`. When `True` distribution
- parameters are checked for validity despite possibly degrading runtime
- performance. When `False` invalid inputs may silently render incorrect
- outputs.
- allow_nan_stats: Python `Boolean`, default `True`. When `True`,
- statistics (e.g., mean, mode, variance) use the value "`NaN`" to
- indicate the result is undefined. When `False`, an exception is raised
- if one or more of the statistic's batch members are undefined.
- name: `String` name prefixed to Ops created by this class.
-
- Raises:
- ValueError: if at most `scale_identity_multiplier` is specified.
- """
- parameters = locals()
- with ops.name_scope(name) as ns:
- with ops.name_scope("init", values=[
- loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
- scale_perturb_diag]):
- has_low_rank = (scale_perturb_factor is not None or
- scale_perturb_diag is not None)
- scale = _make_diag_scale(
- loc=loc,
- scale_diag=scale_diag,
- scale_identity_multiplier=scale_identity_multiplier,
- validate_args=validate_args,
- assert_positive=has_low_rank)
- scale_perturb_factor = _convert_to_tensor(
- scale_perturb_factor,
- name="scale_perturb_factor")
- scale_perturb_diag = _convert_to_tensor(
- scale_perturb_diag,
- name="scale_perturb_diag")
- if has_low_rank:
- scale = linalg.LinearOperatorUDVHUpdate(
- scale,
- u=scale_perturb_factor,
- diag=scale_perturb_diag,
- is_diag_positive=scale_perturb_diag is None,
- is_non_singular=True, # Implied by is_positive_definite=True.
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
- super(MultivariateNormalDiagPlusLowRank, self).__init__(
- loc=loc,
- scale=scale,
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
- name=ns)
- self._parameters = parameters
-
-
-class MultivariateNormalDiag(_MultivariateNormalLinearOperator):
- """The multivariate normal distribution on `R^k`.
-
- The Multivariate Normal distribution is defined over `R^k` and parameterized
- by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
- `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
- matrix-multiplication.
-
- #### Mathematical Details
-
- The probability density function (pdf) is,
-
- ```none
- pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
- y = inv(scale) @ (x - loc),
- Z = (2 pi)**(0.5 k) |det(scale)|,
- ```
-
- where:
-
- * `loc` is a vector in `R^k`,
- * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
- * `Z` denotes the normalization constant, and,
- * `||y||**2` denotes the squared Euclidean norm of `y`.
-
- A (non-batch) `scale` matrix is:
-
- ```none
- scale = diag(scale_diag + scale_identity_multiplier * ones(k))
- ```
-
- where:
-
- * `scale_diag.shape = [k]`, and,
- * `scale_identity_multiplier.shape = []`.
-
- Additional leading dimensions (if any) will index batches.
-
- If both `scale_diag` and `scale_identity_multiplier` are `None`, then
- `scale` is the Identity matrix.
-
- The MultivariateNormal distribution is a member of the [location-scale
- family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
- constructed as,
-
- ```none
- X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
- Y = scale @ X + loc
- ```
-
- #### Examples
-
- ```python
- ds = tf.contrib.distributions
-
- # Initialize a single 2-variate Gaussian.
- mvn = ds.MultivariateNormalDiag(
- loc=[1., -1],
- scale_diag=[1, 2.])
-
- mvn.mean().eval()
- # ==> [1., -1]
-
- mvn.stddev().eval()
- # ==> [1., 2]
-
- # Evaluate this on an observation in `R^2`, returning a scalar.
- mvn.prob([-1., 0]).eval() # shape: []
-
- # Initialize a 3-batch, 2-variate scaled-identity Gaussian.
- mvn = ds.MultivariateNormalDiag(
- loc=[1., -1],
- scale_identity_multiplier=[1, 2., 3])
-
- mvn.mean().eval() # shape: [3, 2]
- # ==> [[1., -1]
- # [1, -1],
- # [1, -1]]
-
- mvn.stddev().eval() # shape: [3, 2]
- # ==> [[1., 1],
- # [2, 2],
- # [3, 3]]
-
- # Evaluate this on an observation in `R^2`, returning a length-3 vector.
- mvn.prob([-1., 0]).eval() # shape: [3]
-
- # Initialize a 2-batch of 3-variate Gaussians.
- mvn = ds.MultivariateNormalDiag(
- loc=[[1., 2, 3],
-        [11, 22, 33]], # shape: [2, 3]
- scale_diag=[[1., 2, 3],
- [0.5, 1, 1.5]]) # shape: [2, 3]
-
-  # Evaluate this on two observations, each in `R^3`, returning a length-2
- # vector.
- x = [[-1., 0, 1],
- [-11, 0, 11.]] # shape: [2, 3].
- mvn.prob(x).eval() # shape: [2]
- ```
-
- """
-
- def __init__(self,
- loc=None,
- scale_diag=None,
- scale_identity_multiplier=None,
- validate_args=False,
- allow_nan_stats=True,
- name="MultivariateNormalDiag"):
- """Construct Multivariate Normal distribution on `R^k`.
-
- The `batch_shape` is the broadcast shape between `loc` and `scale`
- arguments.
-
- The `event_shape` is given by the last dimension of `loc` or the last
- dimension of the matrix implied by `scale`.
-
- Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix
- is:
-
- ```none
- scale = diag(scale_diag + scale_identity_multiplier * ones(k))
- ```
-
- where:
-
- * `scale_diag.shape = [k]`, and,
- * `scale_identity_multiplier.shape = []`.
-
- Additional leading dimensions (if any) will index batches.
-
- If both `scale_diag` and `scale_identity_multiplier` are `None`, then
- `scale` is the Identity matrix.
-
- Args:
- loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
- implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
- `b >= 0` and `k` represents the event size.
- scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
- matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
- and characterizes `b`-batches of `k x k` diagonal matrices added to
- `scale`. When both `scale_identity_multiplier` and `scale_diag` are
- `None` then `scale` is the `Identity`.
- scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
- a scaled-identity-matrix added to `scale`. May have shape
- `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
- `k x k` identity matrices added to `scale`. When both
- `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
- the `Identity`.
- validate_args: Python `Boolean`, default `False`. When `True` distribution
- parameters are checked for validity despite possibly degrading runtime
- performance. When `False` invalid inputs may silently render incorrect
- outputs.
- allow_nan_stats: Python `Boolean`, default `True`. When `True`,
- statistics (e.g., mean, mode, variance) use the value "`NaN`" to
- indicate the result is undefined. When `False`, an exception is raised
- if one or more of the statistic's batch members are undefined.
- name: `String` name prefixed to Ops created by this class.
-
- Raises:
- ValueError: if at most `scale_identity_multiplier` is specified.
- """
- parameters = locals()
- with ops.name_scope(name) as ns:
- with ops.name_scope("init", values=[
- loc, scale_diag, scale_identity_multiplier]):
- scale = _make_diag_scale(
- loc=loc,
- scale_diag=scale_diag,
- scale_identity_multiplier=scale_identity_multiplier,
- validate_args=validate_args,
- assert_positive=False)
- super(MultivariateNormalDiag, self).__init__(
- loc=loc,
- scale=scale,
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
- name=ns)
- self._parameters = parameters
-
-
-class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
- """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`."""
-
- def __init__(self,
- loc,
- scale_diag,
- validate_args=False,
- allow_nan_stats=True,
- name="MultivariateNormalDiagWithSoftplusScale"):
- parameters = locals()
- with ops.name_scope(name, values=[scale_diag]) as ns:
- super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
- loc=loc,
- scale_diag=nn.softplus(scale_diag),
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
- name=ns)
- self._parameters = parameters
-
-
-class MultivariateNormalTriL(_MultivariateNormalLinearOperator):
- """The multivariate normal distribution on `R^k`.
-
- The Multivariate Normal distribution is defined over `R^k` and parameterized
- by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
- `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
- matrix-multiplication.
-
- #### Mathematical Details
-
- The probability density function (pdf) is,
-
- ```none
- pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
- y = inv(scale) @ (x - loc),
- Z = (2 pi)**(0.5 k) |det(scale)|,
- ```
-
- where:
-
- * `loc` is a vector in `R^k`,
- * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
- * `Z` denotes the normalization constant, and,
- * `||y||**2` denotes the squared Euclidean norm of `y`.
-
- A (non-batch) `scale` matrix is:
-
- ```none
- scale = scale_tril
- ```
-
- where `scale_tril` is lower-triangular `k x k` matrix with non-zero diagonal,
- i.e., `tf.diag_part(scale_tril) != 0`.
-
- Additional leading dimensions (if any) will index batches.
-
- The MultivariateNormal distribution is a member of the [location-scale
- family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
- constructed as,
-
- ```none
- X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
- Y = scale @ X + loc
- ```
-
- Trainable (batch) Cholesky matrices can be created with
- `ds.matrix_diag_transform()` and/or `ds.fill_lower_triangular()`
-
- #### Examples
-
- ```python
- ds = tf.contrib.distributions
-
- # Initialize a single 3-variate Gaussian.
- mu = [1., 2, 3]
- cov = [[ 0.36, 0.12, 0.06],
- [ 0.12, 0.29, -0.13],
- [ 0.06, -0.13, 0.26]]
- scale = tf.cholesky(cov)
- # ==> [[ 0.6, 0. , 0. ],
- # [ 0.2, 0.5, 0. ],
- # [ 0.1, -0.3, 0.4]])
- mvn = ds.MultivariateNormalTriL(
- loc=mu,
- scale_tril=scale)
-
- mvn.mean().eval()
- # ==> [1., 2, 3]
-
- # Covariance agrees with cholesky(cov) parameterization.
- mvn.covariance().eval()
- # ==> [[ 0.36, 0.12, 0.06],
- # [ 0.12, 0.29, -0.13],
- # [ 0.06, -0.13, 0.26]]
-
-  # Compute the pdf of an observation in `R^3`; return a scalar.
- mvn.prob([-1., 0, 1]).eval() # shape: []
-
- # Initialize a 2-batch of 3-variate Gaussians.
- mu = [[1., 2, 3],
- [11, 22, 33]] # shape: [2, 3]
- tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal.
- mvn = ds.MultivariateNormalTriL(
- loc=mu,
- scale_tril=tril)
-
- # Compute the pdf of two `R^3` observations; return a length-2 vector.
- x = [[-0.9, 0, 0.1],
- [-10, 0, 9]] # shape: [2, 3]
- mvn.prob(x).eval() # shape: [2]
-
- ```
-
- """
-
- def __init__(self,
- loc=None,
- scale_tril=None,
- validate_args=False,
- allow_nan_stats=True,
- name="MultivariateNormalTriL"):
- """Construct Multivariate Normal distribution on `R^k`.
-
- The `batch_shape` is the broadcast shape between `loc` and `scale`
- arguments.
-
- The `event_shape` is given by the last dimension of `loc` or the last
- dimension of the matrix implied by `scale`.
-
- Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix
- is:
-
- ```none
- scale = scale_tril
- ```
-
- where `scale_tril` is lower-triangular `k x k` matrix with non-zero
- diagonal, i.e., `tf.diag_part(scale_tril) != 0`.
-
- Additional leading dimensions (if any) will index batches.
-
- Args:
- loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
- implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
- `b >= 0` and `k` is the event size.
- scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
- diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where
- `b >= 0` and `k` is the event size.
- validate_args: Python `Boolean`, default `False`. When `True` distribution
- parameters are checked for validity despite possibly degrading runtime
- performance. When `False` invalid inputs may silently render incorrect
- outputs.
- allow_nan_stats: Python `Boolean`, default `True`. When `True`,
- statistics (e.g., mean, mode, variance) use the value "`NaN`" to
- indicate the result is undefined. When `False`, an exception is raised
- if one or more of the statistic's batch members are undefined.
- name: `String` name prefixed to Ops created by this class.
-
- Raises:
- ValueError: if neither `loc` nor `scale_tril` are specified.
- """
- parameters = locals()
- if loc is None and scale_tril is None:
- raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
- with ops.name_scope(name) as ns:
- with ops.name_scope("init", values=[loc, scale_tril]):
- loc = _convert_to_tensor(loc, name="loc")
- scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
- if scale_tril is None:
- scale = linalg.LinearOperatorIdentity(
- num_rows=_event_size_from_loc(loc),
- dtype=loc.dtype,
- is_self_adjoint=True,
- is_positive_definite=True,
- assert_proper_shapes=validate_args)
- else:
- if validate_args:
- scale_tril = control_flow_ops.with_dependencies([
- # TODO(b/35157376): Use `assert_none_equal` once it exists.
- check_ops.assert_greater(
- math_ops.abs(array_ops.matrix_diag_part(scale_tril)),
- array_ops.zeros([], scale_tril.dtype),
- message="`scale_tril` must have non-zero diagonal"),
- ], scale_tril)
- scale = linalg.LinearOperatorTriL(
- scale_tril,
- is_non_singular=True,
- is_self_adjoint=False,
- is_positive_definite=False)
- super(MultivariateNormalTriL, self).__init__(
- loc=loc,
- scale=scale,
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
- name=ns)
- self._parameters = parameters
-
-
-@kullback_leibler.RegisterKL(_MultivariateNormalLinearOperator,
- _MultivariateNormalLinearOperator)
-def _kl_brute_force(a, b, name=None):
- """Batched KL divergence `KL(a || b)` for multivariate Normals.
-
- With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
- covariance `C_a`, `C_b` respectively,
-
- ```
- KL(a || b) = 0.5 * ( L - k + T + Q ),
- L := Log[Det(C_b)] - Log[Det(C_a)]
- T := trace(C_b^{-1} C_a),
- Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
- ```
-
- This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
- methods for solving systems with `C_b` may be available, a dense version of
- (the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B`
- is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
- and `y`.
-
- Args:
- a: Instance of `_MultivariateNormalLinearOperator`.
- b: Instance of `_MultivariateNormalLinearOperator`.
- name: (optional) name to use for created ops. Default "kl_mvn".
-
- Returns:
- Batchwise `KL(a || b)`.
- """
-
- def squared_frobenius_norm(x):
- """Helper to make KL calculation slightly more readable."""
- # http://mathworld.wolfram.com/FrobeniusNorm.html
- return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
-
- # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
- # supports something like:
- # A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
- def is_diagonal(x):
- """Helper to identify if `LinearOperator` has only a diagonal component."""
- return (isinstance(x, linalg.LinearOperatorIdentity) or
- isinstance(x, linalg.LinearOperatorScaledIdentity) or
- isinstance(x, linalg.LinearOperatorDiag))
-
- with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
- a.scale.graph_parents + b.scale.graph_parents):
- # Calculation is based on:
- # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
- # and,
- # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
- # i.e.,
- # If Ca = AA', Cb = BB', then
- # tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
- # = tr[inv(B) A A' inv(B)']
- # = tr[(inv(B) A) (inv(B) A)']
- # = sum_{ij} (inv(B) A)_{ij}^2
- # = ||inv(B) A||_F**2
- # where ||.||_F is the Frobenius norm and the second equality follows from
- # the cyclic permutation property.
- if is_diagonal(a.scale) and is_diagonal(b.scale):
- # Using `stddev` because it handles expansion of Identity cases.
- b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
- else:
- b_inv_a = b.scale.solve(a.scale.to_dense())
- kl_div = (b.scale.log_abs_determinant()
- - a.scale.log_abs_determinant()
- + 0.5 * (
- - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
- + squared_frobenius_norm(b_inv_a)
- + squared_frobenius_norm(b.scale.solve(
- (b.mean() - a.mean())[..., array_ops.newaxis]))))
- kl_div.set_shape(array_ops.broadcast_static_shape(
- a.batch_shape, b.batch_shape))
- return kl_div
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
new file mode 100644
index 0000000000..edc5251769
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -0,0 +1,232 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multivariate Normal distribution classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+
+
+__all__ = [
+ "MultivariateNormalDiag",
+ "MultivariateNormalDiagWithSoftplusScale",
+]
+
+
+class MultivariateNormalDiag(
+ mvn_linop.MultivariateNormalLinearOperator):
+ """The multivariate normal distribution on `R^k`.
+
+ The Multivariate Normal distribution is defined over `R^k` and parameterized
+ by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
+ `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
+ matrix-multiplication.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
+ y = inv(scale) @ (x - loc),
+ Z = (2 pi)**(0.5 k) |det(scale)|,
+ ```
+
+ where:
+
+ * `loc` is a vector in `R^k`,
+ * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
+ * `Z` denotes the normalization constant, and,
+ * `||y||**2` denotes the squared Euclidean norm of `y`.
+
+ A (non-batch) `scale` matrix is:
+
+ ```none
+ scale = diag(scale_diag + scale_identity_multiplier * ones(k))
+ ```
+
+ where:
+
+ * `scale_diag.shape = [k]`, and,
+ * `scale_identity_multiplier.shape = []`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ If both `scale_diag` and `scale_identity_multiplier` are `None`, then
+ `scale` is the Identity matrix.
+
+ The MultivariateNormal distribution is a member of the [location-scale
+ family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
+ Y = scale @ X + loc
+ ```
+
+ #### Examples
+
+ ```python
+ ds = tf.contrib.distributions
+
+ # Initialize a single 2-variate Gaussian.
+ mvn = ds.MultivariateNormalDiag(
+ loc=[1., -1],
+ scale_diag=[1, 2.])
+
+ mvn.mean().eval()
+ # ==> [1., -1]
+
+ mvn.stddev().eval()
+ # ==> [1., 2]
+
+ # Evaluate this on an observation in `R^2`, returning a scalar.
+ mvn.prob([-1., 0]).eval() # shape: []
+
+ # Initialize a 3-batch, 2-variate scaled-identity Gaussian.
+ mvn = ds.MultivariateNormalDiag(
+ loc=[1., -1],
+ scale_identity_multiplier=[1, 2., 3])
+
+ mvn.mean().eval() # shape: [3, 2]
+  # ==> [[1., -1],
+ # [1, -1],
+ # [1, -1]]
+
+ mvn.stddev().eval() # shape: [3, 2]
+ # ==> [[1., 1],
+ # [2, 2],
+ # [3, 3]]
+
+ # Evaluate this on an observation in `R^2`, returning a length-3 vector.
+ mvn.prob([-1., 0]).eval() # shape: [3]
+
+ # Initialize a 2-batch of 3-variate Gaussians.
+ mvn = ds.MultivariateNormalDiag(
+ loc=[[1., 2, 3],
+           [11, 22, 33]],  # shape: [2, 3]
+ scale_diag=[[1., 2, 3],
+ [0.5, 1, 1.5]]) # shape: [2, 3]
+
+  # Evaluate this on two observations, each in `R^3`, returning a length-2
+  # vector.
+ x = [[-1., 0, 1],
+ [-11, 0, 11.]] # shape: [2, 3].
+ mvn.prob(x).eval() # shape: [2]
+ ```
+
+ """
+
+ def __init__(self,
+ loc=None,
+ scale_diag=None,
+ scale_identity_multiplier=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="MultivariateNormalDiag"):
+ """Construct Multivariate Normal distribution on `R^k`.
+
+ The `batch_shape` is the broadcast shape between `loc` and `scale`
+ arguments.
+
+ The `event_shape` is given by the last dimension of `loc` or the last
+ dimension of the matrix implied by `scale`.
+
+ Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:
+
+ ```none
+ scale = diag(scale_diag + scale_identity_multiplier * ones(k))
+ ```
+
+ where:
+
+ * `scale_diag.shape = [k]`, and,
+ * `scale_identity_multiplier.shape = []`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ If both `scale_diag` and `scale_identity_multiplier` are `None`, then
+ `scale` is the Identity matrix.
+
+ Args:
+ loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
+ implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
+ `b >= 0` and `k` is the event size.
+ scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
+ matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
+ and characterizes `b`-batches of `k x k` diagonal matrices added to
+ `scale`. When both `scale_identity_multiplier` and `scale_diag` are
+ `None` then `scale` is the `Identity`.
+ scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
+ a scaled-identity-matrix added to `scale`. May have shape
+ `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
+ `k x k` identity matrices added to `scale`. When both
+ `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
+ the `Identity`.
+      validate_args: Python `Boolean`, default `False`. When `True`,
+        distribution parameters are checked for validity despite possibly
+        degrading runtime performance. When `False`, invalid inputs may
+        silently render incorrect outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
+
+ Raises:
+      ValueError: if at most `scale_identity_multiplier` is specified (so the
+        event size `k` cannot be inferred).
+ """
+ parameters = locals()
+ with ops.name_scope(name) as ns:
+ with ops.name_scope("init", values=[
+ loc, scale_diag, scale_identity_multiplier]):
+ scale = distribution_util.make_diag_scale(
+ loc=loc,
+ scale_diag=scale_diag,
+ scale_identity_multiplier=scale_identity_multiplier,
+ validate_args=validate_args,
+ assert_positive=False)
+ super(MultivariateNormalDiag, self).__init__(
+ loc=loc,
+ scale=scale,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=ns)
+ self._parameters = parameters
+
+
+class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
+ """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`."""
+
+ def __init__(self,
+ loc,
+ scale_diag,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="MultivariateNormalDiagWithSoftplusScale"):
+ parameters = locals()
+ with ops.name_scope(name, values=[scale_diag]) as ns:
+ super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
+ loc=loc,
+ scale_diag=nn.softplus(scale_diag),
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=ns)
+ self._parameters = parameters
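`MultivariateNormalDiagWithSoftplusScale` is convenient when the diagonal scale comes from an unconstrained parameter, e.g. a network output. A minimal usage sketch, assuming the TF 1.x `tf.contrib.distributions` API added by this diff (the tensor values are illustrative only):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Unconstrained values, e.g. the raw output of a dense layer.
raw_scale = tf.constant([-1., 0., 2.])

# softplus(x) = log(1 + exp(x)) > 0, so the diagonal scale is strictly
# positive no matter what the raw values are.
mvn = ds.MultivariateNormalDiagWithSoftplusScale(
    loc=tf.zeros([3]),
    scale_diag=raw_scale)

with tf.Session() as sess:
  print(sess.run(mvn.stddev()))  # ==> approx [0.313, 0.693, 2.127]
```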
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
new file mode 100644
index 0000000000..51487cf3a3
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -0,0 +1,255 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multivariate Normal distribution classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import linalg
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
+from tensorflow.python.framework import ops
+
+
+__all__ = [
+ "MultivariateNormalDiagPlusLowRank",
+]
+
+
+class MultivariateNormalDiagPlusLowRank(
+ mvn_linop.MultivariateNormalLinearOperator):
+ """The multivariate normal distribution on `R^k`.
+
+ The Multivariate Normal distribution is defined over `R^k` and parameterized
+ by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
+ `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
+ matrix-multiplication.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
+ y = inv(scale) @ (x - loc),
+ Z = (2 pi)**(0.5 k) |det(scale)|,
+ ```
+
+ where:
+
+ * `loc` is a vector in `R^k`,
+ * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
+ * `Z` denotes the normalization constant, and,
+ * `||y||**2` denotes the squared Euclidean norm of `y`.
+
+ A (non-batch) `scale` matrix is:
+
+ ```none
+  scale = diag(scale_diag + scale_identity_multiplier * ones(k)) +
+ scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
+ ```
+
+ where:
+
+ * `scale_diag.shape = [k]`,
+ * `scale_identity_multiplier.shape = []`,
+ * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
+ * `scale_perturb_diag.shape = [r]`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ If both `scale_diag` and `scale_identity_multiplier` are `None`, then
+ `scale` is the Identity matrix.
+
+ The MultivariateNormal distribution is a member of the [location-scale
+ family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
+ Y = scale @ X + loc
+ ```
+
+ #### Examples
+
+ ```python
+ ds = tf.contrib.distributions
+
+ # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
+ # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
+ # a rank-2 update.
+  mu = [-0.5, 0, 0.5]  # shape: [3]
+ d = [1.5, 0.5, 2] # shape: [3]
+ U = [[1., 2],
+ [-1, 1],
+ [2, -0.5]] # shape: [3, 2]
+ m = [4., 5] # shape: [2]
+ mvn = ds.MultivariateNormalDiagPlusLowRank(
+      loc=mu,
+      scale_diag=d,
+ scale_perturb_factor=U,
+ scale_perturb_diag=m)
+
+ # Evaluate this on an observation in `R^3`, returning a scalar.
+ mvn.prob([-1, 0, 1]).eval() # shape: []
+
+  # Initialize a 2-batch of 3-variate Gaussians; `S = I + U @ diag(m) @ U.T`.
+ mu = [[1., 2, 3],
+ [11, 22, 33]] # shape: [b, k] = [2, 3]
+ U = [[[1., 2],
+ [3, 4],
+ [5, 6]],
+ [[0.5, 0.75],
+       [1.0, 0.25],
+ [1.5, 1.25]]] # shape: [b, k, r] = [2, 3, 2]
+ m = [[0.1, 0.2],
+ [0.4, 0.5]] # shape: [b, r] = [2, 2]
+
+ mvn = ds.MultivariateNormalDiagPlusLowRank(
+ loc=mu,
+ scale_perturb_factor=U,
+ scale_perturb_diag=m)
+
+ mvn.covariance().eval() # shape: [2, 3, 3]
+ # ==> [[[ 15.63 31.57 48.51]
+ # [ 31.57 69.31 105.05]
+ # [ 48.51 105.05 162.59]]
+ #
+ # [[ 2.59 1.41 3.35]
+ # [ 1.41 2.71 3.34]
+ # [ 3.35 3.34 8.35]]]
+
+ # Compute the pdf of two `R^3` observations (one from each batch);
+ # return a length-2 vector.
+ x = [[-0.9, 0, 0.1],
+ [-10, 0, 9]] # shape: [2, 3]
+ mvn.prob(x).eval() # shape: [2]
+ ```
+
+ """
+
+ def __init__(self,
+ loc=None,
+ scale_diag=None,
+ scale_identity_multiplier=None,
+ scale_perturb_factor=None,
+ scale_perturb_diag=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="MultivariateNormalDiagPlusLowRank"):
+ """Construct Multivariate Normal distribution on `R^k`.
+
+ The `batch_shape` is the broadcast shape between `loc` and `scale`
+ arguments.
+
+ The `event_shape` is given by the last dimension of `loc` or the last
+ dimension of the matrix implied by `scale`.
+
+ Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:
+
+ ```none
+    scale = diag(scale_diag + scale_identity_multiplier * ones(k)) +
+ scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
+ ```
+
+ where:
+
+ * `scale_diag.shape = [k]`,
+ * `scale_identity_multiplier.shape = []`,
+ * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
+ * `scale_perturb_diag.shape = [r]`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ If both `scale_diag` and `scale_identity_multiplier` are `None`, then
+ `scale` is the Identity matrix.
+
+ Args:
+ loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
+ implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
+ `b >= 0` and `k` is the event size.
+ scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
+ matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
+ and characterizes `b`-batches of `k x k` diagonal matrices added to
+ `scale`. When both `scale_identity_multiplier` and `scale_diag` are
+ `None` then `scale` is the `Identity`.
+ scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
+ a scaled-identity-matrix added to `scale`. May have shape
+ `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
+ `k x k` identity matrices added to `scale`. When both
+ `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
+ the `Identity`.
+ scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
+ perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
+ `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
+ When `None`, no rank-`r` update is added to `scale`.
+ scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
+ inside the rank-`r` perturbation added to `scale`. May have shape
+ `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r`
+ diagonal matrices inside the perturbation added to `scale`. When
+ `None`, an identity matrix is used inside the perturbation. Can only be
+ specified if `scale_perturb_factor` is also specified.
+      validate_args: Python `Boolean`, default `False`. When `True`,
+        distribution parameters are checked for validity despite possibly
+        degrading runtime performance. When `False`, invalid inputs may
+        silently render incorrect outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
+
+ Raises:
+      ValueError: if at most `scale_identity_multiplier` is specified (so the
+        event size `k` cannot be inferred).
+ """
+ parameters = locals()
+ def _convert_to_tensor(x, name):
+ return None if x is None else ops.convert_to_tensor(x, name=name)
+ with ops.name_scope(name) as ns:
+ with ops.name_scope("init", values=[
+ loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
+ scale_perturb_diag]):
+ has_low_rank = (scale_perturb_factor is not None or
+ scale_perturb_diag is not None)
+ scale = distribution_util.make_diag_scale(
+ loc=loc,
+ scale_diag=scale_diag,
+ scale_identity_multiplier=scale_identity_multiplier,
+ validate_args=validate_args,
+ assert_positive=has_low_rank)
+ scale_perturb_factor = _convert_to_tensor(
+ scale_perturb_factor,
+ name="scale_perturb_factor")
+ scale_perturb_diag = _convert_to_tensor(
+ scale_perturb_diag,
+ name="scale_perturb_diag")
+ if has_low_rank:
+ scale = linalg.LinearOperatorUDVHUpdate(
+ scale,
+ u=scale_perturb_factor,
+ diag=scale_perturb_diag,
+ is_diag_positive=scale_perturb_diag is None,
+ is_non_singular=True, # Implied by is_positive_definite=True.
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ is_square=True)
+ super(MultivariateNormalDiagPlusLowRank, self).__init__(
+ loc=loc,
+ scale=scale,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=ns)
+ self._parameters = parameters
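The `scale` construction above is easy to check numerically. A NumPy sketch using the values from the first docstring example (the `@` products from the docstrings are written with `np.dot` so the snippet also runs on Python 2):

```python
import numpy as np

d = np.array([1.5, 0.5, 2.])    # scale_diag,           shape [3]
U = np.array([[1., 2.],
              [-1., 1.],
              [2., -0.5]])      # scale_perturb_factor, shape [3, 2]
m = np.array([4., 5.])          # scale_perturb_diag,   shape [2]

# S = diag(d) + U @ diag(m) @ U.T; covariance = S @ S.T.
scale = np.diag(d) + U.dot(np.diag(m)).dot(U.T)
cov = scale.dot(scale.T)
print(cov.shape)  # ==> (3, 3)

# For k = 3, r = 2 the low-rank form is not yet smaller, but for k >> r it
# needs only O(k*r) parameters versus O(k**2) for a dense factor.
```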
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
new file mode 100644
index 0000000000..f6f26a0b1d
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -0,0 +1,383 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multivariate Normal distribution classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import linalg
+from tensorflow.contrib.distributions.python.ops import bijector as bijectors
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import kullback_leibler
+from tensorflow.contrib.distributions.python.ops import normal
+from tensorflow.contrib.distributions.python.ops import transformed_distribution
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+__all__ = [
+ "MultivariateNormalLinearOperator",
+]
+
+
+_mvn_sample_note = """
+`value` is a batch vector with compatible shape if `value` is a `Tensor` whose
+shape can be broadcast up to either:
+
+```python
+self.batch_shape + self.event_shape
+```
+
+or
+
+```python
+[M1, ..., Mm] + self.batch_shape + self.event_shape
+```
+
+"""
+
+
+# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
+class MultivariateNormalLinearOperator(
+ transformed_distribution.TransformedDistribution):
+ """The multivariate normal distribution on `R^k`.
+
+ The Multivariate Normal distribution is defined over `R^k` and parameterized
+ by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
+ `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
+ matrix-multiplication.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
+ y = inv(scale) @ (x - loc),
+ Z = (2 pi)**(0.5 k) |det(scale)|,
+ ```
+
+ where:
+
+ * `loc` is a vector in `R^k`,
+ * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
+ * `Z` denotes the normalization constant, and,
+ * `||y||**2` denotes the squared Euclidean norm of `y`.
+
+ The MultivariateNormal distribution is a member of the [location-scale
+ family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
+ Y = scale @ X + loc
+ ```
+
+ #### Examples
+
+ ```python
+ ds = tf.contrib.distributions
+ la = tf.contrib.linalg
+
+ # Initialize a single 3-variate Gaussian.
+ mu = [1., 2, 3]
+ cov = [[ 0.36, 0.12, 0.06],
+ [ 0.12, 0.29, -0.13],
+ [ 0.06, -0.13, 0.26]]
+ scale = tf.cholesky(cov)
+ # ==> [[ 0.6, 0. , 0. ],
+ # [ 0.2, 0.5, 0. ],
+ # [ 0.1, -0.3, 0.4]])
+
+ mvn = ds.MultivariateNormalLinearOperator(
+ loc=mu,
+ scale=la.LinearOperatorTriL(scale))
+
+ # Covariance agrees with cholesky(cov) parameterization.
+ mvn.covariance().eval()
+ # ==> [[ 0.36, 0.12, 0.06],
+ # [ 0.12, 0.29, -0.13],
+ # [ 0.06, -0.13, 0.26]]
+
+  # Compute the pdf of an `R^3` observation; return a scalar.
+ mvn.prob([-1., 0, 1]).eval() # shape: []
+
+ # Initialize a 2-batch of 3-variate Gaussians.
+ mu = [[1., 2, 3],
+ [11, 22, 33]] # shape: [2, 3]
+ scale_diag = [[1., 2, 3],
+ [0.5, 1, 1.5]] # shape: [2, 3]
+
+ mvn = ds.MultivariateNormalLinearOperator(
+ loc=mu,
+ scale=la.LinearOperatorDiag(scale_diag))
+
+ # Compute the pdf of two `R^3` observations; return a length-2 vector.
+ x = [[-0.9, 0, 0.1],
+ [-10, 0, 9]] # shape: [2, 3]
+ mvn.prob(x).eval() # shape: [2]
+ ```
+
+ """
+
+ def __init__(self,
+ loc=None,
+ scale=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="MultivariateNormalLinearOperator"):
+ """Construct Multivariate Normal distribution on `R^k`.
+
+ The `batch_shape` is the broadcast shape between `loc` and `scale`
+ arguments.
+
+ The `event_shape` is given by the last dimension of `loc` or the last
+ dimension of the matrix implied by `scale`.
+
+ Recall that `covariance = scale @ scale.T`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ Args:
+ loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
+ implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
+ `b >= 0` and `k` is the event size.
+ scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
+ `[B1, ..., Bb, k, k]`.
+ validate_args: `Boolean`, default `False`. Whether to validate input
+ with asserts. If `validate_args` is `False`, and the inputs are
+ invalid, correct behavior is not guaranteed.
+      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading
+        to undefined statistics will return NaN for this statistic.
+ name: The name to give Ops created by the initializer.
+
+ Raises:
+ ValueError: if `scale` is unspecified.
+      TypeError: if not `scale.dtype.is_floating`.
+ """
+ parameters = locals()
+ if scale is None:
+ raise ValueError("Missing required `scale` parameter.")
+ if not scale.dtype.is_floating:
+ raise TypeError("`scale` parameter must have floating-point dtype.")
+
+ # Since expand_dims doesn't preserve constant-ness, we obtain the
+ # non-dynamic value if possible.
+ event_shape = scale.domain_dimension_tensor()
+ if tensor_util.constant_value(event_shape) is not None:
+ event_shape = tensor_util.constant_value(event_shape)
+ event_shape = event_shape[array_ops.newaxis]
+
+ super(MultivariateNormalLinearOperator, self).__init__(
+ distribution=normal.Normal(
+ loc=array_ops.zeros([], dtype=scale.dtype),
+ scale=array_ops.ones([], dtype=scale.dtype)),
+ bijector=bijectors.AffineLinearOperator(
+ shift=loc, scale=scale, validate_args=validate_args),
+ batch_shape=scale.batch_shape_tensor(),
+ event_shape=event_shape,
+ validate_args=validate_args,
+ name=name)
+ self._parameters = parameters
+
+ @property
+ def loc(self):
+ """The `loc` `Tensor` in `Y = scale @ X + loc`."""
+ return self.bijector.shift
+
+ @property
+ def scale(self):
+ """The `scale` `LinearOperator` in `Y = scale @ X + loc`."""
+ return self.bijector.scale
+
+ def log_det_covariance(self, name="log_det_covariance"):
+ """Log of determinant of covariance matrix."""
+ with ops.name_scope(self.name):
+ with ops.name_scope(name, values=self.scale.graph_parents):
+ return 2. * self.scale.log_abs_determinant()
+
+ def det_covariance(self, name="det_covariance"):
+ """Determinant of covariance matrix."""
+ with ops.name_scope(self.name):
+ with ops.name_scope(name, values=self.scale.graph_parents):
+        return math_ops.exp(2. * self.scale.log_abs_determinant())
+
+ @distribution_util.AppendDocstring(_mvn_sample_note)
+ def _log_prob(self, x):
+ return super(MultivariateNormalLinearOperator, self)._log_prob(x)
+
+ @distribution_util.AppendDocstring(_mvn_sample_note)
+ def _prob(self, x):
+ return super(MultivariateNormalLinearOperator, self)._prob(x)
+
+ def _mean(self):
+ if self.loc is None:
+ shape = array_ops.concat([
+ self.batch_shape_tensor(),
+ self.event_shape_tensor(),
+ ], 0)
+ return array_ops.zeros(shape, self.dtype)
+ return array_ops.identity(self.loc)
+
+ def _covariance(self):
+ # TODO(b/35041434): Remove special-case logic once LinOp supports
+ # `diag_part`.
+ if (isinstance(self.scale, linalg.LinearOperatorIdentity) or
+ isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or
+ isinstance(self.scale, linalg.LinearOperatorDiag)):
+ shape = array_ops.concat([self.batch_shape_tensor(),
+ self.event_shape_tensor()], 0)
+ diag_part = array_ops.ones(shape, self.scale.dtype)
+ if isinstance(self.scale, linalg.LinearOperatorScaledIdentity):
+ diag_part *= math_ops.square(
+ self.scale.multiplier[..., array_ops.newaxis])
+ elif isinstance(self.scale, linalg.LinearOperatorDiag):
+ diag_part *= math_ops.square(self.scale.diag)
+ return array_ops.matrix_diag(diag_part)
+ else:
+ # TODO(b/35040238): Remove transpose once LinOp supports `transpose`.
+ return self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense()))
+
+ def _variance(self):
+ # TODO(b/35041434): Remove special-case logic once LinOp supports
+ # `diag_part`.
+ if (isinstance(self.scale, linalg.LinearOperatorIdentity) or
+ isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or
+ isinstance(self.scale, linalg.LinearOperatorDiag)):
+ shape = array_ops.concat([self.batch_shape_tensor(),
+ self.event_shape_tensor()], 0)
+ diag_part = array_ops.ones(shape, self.scale.dtype)
+ if isinstance(self.scale, linalg.LinearOperatorScaledIdentity):
+ diag_part *= math_ops.square(
+ self.scale.multiplier[..., array_ops.newaxis])
+ elif isinstance(self.scale, linalg.LinearOperatorDiag):
+ diag_part *= math_ops.square(self.scale.diag)
+ return diag_part
+ elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate)
+ and self.scale.is_self_adjoint):
+ return array_ops.matrix_diag_part(
+ self.scale.apply(self.scale.to_dense()))
+ else:
+ # TODO(b/35040238): Remove transpose once LinOp supports `transpose`.
+ return array_ops.matrix_diag_part(
+ self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense())))
+
+ def _stddev(self):
+ # TODO(b/35041434): Remove special-case logic once LinOp supports
+ # `diag_part`.
+ if (isinstance(self.scale, linalg.LinearOperatorIdentity) or
+ isinstance(self.scale, linalg.LinearOperatorScaledIdentity) or
+ isinstance(self.scale, linalg.LinearOperatorDiag)):
+ shape = array_ops.concat([self.batch_shape_tensor(),
+ self.event_shape_tensor()], 0)
+ diag_part = array_ops.ones(shape, self.scale.dtype)
+ if isinstance(self.scale, linalg.LinearOperatorScaledIdentity):
+ diag_part *= self.scale.multiplier[..., array_ops.newaxis]
+ elif isinstance(self.scale, linalg.LinearOperatorDiag):
+ diag_part *= self.scale.diag
+ return math_ops.abs(diag_part)
+ elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate)
+ and self.scale.is_self_adjoint):
+ return math_ops.sqrt(array_ops.matrix_diag_part(
+ self.scale.apply(self.scale.to_dense())))
+ else:
+ # TODO(b/35040238): Remove transpose once LinOp supports `transpose`.
+ return math_ops.sqrt(array_ops.matrix_diag_part(
+ self.scale.apply(array_ops.matrix_transpose(self.scale.to_dense()))))
+
+ def _mode(self):
+ return self._mean()
+
+
+@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator,
+ MultivariateNormalLinearOperator)
+def _kl_brute_force(a, b, name=None):
+ """Batched KL divergence `KL(a || b)` for multivariate Normals.
+
+ With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
+ covariance `C_a`, `C_b` respectively,
+
+ ```
+ KL(a || b) = 0.5 * ( L - k + T + Q ),
+  L := Log[Det(C_b)] - Log[Det(C_a)],
+ T := trace(C_b^{-1} C_a),
+ Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
+ ```
+
+ This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
+ methods for solving systems with `C_b` may be available, a dense version of
+ (the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B`
+ is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
+ and `y`.
+
+ Args:
+ a: Instance of `MultivariateNormalLinearOperator`.
+ b: Instance of `MultivariateNormalLinearOperator`.
+ name: (optional) name to use for created ops. Default "kl_mvn".
+
+ Returns:
+ Batchwise `KL(a || b)`.
+ """
+
+ def squared_frobenius_norm(x):
+ """Helper to make KL calculation slightly more readable."""
+ # http://mathworld.wolfram.com/FrobeniusNorm.html
+ return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
+
+ # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
+ # supports something like:
+ # A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
+ def is_diagonal(x):
+ """Helper to identify if `LinearOperator` has only a diagonal component."""
+ return (isinstance(x, linalg.LinearOperatorIdentity) or
+ isinstance(x, linalg.LinearOperatorScaledIdentity) or
+ isinstance(x, linalg.LinearOperatorDiag))
+
+ with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
+ a.scale.graph_parents + b.scale.graph_parents):
+ # Calculation is based on:
+ # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
+ # and,
+ # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
+ # i.e.,
+ # If Ca = AA', Cb = BB', then
+ # tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
+ # = tr[inv(B) A A' inv(B)']
+ # = tr[(inv(B) A) (inv(B) A)']
+ # = sum_{ij} (inv(B) A)_{ij}^2
+ # = ||inv(B) A||_F**2
+ # where ||.||_F is the Frobenius norm and the second equality follows from
+ # the cyclic permutation property.
+ if is_diagonal(a.scale) and is_diagonal(b.scale):
+ # Using `stddev` because it handles expansion of Identity cases.
+ b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
+ else:
+ b_inv_a = b.scale.solve(a.scale.to_dense())
+ kl_div = (b.scale.log_abs_determinant()
+ - a.scale.log_abs_determinant()
+ + 0.5 * (
+ - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
+ + squared_frobenius_norm(b_inv_a)
+ + squared_frobenius_norm(b.scale.solve(
+ (b.mean() - a.mean())[..., array_ops.newaxis]))))
+ kl_div.set_shape(array_ops.broadcast_static_shape(
+ a.batch_shape, b.batch_shape))
+ return kl_div
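The trace-to-Frobenius-norm reduction in the comment above can be verified directly. A self-contained NumPy check with random, well-conditioned factors (purely illustrative):

```python
import numpy as np

rng = np.random.RandomState(0)
k = 4
# Lower-triangular "Cholesky-like" factors with a dominant diagonal.
A = np.tril(rng.randn(k, k)) + k * np.eye(k)
B = np.tril(rng.randn(k, k)) + k * np.eye(k)
Ca = A.dot(A.T)  # plays the role of C_a
Cb = B.dot(B.T)  # plays the role of C_b

lhs = np.trace(np.linalg.solve(Cb, Ca))
rhs = np.sum(np.linalg.solve(B, A) ** 2)  # ||inv(B) A||_F**2
print(np.allclose(lhs, rhs))  # ==> True
```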
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
new file mode 100644
index 0000000000..8fdc0822c4
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -0,0 +1,213 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multivariate Normal distribution classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import linalg
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+__all__ = [
+ "MultivariateNormalTriL",
+]
+
+
+class MultivariateNormalTriL(
+ mvn_linop.MultivariateNormalLinearOperator):
+ """The multivariate normal distribution on `R^k`.
+
+ The Multivariate Normal distribution is defined over `R^k` and parameterized
+ by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
+ `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
+ matrix-multiplication.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
+ y = inv(scale) @ (x - loc),
+ Z = (2 pi)**(0.5 k) |det(scale)|,
+ ```
+
+ where:
+
+ * `loc` is a vector in `R^k`,
+ * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
+ * `Z` denotes the normalization constant, and,
+ * `||y||**2` denotes the squared Euclidean norm of `y`.
+
+ A (non-batch) `scale` matrix is:
+
+ ```none
+ scale = scale_tril
+ ```
+
+  where `scale_tril` is a lower-triangular `k x k` matrix with non-zero
+  diagonal, i.e., `tf.diag_part(scale_tril) != 0`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ The MultivariateNormal distribution is a member of the [location-scale
+ family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift.
+ Y = scale @ X + loc
+ ```
+
+ Trainable (batch) lower-triangular matrices can be created with
+  `ds.matrix_diag_transform()` and/or `ds.fill_lower_triangular()`.
+
+ #### Examples
+
+ ```python
+ ds = tf.contrib.distributions
+
+ # Initialize a single 3-variate Gaussian.
+ mu = [1., 2, 3]
+ cov = [[ 0.36, 0.12, 0.06],
+ [ 0.12, 0.29, -0.13],
+ [ 0.06, -0.13, 0.26]]
+ scale = tf.cholesky(cov)
+ # ==> [[ 0.6, 0. , 0. ],
+ # [ 0.2, 0.5, 0. ],
+ # [ 0.1, -0.3, 0.4]])
+ mvn = ds.MultivariateNormalTriL(
+ loc=mu,
+ scale_tril=scale)
+
+ mvn.mean().eval()
+ # ==> [1., 2, 3]
+
+ # Covariance agrees with cholesky(cov) parameterization.
+ mvn.covariance().eval()
+ # ==> [[ 0.36, 0.12, 0.06],
+ # [ 0.12, 0.29, -0.13],
+ # [ 0.06, -0.13, 0.26]]
+
+  # Compute the pdf of an observation in `R^3`; return a scalar.
+ mvn.prob([-1., 0, 1]).eval() # shape: []
+
+ # Initialize a 2-batch of 3-variate Gaussians.
+ mu = [[1., 2, 3],
+ [11, 22, 33]] # shape: [2, 3]
+ tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal.
+ mvn = ds.MultivariateNormalTriL(
+ loc=mu,
+ scale_tril=tril)
+
+ # Compute the pdf of two `R^3` observations; return a length-2 vector.
+ x = [[-0.9, 0, 0.1],
+ [-10, 0, 9]] # shape: [2, 3]
+ mvn.prob(x).eval() # shape: [2]
+
+ ```
+
+ """
+
+ def __init__(self,
+ loc=None,
+ scale_tril=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="MultivariateNormalTriL"):
+ """Construct Multivariate Normal distribution on `R^k`.
+
+ The `batch_shape` is the broadcast shape between `loc` and `scale`
+ arguments.
+
+ The `event_shape` is given by the last dimension of `loc` or the last
+ dimension of the matrix implied by `scale`.
+
+ Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:
+
+ ```none
+ scale = scale_tril
+ ```
+
+    where `scale_tril` is a lower-triangular `k x k` matrix with non-zero
+ diagonal, i.e., `tf.diag_part(scale_tril) != 0`.
+
+ Additional leading dimensions (if any) will index batches.
+
+ Args:
+ loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
+ implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
+ `b >= 0` and `k` is the event size.
+ scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
+ diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where
+ `b >= 0` and `k` is the event size.
+      validate_args: Python `Boolean`, default `False`. When `True`,
+        distribution parameters are checked for validity despite possibly
+        degrading runtime performance. When `False`, invalid inputs may
+        silently render incorrect outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
+
+ Raises:
+      ValueError: if neither `loc` nor `scale_tril` is specified.
+ """
+ parameters = locals()
+ def _convert_to_tensor(x, name):
+ return None if x is None else ops.convert_to_tensor(x, name=name)
+ if loc is None and scale_tril is None:
+ raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
+ with ops.name_scope(name) as ns:
+ with ops.name_scope("init", values=[loc, scale_tril]):
+ loc = _convert_to_tensor(loc, name="loc")
+ scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
+ if scale_tril is None:
+ scale = linalg.LinearOperatorIdentity(
+ num_rows=distribution_util.dimension_size(loc, -1),
+ dtype=loc.dtype,
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ assert_proper_shapes=validate_args)
+ else:
+ if validate_args:
+ scale_tril = control_flow_ops.with_dependencies([
+ # TODO(b/35157376): Use `assert_none_equal` once it exists.
+ check_ops.assert_greater(
+ math_ops.abs(array_ops.matrix_diag_part(scale_tril)),
+ array_ops.zeros([], scale_tril.dtype),
+ message="`scale_tril` must have non-zero diagonal"),
+ ], scale_tril)
+ scale = linalg.LinearOperatorTriL(
+ scale_tril,
+ is_non_singular=True,
+ is_self_adjoint=False,
+ is_positive_definite=False)
+ super(MultivariateNormalTriL, self).__init__(
+ loc=loc,
+ scale=scale,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=ns)
+ self._parameters = parameters
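Following the docstring's pointer to `matrix_diag_transform`, here is a hedged sketch of a trainable `scale_tril` (assumes the TF 1.x contrib API in this diff; the variable `raw` and the observed point are placeholders):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Unconstrained square matrix; only its lower triangle is used.
raw = tf.Variable(tf.random_normal([3, 3]))

# Keep the lower triangle and pass the diagonal through softplus so it is
# strictly positive, making `scale_tril` non-singular.
tril = ds.matrix_diag_transform(
    tf.matrix_band_part(raw, -1, 0), transform=tf.nn.softplus)

mvn = ds.MultivariateNormalTriL(loc=tf.zeros([3]), scale_tril=tril)

# A scalar loss suitable for gradient-based fitting of `raw`.
neg_log_lik = -mvn.log_prob([0.1, -0.2, 0.3])
```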