# about | summary | refs | log | tree | commit | diff | homepage
# path: root/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
# blob: cf55da27236d17c709cbde689831ad68da9a8a7b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for receptive_fields module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import slim
from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test


# TODO(andrearaujo): Rename the create_test_network_* functions in order to have
# more descriptive names.
def create_test_network_1():
  """Builds a two-branch network whose branches are aligned.

  The graph reproduces the example in the second figure of
  go/cnn-rf-computation#arbitrary-computation-graphs. The left branch is a
  1x1/stride-4 convolution. The right branch is a padded 3x3/stride-2
  convolution followed by a 1x1/stride-2 convolution. The two branches are
  summed and passed through a ReLU.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # Left branch.
    left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: pad one element at the start of dimensions 1 and 2
    # (the spatial dims, assuming NHWC), then two strided convolutions.
    padded = array_ops.pad(image, [[0, 0], [1, 0], [1, 0], [0, 0]])
    right = slim.conv2d(
        padded, 1, [3, 3], stride=2, scope='L2', padding='VALID')
    right = slim.conv2d(
        right, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    # Merge both branches.
    nn.relu(left + right, name='output')
  return g


def create_test_network_2():
  """Builds an aligned two-branch network that uses max-pooling.

  This is the network of create_test_network_1() with layers 2 and 3 replaced
  by max-pooling ops. Those ops have the same kernel sizes, strides and
  padding as the convolutions they replace. Max-pooling and convolution share
  the same spatial geometry, so the network stays aligned and has the same
  receptive field as create_test_network_1().

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # Left branch.
    left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: padding followed by two strided max-pooling ops.
    padded = array_ops.pad(image, [[0, 0], [1, 0], [1, 0], [0, 0]])
    right = slim.max_pool2d(
        padded, [3, 3], stride=2, scope='L2', padding='VALID')
    right = slim.max_pool2d(
        right, [1, 1], stride=2, scope='L3', padding='VALID')
    # Merge both branches.
    nn.relu(left + right, name='output')
  return g


def create_test_network_3():
  """Builds a two-branch network whose branches are misaligned.

  The graph reproduces the example in the first figure of
  go/cnn-rf-computation#arbitrary-computation-graphs. The left branch is an
  asymmetrically padded 5x5/stride-2 convolution. The right branch is two
  unpadded 3x3/stride-1 convolutions.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # Left branch: pad 2 before and 1 after on dimensions 1 and 2.
    padded = array_ops.pad(image, [[0, 0], [2, 1], [2, 1], [0, 0]])
    left = slim.conv2d(
        padded, 1, [5, 5], stride=2, scope='L1', padding='VALID')
    # Right branch.
    right = slim.conv2d(
        image, 1, [3, 3], stride=1, scope='L2', padding='VALID')
    right = slim.conv2d(
        right, 1, [3, 3], stride=1, scope='L3', padding='VALID')
    # Merge both branches.
    nn.relu(left + right, name='output')
  return g


def create_test_network_4():
  """Builds a network whose alignment depends on the input resolution.

  This is a variation of the example in the second figure of
  go/cnn-rf-computation#arbitrary-computation-graphs. Instead of an explicit
  pad followed by a VALID convolution, layer 2 uses 'SAME' padding. The
  amount of padding therefore depends on the input image dimensions. When
  the input resolution is unknown, the effective padding is undetermined and
  the utility cannot check the network alignment.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # Left branch.
    left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: the first layer relies on implicit 'SAME' padding.
    right = slim.conv2d(
        image, 1, [3, 3], stride=2, scope='L2', padding='SAME')
    right = slim.conv2d(
        right, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    # Merge both branches.
    nn.relu(left + right, name='output')
  return g


def create_test_network_5():
  """Builds a single-path network with non-square kernels.

  The layers mirror the right branch of create_test_network_1(), but with
  non-square kernel sizes and no explicit padding.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # Two strided convolutions; the first kernel is [3, 5], the second [3, 1].
    features = slim.conv2d(
        image, 1, [3, 5], stride=2, scope='L1', padding='VALID')
    features = slim.conv2d(
        features, 1, [3, 1], stride=2, scope='L2', padding='VALID')
    nn.relu(features, name='output')
  return g


def create_test_network_6():
  """Builds an aligned two-branch network with a dropout layer.

  This is create_test_network_1() with a dropout layer appended to the right
  branch, just before the two branches are summed.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # Left branch.
    left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: padding, two strided convolutions, then dropout.
    padded = array_ops.pad(image, [[0, 0], [1, 0], [1, 0], [0, 0]])
    right = slim.conv2d(
        padded, 1, [3, 3], stride=2, scope='L2', padding='VALID')
    right = slim.conv2d(
        right, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    right = slim.dropout(right)
    # Merge both branches.
    nn.relu(left + right, name='output')
  return g


def create_test_network_7():
  """Builds an aligned two-branch network containing a control dependency.

  This is create_test_network_1() on a fixed 1x8x8x1 input. An Assert op
  checks the size of dimension 1 of the left branch's output. The final
  addition and ReLU are created under a control dependency on that Assert.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Fixed-size input: one 8x8 single-channel image.
    image = array_ops.placeholder(
        dtypes.float32, (1, 8, 8, 1), name='input_image')
    # Left branch, plus a runtime check that dimension 1 of its output is 2.
    left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    left_shape = array_ops.shape(left)
    check_left = control_flow_ops.Assert(
        gen_math_ops.equal(left_shape[1], 2), [left_shape], summarize=4)
    # Right branch.
    padded = array_ops.pad(image, [[0, 0], [1, 0], [1, 0], [0, 0]])
    right = slim.conv2d(
        padded, 1, [3, 3], stride=2, scope='L2', padding='VALID')
    right = slim.conv2d(
        right, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    # The addition and ReLU depend on the Assert via a control edge only.
    with ops.control_dependencies([check_left]):
      nn.relu(left + right, name='output')
  return g


def create_test_network_8():
  """Builds an aligned network with an intermediate addition.

  The first stage is create_test_network_1() without the 'output' name. A
  second stage on top of it has two branches with different receptive fields,
  summed into the final op. This makes the case harder than a single merge;
  in particular, a naive DFS-like RF computation fails on it.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # First stage: left and right branches, then the intermediate addition.
    first_left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    first_padded = array_ops.pad(image, [[0, 0], [1, 0], [1, 0], [0, 0]])
    first_right = slim.conv2d(
        first_padded, 1, [3, 3], stride=2, scope='L2', padding='VALID')
    first_right = slim.conv2d(
        first_right, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    merged = nn.relu(first_left + first_right)
    # Second stage: two more branches fed by the intermediate sum.
    second_left = slim.conv2d(
        merged, 1, [1, 1], stride=2, scope='L5', padding='VALID')
    second_padded = array_ops.pad(merged, [[0, 0], [1, 0], [1, 0], [0, 0]])
    second_right = slim.conv2d(
        second_padded, 1, [3, 3], stride=2, scope='L6', padding='VALID')
    # Final addition.
    nn.relu(second_left + second_right, name='output')

  return g


def create_test_network_9():
  """Builds the network of create_test_network_8() using 'SAME' padding.

  Every convolution uses 'SAME' padding instead of 'VALID'. The explicit pad
  ops of create_test_network_8() are dropped because the implicit padding
  takes their place.

  Returns:
    g: A `Graph` object (not a GraphDef proto) holding the network. Its input
      placeholder is named 'input_image' and its final op 'output'.
  """
  with ops.Graph().as_default() as g:
    # Single-channel input whose batch size and spatial size are left unknown.
    image = array_ops.placeholder(
        dtypes.float32, (None, None, None, 1), name='input_image')
    # First stage: left and right branches, then the intermediate addition.
    first_left = slim.conv2d(
        image, 1, [1, 1], stride=4, scope='L1', padding='SAME')
    first_right = slim.conv2d(
        image, 1, [3, 3], stride=2, scope='L2', padding='SAME')
    first_right = slim.conv2d(
        first_right, 1, [1, 1], stride=2, scope='L3', padding='SAME')
    merged = nn.relu(first_left + first_right)
    # Second stage: two more branches fed by the intermediate sum.
    second_left = slim.conv2d(
        merged, 1, [1, 1], stride=2, scope='L5', padding='SAME')
    second_right = slim.conv2d(
        merged, 1, [3, 3], stride=2, scope='L6', padding='SAME')
    # Final addition.
    nn.relu(second_left + second_right, name='output')

  return g


class ReceptiveFieldTest(test.TestCase):
  """Tests for receptive_field.compute_receptive_field_from_graph_def."""

  # Every create_test_network_* graph uses these input/output node names.
  _INPUT_NODE = 'input_image'
  _OUTPUT_NODE = 'output'

  def _compute_rf(self, graph_def, *args, **kwargs):
    """Computes RF parameters between the standard input and output nodes.

    Args:
      graph_def: GraphDef of one of the create_test_network_* graphs.
      *args: Extra positional arguments forwarded to
        compute_receptive_field_from_graph_def after the two node names
        (e.g. the list of nodes at which propagation is stopped).
      **kwargs: Extra keyword arguments forwarded to
        compute_receptive_field_from_graph_def (e.g. input_resolution).

    Returns:
      The value returned by compute_receptive_field_from_graph_def.
    """
    return receptive_field.compute_receptive_field_from_graph_def(
        graph_def, self._INPUT_NODE, self._OUTPUT_NODE, *args, **kwargs)

  def _assert_rf_params(self, rf_params, expected_size, expected_stride,
                        expected_padding):
    """Checks the six computed RF parameters against (x, y) expectations.

    Args:
      rf_params: Result of _compute_rf. It must unpack into exactly six values:
        receptive field x/y, effective stride x/y, effective padding x/y.
      expected_size: Expected (receptive_field_x, receptive_field_y).
      expected_stride: Expected (effective_stride_x, effective_stride_y).
      expected_padding: Expected (effective_padding_x, effective_padding_y).
        Use (None, None) when the padding is expected to be undetermined.
    """
    # Explicit unpacking also asserts that exactly six values are returned.
    (receptive_field_x, receptive_field_y, effective_stride_x,
     effective_stride_y, effective_padding_x, effective_padding_y) = rf_params
    self.assertEqual((receptive_field_x, receptive_field_y), expected_size)
    self.assertEqual((effective_stride_x, effective_stride_y), expected_stride)
    self.assertEqual((effective_padding_x, effective_padding_y),
                     expected_padding)

  def testComputeRFFromGraphDefAligned(self):
    rf_params = self._compute_rf(create_test_network_1().as_graph_def())
    self._assert_rf_params(rf_params, (3, 3), (4, 4), (1, 1))

  def testComputeRFFromGraphDefAligned2(self):
    rf_params = self._compute_rf(create_test_network_2().as_graph_def())
    self._assert_rf_params(rf_params, (3, 3), (4, 4), (1, 1))

  def testComputeRFFromGraphDefUnaligned(self):
    graph_def = create_test_network_3().as_graph_def()
    with self.assertRaises(ValueError):
      self._compute_rf(graph_def)

  def testComputeRFFromGraphDefUndefinedPadding(self):
    # Without a fixed input resolution, 'SAME' padding leaves the effective
    # padding undetermined.
    rf_params = self._compute_rf(create_test_network_4().as_graph_def())
    self._assert_rf_params(rf_params, (3, 3), (4, 4), (None, None))

  def testComputeRFFromGraphDefFixedInputDim(self):
    rf_params = self._compute_rf(
        create_test_network_4().as_graph_def(), input_resolution=[9, 9])
    self._assert_rf_params(rf_params, (3, 3), (4, 4), (1, 1))

  def testComputeRFFromGraphDefUnalignedFixedInputDim(self):
    graph_def = create_test_network_4().as_graph_def()
    with self.assertRaises(ValueError):
      self._compute_rf(graph_def, input_resolution=[8, 8])

  def testComputeRFFromGraphDefNonSquareRF(self):
    rf_params = self._compute_rf(create_test_network_5().as_graph_def())
    self._assert_rf_params(rf_params, (5, 7), (4, 4), (0, 0))

  def testComputeRFFromGraphDefStopPropagation(self):
    # Compute the receptive field but stop the propagation for the random
    # uniform variable of the dropout.
    rf_params = self._compute_rf(
        create_test_network_6().as_graph_def(),
        ['Dropout/dropout/random_uniform'])
    self._assert_rf_params(rf_params, (3, 3), (4, 4), (1, 1))

  def testComputeCoordinatesRoundtrip(self):
    # Pass a GraphDef, consistent with every other test in this class.
    rf = self._compute_rf(create_test_network_1().as_graph_def())

    # Seeded so that any failure is reproducible.
    x = np.random.RandomState(0).randint(0, 100, (50, 2))
    y = rf.compute_feature_coordinates(x)
    x2 = rf.compute_input_center_coordinates(y)

    self.assertAllEqual(x, x2)

  def testComputeRFFromGraphDefAlignedWithControlDependencies(self):
    rf_params = self._compute_rf(create_test_network_7().as_graph_def())
    self._assert_rf_params(rf_params, (3, 3), (4, 4), (1, 1))

  def testComputeRFFromGraphDefWithIntermediateAddNode(self):
    rf_params = self._compute_rf(create_test_network_8().as_graph_def())
    self._assert_rf_params(rf_params, (11, 11), (8, 8), (5, 5))

  def testComputeRFFromGraphDefWithIntermediateAddNodeSamePaddingFixedInputDim(
      self):
    rf_params = self._compute_rf(
        create_test_network_9().as_graph_def(), input_resolution=[17, 17])
    self._assert_rf_params(rf_params, (11, 11), (8, 8), (5, 5))


if __name__ == '__main__':
  # Run the test cases in this module through TensorFlow's test entry point
  # (tensorflow.python.platform.test).
  test.main()