diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/multinomial_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/multinomial_test.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py index e24e8ade73..ebc89f15c5 100644 --- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py @@ -250,11 +250,13 @@ class MultinomialTest(test.TestCase): theta = np.array([[1., 2, 3], [2.5, 4, 0.01]], dtype=np.float32) theta /= np.sum(theta, 1)[..., array_ops.newaxis] - n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32) + # Ideally we'd be able to test broadcasting but, the multinomial sampler + # doesn't support different total counts. + n = np.float32(5) with self.test_session() as sess: - # batch_shape=[3, 2], event_shape=[3] + # batch_shape=[2], event_shape=[3] dist = multinomial.Multinomial(n, theta) - x = dist.sample(int(1000e3), seed=1) + x = dist.sample(int(250e3), seed=1) sample_mean = math_ops.reduce_mean(x, 0) x_centered = x - sample_mean[array_ops.newaxis, ...] sample_cov = math_ops.reduce_mean(math_ops.matmul( @@ -289,9 +291,9 @@ class MultinomialTest(test.TestCase): def testSampleUnbiasedNonScalarBatch(self): with self.test_session() as sess: dist = multinomial.Multinomial( - total_count=[7., 6., 5.], + total_count=5., logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32))) - n = int(3e4) + n = int(3e3) x = dist.sample(n, seed=0) sample_mean = math_ops.reduce_mean(x, 0) # Cyclically rotate event dims left. |