Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/beta_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/distributions/beta_test.py  394
1 file changed, 394 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
new file mode 100644
index 0000000000..91a451f033
--- /dev/null
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -0,0 +1,394 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import beta as beta_lib
+from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name): # pylint: disable=invalid-name
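+  """Imports and returns the named module, or None (with a warning) on failure."""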
+ module = None
+ try:
+ module = importlib.import_module(name)
+ except ImportError as e:
+ tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+ return module
+
+
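+# SciPy is an optional test dependency; tests that need it return early
+# when it is unavailable.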
+special = try_import("scipy.special")
+stats = try_import("scipy.stats")
+
+
+class BetaTest(test.TestCase):
+
+ def testSimpleShapes(self):
+ with self.test_session():
+ a = np.random.rand(3)
+ b = np.random.rand(3)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], dist.event_shape_tensor().eval())
+ self.assertAllEqual([3], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
+
+ def testComplexShapes(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(3, 2, 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], dist.event_shape_tensor().eval())
+ self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(
+ tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+
+ def testComplexShapesBroadcast(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(2, 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], dist.event_shape_tensor().eval())
+ self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(
+ tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+
+ def testAlphaProperty(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual([1, 3], dist.concentration1.get_shape())
+ self.assertAllClose(a, dist.concentration1.eval())
+
+ def testBetaProperty(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual([1, 3], dist.concentration0.get_shape())
+ self.assertAllClose(b, dist.concentration0.eval())
+
+ def testPdfXProper(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = beta_lib.Beta(a, b, validate_args=True)
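+      # These samples lie strictly inside (0, 1), so prob() must not raise.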
+ dist.prob([.1, .3, .6]).eval()
+ dist.prob([.2, .3, .5]).eval()
+      # Both a negative sample and a zero sample violate strict positivity.
+ with self.assertRaisesOpError("sample must be positive"):
+ dist.prob([-1., 0.1, 0.5]).eval()
+ with self.assertRaisesOpError("sample must be positive"):
+ dist.prob([0., 0.1, 0.5]).eval()
+ with self.assertRaisesOpError("sample must be no larger than `1`"):
+ dist.prob([.1, .2, 1.2]).eval()
+
+ def testPdfTwoBatches(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [.5, .5]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1., 3. / 2], pdf.eval())
+ self.assertEqual((2,), pdf.get_shape())
+
+ def testPdfTwoBatchesNontrivialX(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [.3, .7]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1, 63. / 50], pdf.eval())
+ self.assertEqual((2,), pdf.get_shape())
+
+ def testPdfUniformZeroBatch(self):
+ with self.test_session():
+      # Beta(1, 1) is the uniform distribution on [0, 1].
+ a = 1.
+ b = 1.
+ x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1.] * 5, pdf.eval())
+ self.assertEqual((5,), pdf.get_shape())
+
+ def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ a = [[1., 2]]
+ b = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testBetaMean(self):
+ with session.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_mean = stats.beta.mean(a, b)
+ self.assertAllClose(expected_mean, dist.mean().eval())
+
+ def testBetaVariance(self):
+ with session.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variance = stats.beta.var(a, b)
+ self.assertAllClose(expected_variance, dist.variance().eval())
+
+ def testBetaMode(self):
+ with session.Session():
+ a = np.array([1.1, 2, 3])
+ b = np.array([2., 4, 1.2])
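+      # The mode of Beta(a, b) is (a - 1) / (a + b - 2), defined for a, b > 1.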
+ expected_mode = (a - 1) / (a + b - 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.mode().get_shape(), (3,))
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ def testBetaModeInvalid(self):
+ with session.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ dist.mode().eval()
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ dist.mode().eval()
+
+ def testBetaModeEnableAllowNanStats(self):
+ with session.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+
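+      # a[0] == 1 leaves the mode undefined, so NaN is expected in its place.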
+ expected_mode = (a - 1) / (a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+
+ expected_mode = (a - 1) / (a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ def testBetaEntropy(self):
+ with session.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.beta.entropy(a, b)
+ self.assertAllClose(expected_entropy, dist.entropy().eval())
+
+ def testBetaSample(self):
+ with self.test_session():
+ a = 1.
+ b = 2.
+ beta = beta_lib.Beta(a, b)
+ n = constant_op.constant(100000)
+ samples = beta.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values,
+ stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
+ # The standard error of the sample mean is 1 / (sqrt(18 * n))
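+      # (Var(Beta(1, 2)) = 1/18; with n = 100000 that is ~7.5e-4, well within atol.)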
+ self.assertAllClose(
+ sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
+ self.assertAllClose(
+ np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
+
+  def testBetaSampleMultipleTimes(self):
+    """Sampling with the same seed twice gives the same results."""
+    with self.test_session():
+ a_val = 1.
+ b_val = 2.
+ n_val = 100
+
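+      # The op-level seed passed to sample() is combined with the graph-level
+      # seed set here, so resetting both makes the draws reproducible.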
+ random_seed.set_random_seed(654321)
+ beta1 = beta_lib.Beta(concentration1=a_val,
+ concentration0=b_val,
+ name="beta1")
+ samples1 = beta1.sample(n_val, seed=123456).eval()
+
+ random_seed.set_random_seed(654321)
+ beta2 = beta_lib.Beta(concentration1=a_val,
+ concentration0=b_val,
+ name="beta2")
+ samples2 = beta2.sample(n_val, seed=123456).eval()
+
+ self.assertAllClose(samples1, samples2)
+
+ def testBetaSampleMultidimensional(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2).astype(np.float32)
+ b = np.random.rand(3, 2, 2).astype(np.float32)
+ beta = beta_lib.Beta(a, b)
+ n = constant_op.constant(100000)
+ samples = beta.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values[:, 1, :].mean(axis=0),
+ stats.beta.mean(a, b)[1, :],
+ atol=1e-1)
+
+ def testBetaCdf(self):
+ with self.test_session():
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = beta_lib.Beta(a, b).cdf(x).eval()
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+
+ def testBetaLogCdf(self):
+ with self.test_session():
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)).eval()
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+
+ def testBetaWithSoftplusConcentration(self):
+ with self.test_session():
+ a, b = -4.2, -9.1
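+      # softplus maps these negative inputs to small positive concentrations.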
+ dist = beta_lib.BetaWithSoftplusConcentration(a, b)
+ self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval())
+ self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval())
+
+ def testBetaBetaKL(self):
+ with self.test_session() as sess:
+ for shape in [(10,), (4, 5)]:
+ a1 = 6.0 * np.random.random(size=shape) + 1e-4
+ b1 = 6.0 * np.random.random(size=shape) + 1e-4
+ a2 = 6.0 * np.random.random(size=shape) + 1e-4
+ b2 = 6.0 * np.random.random(size=shape) + 1e-4
+        # Apply the inverse softplus, log(exp(x) - 1), so that
+        # BetaWithSoftplusConcentration recovers the original concentrations.
+        a1_sp = np.log(np.exp(a1) - 1.0)
+        b1_sp = np.log(np.exp(b1) - 1.0)
+        a2_sp = np.log(np.exp(a2) - 1.0)
+        b2_sp = np.log(np.exp(b2) - 1.0)
+
+        d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
+        d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
+        d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
+                                                       concentration0=b1_sp)
+        d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
+                                                       concentration0=b2_sp)
+
+        if not special:
+          return
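+        # Closed-form KL(Beta(a1, b1) || Beta(a2, b2)) via betaln and digamma.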
+        kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
+                       (a1 - a2) * special.digamma(a1) +
+                       (b1 - b2) * special.digamma(b1) +
+                       (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))
+
+        for dist1 in [d1, d1_sp]:
+          for dist2 in [d2, d2_sp]:
+            kl = kullback_leibler.kl_divergence(dist1, dist2)
+            kl_val = sess.run(kl)
+            self.assertEqual(kl.get_shape(), shape)
+            self.assertAllClose(kl_val, kl_expected)
+
+        # Make sure KL(d1||d1) is 0
+        kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1))
+        self.assertAllClose(kl_same, np.zeros_like(kl_expected))
+
+
+if __name__ == "__main__":
+ test.main()