diff options
Diffstat (limited to 'tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py')
-rw-r--r-- | tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py | 27 |
1 files changed, 26 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index a25de55e18..31a6fe1d94 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import numpy from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def testPeriodicResampleErrors(self): input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) with self.test_session(): - variables.global_variables_initializer().run() with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Dimension 3 input tensor has size 4, desired shape has size 1'): @@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): '4, to be the same as the length of the desired shape, 3'): periodic_resample(input_tensor, [None, 4, 4]).eval() + def testPeriodicResampleGradient(self): + desired_shape = numpy.array([4, 4, None]) + result_shape = (4, 4, 1) + input_shape = (2, 2, 4) + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=input_shape) + output = periodic_resample(x, desired_shape) + error = gradient_checker.compute_gradient_error( + x, input_shape, output, result_shape) + self.assertLess(error, 1e-4) + + def testPeriodicResampleShapeInference(self): + with self.test_session() as sess: + # Case 1: output shape can be fully inferreed. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4)) + output = periodic_resample(x, [4, 4, None]) + self.assertEqual(output.shape, [4, 4, 1]) + # Case 2: output shape can not be inferred - report desired shape. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None)) + output = periodic_resample(x, [4, 4, None]) + self.assertTrue(output.shape.is_compatible_with([4, 4, None])) + self.assertEqual(output.shape[2].value, None) + if __name__ == '__main__': googletest.main() |