diff options
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py')
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 4d2f40e27f..c6c8d2cf6e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib +from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops @@ -514,6 +515,42 @@ class _BatchReshapeTest(object): batch_shape=new_batch_shape_ph, validate_args=True).sample().eval() + def test_broadcasting_explicitly_unsupported(self): + old_batch_shape = [4] + new_batch_shape = [1, 4, 1] + rate_ = self.dtype([1, 10, 2, 20]) + + rate = array_ops.placeholder_with_default( + rate_, + shape=old_batch_shape if self.is_static_shape else None) + poisson_4 = poisson_lib.Poisson(rate) + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + poisson_141_reshaped = batch_reshape_lib.BatchReshape( + poisson_4, new_batch_shape_ph, validate_args=True) + + x_4 = self.dtype([2, 12, 3, 23]) + x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4) + + if self.is_static_shape: + with self.assertRaisesRegexp(NotImplementedError, + "too few event dims"): + poisson_141_reshaped.log_prob(x_4) + with self.assertRaisesRegexp(NotImplementedError, + "unexpected batch and event shape"): + poisson_141_reshaped.log_prob(x_114) + return + + with self.assertRaisesOpError("too few event dims"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_4).eval() + + with self.assertRaisesOpError("unexpected batch and event shape"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_114).eval() + class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase): |