aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py')
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py27
1 files changed, 26 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index a25de55e18..31a6fe1d94 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import numpy
from tensorflow.contrib.periodic_resample import periodic_resample
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
with self.test_session():
- variables.global_variables_initializer().run()
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
'4, to be the same as the length of the desired shape, 3'):
periodic_resample(input_tensor, [None, 4, 4]).eval()
+ def testPeriodicResampleGradient(self):
+ desired_shape = numpy.array([4, 4, None])
+ result_shape = (4, 4, 1)
+ input_shape = (2, 2, 4)
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32, shape=input_shape)
+ output = periodic_resample(x, desired_shape)
+ error = gradient_checker.compute_gradient_error(
+ x, input_shape, output, result_shape)
+ self.assertLess(error, 1e-4)
+
+ def testPeriodicResampleShapeInference(self):
+ with self.test_session() as sess:
+ # Case 1: output shape can be fully inferreed.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
+ output = periodic_resample(x, [4, 4, None])
+ self.assertEqual(output.shape, [4, 4, 1])
+ # Case 2: output shape can not be inferred - report desired shape.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None))
+ output = periodic_resample(x, [4, 4, None])
+ self.assertTrue(output.shape.is_compatible_with([4, 4, None]))
+ self.assertEqual(output.shape[2].value, None)
+
if __name__ == '__main__':
googletest.main()