aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
index 4d2f40e27f..c6c8d2cf6e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib
+from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
@@ -514,6 +515,42 @@ class _BatchReshapeTest(object):
batch_shape=new_batch_shape_ph,
validate_args=True).sample().eval()
+ def test_broadcasting_explicitly_unsupported(self):
+ old_batch_shape = [4]
+ new_batch_shape = [1, 4, 1]
+ rate_ = self.dtype([1, 10, 2, 20])
+
+ rate = array_ops.placeholder_with_default(
+ rate_,
+ shape=old_batch_shape if self.is_static_shape else None)
+ poisson_4 = poisson_lib.Poisson(rate)
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+ poisson_141_reshaped = batch_reshape_lib.BatchReshape(
+ poisson_4, new_batch_shape_ph, validate_args=True)
+
+ x_4 = self.dtype([2, 12, 3, 23])
+ x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4)
+
+ if self.is_static_shape:
+ with self.assertRaisesRegexp(NotImplementedError,
+ "too few event dims"):
+ poisson_141_reshaped.log_prob(x_4)
+ with self.assertRaisesRegexp(NotImplementedError,
+ "unexpected batch and event shape"):
+ poisson_141_reshaped.log_prob(x_114)
+ return
+
+ with self.assertRaisesOpError("too few event dims"):
+ with self.test_session():
+ poisson_141_reshaped.log_prob(x_4).eval()
+
+ with self.assertRaisesOpError("unexpected batch and event shape"):
+ with self.test_session():
+ poisson_141_reshaped.log_prob(x_114).eval()
+
class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase):