diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/multinomial_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/distributions/multinomial_test.py | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py index ebc89f15c5..e24e8ade73 100644 --- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py @@ -250,13 +250,11 @@ class MultinomialTest(test.TestCase): theta = np.array([[1., 2, 3], [2.5, 4, 0.01]], dtype=np.float32) theta /= np.sum(theta, 1)[..., array_ops.newaxis] - # Ideally we'd be able to test broadcasting but, the multinomial sampler - # doesn't support different total counts. - n = np.float32(5) + n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32) with self.test_session() as sess: - # batch_shape=[2], event_shape=[3] + # batch_shape=[3, 2], event_shape=[3] dist = multinomial.Multinomial(n, theta) - x = dist.sample(int(250e3), seed=1) + x = dist.sample(int(1000e3), seed=1) sample_mean = math_ops.reduce_mean(x, 0) x_centered = x - sample_mean[array_ops.newaxis, ...] sample_cov = math_ops.reduce_mean(math_ops.matmul( @@ -291,9 +289,9 @@ class MultinomialTest(test.TestCase): def testSampleUnbiasedNonScalarBatch(self): with self.test_session() as sess: dist = multinomial.Multinomial( - total_count=5., + total_count=[7., 6., 5.], logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32))) - n = int(3e3) + n = int(3e4) x = dist.sample(n, seed=0) sample_mean = math_ops.reduce_mean(x, 0) # Cyclically rotate event dims left. |