diff options
Diffstat (limited to 'tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py')
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py index eebefbad28..d42da11c77 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py @@ -331,5 +331,27 @@ class TestSurrogateLosses(tf.test.TestCase): loss_fn=None) self.assertEqual(None, dt.loss(tf.constant([2.0]))) + def testExplicitStochasticTensors(self): + with self.test_session() as sess: + mu = tf.constant([0.0, 0.1, 0.2]) + sigma = tf.constant([1.1, 1.2, 1.3]) + with sg.value_type(sg.SampleAndReshapeValue()): + dt1 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma) + dt2 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma) + loss = tf.square(tf.identity(dt1)) + 10. + dt2 + + sl_all = sg.surrogate_loss([loss]) + sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1]) + sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2]) + + dt1_term = dt1.distribution.log_pdf(dt1) * loss + dt2_term = dt2.distribution.log_pdf(dt2) * loss + + self.assertAllClose(*sess.run( + [sl_all, sum([loss, dt1_term, dt2_term])])) + self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])])) + self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])])) + + if __name__ == "__main__": tf.test.main() |