aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 11:09:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 16:41:12 -0700
commitf6a55cc344cd96098cabd500144aad266e692598 (patch)
treef15c10f3db1ea29792cc60d4d5247005b5e6a2a6 /tensorflow/contrib/distributions/python
parent170634d5a10a94d3bd12cc794c284eafcf47fa54 (diff)
Add tests for broadcasting KL divergence calculations.
PiperOrigin-RevId: 195690035
Diffstat (limited to 'tensorflow/contrib/distributions/python')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py31
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py39
2 files changed, 62 insertions, 8 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
index 7435bcbc68..b003526392 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
@@ -131,8 +131,8 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
return mu, sigma
def testKLBatch(self):
- batch_shape = (2,)
- event_shape = (3,)
+ batch_shape = [2]
+ event_shape = [3]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
@@ -156,6 +156,33 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
self.assertAllClose(expected_kl_0, kl_v[0])
self.assertAllClose(expected_kl_1, kl_v[1])
+ def testKLBatchBroadcast(self):
+ batch_shape = [2]
+ event_shape = [3]
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ # No batch shape.
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
+ mvn_a = ds.MultivariateNormalFullCovariance(
+ loc=mu_a,
+ covariance_matrix=sigma_a,
+ validate_args=True)
+ mvn_b = ds.MultivariateNormalFullCovariance(
+ loc=mu_b,
+ covariance_matrix=sigma_b,
+ validate_args=True)
+
+ kl = ds.kl_divergence(mvn_a, mvn_b)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+ expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
+ mu_b, sigma_b)
+ expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
+ mu_b, sigma_b)
+ self.assertAllClose(expected_kl_0, kl_v[0])
+ self.assertAllClose(expected_kl_1, kl_v[1])
+
def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
"""Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
index 685f32883d..b556d06123 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
@@ -235,8 +235,8 @@ class MultivariateNormalTriLTest(test.TestCase):
return mu, sigma
def testKLNonBatch(self):
- batch_shape = ()
- event_shape = (2,)
+ batch_shape = []
+ event_shape = [2]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
@@ -257,8 +257,8 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_kl, kl_v)
def testKLBatch(self):
- batch_shape = (2,)
- event_shape = (3,)
+ batch_shape = [2]
+ event_shape = [3]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
@@ -282,9 +282,36 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_kl_0, kl_v[0])
self.assertAllClose(expected_kl_1, kl_v[1])
+ def testKLBatchBroadcast(self):
+ batch_shape = [2]
+ event_shape = [3]
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ # No batch shape.
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
+ mvn_a = ds.MultivariateNormalTriL(
+ loc=mu_a,
+ scale_tril=np.linalg.cholesky(sigma_a),
+ validate_args=True)
+ mvn_b = ds.MultivariateNormalTriL(
+ loc=mu_b,
+ scale_tril=np.linalg.cholesky(sigma_b),
+ validate_args=True)
+
+ kl = ds.kl_divergence(mvn_a, mvn_b)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+ expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
+ mu_b, sigma_b)
+ expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
+ mu_b, sigma_b)
+ self.assertAllClose(expected_kl_0, kl_v[0])
+ self.assertAllClose(expected_kl_1, kl_v[1])
+
def testKLTwoIdenticalDistributionsIsZero(self):
- batch_shape = (2,)
- event_shape = (3,)
+ batch_shape = [2]
+ event_shape = [3]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = ds.MultivariateNormalTriL(