aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/distributions/multinomial_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/multinomial_test.py')
-rw-r--r--tensorflow/python/kernel_tests/distributions/multinomial_test.py20
1 files changed, 9 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index ebc89f15c5..d62aca151a 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -250,13 +250,11 @@ class MultinomialTest(test.TestCase):
theta = np.array([[1., 2, 3],
[2.5, 4, 0.01]], dtype=np.float32)
theta /= np.sum(theta, 1)[..., array_ops.newaxis]
- # Ideally we'd be able to test broadcasting but, the multinomial sampler
- # doesn't support different total counts.
- n = np.float32(5)
+ n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32)
with self.test_session() as sess:
- # batch_shape=[2], event_shape=[3]
+ # batch_shape=[3, 2], event_shape=[3]
dist = multinomial.Multinomial(n, theta)
- x = dist.sample(int(250e3), seed=1)
+ x = dist.sample(int(1000e3), seed=1)
sample_mean = math_ops.reduce_mean(x, 0)
x_centered = x - sample_mean[array_ops.newaxis, ...]
sample_cov = math_ops.reduce_mean(math_ops.matmul(
@@ -283,17 +281,17 @@ class MultinomialTest(test.TestCase):
dist.variance(),
dist.stddev(),
])
- self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01)
- self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01)
- self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01)
- self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)
+ self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.01)
+ self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.01)
+ self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.01)
+ self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.01)
def testSampleUnbiasedNonScalarBatch(self):
with self.test_session() as sess:
dist = multinomial.Multinomial(
- total_count=5.,
+ total_count=[7., 6., 5.],
logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
- n = int(3e3)
+ n = int(3e4)
x = dist.sample(n, seed=0)
sample_mean = math_ops.reduce_mean(x, 0)
# Cyclically rotate event dims left.