aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
blob: 7d9d4e517527e457c0da73d4f4b2a8763359a693 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ExtractImagePatches gradient."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test


class ExtractImagePatchesGradTest(test.TestCase):
  """Gradient-checking for ExtractImagePatches op."""

  # Each case gives the input tensor shape plus the `ksizes`, `strides` and
  # `rates` arguments passed to extract_image_patches. testGradient checks
  # every case under both 'VALID' and 'SAME' padding.
  _TEST_CASES = [
      {
          'in_shape': [2, 5, 5, 3],
          'ksizes': [1, 1, 1, 1],
          'strides': [1, 2, 3, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [2, 7, 7, 3],
          'ksizes': [1, 3, 3, 1],
          'strides': [1, 1, 1, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [2, 8, 7, 3],
          'ksizes': [1, 2, 2, 1],
          'strides': [1, 1, 1, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [2, 7, 8, 3],
          'ksizes': [1, 3, 2, 1],
          'strides': [1, 4, 3, 1],
          'rates': [1, 1, 1, 1],
      },
      {
          'in_shape': [1, 15, 20, 3],
          'ksizes': [1, 4, 3, 1],
          'strides': [1, 1, 1, 1],
          'rates': [1, 2, 4, 1],
      },
      {
          'in_shape': [2, 7, 8, 1],
          'ksizes': [1, 3, 2, 1],
          'strides': [1, 3, 2, 1],
          'rates': [1, 2, 2, 1],
      },
      {
          'in_shape': [2, 8, 9, 4],
          'ksizes': [1, 2, 2, 1],
          'strides': [1, 4, 2, 1],
          'rates': [1, 3, 2, 1],
      },
  ]

  def testGradient(self):
    """Checks the ExtractImagePatches gradient for every case and padding.

    For each entry in `_TEST_CASES` and each of 'VALID' and 'SAME' padding,
    the error reported by `gradient_checker.compute_gradient_error` must be
    below 1e-4.
    """
    # Set graph seed for determinism.
    random_seed = 42
    random_seed_lib.set_random_seed(random_seed)

    with self.cached_session():
      for test_case in self._TEST_CASES:
        # Re-seed NumPy per case so each case's input values do not depend on
        # how many cases ran before it.
        np.random.seed(random_seed)
        in_shape = test_case['in_shape']
        in_val = constant_op.constant(
            np.random.random(in_shape), dtype=dtypes.float32)

        for padding in ['VALID', 'SAME']:
          out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
                                                    test_case['strides'],
                                                    test_case['rates'], padding)
          out_shape = out_val.get_shape().as_list()

          err = gradient_checker.compute_gradient_error(in_val, in_shape,
                                                        out_val, out_shape)

          print('extract_image_patches gradient err: %.4e' % err)
          # All cases share one test method, so name the failing
          # configuration in the assertion message.
          self.assertLess(
              err, 1e-4,
              msg='test_case=%s, padding=%s' % (test_case, padding))

  def testConstructGradientWithLargeImages(self):
    """Regression test: building the gradient graph for large images.

    Github issue: #20146
    tf.extract_image_patches() gradient very slow at graph construction time.
    Only graph construction is exercised here; no session is run.
    """
    batch_size = 4
    height = 1024
    width = 1024
    ksize = 5
    images = variable_scope.get_variable('inputs',
                                         (batch_size, height, width, 1))
    patches = array_ops.extract_image_patches(images,
                                              ksizes=[1, ksize, ksize, 1],
                                              strides=[1, 1, 1, 1],
                                              rates=[1, 1, 1, 1],
                                              padding='SAME')
    # Won't time out.
    gradients = gradients_impl.gradients(patches, images)
    # `gradients` returns a list with one entry per element of `xs`, so the
    # list itself is never None. Check the single gradient tensor instead,
    # which would be None if no gradient path existed.
    self.assertEqual(len(gradients), 1)
    self.assertIsNotNone(gradients[0])


if __name__ == '__main__':
  # When executed as a script, hand control to TensorFlow's test runner.
  test.main()