diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/pooling_ops_3d_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/pooling_ops_3d_test.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py index ca38f1af9f..fa1553a3f6 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -234,7 +235,8 @@ class PoolingTest(test.TestCase): x = np.arange(1, total_size + 1, dtype=np.float32) with self.test_session(use_gpu=use_gpu): input_tensor = constant_op.constant(x, shape=input_sizes, name="input") - err_margin = 1e-3 + err_g_margin = 1e-3 + err_gg_margin = 1.5e-2 if pool_func == nn_ops.avg_pool3d: func_name = "avg_pool3d" x_init_value = None @@ -259,19 +261,27 @@ class PoolingTest(test.TestCase): padding=padding, data_format=data_format, name=func_name) + t_g = gradients_impl.gradients(t**2, input_tensor)[0] - if data_format == "NCDHW": - t = test_util.NCHWToNHWC(t) - - err = gradient_checker.compute_gradient_error( + err_g = gradient_checker.compute_gradient_error( input_tensor, input_sizes, t, output_sizes, x_init_value=x_init_value, delta=1e-2) - print("%s gradient error = " % func_name, err) - self.assertLess(err, err_margin) + err_gg = gradient_checker.compute_gradient_error( + input_tensor, + input_sizes, + t_g, + input_sizes, + x_init_value=x_init_value, + delta=1e-2) + + print("%s gradient error = " % func_name, err_g) + self.assertLess(err_g, err_g_margin) + print("%s second-order gradient error = " % func_name, err_gg) + self.assertLess(err_gg, err_gg_margin) def _ConstructAndTestGradient(self, pool_func, |