aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2017-07-14 16:46:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-14 16:50:53 -0700
commit7687debbf63e31375d960d663373da8d469f2d2e (patch)
treea6b132cf53fae4dff19004b5b7dd0ee66941ed08
parentee8d9f353a29c3b0df8447e94c11f1a61ee75583 (diff)
Add unit-test for questions:
- http://stackoverflow.com/q/45109305 - #10766 PiperOrigin-RevId: 162026912
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
index 3f4582eb7e..e973c056e0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
@@ -24,7 +24,12 @@ from tensorflow.contrib import distributions
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -233,6 +238,43 @@ class MultivariateNormalDiagTest(test.TestCase):
self.assertAllClose(mu, samps.mean(axis=0), atol=0.1)
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
+ def testMultivariateNormalDiagNegLogLikelihood(self):
+ num_draws = 50
+ dims = 3
+ with self.test_session() as sess:
+ x_pl = array_ops.placeholder(dtype=dtypes.float32,
+ shape=[None, dims],
+ name="x")
+ mu_var = variable_scope.get_variable(
+ name="mu",
+ shape=[dims],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(1.))
+ sess.run([variables.global_variables_initializer()])
+
+ mvn = ds.MultivariateNormalDiag(
+ loc=mu_var,
+ scale_diag=array_ops.ones(shape=[dims], dtype=dtypes.float32))
+
+ # Typically you'd use `mvn.log_prob(x_pl)` which is always at least as
+ # numerically stable as `tf.log(mvn.prob(x_pl))`. However in this test
+ # we're testing a bug specific to `prob` and not `log_prob`;
+ # http://stackoverflow.com/q/45109305. (The underlying issue was not
+ # related to `Distributions` but that `reduce_prod` didn't correctly
+ # handle negative indexes.)
+ neg_log_likelihood = -math_ops.reduce_sum(math_ops.log(mvn.prob(x_pl)))
+ grad_neg_log_likelihood = gradients_impl.gradients(
+ neg_log_likelihood, variables.trainable_variables())
+
+ x = np.zeros([num_draws, dims], dtype=np.float32)
+ grad_neg_log_likelihood_ = sess.run(
+ grad_neg_log_likelihood,
+ feed_dict={x_pl: x})
+ self.assertEqual(1, len(grad_neg_log_likelihood_))
+ self.assertAllClose(grad_neg_log_likelihood_[0],
+ np.tile(num_draws, dims),
+ rtol=1e-6, atol=0.)
+
if __name__ == "__main__":
test.main()