1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ExtractImagePatches gradient."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
class ExtractImagePatchesGradTest(test.TestCase):
  """Gradient-checking for ExtractImagePatches op."""

  # Each entry supplies the shape of the randomly generated input together
  # with the `ksizes`, `strides` and `rates` arguments that `testGradient`
  # forwards to `array_ops.extract_image_patches`. Every entry is exercised
  # with both 'VALID' and 'SAME' padding.
  _TEST_CASES = [
      {
          'in_shape': [2, 5, 5, 3],
          'ksizes': [1, 1, 1, 1],
          'strides': [1, 2, 3, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [2, 7, 7, 3],
          'ksizes': [1, 3, 3, 1],
          'strides': [1, 1, 1, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [2, 8, 7, 3],
          'ksizes': [1, 2, 2, 1],
          'strides': [1, 1, 1, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [2, 7, 8, 3],
          'ksizes': [1, 3, 2, 1],
          'strides': [1, 4, 3, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [1, 15, 20, 3],
          'ksizes': [1, 4, 3, 1],
          'strides': [1, 1, 1, 1],
          'rates': [1, 2, 4, 1],
      },
      {
          'in_shape': [2, 7, 8, 1],
          'ksizes': [1, 3, 2, 1],
          'strides': [1, 3, 2, 1],
          'rates': [1, 2, 2, 1],
      },
      {
          'in_shape': [2, 8, 9, 4],
          'ksizes': [1, 2, 2, 1],
          'strides': [1, 4, 2, 1],
          'rates': [1, 3, 2, 1],
      },
  ]

  def testGradient(self):
    """Checks the gradient error of every test case under both paddings.

    For each entry of `_TEST_CASES` a float32 constant of random values is
    built, `extract_image_patches` is applied with 'VALID' and then 'SAME'
    padding, and the value returned by
    `gradient_checker.compute_gradient_error` must be below 1e-4.
    """
    # Set graph seed for determinism.
    random_seed = 42
    random_seed_lib.set_random_seed(random_seed)

    with self.cached_session():
      for test_case in self._TEST_CASES:
        # Re-seed NumPy per case so each case's input does not depend on how
        # many random values earlier cases consumed.
        np.random.seed(random_seed)
        in_shape = test_case['in_shape']
        in_val = constant_op.constant(
            np.random.random(in_shape), dtype=dtypes.float32)

        for padding in ['VALID', 'SAME']:
          out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
                                                    test_case['strides'],
                                                    test_case['rates'], padding)
          out_shape = out_val.get_shape().as_list()

          err = gradient_checker.compute_gradient_error(in_val, in_shape,
                                                        out_val, out_shape)

          print('extract_image_patches gradient err: %.4e' % err)
          # Name the failing configuration; without it a failure cannot be
          # attributed to one of the case/padding combinations.
          self.assertLess(
              err, 1e-4,
              msg='gradient error too large for case %r with padding %r' %
              (test_case, padding))

  def testConstructGradientWithLargeImages(self):
    """Builds the gradient graph for a 4x1024x1024x1 input with 5x5 patches.

    Regression test for slow graph construction of the gradient; the test
    passes as long as gradient construction finishes and yields a gradient.
    """
    batch_size = 4
    height = 1024
    width = 1024
    ksize = 5
    images = variable_scope.get_variable('inputs',
                                         (batch_size, height, width, 1))
    patches = array_ops.extract_image_patches(images,
                                              ksizes=[1, ksize, ksize, 1],
                                              strides=[1, 1, 1, 1],
                                              rates=[1, 1, 1, 1],
                                              padding='SAME')
    # Github issue: #20146
    # tf.extract_image_patches() gradient very slow at graph construction time
    gradients = gradients_impl.gradients(patches, images)
    # Won't time out.
    self.assertIsNotNone(gradients)
    # `gradients` is a list with one entry per input, and a list is never
    # None, so also verify each entry to make the check meaningful.
    for grad in gradients:
      self.assertIsNotNone(grad)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
|