aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/pooling_ops_3d_test.py')
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_3d_test.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
index ca38f1af9f..fa1553a3f6 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -234,7 +235,8 @@ class PoolingTest(test.TestCase):
x = np.arange(1, total_size + 1, dtype=np.float32)
with self.test_session(use_gpu=use_gpu):
input_tensor = constant_op.constant(x, shape=input_sizes, name="input")
- err_margin = 1e-3
+ err_g_margin = 1e-3
+ err_gg_margin = 1.5e-2
if pool_func == nn_ops.avg_pool3d:
func_name = "avg_pool3d"
x_init_value = None
@@ -259,19 +261,27 @@ class PoolingTest(test.TestCase):
padding=padding,
data_format=data_format,
name=func_name)
+ t_g = gradients_impl.gradients(t**2, input_tensor)[0]
- if data_format == "NCDHW":
- t = test_util.NCHWToNHWC(t)
-
- err = gradient_checker.compute_gradient_error(
+ err_g = gradient_checker.compute_gradient_error(
input_tensor,
input_sizes,
t,
output_sizes,
x_init_value=x_init_value,
delta=1e-2)
- print("%s gradient error = " % func_name, err)
- self.assertLess(err, err_margin)
+ err_gg = gradient_checker.compute_gradient_error(
+ input_tensor,
+ input_sizes,
+ t_g,
+ input_sizes,
+ x_init_value=x_init_value,
+ delta=1e-2)
+
+ print("%s gradient error = " % func_name, err_g)
+ self.assertLess(err_g, err_g_margin)
+ print("%s second-order gradient error = " % func_name, err_gg)
+ self.assertLess(err_gg, err_gg_margin)
def _ConstructAndTestGradient(self,
pool_func,