aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
index 90910f3839..200310bc41 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
@@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase):
self.assertAllClose(
np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_)
+ def testEntropy(self):
+ loc = np.array([-0.1, -3.2, 7.])
+ deterministic = deterministic_lib.Deterministic(loc=loc)
+ with self.test_session() as sess:
+ entropy_ = sess.run(deterministic.entropy())
+ self.assertAllEqual(np.zeros(3), entropy_)
+
class VectorDeterministicTest(test.TestCase):
@@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase):
self.assertAllClose(
np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_)
+ def testEntropy(self):
+ loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]])
+ deterministic = deterministic_lib.VectorDeterministic(loc=loc)
+ with self.test_session() as sess:
+ entropy_ = sess.run(deterministic.entropy())
+ self.assertAllEqual(np.zeros(2), entropy_)
+
if __name__ == "__main__":
test.main()