Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/bernoulli_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/bernoulli_test.py | 320
1 files changed, 320 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
new file mode 100644
index 0000000000..ef93c4dab0
--- /dev/null
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -0,0 +1,320 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Bernoulli distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bernoulli
+from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name):  # pylint: disable=invalid-name
+  module = None
+  try:
+    module = importlib.import_module(name)
+  except ImportError as e:
+    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+  return module
+
+
+special = try_import("scipy.special")
+
+
+def make_bernoulli(batch_shape, dtype=dtypes.int32):
+  p = np.random.uniform(size=list(batch_shape))
+  p = constant_op.constant(p, dtype=dtypes.float32)
+  return bernoulli.Bernoulli(probs=p, dtype=dtype)
+
+
+def entropy(p):
+  q = 1. - p
+  return -q * np.log(q) - p * np.log(p)
+
+
+class BernoulliTest(test.TestCase):
+
+  def testP(self):
+    p = [0.2, 0.4]
+    dist = bernoulli.Bernoulli(probs=p)
+    with self.test_session():
+      self.assertAllClose(p, dist.probs.eval())
+
+  def testLogits(self):
+    logits = [-42., 42.]
+    dist = bernoulli.Bernoulli(logits=logits)
+    with self.test_session():
+      self.assertAllClose(logits, dist.logits.eval())
+
+    if not special:
+      return
+
+    with self.test_session():
+      self.assertAllClose(special.expit(logits), dist.probs.eval())
+
+    p = [0.01, 0.99, 0.42]
+    dist = bernoulli.Bernoulli(probs=p)
+    with self.test_session():
+      self.assertAllClose(special.logit(p), dist.logits.eval())
+
+  def testInvalidP(self):
+    invalid_ps = [1.01, 2.]
+    for p in invalid_ps:
+      with self.test_session():
+        with self.assertRaisesOpError("probs has components greater than 1"):
+          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+          dist.probs.eval()
+
+    invalid_ps = [-0.01, -3.]
+    for p in invalid_ps:
+      with self.test_session():
+        with self.assertRaisesOpError("Condition x >= 0"):
+          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+          dist.probs.eval()
+
+    valid_ps = [0.0, 0.5, 1.0]
+    for p in valid_ps:
+      with self.test_session():
+        dist = bernoulli.Bernoulli(probs=p)
+        self.assertEqual(p, dist.probs.eval())  # Should not fail
+
+  def testShapes(self):
+    with self.test_session():
+      for batch_shape in ([], [1], [2, 3, 4]):
+        dist = make_bernoulli(batch_shape)
+        self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
+        self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
+        self.assertAllEqual([], dist.event_shape.as_list())
+        self.assertAllEqual([], dist.event_shape_tensor().eval())
+
+  def testDtype(self):
+    dist = make_bernoulli([])
+    self.assertEqual(dist.dtype, dtypes.int32)
+    self.assertEqual(dist.dtype, dist.sample(5).dtype)
+    self.assertEqual(dist.dtype, dist.mode().dtype)
+    self.assertEqual(dist.probs.dtype, dist.mean().dtype)
+    self.assertEqual(dist.probs.dtype, dist.variance().dtype)
+    self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
+    self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
+    self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
+    self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
+
+    dist64 = make_bernoulli([], dtypes.int64)
+    self.assertEqual(dist64.dtype, dtypes.int64)
+    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
+    self.assertEqual(dist64.dtype, dist64.mode().dtype)
+
+  def _testPmf(self, **kwargs):
+    dist = bernoulli.Bernoulli(**kwargs)
+    with self.test_session():
+      # pylint: disable=bad-continuation
+      xs = [
+          0,
+          [1],
+          [1, 0],
+          [[1, 0]],
+          [[1, 0], [1, 1]],
+      ]
+      expected_pmfs = [
+          [[0.8, 0.6], [0.7, 0.4]],
+          [[0.2, 0.4], [0.3, 0.6]],
+          [[0.2, 0.6], [0.3, 0.4]],
+          [[0.2, 0.6], [0.3, 0.4]],
+          [[0.2, 0.6], [0.3, 0.6]],
+      ]
+      # pylint: enable=bad-continuation
+
+      for x, expected_pmf in zip(xs, expected_pmfs):
+        self.assertAllClose(dist.prob(x).eval(), expected_pmf)
+        self.assertAllClose(dist.log_prob(x).eval(), np.log(expected_pmf))
+
+  def testPmfCorrectBroadcastDynamicShape(self):
+    with self.test_session():
+      p = array_ops.placeholder(dtype=dtypes.float32)
+      dist = bernoulli.Bernoulli(probs=p)
+      event1 = [1, 0, 1]
+      event2 = [[1, 0, 1]]
+      self.assertAllClose(
+          dist.prob(event1).eval({
+              p: [0.2, 0.3, 0.4]
+          }), [0.2, 0.7, 0.4])
+      self.assertAllClose(
+          dist.prob(event2).eval({
+              p: [0.2, 0.3, 0.4]
+          }), [[0.2, 0.7, 0.4]])
+
+  def testPmfInvalid(self):
+    p = [0.1, 0.2, 0.7]
+    with self.test_session():
+      dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+      with self.assertRaisesOpError("must be non-negative."):
+        dist.prob([1, 1, -1]).eval()
+      with self.assertRaisesOpError("is not less than or equal to 1."):
+        dist.prob([2, 0, 1]).eval()
+
+  def testPmfWithP(self):
+    p = [[0.2, 0.4], [0.3, 0.6]]
+    self._testPmf(probs=p)
+    if not special:
+      return
+    self._testPmf(logits=special.logit(p))
+
+  def testBroadcasting(self):
+    with self.test_session():
+      p = array_ops.placeholder(dtypes.float32)
+      dist = bernoulli.Bernoulli(probs=p)
+      self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
+      self.assertAllClose(
+          np.log([0.5, 0.5, 0.5]), dist.log_prob([1, 1, 1]).eval({
+              p: 0.5
+          }))
+      self.assertAllClose(
+          np.log([0.5, 0.5, 0.5]), dist.log_prob(1).eval({
+              p: [0.5, 0.5, 0.5]
+          }))
+
+  def testPmfShapes(self):
+    with self.test_session():
+      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
+      dist = bernoulli.Bernoulli(probs=p)
+      self.assertEqual(2,
+                       len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
+
+    with self.test_session():
+      dist = bernoulli.Bernoulli(probs=0.5)
+      self.assertEqual(2, len(dist.log_prob([[1], [1]]).eval().shape))
+
+    with self.test_session():
+      dist = bernoulli.Bernoulli(probs=0.5)
+      self.assertEqual((), dist.log_prob(1).get_shape())
+      self.assertEqual((1), dist.log_prob([1]).get_shape())
+      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
+
+    with self.test_session():
+      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
+      self.assertEqual((2, 1), dist.log_prob(1).get_shape())
+
+  def testBoundaryConditions(self):
+    with self.test_session():
+      dist = bernoulli.Bernoulli(probs=1.0)
+      self.assertAllClose(np.nan, dist.log_prob(0).eval())
+      self.assertAllClose([np.nan], [dist.log_prob(1).eval()])
+
+  def testEntropyNoBatch(self):
+    p = 0.2
+    dist = bernoulli.Bernoulli(probs=p)
+    with self.test_session():
+      self.assertAllClose(dist.entropy().eval(), entropy(p))
+
+  def testEntropyWithBatch(self):
+    p = [[0.1, 0.7], [0.2, 0.6]]
+    dist = bernoulli.Bernoulli(probs=p, validate_args=False)
+    with self.test_session():
+      self.assertAllClose(dist.entropy().eval(), [[entropy(0.1), entropy(0.7)],
+                                                  [entropy(0.2), entropy(0.6)]])
+
+  def testSampleN(self):
+    with self.test_session():
+      p = [0.2, 0.6]
+      dist = bernoulli.Bernoulli(probs=p)
+      n = 100000
+      samples = dist.sample(n)
+      samples.set_shape([n, 2])
+      self.assertEqual(samples.dtype, dtypes.int32)
+      sample_values = samples.eval()
+      self.assertTrue(np.all(sample_values >= 0))
+      self.assertTrue(np.all(sample_values <= 1))
+      # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p)
+      # / n). This means that the tolerance is very sensitive to the value of p
+      # as well as n.
+      self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
+      self.assertEqual(set([0, 1]), set(sample_values.flatten()))
+      # In this test we're just interested in verifying there isn't a crash
+      # owing to mismatched types. b/30940152
+      dist = bernoulli.Bernoulli(np.log([.2, .4]))
+      self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+
+  def testSampleActsLikeSampleN(self):
+    with self.test_session() as sess:
+      p = [0.2, 0.6]
+      dist = bernoulli.Bernoulli(probs=p)
+      n = 1000
+      seed = 42
+      self.assertAllEqual(
+          dist.sample(n, seed).eval(), dist.sample(n, seed).eval())
+      n = array_ops.placeholder(dtypes.int32)
+      sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
+                                  feed_dict={n: 1000})
+      self.assertAllEqual(sample1, sample2)
+
+  def testMean(self):
+    with self.test_session():
+      p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
+      dist = bernoulli.Bernoulli(probs=p)
+      self.assertAllEqual(dist.mean().eval(), p)
+
+  def testVarianceAndStd(self):
+    var = lambda p: p * (1. - p)
+    with self.test_session():
+      p = [[0.2, 0.7], [0.5, 0.4]]
+      dist = bernoulli.Bernoulli(probs=p)
+      self.assertAllClose(
+          dist.variance().eval(),
+          np.array(
+              [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
+      self.assertAllClose(
+          dist.stddev().eval(),
+          np.array(
+              [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
+               [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
+              dtype=np.float32))
+
+  def testBernoulliWithSigmoidProbs(self):
+    p = np.array([8.3, 4.2])
+    dist = bernoulli.BernoulliWithSigmoidProbs(logits=p)
+    with self.test_session():
+      self.assertAllClose(math_ops.sigmoid(p).eval(), dist.probs.eval())
+
+  def testBernoulliBernoulliKL(self):
+    with self.test_session() as sess:
+      batch_size = 6
+      a_p = np.array([0.5] * batch_size, dtype=np.float32)
+      b_p = np.array([0.4] * batch_size, dtype=np.float32)
+
+      a = bernoulli.Bernoulli(probs=a_p)
+      b = bernoulli.Bernoulli(probs=b_p)
+
+      kl = kullback_leibler.kl_divergence(a, b)
+      kl_val = sess.run(kl)
+
+      kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
+          (1. - a_p) / (1. - b_p)))
+
+      self.assertEqual(kl.get_shape(), (batch_size,))
+      self.assertAllClose(kl_val, kl_expected)
+
+
+if __name__ == "__main__":
+  test.main()