aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-14 12:41:36 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-14 12:56:29 +0800
commitf982cfe9f943c9920cafeefff7818ea298d5b509 (patch)
tree591a63992810421d476a790130f989c94a1c93e6 /tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
parent4aaab50552a3cdb4b785653f071ae6c7193992ca (diff)
TST: add benchmark
Diffstat (limited to 'tensorflow/python/kernel_tests/extract_image_patches_grad_test.py')
-rw-r--r--tensorflow/python/kernel_tests/extract_image_patches_grad_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index 60090a1510..e1f5a6b620 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -25,6 +25,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@@ -100,6 +102,24 @@ class ExtractImagePatchesGradTest(test.TestCase):
print('extract_image_patches gradient err: %.4e' % err)
self.assertLess(err, 1e-4)
+ def testConstructGradientWithLargeImages(self):
+ batch_size = 4
+ height = 1024
+ width = 1024
+ ksize = 5
+ images = variable_scope.get_variable('inputs',
+ (batch_size, height, width, 1))
+ patches = array_ops.extract_image_patches(images,
+ ksizes=[1, ksize, ksize, 1],
+ strides=[1, 1, 1, 1],
+ rates=[1, 1, 1, 1],
+ padding='SAME')
+ # Github issue: #20146
+ # tf.extract_image_patches() gradient very slow at graph construction time
+ gradients = gradients_impl.gradients(patches, images)
+ # Won't time out.
+ self.assertIsNotNone(gradients)
+
if __name__ == '__main__':
test.main()