aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py')
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py480
1 files changed, 480 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
new file mode 100644
index 0000000000..d009f4e931
--- /dev/null
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -0,0 +1,480 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import dirichlet_multinomial
+from tensorflow.python.platform import test
+
+
+ds = dirichlet_multinomial
+
+
+class DirichletMultinomialTest(test.TestCase):
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testSimpleShapes(self):
+ with self.test_session():
+ alpha = np.random.rand(3)
+ dist = ds.DirichletMultinomial(1., alpha)
+ self.assertEqual(3, dist.event_shape_tensor().eval())
+ self.assertAllEqual([], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
+
+ def testComplexShapes(self):
+ with self.test_session():
+ alpha = np.random.rand(3, 2, 2)
+ n = [[3., 2], [4, 5], [6, 7]]
+ dist = ds.DirichletMultinomial(n, alpha)
+ self.assertEqual(2, dist.event_shape_tensor().eval())
+ self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
+
+ def testNproperty(self):
+ alpha = [[1., 2, 3]]
+ n = [[5.]]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(n, alpha)
+ self.assertEqual([1, 1], dist.total_count.get_shape())
+ self.assertAllClose(n, dist.total_count.eval())
+
+ def testAlphaProperty(self):
+ alpha = [[1., 2, 3]]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(1, alpha)
+ self.assertEqual([1, 3], dist.concentration.get_shape())
+ self.assertAllClose(alpha, dist.concentration.eval())
+
+ def testPmfNandCountsAgree(self):
+ alpha = [[1., 2, 3]]
+ n = [[5.]]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
+ dist.prob([2., 3, 0]).eval()
+ dist.prob([3., 0, 2]).eval()
+ with self.assertRaisesOpError("counts must be non-negative"):
+ dist.prob([-1., 4, 2]).eval()
+ with self.assertRaisesOpError(
+ "counts last-dimension must sum to `self.total_count`"):
+ dist.prob([3., 3, 0]).eval()
+
+ def testPmfNonIntegerCounts(self):
+ alpha = [[1., 2, 3]]
+ n = [[5.]]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
+ dist.prob([2., 3, 0]).eval()
+ dist.prob([3., 0, 2]).eval()
+ dist.prob([3.0, 0, 2.0]).eval()
+ # Both equality and integer checking fail.
+ placeholder = array_ops.placeholder(dtypes.float32)
+ with self.assertRaisesOpError(
+ "counts cannot contain fractional components"):
+ dist.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]})
+ dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
+ dist.prob([1., 2., 3.]).eval()
+ # Non-integer arguments work.
+ dist.prob([1.0, 2.5, 1.5]).eval()
+
+ def testPmfBothZeroBatches(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ # Both zero-batches. No broadcast
+ alpha = [1., 2]
+ counts = [1., 0]
+ dist = ds.DirichletMultinomial(1., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(1 / 3., pmf.eval())
+ self.assertEqual((), pmf.get_shape())
+
+ def testPmfBothZeroBatchesNontrivialN(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ # Both zero-batches. No broadcast
+ alpha = [1., 2]
+ counts = [3., 2]
+ dist = ds.DirichletMultinomial(5., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(1 / 7., pmf.eval())
+ self.assertEqual((), pmf.get_shape())
+
+ def testPmfBothZeroBatchesMultidimensionalN(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ alpha = [1., 2]
+ counts = [3., 2]
+ n = np.full([4, 3], 5., dtype=np.float32)
+ dist = ds.DirichletMultinomial(n, alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, pmf.eval())
+ self.assertEqual((4, 3), pmf.get_shape())
+
+ def testPmfAlphaStretchedInBroadcastWhenSameRank(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ alpha = [[1., 2]]
+ counts = [[1., 0], [0., 1]]
+ dist = ds.DirichletMultinomial([1.], alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
+ self.assertAllEqual([2], pmf.get_shape())
+
+ def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ alpha = [1., 2]
+ counts = [[1., 0], [0., 1]]
+ pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
+ self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
+ self.assertAllEqual([2], pmf.get_shape())
+
+ def testPmfCountsStretchedInBroadcastWhenSameRank(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ alpha = [[1., 2], [2., 3]]
+ counts = [[1., 0]]
+ pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
+ self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
+ self.assertAllEqual([2], pmf.get_shape())
+
+ def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ with self.test_session():
+ alpha = [[1., 2], [2., 3]]
+ counts = [1., 0]
+ pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
+ self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
+ self.assertAllEqual([2], pmf.get_shape())
+
+ def testPmfForOneVoteIsTheMeanWithOneRecordInput(self):
+ # The probabilities of one vote falling into class k is the mean for class
+ # k.
+ alpha = [1., 2, 3]
+ with self.test_session():
+ for class_num in range(3):
+ counts = np.zeros([3], dtype=np.float32)
+ counts[class_num] = 1
+ dist = ds.DirichletMultinomial(1., alpha)
+ mean = dist.mean().eval()
+ pmf = dist.prob(counts).eval()
+
+ self.assertAllClose(mean[class_num], pmf)
+ self.assertAllEqual([3], mean.shape)
+ self.assertAllEqual([], pmf.shape)
+
+ def testMeanDoubleTwoVotes(self):
+ # The probabilities of two votes falling into class k for
+ # DirichletMultinomial(2, alpha) is twice as much as the probability of one
+ # vote falling into class k for DirichletMultinomial(1, alpha)
+ alpha = [1., 2, 3]
+ with self.test_session():
+ for class_num in range(3):
+ counts_one = np.zeros([3], dtype=np.float32)
+ counts_one[class_num] = 1.
+ counts_two = np.zeros([3], dtype=np.float32)
+ counts_two[class_num] = 2
+
+ dist1 = ds.DirichletMultinomial(1., alpha)
+ dist2 = ds.DirichletMultinomial(2., alpha)
+
+ mean1 = dist1.mean().eval()
+ mean2 = dist2.mean().eval()
+
+ self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
+ self.assertAllEqual([3], mean1.shape)
+
+ def testCovarianceFromSampling(self):
+ # We will test mean, cov, var, stddev on a DirichletMultinomial constructed
+ # via broadcast between alpha, n.
+ alpha = np.array([[1., 2, 3],
+ [2.5, 4, 0.01]], dtype=np.float32)
+ # Ideally we'd be able to test broadcasting but, the multinomial sampler
+ # doesn't support different total counts.
+ n = np.float32(5)
+ with self.test_session() as sess:
+ # batch_shape=[2], event_shape=[3]
+ dist = ds.DirichletMultinomial(n, alpha)
+ x = dist.sample(int(250e3), seed=1)
+ sample_mean = math_ops.reduce_mean(x, 0)
+ x_centered = x - sample_mean[array_ops.newaxis, ...]
+ sample_cov = math_ops.reduce_mean(math_ops.matmul(
+ x_centered[..., array_ops.newaxis],
+ x_centered[..., array_ops.newaxis, :]), 0)
+ sample_var = array_ops.matrix_diag_part(sample_cov)
+ sample_stddev = math_ops.sqrt(sample_var)
+ [
+ sample_mean_,
+ sample_cov_,
+ sample_var_,
+ sample_stddev_,
+ analytic_mean,
+ analytic_cov,
+ analytic_var,
+ analytic_stddev,
+ ] = sess.run([
+ sample_mean,
+ sample_cov,
+ sample_var,
+ sample_stddev,
+ dist.mean(),
+ dist.covariance(),
+ dist.variance(),
+ dist.stddev(),
+ ])
+ self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
+ self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05)
+ self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
+ self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
+
+ def testCovariance(self):
+ # Shape [2]
+ alpha = [1., 2]
+ ns = [2., 3., 4., 5.]
+ alpha_0 = np.sum(alpha)
+
+ # Diagonal entries are of the form:
+ # Var(X_i) = n * alpha_i / alpha_sum * (1 - alpha_i / alpha_sum) *
+ # (alpha_sum + n) / (alpha_sum + 1)
+ variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
+ # Off diagonal entries are of the form:
+ # Cov(X_i, X_j) = -n * alpha_i * alpha_j / (alpha_sum ** 2) *
+ # (alpha_sum + n) / (alpha_sum + 1)
+ covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
+ # Shape [2, 2].
+ shared_matrix = np.array([[
+ variance_entry(alpha[0], alpha_0),
+ covariance_entry(alpha[0], alpha[1], alpha_0)
+ ], [
+ covariance_entry(alpha[1], alpha[0], alpha_0),
+ variance_entry(alpha[1], alpha_0)
+ ]])
+
+ with self.test_session():
+ for n in ns:
+ # n is shape [] and alpha is shape [2].
+ dist = ds.DirichletMultinomial(n, alpha)
+ covariance = dist.covariance()
+ expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
+
+ self.assertEqual([2, 2], covariance.get_shape())
+ self.assertAllClose(expected_covariance, covariance.eval())
+
+ def testCovarianceNAlphaBroadcast(self):
+ alpha_v = [1., 2, 3]
+ alpha_0 = 6.
+
+ # Shape [4, 3]
+ alpha = np.array(4 * [alpha_v], dtype=np.float32)
+ # Shape [4, 1]
+ ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)
+
+ variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
+ covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
+ # Shape [4, 3, 3]
+ shared_matrix = np.array(
+ 4 * [[[
+ variance_entry(alpha_v[0], alpha_0),
+ covariance_entry(alpha_v[0], alpha_v[1], alpha_0),
+ covariance_entry(alpha_v[0], alpha_v[2], alpha_0)
+ ], [
+ covariance_entry(alpha_v[1], alpha_v[0], alpha_0),
+ variance_entry(alpha_v[1], alpha_0),
+ covariance_entry(alpha_v[1], alpha_v[2], alpha_0)
+ ], [
+ covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
+ covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
+ variance_entry(alpha_v[2], alpha_0)
+ ]]],
+ dtype=np.float32)
+
+ with self.test_session():
+ # ns is shape [4, 1], and alpha is shape [4, 3].
+ dist = ds.DirichletMultinomial(ns, alpha)
+ covariance = dist.covariance()
+ expected_covariance = shared_matrix * (
+ ns * (ns + alpha_0) / (1 + alpha_0))[..., array_ops.newaxis]
+
+ self.assertEqual([4, 3, 3], covariance.get_shape())
+ self.assertAllClose(expected_covariance, covariance.eval())
+
+ def testCovarianceMultidimensional(self):
+ alpha = np.random.rand(3, 5, 4).astype(np.float32)
+ alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
+
+ ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
+ ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)
+
+ with self.test_session():
+ dist = ds.DirichletMultinomial(ns, alpha)
+ dist2 = ds.DirichletMultinomial(ns2, alpha2)
+
+ covariance = dist.covariance()
+ covariance2 = dist2.covariance()
+ self.assertEqual([3, 5, 4, 4], covariance.get_shape())
+ self.assertEqual([6, 3, 3, 3], covariance2.get_shape())
+
+ def testZeroCountsResultsInPmfEqualToOne(self):
+ # There is only one way for zero items to be selected, and this happens with
+ # probability 1.
+ alpha = [5, 0.5]
+ counts = [0., 0]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(0., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(1.0, pmf.eval())
+ self.assertEqual((), pmf.get_shape())
+
+ def testLargeTauGivesPreciseProbabilities(self):
+ # If tau is large, we are doing coin flips with probability mu.
+ mu = np.array([0.1, 0.1, 0.8], dtype=np.float32)
+ tau = np.array([100.], dtype=np.float32)
+ alpha = tau * mu
+
+ # One (three sided) coin flip. Prob[coin 3] = 0.8.
+ # Note that since it was one flip, value of tau didn't matter.
+ counts = [0., 0, 1]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(1., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(0.8, pmf.eval(), atol=1e-4)
+ self.assertEqual((), pmf.get_shape())
+
+ # Two (three sided) coin flips. Prob[coin 3] = 0.8.
+ counts = [0., 0, 2]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(2., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
+ self.assertEqual((), pmf.get_shape())
+
+ # Three (three sided) coin flips.
+ counts = [1., 0, 2]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(3., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
+ self.assertEqual((), pmf.get_shape())
+
+ def testSmallTauPrefersCorrelatedResults(self):
+ # If tau is small, then correlation between draws is large, so draws that
+ # are both of the same class are more likely.
+ mu = np.array([0.5, 0.5], dtype=np.float32)
+ tau = np.array([0.1], dtype=np.float32)
+ alpha = tau * mu
+
+ # If there is only one draw, it is still a coin flip, even with small tau.
+ counts = [1., 0]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(1., alpha)
+ pmf = dist.prob(counts)
+ self.assertAllClose(0.5, pmf.eval())
+ self.assertEqual((), pmf.get_shape())
+
+ # If there are two draws, it is much more likely that they are the same.
+ counts_same = [2., 0]
+ counts_different = [1, 1.]
+ with self.test_session():
+ dist = ds.DirichletMultinomial(2., alpha)
+ pmf_same = dist.prob(counts_same)
+ pmf_different = dist.prob(counts_different)
+ self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
+ self.assertEqual((), pmf_same.get_shape())
+
+ def testNonStrictTurnsOffAllChecks(self):
+ # Make totally invalid input.
+ with self.test_session():
+ alpha = [[-1., 2]] # alpha should be positive.
+ counts = [[1., 0], [0., -1]] # counts should be non-negative.
+ n = [-5.3] # n should be a non negative integer equal to counts.sum.
+ dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
+ dist.prob(counts).eval() # Should not raise.
+
+ def testSampleUnbiasedNonScalarBatch(self):
+ with self.test_session() as sess:
+ dist = ds.DirichletMultinomial(
+ total_count=5.,
+ concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
+ n = int(3e3)
+ x = dist.sample(n, seed=0)
+ sample_mean = math_ops.reduce_mean(x, 0)
+ # Cyclically rotate event dims left.
+ x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
+ sample_covariance = math_ops.matmul(
+ x_centered, x_centered, adjoint_b=True) / n
+ [
+ sample_mean_,
+ sample_covariance_,
+ actual_mean_,
+ actual_covariance_,
+ ] = sess.run([
+ sample_mean,
+ sample_covariance,
+ dist.mean(),
+ dist.covariance(),
+ ])
+ self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
+ self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15)
+ self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
+ self.assertAllClose(
+ actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
+
+ def testSampleUnbiasedScalarBatch(self):
+ with self.test_session() as sess:
+ dist = ds.DirichletMultinomial(
+ total_count=5.,
+ concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
+ n = int(5e3)
+ x = dist.sample(n, seed=0)
+ sample_mean = math_ops.reduce_mean(x, 0)
+ x_centered = x - sample_mean # Already transposed to [n, 2].
+ sample_covariance = math_ops.matmul(
+ x_centered, x_centered, adjoint_a=True) / n
+ [
+ sample_mean_,
+ sample_covariance_,
+ actual_mean_,
+ actual_covariance_,
+ ] = sess.run([
+ sample_mean,
+ sample_covariance,
+ dist.mean(),
+ dist.covariance(),
+ ])
+ self.assertAllEqual([4], sample_mean.get_shape())
+ self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05)
+ self.assertAllEqual([4, 4], sample_covariance.get_shape())
+ self.assertAllClose(
+ actual_covariance_, sample_covariance_, atol=0., rtol=0.15)
+
+
+if __name__ == "__main__":
+ test.main()