diff options
author | 2017-07-14 16:46:49 -0700 | |
---|---|---|
committer | 2017-07-14 16:50:53 -0700 | |
commit | 7687debbf63e31375d960d663373da8d469f2d2e (patch) | |
tree | a6b132cf53fae4dff19004b5b7dd0ee66941ed08 | |
parent | ee8d9f353a29c3b0df8447e94c11f1a61ee75583 (diff) |
Add unit-test for questions:
- http://stackoverflow.com/q/45109305
- #10766
PiperOrigin-RevId: 162026912
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py | 42 |
1 file changed, 42 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 3f4582eb7e..e973c056e0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -24,7 +24,12 @@ from tensorflow.contrib import distributions from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -233,6 +238,43 @@ class MultivariateNormalDiagTest(test.TestCase): self.assertAllClose(mu, samps.mean(axis=0), atol=0.1) self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1) + def testMultivariateNormalDiagNegLogLikelihood(self): + num_draws = 50 + dims = 3 + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtype=dtypes.float32, + shape=[None, dims], + name="x") + mu_var = variable_scope.get_variable( + name="mu", + shape=[dims], + dtype=dtypes.float32, + initializer=init_ops.constant_initializer(1.)) + sess.run([variables.global_variables_initializer()]) + + mvn = ds.MultivariateNormalDiag( + loc=mu_var, + scale_diag=array_ops.ones(shape=[dims], dtype=dtypes.float32)) + + # Typically you'd use `mvn.log_prob(x_pl)` which is always at least as + # numerically stable as `tf.log(mvn.prob(x_pl))`. However in this test + # we're testing a bug specific to `prob` and not `log_prob`; + # http://stackoverflow.com/q/45109305. (The underlying issue was not + # related to `Distributions` but that `reduce_prod` didn't correctly + # handle negative indexes.) 
+ neg_log_likelihood = -math_ops.reduce_sum(math_ops.log(mvn.prob(x_pl))) + grad_neg_log_likelihood = gradients_impl.gradients( + neg_log_likelihood, variables.trainable_variables()) + + x = np.zeros([num_draws, dims], dtype=np.float32) + grad_neg_log_likelihood_ = sess.run( + grad_neg_log_likelihood, + feed_dict={x_pl: x}) + self.assertEqual(1, len(grad_neg_log_likelihood_)) + self.assertAllClose(grad_neg_log_likelihood_[0], + np.tile(num_draws, dims), + rtol=1e-6, atol=0.) + if __name__ == "__main__": test.main() |