From 7687debbf63e31375d960d663373da8d469f2d2e Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Fri, 14 Jul 2017 16:46:49 -0700 Subject: Add unit-test for questions: - http://stackoverflow.com/q/45109305 - #10766 PiperOrigin-RevId: 162026912 --- .../python/kernel_tests/mvn_diag_test.py | 42 ++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 3f4582eb7e..e973c056e0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -24,7 +24,12 @@ from tensorflow.contrib import distributions from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -233,6 +238,43 @@ class MultivariateNormalDiagTest(test.TestCase): self.assertAllClose(mu, samps.mean(axis=0), atol=0.1) self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1) + def testMultivariateNormalDiagNegLogLikelihood(self): + num_draws = 50 + dims = 3 + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtype=dtypes.float32, + shape=[None, dims], + name="x") + mu_var = variable_scope.get_variable( + name="mu", + shape=[dims], + dtype=dtypes.float32, + initializer=init_ops.constant_initializer(1.)) + sess.run([variables.global_variables_initializer()]) + + mvn = ds.MultivariateNormalDiag( + loc=mu_var, + scale_diag=array_ops.ones(shape=[dims], dtype=dtypes.float32)) + + # Typically you'd use `mvn.log_prob(x_pl)` which is always at least as + # numerically stable as `tf.log(mvn.prob(x_pl))`. However in this test + # we're testing a bug specific to `prob` and not `log_prob`; + # http://stackoverflow.com/q/45109305. (The underlying issue was not + # related to `Distributions` but that `reduce_prod` didn't correctly + # handle negative indexes.) + neg_log_likelihood = -math_ops.reduce_sum(math_ops.log(mvn.prob(x_pl))) + grad_neg_log_likelihood = gradients_impl.gradients( + neg_log_likelihood, variables.trainable_variables()) + + x = np.zeros([num_draws, dims], dtype=np.float32) + grad_neg_log_likelihood_ = sess.run( + grad_neg_log_likelihood, + feed_dict={x_pl: x}) + self.assertEqual(1, len(grad_neg_log_likelihood_)) + self.assertAllClose(grad_neg_log_likelihood_[0], + np.tile(num_draws, dims), + rtol=1e-6, atol=0.) + if __name__ == "__main__": test.main() -- cgit v1.2.3