aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py')
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
index eebefbad28..d42da11c77 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
@@ -331,5 +331,27 @@ class TestSurrogateLosses(tf.test.TestCase):
loss_fn=None)
self.assertEqual(None, dt.loss(tf.constant([2.0])))
+ def testExplicitStochasticTensors(self):
+ with self.test_session() as sess:
+ mu = tf.constant([0.0, 0.1, 0.2])
+ sigma = tf.constant([1.1, 1.2, 1.3])
+ with sg.value_type(sg.SampleAndReshapeValue()):
+ dt1 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma)
+ dt2 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma)
+ loss = tf.square(tf.identity(dt1)) + 10. + dt2
+
+ sl_all = sg.surrogate_loss([loss])
+ sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1])
+ sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2])
+
+ dt1_term = dt1.distribution.log_pdf(dt1) * loss
+ dt2_term = dt2.distribution.log_pdf(dt2) * loss
+
+ self.assertAllClose(*sess.run(
+ [sl_all, sum([loss, dt1_term, dt2_term])]))
+ self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])]))
+ self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])]))
+
+
if __name__ == "__main__":
tf.test.main()