aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py')
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 1d727870f6..30a2077570 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -19,8 +19,9 @@ from __future__ import division
from __future__ import print_function
import numpy
-import tensorflow
+
from tensorflow.contrib.periodic_resample import periodic_resample
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -96,6 +97,19 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
+ def testPeriodicResampleErrors(self):
+ input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ 'Dimension 3 input tensor has size 4, desired shape has size 1'):
+ periodic_resample(input_tensor, [None, 4, 4, 1]).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ '4, to be the same as the length of the desired shape, 3'):
+ periodic_resample(input_tensor, [None, 4, 4]).eval()
+
if __name__ == "__main__":
googletest.main()