aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/receptive_field
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-16 15:53:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-16 15:59:31 -0800
commit7e13a9ea3709301186b946a8c1f864e1245e6271 (patch)
treec23f7516170a3e7d953257795b50b1e755d5a98a /tensorflow/contrib/receptive_field
parentccbd14b741e6efbe51769f0f1b9cb3719c42c23b (diff)
Introducing RF computation considering models with specific input resolution. Previously, the input resolution was not taken into account, which led to undefined padding for many well-known models (since those rely on SAME padding, and in some cases SAME padding depends on input resolution).
This change also redesigns many aspects of the topological sorting and layer parsing functions, introducing new modules and tests. PiperOrigin-RevId: 182124694
Diffstat (limited to 'tensorflow/contrib/receptive_field')
-rw-r--r--tensorflow/contrib/receptive_field/BUILD50
-rw-r--r--tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py146
-rw-r--r--tensorflow/contrib/receptive_field/python/util/graph_compute_order.py184
-rw-r--r--tensorflow/contrib/receptive_field/python/util/graph_compute_order_test.py152
-rw-r--r--tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py297
-rw-r--r--tensorflow/contrib/receptive_field/python/util/parse_layer_parameters_test.py149
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field.py351
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field_test.py108
8 files changed, 1060 insertions, 377 deletions
diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD
index c67974c797..e975aeaea7 100644
--- a/tensorflow/contrib/receptive_field/BUILD
+++ b/tensorflow/contrib/receptive_field/BUILD
@@ -25,18 +25,34 @@ py_library(
"python/util/graph_compute_order.py",
],
srcs_version = "PY2AND3",
+ deps = [
+ ":parse_layer_parameters_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_library(
+ name = "parse_layer_parameters_py",
+ srcs = [
+ "python/util/parse_layer_parameters.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
)
py_library(
name = "receptive_field_py",
srcs = [
+ "python/util/parse_layer_parameters.py",
"python/util/receptive_field.py",
"receptive_field_api.py",
],
srcs_version = "PY2AND3",
deps = [
":graph_compute_order_py",
- "//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//third_party/py/numpy",
@@ -44,6 +60,38 @@ py_library(
)
py_test(
+ name = "graph_compute_order_test",
+ srcs = ["python/util/graph_compute_order_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":graph_compute_order_py",
+ ":receptive_field_py",
+ "//tensorflow/contrib/slim",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:nn",
+ ],
+)
+
+py_test(
+ name = "parse_layer_parameters_test",
+ srcs = ["python/util/parse_layer_parameters_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":graph_compute_order_py",
+ ":parse_layer_parameters_py",
+ "//tensorflow/contrib/slim",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:nn",
+ ],
+)
+
+py_test(
name = "receptive_field_test",
srcs = ["python/util/receptive_field_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py
index cd16abd5ab..a298b4d490 100644
--- a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py
+++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py
@@ -245,7 +245,8 @@ def _model_rf(graphdef,
end_points,
desired_end_point_keys,
model_type='resnet_v1_50',
- csv_writer=None):
+ csv_writer=None,
+ input_resolution=None):
"""Computes receptive field information for a given CNN model.
The information will be printed to stdout. If the RF parameters are the same
@@ -261,45 +262,93 @@ def _model_rf(graphdef,
information will be computed.
model_type: Type of model to be used, used only for printing purposes.
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ input_resolution: Input resolution to use when computing RF
+ parameters. This is important for the case where padding can only be
+ defined if the input resolution is known, which may happen if using SAME
+ padding. This is assumed the resolution for both height and width. If
+ None, we consider the resolution is unknown.
"""
for desired_end_point_key in desired_end_point_keys:
print('- %s:' % desired_end_point_key)
output_node_with_colon = end_points[desired_end_point_key].name
pos = output_node_with_colon.rfind(':')
output_node = output_node_with_colon[:pos]
- (receptive_field_x, receptive_field_y, effective_stride_x,
- effective_stride_y, effective_padding_x, effective_padding_y
- ) = receptive_field.compute_receptive_field_from_graph_def(
- graphdef, _INPUT_NODE, output_node)
- # If values are the same in horizontal/vertical directions, just report one
- # of them. Otherwise, report both.
- if (receptive_field_x == receptive_field_y) and (
- effective_stride_x == effective_stride_y) and (
- effective_padding_x == effective_padding_y):
- print('Receptive field size = %5s, effective stride = %5s, effective '
- 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x),
- str(effective_padding_x)))
- else:
- print('Receptive field size: horizontal = %5s, vertical = %5s. '
- 'Effective stride: horizontal = %5s, vertical = %5s. Effective '
- 'padding: horizontal = %5s, vertical = %5s' %
- (str(receptive_field_x), str(receptive_field_y),
- str(effective_stride_x), str(effective_stride_y),
- str(effective_padding_x), str(effective_padding_y)))
- if csv_writer is not None:
- csv_writer.writerow({
- 'CNN': model_type,
- 'end_point': desired_end_point_key,
- 'RF size hor': str(receptive_field_x),
- 'RF size ver': str(receptive_field_y),
- 'effective stride hor': str(effective_stride_x),
- 'effective stride ver': str(effective_stride_y),
- 'effective padding hor': str(effective_padding_x),
- 'effective padding ver': str(effective_padding_y)
- })
-
-
-def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None):
+ try:
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y
+ ) = receptive_field.compute_receptive_field_from_graph_def(
+ graphdef, _INPUT_NODE, output_node, input_resolution=input_resolution)
+ # If values are the same in horizontal/vertical directions, just report
+ # one of them. Otherwise, report both.
+ if (receptive_field_x == receptive_field_y) and (
+ effective_stride_x == effective_stride_y) and (
+ effective_padding_x == effective_padding_y):
+ print('Receptive field size = %5s, effective stride = %5s, effective '
+ 'padding = %5s' % (str(receptive_field_x),
+ str(effective_stride_x),
+ str(effective_padding_x)))
+ else:
+ print('Receptive field size: horizontal = %5s, vertical = %5s. '
+ 'Effective stride: horizontal = %5s, vertical = %5s. Effective '
+ 'padding: horizontal = %5s, vertical = %5s' %
+ (str(receptive_field_x), str(receptive_field_y),
+ str(effective_stride_x), str(effective_stride_y),
+ str(effective_padding_x), str(effective_padding_y)))
+ if csv_writer is not None:
+ csv_writer.writerow({
+ 'CNN':
+ model_type,
+ 'input resolution':
+ str(input_resolution[0])
+ if input_resolution is not None else 'None',
+ 'end_point':
+ desired_end_point_key,
+ 'RF size hor':
+ str(receptive_field_x),
+ 'RF size ver':
+ str(receptive_field_y),
+ 'effective stride hor':
+ str(effective_stride_x),
+ 'effective stride ver':
+ str(effective_stride_y),
+ 'effective padding hor':
+ str(effective_padding_x),
+ 'effective padding ver':
+ str(effective_padding_y)
+ })
+ except ValueError as e:
+ print('---->ERROR: Computing RF parameters for model %s with final end '
+ 'point %s and input resolution %s did not work' %
+ (model_type, desired_end_point_key, input_resolution))
+ print('---->The returned error is: %s' % e)
+ if csv_writer is not None:
+ csv_writer.writerow({
+ 'CNN':
+ model_type,
+ 'input resolution':
+ str(input_resolution[0])
+ if input_resolution is not None else 'None',
+ 'end_point':
+ desired_end_point_key,
+ 'RF size hor':
+ 'None',
+ 'RF size ver':
+ 'None',
+ 'effective stride hor':
+ 'None',
+ 'effective stride ver':
+ 'None',
+ 'effective padding hor':
+ 'None',
+ 'effective padding ver':
+ 'None'
+ })
+
+
+def _process_model_rf(model_type='resnet_v1_50',
+ csv_writer=None,
+ arg_sc=None,
+ input_resolutions=None):
"""Contructs model graph and desired end-points, and compute RF.
The computed RF parameters are printed to stdout by the _model_rf function.
@@ -308,13 +357,30 @@ def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None):
model_type: Type of model to be used.
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
arg_sc: Optional arg scope to use in constructing the graph.
+ input_resolutions: List of 1D input resolutions to use when computing RF
+ parameters. This is important for the case where padding can only be
+ defined if the input resolution is known, which may happen if using SAME
+ padding. The entries in the list are assumed the resolution for both
+ height and width. If one of the elements in the list is None, we consider
+ it to mean that the resolution is unknown. If the list itself is None,
+ we use the default list [None, 224, 321].
"""
- print('********************%s' % model_type)
- graphdef, end_points = _model_graph_def(model_type, arg_sc)
- desired_end_point_keys = _get_desired_end_point_keys(model_type)
- _model_rf(graphdef, end_points, desired_end_point_keys, model_type,
- csv_writer)
+ # Process default value for this list.
+ if input_resolutions is None:
+ input_resolutions = [None, 224, 321]
+
+ for n in input_resolutions:
+ print('********************%s, input resolution = %s' % (model_type, n))
+ graphdef, end_points = _model_graph_def(model_type, arg_sc)
+ desired_end_point_keys = _get_desired_end_point_keys(model_type)
+ _model_rf(
+ graphdef,
+ end_points,
+ desired_end_point_keys,
+ model_type,
+ csv_writer,
+ input_resolution=[n, n] if n is not None else None)
def _resnet_rf(csv_writer=None):
@@ -421,7 +487,7 @@ def main(unused_argv):
if cmd_args.csv_path:
csv_file = open(cmd_args.csv_path, 'w')
field_names = [
- 'CNN', 'end_point', 'RF size hor', 'RF size ver',
+ 'CNN', 'input resolution', 'end_point', 'RF size hor', 'RF size ver',
'effective stride hor', 'effective stride ver', 'effective padding hor',
'effective padding ver'
]
diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
index 6153607656..b2360fec6c 100644
--- a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
+++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
@@ -20,45 +20,100 @@ from __future__ import division
from __future__ import print_function
import collections
+import math
+from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters
+from tensorflow.python.platform import tf_logging as logging
-class GraphDefHelper(object):
- """Helper class to collect node names and definitions.
+def parse_graph_nodes(graph_def):
+ """Helper function to parse GraphDef's nodes.
- Example:
- b = GraphDefHelper(graph_def)
- # Prints node that produces given output.
- print b.output_of['conv/foo/bar']
+ It returns a dict mapping from node name to NodeDef.
+
+ Args:
+ graph_def: A GraphDef object.
+
+ Returns:
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
"""
+ name_to_node = {}
+ for node_def in graph_def.node:
+ name_to_node[node_def.name] = node_def
+ return name_to_node
- def __init__(self, gd):
- self.output_of = {}
- for each in gd.node:
- self.output_of[each.name] = each
+# Named tuple used to collect information from each node in a computation graph.
+_node_info = collections.namedtuple(
+ 'NodeInfo', field_names=['order', 'node', 'input_size', 'output_size'])
-# pylint: disable=invalid-name
-_NodeEntry = collections.namedtuple('NodeEntry', field_names=['order', 'node'])
+def _compute_output_resolution(input_spatial_resolution, kernel_size, stride,
+ total_padding):
+ """Computes output resolution, given input resolution and layer parameters.
-def _get_computed_nodes(g, output, seen):
- """Traverses the graph in topological order.
+ Note that this computation is done only over one dimension (eg, x or y).
+ If any of the inputs is None, returns None.
+
+ Args:
+ input_spatial_resolution: Input spatial resolution (int).
+ kernel_size: Kernel size (int).
+ stride: Stride (int).
+ total_padding: Total padding to be applied (int).
+ Returns:
+ output_resolution: Ouput dimension (int) or None.
+ """
+ if (input_spatial_resolution is None) or (kernel_size is None) or (
+ stride is None) or (total_padding is None):
+ return None
+ return int(
+ math.ceil((
+ input_spatial_resolution + total_padding - kernel_size + 1) / stride))
+
+
+def _get_computed_nodes(name_to_node,
+ current,
+ node_info,
+ input_node_name='',
+ input_node_size=None):
+ """Traverses the graph recursively to compute its topological order.
+
+ Optionally, the function may also compute the input and output feature map
+ resolutions at each node. In this case, input_node_name and input_node_size
+ must be set. Note that if a node's op type is unknown, the input and output
+ resolutions are ignored and set to None.
Args:
- g: GraphDefHelper object.
- output: current node.
- seen: map of nodes we've already traversed.
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
+ current: Current node name.
+ node_info: Map of nodes we've already traversed, containing their _node_info
+ information.
+ input_node_name: Name of node with fixed input resolution (optional).
+ input_node_size: Fixed input resolution to use (optional).
Returns:
- order in topological sort for 'output'.
+ order: Order in topological sort for 'current'.
+ input_size: Tensor spatial resolution at input of current node.
+ output_size: Tensor spatial resolution at output of current node.
"""
- if output in seen:
- return seen[output].order
- node_def = g.output_of.get(output, None)
- if node_def is None:
- seen[output] = _NodeEntry(0, None)
- return 0
-
- r = 0
+ if current in node_info:
+ return (node_info[current].order, node_info[current].input_size,
+ node_info[current].output_size)
+
+ node_def = name_to_node[current]
+
+ if current == input_node_name:
+ order = 0
+ input_size = None
+ output_size = input_node_size
+ node_info[current] = _node_info(order, node_def, input_size, output_size)
+ return (order, input_size, output_size)
+
+ input_size = None
+ output_size = None
+
+ order = 0
+ number_inputs = 0
for each in node_def.input:
# Parses name of input node.
if each.startswith('^'):
@@ -67,24 +122,71 @@ def _get_computed_nodes(g, output, seen):
continue
each = each.split(':')[0]
# Recursively computes ordering.
- new_v = _get_computed_nodes(g, each, seen)
- r = max(r, new_v + 1)
-
- seen[output] = _NodeEntry(r, node_def)
-
- return seen[output].order
-
-
-def get_compute_order(graph_def):
- """Computes order of computation for a given graph.
+ (parent_order, _, parent_output_size) = _get_computed_nodes(
+ name_to_node, each, node_info, input_node_name, input_node_size)
+ order = max(order, parent_order + 1)
+ if number_inputs == 0:
+ # For all the types of nodes we consider, the first input corresponds to
+ # the feature map.
+ input_size = parent_output_size
+ number_inputs += 1
+
+ # Figure out output size for this layer.
+ logging.vlog(3, 'input_size = %s', input_size)
+ if input_size is None:
+ output_size = None
+ else:
+ (kernel_size_x, kernel_size_y, stride_x, stride_y, _, _, total_padding_x,
+ total_padding_y) = (
+ parse_layer_parameters.get_layer_params(
+ node_def, name_to_node, input_size, force=True))
+ logging.vlog(3, 'kernel_size_x = %s, kernel_size_y = %s, '
+ 'stride_x = %s, stride_y = %s, '
+ 'total_padding_x = %s, total_padding_y = %s' %
+ (kernel_size_x, kernel_size_y, stride_x, stride_y,
+ total_padding_x, total_padding_y))
+ output_size = [None] * 2
+ output_size[0] = _compute_output_resolution(input_size[0], kernel_size_x,
+ stride_x, total_padding_x)
+ output_size[1] = _compute_output_resolution(input_size[1], kernel_size_y,
+ stride_y, total_padding_y)
+
+ logging.vlog(3, 'output_size = %s', output_size)
+ node_info[current] = _node_info(order, node_def, input_size, output_size)
+
+ return order, input_size, output_size
+
+
+def get_compute_order(graph_def, input_node_name='', input_node_size=None):
+ """Computes order of computation for a given CNN graph.
+
+ Optionally, the function may also compute the input and output feature map
+ resolutions at each node. In this case, input_node_name and input_node_size
+ must be set. Note that if a node's op type is unknown, the input and output
+ resolutions are ignored and set to None.
Args:
graph_def: GraphDef object.
+ input_node_name: Name of node with fixed input resolution (optional). This
+ is usually the node name for the input image in a CNN.
+ input_node_size: 2D list of integers, fixed input resolution to use
+ (optional). This is usually the input resolution used for the input image
+ in a CNN (common examples are: [224, 224], [299, 299], [321, 321]).
Returns:
- map: name -> {order, node}
+ node_info: Default dict keyed by node name, mapping to a named tuple with
+ the following fields:
+ - order: Integer denoting topological order;
+ - node: NodeDef for the given node;
+ - input_size: 2D list of integers, denoting the input spatial resolution
+ to the node;
+ - output_size: 2D list of integers, denoting the output spatial resolution
+ of the node.
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
"""
- helper = GraphDefHelper(graph_def)
- seen = collections.defaultdict(_NodeEntry)
+ name_to_node = parse_graph_nodes(graph_def)
+ node_info = collections.defaultdict(_node_info)
for each in graph_def.node:
- _get_computed_nodes(helper, each.name, seen)
- return seen
+ _get_computed_nodes(name_to_node, each.name, node_info, input_node_name,
+ input_node_size)
+ return node_info, name_to_node
diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order_test.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order_test.py
new file mode 100644
index 0000000000..94c992ad21
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order_test.py
@@ -0,0 +1,152 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for graph_compute_order module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import slim
+from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+def create_test_network():
+ """Convolutional neural network for test.
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
+ # Left branch before first addition.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
+ # Right branch before first addition.
+ l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]], name='L2_pad')
+ l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
+ l3 = slim.max_pool2d(l2, [3, 3], stride=2, scope='L3', padding='SAME')
+ # First addition.
+ l4 = nn.relu(l1 + l3, name='L4_relu')
+ # Left branch after first addition.
+ l5 = slim.conv2d(l4, 1, [1, 1], stride=2, scope='L5', padding='SAME')
+ # Right branch after first addition.
+ l6 = slim.conv2d(l4, 1, [3, 3], stride=2, scope='L6', padding='SAME')
+ # Final addition.
+ gen_math_ops.add(l5, l6, name='L7_add')
+
+ return g
+
+
+class GraphComputeOrderTest(test.TestCase):
+
+ def check_topological_sort_and_sizes(self,
+ node_info,
+ expected_input_sizes=None,
+ expected_output_sizes=None):
+ """Helper function to check topological sorting and sizes are correct.
+
+ The arguments expected_input_sizes and expected_output_sizes are used to
+ check that the sizes are correct, if they are given.
+
+ Args:
+ node_info: Default dict keyed by node name, mapping to a named tuple with
+ the following keys: {order, node, input_size, output_size}.
+ expected_input_sizes: Dict mapping node names to expected input sizes
+ (optional).
+ expected_output_sizes: Dict mapping node names to expected output sizes
+ (optional).
+ """
+ # Loop over nodes in sorted order, collecting those that were already seen.
+ # These will be used to make sure that the graph is topologically sorted.
+ # At the same time, we construct dicts from node name to input/output size,
+ # which will be used to check those.
+ already_seen_nodes = []
+ input_sizes = {}
+ output_sizes = {}
+ for _, (_, node, input_size, output_size) in sorted(
+ node_info.items(), key=lambda x: x[1].order):
+ for inp_name in node.input:
+ # Since the graph is topologically sorted, the inputs to the current
+ # node must have been seen beforehand.
+ self.assertIn(inp_name, already_seen_nodes)
+ input_sizes[node.name] = input_size
+ output_sizes[node.name] = output_size
+ already_seen_nodes.append(node.name)
+
+ # Check input sizes, if desired.
+ if expected_input_sizes is not None:
+ for k, v in expected_input_sizes.items():
+ self.assertIn(k, input_sizes)
+ self.assertEqual(input_sizes[k], v)
+
+ # Check output sizes, if desired.
+ if expected_output_sizes is not None:
+ for k, v in expected_output_sizes.items():
+ self.assertIn(k, output_sizes)
+ self.assertEqual(output_sizes[k], v)
+
+ def testGraphOrderIsCorrect(self):
+ """Tests that the order and sizes of create_test_network() are correct."""
+
+ graph_def = create_test_network().as_graph_def()
+
+ # Case 1: Input node name/size are not given.
+ node_info, _ = receptive_field.get_compute_order(graph_def)
+ self.check_topological_sort_and_sizes(node_info)
+
+ # Case 2: Input node name is given, but not size.
+ node_info, _ = receptive_field.get_compute_order(
+ graph_def, input_node_name='input_image')
+ self.check_topological_sort_and_sizes(node_info)
+
+ # Case 3: Input node name and size (224) are given.
+ node_info, _ = receptive_field.get_compute_order(
+ graph_def, input_node_name='input_image', input_node_size=[224, 224])
+ expected_input_sizes = {
+ 'input_image': None,
+ 'L1/Conv2D': [224, 224],
+ 'L2_pad': [224, 224],
+ 'L2/Conv2D': [225, 225],
+ 'L3/MaxPool': [112, 112],
+ 'L4_relu': [56, 56],
+ 'L5/Conv2D': [56, 56],
+ 'L6/Conv2D': [56, 56],
+ 'L7_add': [28, 28],
+ }
+ expected_output_sizes = {
+ 'input_image': [224, 224],
+ 'L1/Conv2D': [56, 56],
+ 'L2_pad': [225, 225],
+ 'L2/Conv2D': [112, 112],
+ 'L3/MaxPool': [56, 56],
+ 'L4_relu': [56, 56],
+ 'L5/Conv2D': [28, 28],
+ 'L6/Conv2D': [28, 28],
+ 'L7_add': [28, 28],
+ }
+ self.check_topological_sort_and_sizes(node_info, expected_input_sizes,
+ expected_output_sizes)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py
new file mode 100644
index 0000000000..44998b3b65
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py
@@ -0,0 +1,297 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to parse RF-related parameters from TF layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from tensorflow.contrib.util import make_ndarray
+from tensorflow.python.platform import tf_logging as logging
+
+# White-listed layer operations, which do not affect the receptive field
+# computation.
+_UNCHANGED_RF_LAYER_OPS = [
+ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor",
+ "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu",
+ "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2"
+]
+
+# Different ways in which padding modes may be spelled.
+_VALID_PADDING = ["VALID", b"VALID"]
+_SAME_PADDING = ["SAME", b"SAME"]
+
+
+def _stride_size(node):
+ """Computes stride size given a TF node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+
+ Returns:
+ stride_x: Stride size for horizontal direction (integer).
+ stride_y: Stride size for vertical direction (integer).
+ """
+ strides_attr = node.attr["strides"]
+ logging.vlog(4, "strides_attr = %s", strides_attr)
+ stride_y = strides_attr.list.i[1]
+ stride_x = strides_attr.list.i[2]
+ return stride_x, stride_y
+
+
+def _conv_kernel_size(node, name_to_node):
+ """Computes kernel size given a TF convolution or pooling node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
+
+ Returns:
+ kernel_size_x: Kernel size for horizontal direction (integer).
+ kernel_size_y: Kernel size for vertical direction (integer).
+
+ Raises:
+ ValueError: If the weight layer node is invalid.
+ """
+ weights_layer_read_name = node.input[1]
+ if not weights_layer_read_name.endswith("/read"):
+ raise ValueError(
+ "Weight layer's name input to conv layer does not end with '/read'")
+ weights_layer_param_name = weights_layer_read_name[:-5]
+ weights_node = name_to_node[weights_layer_param_name]
+ if weights_node.op != "VariableV2":
+ raise ValueError("Weight layer is not of type VariableV2")
+ shape = weights_node.attr["shape"]
+ logging.vlog(4, "weight shape = %s", shape)
+ kernel_size_y = shape.shape.dim[0].size
+ kernel_size_x = shape.shape.dim[1].size
+ return kernel_size_x, kernel_size_y
+
+
+def _padding_size_conv_pool(node, kernel_size, stride, input_resolution=None):
+ """Computes padding size given a TF convolution or pooling node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ kernel_size: Kernel size of node (integer).
+ stride: Stride size of node (integer).
+ input_resolution: Input resolution to assume, if not None (integer).
+
+ Returns:
+ total_padding: Total padding size (integer).
+ padding: Padding size, applied to the left or top (integer).
+
+ Raises:
+ ValueError: If padding is invalid.
+ """
+ # In this case, we need to carefully consider the different TF padding modes.
+ # The padding depends on kernel size, and may depend on input size. If it
+ # depends on input size and input_resolution is None, we raise an exception.
+ padding_attr = node.attr["padding"]
+ logging.vlog(4, "padding_attr = %s", padding_attr)
+ if padding_attr.s in _VALID_PADDING:
+ total_padding = 0
+ padding = 0
+ elif padding_attr.s in _SAME_PADDING:
+ if input_resolution is None:
+ # In this case, we do not know the input resolution, so we can only know
+ # the padding in some special cases.
+ if kernel_size == 1:
+ total_padding = 0
+ padding = 0
+ elif stride == 1:
+ total_padding = kernel_size - 1
+ padding = int(math.floor(float(total_padding) / 2))
+ elif stride == 2 and kernel_size % 2 == 0:
+ # In this case, we can be sure of the left/top padding, but not of the
+ # total padding.
+ total_padding = None
+ padding = int(math.floor((float(kernel_size) - 1) / 2))
+ else:
+ total_padding = None
+ padding = None
+ logging.warning(
+ "Padding depends on input size, which means that the effective "
+ "padding may be different depending on the input image "
+ "dimensionality. In this case, alignment check will be skipped. If"
+ " you know the input resolution, please set it.")
+ else:
+ # First, compute total_padding based on documentation.
+ if input_resolution % stride == 0:
+ total_padding = int(max(float(kernel_size - stride), 0.0))
+ else:
+ total_padding = int(
+ max(float(kernel_size - (input_resolution % stride)), 0.0))
+ # Then, compute left/top padding.
+ padding = int(math.floor(float(total_padding) / 2))
+
+ else:
+ raise ValueError("Invalid padding operation %s" % padding_attr.s)
+ return total_padding, padding
+
+
+def _pool_kernel_size(node):
+ """Computes kernel size given a TF pooling node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+
+ Returns:
+ kernel_size_x: Kernel size for horizontal direction (integer).
+ kernel_size_y: Kernel size for vertical direction (integer).
+
+ Raises:
+ ValueError: If pooling is invalid.
+ """
+ ksize = node.attr["ksize"]
+ kernel_size_y = ksize.list.i[1]
+ kernel_size_x = ksize.list.i[2]
+ if ksize.list.i[0] != 1:
+ raise ValueError("pool ksize for first dim is not 1")
+ if ksize.list.i[3] != 1:
+ raise ValueError("pool ksize for last dim is not 1")
+ return kernel_size_x, kernel_size_y
+
+
+def _padding_size_pad_layer(node, name_to_node):
+ """Computes padding size given a TF padding node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
+
+ Returns:
+ total_padding_x: Total padding size for horizontal direction (integer).
+ padding_x: Padding size for horizontal direction, left side (integer).
+ total_padding_y: Total padding size for vertical direction (integer).
+ padding_y: Padding size for vertical direction, top side (integer).
+
+ Raises:
+ ValueError: If padding layer is invalid.
+ """
+ paddings_layer_name = node.input[1]
+ if not paddings_layer_name.endswith("/paddings"):
+ raise ValueError("Padding layer name does not end with '/paddings'")
+ paddings_node = name_to_node[paddings_layer_name]
+ if paddings_node.op != "Const":
+ raise ValueError("Padding op is not Const")
+ value = paddings_node.attr["value"]
+ t = make_ndarray(value.tensor)
+ padding_y = t[1][0]
+ padding_x = t[2][0]
+ total_padding_y = padding_y + t[1][1]
+ total_padding_x = padding_x + t[2][1]
+ if (t[0][0] != 0) or (t[0][1] != 0):
+ raise ValueError("padding is not zero for first tensor dim")
+ if (t[3][0] != 0) or (t[3][1] != 0):
+ raise ValueError("padding is not zero for last tensor dim")
+ return total_padding_x, padding_x, total_padding_y, padding_y
+
+
+def get_layer_params(node, name_to_node, input_resolution=None, force=False):
+ """Gets layer parameters relevant for RF computation.
+
+ Currently, only these nodes are supported:
+ - Conv2D
+ - DepthwiseConv2dNative
+ - Pad
+ - MaxPool
+ - AvgPool
+ - all nodes listed in _UNCHANGED_RF_LAYER_OPS
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
+ input_resolution: List with 2 dimensions, denoting the height/width of the
+ input feature map to this layer. If set to None, then the padding may be
+ undefined (in tensorflow, SAME padding depends on input spatial
+ resolution).
+ force: If True, the function does not raise a ValueError if the layer op is
+ unknown. Instead, in this case it sets each of the returned parameters to
+ None.
+
+ Returns:
+ kernel_size_x: Kernel size for horizontal direction (integer).
+ kernel_size_y: Kernel size for vertical direction (integer).
+ stride_x: Stride size for horizontal direction (integer).
+ stride_y: Stride size for vertical direction (integer).
+ padding_x: Padding size for horizontal direction, left side (integer).
+ padding_y: Padding size for vertical direction, top side (integer).
+ total_padding_x: Total padding size for horizontal direction (integer).
+ total_padding_y: Total padding size for vertical direction (integer).
+
+ Raises:
+ ValueError: If layer op is unknown and force is False.
+ """
+ logging.vlog(3, "node.name = %s", node.name)
+ logging.vlog(3, "node.op = %s", node.op)
+ logging.vlog(4, "node = %s", node)
+ if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative":
+ stride_x, stride_y = _stride_size(node)
+ kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_node)
+ # Compute the padding for this node separately for each direction.
+ total_padding_x, padding_x = _padding_size_conv_pool(
+ node, kernel_size_x, stride_x, input_resolution[1]
+ if input_resolution is not None else None)
+ total_padding_y, padding_y = _padding_size_conv_pool(
+ node, kernel_size_y, stride_y, input_resolution[0]
+ if input_resolution is not None else None)
+ elif node.op == "Pad":
+ # Kernel and stride are simply 1 in this case.
+ kernel_size_x = 1
+ kernel_size_y = 1
+ stride_x = 1
+ stride_y = 1
+ total_padding_x, padding_x, total_padding_y, padding_y = (
+ _padding_size_pad_layer(node, name_to_node))
+ elif node.op == "MaxPool" or node.op == "AvgPool":
+ stride_x, stride_y = _stride_size(node)
+ kernel_size_x, kernel_size_y = _pool_kernel_size(node)
+ # Compute the padding for this node separately for each direction.
+ total_padding_x, padding_x = _padding_size_conv_pool(
+ node, kernel_size_x, stride_x, input_resolution[1]
+ if input_resolution is not None else None)
+ total_padding_y, padding_y = _padding_size_conv_pool(
+ node, kernel_size_y, stride_y, input_resolution[0]
+ if input_resolution is not None else None)
+ elif node.op in _UNCHANGED_RF_LAYER_OPS:
+ # These nodes do not modify the RF parameters.
+ kernel_size_x = 1
+ kernel_size_y = 1
+ stride_x = 1
+ stride_y = 1
+ total_padding_x = 0
+ padding_x = 0
+ total_padding_y = 0
+ padding_y = 0
+ else:
+ if force:
+ kernel_size_x = None
+ kernel_size_y = None
+ stride_x = None
+ stride_y = None
+ total_padding_x = None
+ padding_x = None
+ total_padding_y = None
+ padding_y = None
+ else:
+ raise ValueError("Unknown layer for operation '%s': %s" % (node.name,
+ node.op))
+ return (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
+ padding_y, total_padding_x, total_padding_y)
diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters_test.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters_test.py
new file mode 100644
index 0000000000..369758a284
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters_test.py
@@ -0,0 +1,149 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for parse_layer_parameters module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import slim
+from tensorflow.contrib.receptive_field.python.util import graph_compute_order
+from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+def create_test_network():
+ """Convolutional neural network for test.
+
+ Returns:
+ name_to_node: Dict keyed by node name, each entry containing the node's
+ NodeDef.
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
+ # Left branch before first addition.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
+ # Right branch before first addition.
+ l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]], name='L2_pad')
+ l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
+ l3 = slim.max_pool2d(l2, [3, 3], stride=2, scope='L3', padding='SAME')
+ # First addition.
+ l4 = nn.relu(l1 + l3, name='L4_relu')
+ # Left branch after first addition.
+ l5 = slim.conv2d(l4, 1, [1, 1], stride=2, scope='L5', padding='SAME')
+ # Right branch after first addition.
+ l6 = slim.conv2d(l4, 1, [3, 3], stride=2, scope='L6', padding='SAME')
+ # Final addition.
+ gen_math_ops.add(l5, l6, name='L7_add')
+
+ name_to_node = graph_compute_order.parse_graph_nodes(g.as_graph_def())
+ return name_to_node
+
+
+class ParseLayerParametersTest(test.TestCase):
+
+ def testParametersAreParsedCorrectly(self):
+ """Checks parameters from create_test_network() are parsed correctly."""
+ name_to_node = create_test_network()
+
+ # L1.
+ l1_node_name = 'L1/Conv2D'
+ l1_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l1_node_name], name_to_node)
+ expected_l1_params = (1, 1, 4, 4, 0, 0, 0, 0)
+ self.assertEqual(l1_params, expected_l1_params)
+
+ # L2 padding.
+ l2_pad_name = 'L2_pad'
+ l2_pad_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l2_pad_name], name_to_node)
+ expected_l2_pad_params = (1, 1, 1, 1, 1, 1, 1, 1)
+ self.assertEqual(l2_pad_params, expected_l2_pad_params)
+
+ # L2.
+ l2_node_name = 'L2/Conv2D'
+ l2_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l2_node_name], name_to_node)
+ expected_l2_params = (3, 3, 2, 2, 0, 0, 0, 0)
+ self.assertEqual(l2_params, expected_l2_params)
+
+ # L3.
+ l3_node_name = 'L3/MaxPool'
+ # - Without knowing input size.
+ l3_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l3_node_name], name_to_node)
+ expected_l3_params = (3, 3, 2, 2, None, None, None, None)
+ self.assertEqual(l3_params, expected_l3_params)
+ # - Input size is even.
+ l3_even_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l3_node_name], name_to_node, input_resolution=[4, 4])
+ expected_l3_even_params = (3, 3, 2, 2, 0, 0, 1, 1)
+ self.assertEqual(l3_even_params, expected_l3_even_params)
+ # - Input size is odd.
+ l3_odd_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l3_node_name], name_to_node, input_resolution=[5, 5])
+ expected_l3_odd_params = (3, 3, 2, 2, 1, 1, 2, 2)
+ self.assertEqual(l3_odd_params, expected_l3_odd_params)
+
+ # L4.
+ l4_node_name = 'L4_relu'
+ l4_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l4_node_name], name_to_node)
+ expected_l4_params = (1, 1, 1, 1, 0, 0, 0, 0)
+ self.assertEqual(l4_params, expected_l4_params)
+
+ # L5.
+ l5_node_name = 'L5/Conv2D'
+ l5_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l5_node_name], name_to_node)
+ expected_l5_params = (1, 1, 2, 2, 0, 0, 0, 0)
+ self.assertEqual(l5_params, expected_l5_params)
+
+ # L6.
+ l6_node_name = 'L6/Conv2D'
+ # - Without knowing input size.
+ l6_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l6_node_name], name_to_node)
+ expected_l6_params = (3, 3, 2, 2, None, None, None, None)
+ self.assertEqual(l6_params, expected_l6_params)
+ # - Input size is even.
+ l6_even_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l6_node_name], name_to_node, input_resolution=[4, 4])
+ expected_l6_even_params = (3, 3, 2, 2, 0, 0, 1, 1)
+ self.assertEqual(l6_even_params, expected_l6_even_params)
+ # - Input size is odd.
+ l6_odd_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l6_node_name], name_to_node, input_resolution=[5, 5])
+ expected_l6_odd_params = (3, 3, 2, 2, 1, 1, 2, 2)
+ self.assertEqual(l6_odd_params, expected_l6_odd_params)
+
+ # L7.
+ l7_node_name = 'L7_add'
+ l7_params = parse_layer_parameters.get_layer_params(
+ name_to_node[l7_node_name], name_to_node)
+ expected_l7_params = (1, 1, 1, 1, 0, 0, 0, 0)
+ self.assertEqual(l7_params, expected_l7_params)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
index 0955e1fc39..b9bd2f0976 100644
--- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
@@ -23,243 +23,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
+import numpy as np
from tensorflow.contrib.receptive_field.python.util import graph_compute_order
-from tensorflow.contrib.util import make_ndarray
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters
from tensorflow.python.framework import ops as framework_ops
-import numpy as np
-
-# White-listed layer operations, which do not affect the receptive field
-# computation.
-_UNCHANGED_RF_LAYER_OPS = [
- "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor",
- "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu",
- "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2"
-]
-
-# Different ways in which padding modes may be spelled.
-_VALID_PADDING = ["VALID", b"VALID"]
-_SAME_PADDING = ["SAME", b"SAME"]
-
-
-def _stride_size(node):
- """Computes stride size given a TF node.
-
- Args:
- node: Tensorflow node (NodeDef proto).
-
- Returns:
- stride_x: Stride size for horizontal direction (integer).
- stride_y: Stride size for vertical direction (integer).
- """
- strides_attr = node.attr["strides"]
- logging.vlog(4, "strides_attr = %s", strides_attr)
- stride_y = strides_attr.list.i[1]
- stride_x = strides_attr.list.i[2]
- return stride_x, stride_y
-
-
-def _conv_kernel_size(node, name_to_order_node):
- """Computes kernel size given a TF convolution or pooling node.
-
- Args:
- node: Tensorflow node (NodeDef proto).
- name_to_order_node: Map from name to {order, node}. Output of
- graph_compute_order.get_compute_order().
-
- Returns:
- kernel_size_x: Kernel size for horizontal direction (integer).
- kernel_size_y: Kernel size for vertical direction (integer).
-
- Raises:
- ValueError: If the weight layer node is invalid.
- """
- weights_layer_read_name = node.input[1]
- if not weights_layer_read_name.endswith("/read"):
- raise ValueError(
- "Weight layer's name input to conv layer does not end with '/read'")
- weights_layer_param_name = weights_layer_read_name[:-5]
- weights_node = name_to_order_node[weights_layer_param_name].node
- if weights_node.op != "VariableV2":
- raise ValueError("Weight layer is not of type VariableV2")
- shape = weights_node.attr["shape"]
- logging.vlog(4, "weight shape = %s", shape)
- kernel_size_y = shape.shape.dim[0].size
- kernel_size_x = shape.shape.dim[1].size
- return kernel_size_x, kernel_size_y
-
-
-def _padding_size_conv_pool(node, kernel_size, stride):
- """Computes padding size given a TF convolution or pooling node.
-
- Args:
- node: Tensorflow node (NodeDef proto).
- kernel_size: Kernel size of node (integer).
- stride: Stride size of node (integer).
-
- Returns:
- padding: Padding size (integer).
-
- Raises:
- ValueError: If padding is invalid.
- """
- # In this case, we need to carefully consider the different TF padding modes.
- # The padding depends on kernel size, and may depend on input size. If it
- # depends on input size, we raise an exception.
- padding_attr = node.attr["padding"]
- logging.vlog(4, "padding_attr = %s", padding_attr)
- if padding_attr.s in _VALID_PADDING:
- padding = 0
- elif padding_attr.s in _SAME_PADDING:
- if kernel_size == 1:
- padding = 0
- elif stride == 1:
- padding = int(math.floor((float(kernel_size) - 1) / 2))
- elif stride == 2 and kernel_size % 2 == 0:
- padding = int(math.floor((float(kernel_size) - 1) / 2))
- else:
- padding = None
- logging.warning(
- "Padding depends on input size, which means that the effective "
- "padding may be different depending on the input image "
- "dimensionality. In this case, alignment check will be skipped.")
- else:
- raise ValueError("Invalid padding operation %s" % padding_attr.s)
- return padding
-
-
-def _pool_kernel_size(node):
- """Computes kernel size given a TF pooling node.
-
- Args:
- node: Tensorflow node (NodeDef proto).
-
- Returns:
- kernel_size_x: Kernel size for horizontal direction (integer).
- kernel_size_y: Kernel size for vertical direction (integer).
-
- Raises:
- ValueError: If pooling is invalid.
- """
- ksize = node.attr["ksize"]
- kernel_size_y = ksize.list.i[1]
- kernel_size_x = ksize.list.i[2]
- if ksize.list.i[0] != 1:
- raise ValueError("pool ksize for first dim is not 1")
- if ksize.list.i[3] != 1:
- raise ValueError("pool ksize for last dim is not 1")
- return kernel_size_x, kernel_size_y
-
-
-def _padding_size_pad_layer(node, name_to_order_node):
- """Computes padding size given a TF padding node.
-
- Args:
- node: Tensorflow node (NodeDef proto).
- name_to_order_node: Map from name to {order, node}. Output of
- graph_compute_order.get_compute_order().
-
- Returns:
- padding_x: Padding size for horizontal direction (integer).
- padding_y: Padding size for vertical direction (integer).
-
- Raises:
- ValueError: If padding layer is invalid.
- """
- paddings_layer_name = node.input[1]
- if not paddings_layer_name.endswith("/paddings"):
- raise ValueError("Padding layer name does not end with '/paddings'")
- paddings_node = name_to_order_node[paddings_layer_name].node
- if paddings_node.op != "Const":
- raise ValueError("Padding op is not Const")
- value = paddings_node.attr["value"]
- t = make_ndarray(value.tensor)
- padding_y = t[1][0]
- padding_x = t[2][0]
- if t[0][0] != 0:
- raise ValueError("padding is not zero for first tensor dim")
- if t[3][0] != 0:
- raise ValueError("padding is not zero for last tensor dim")
- return padding_x, padding_y
-
-
-def _get_layer_params(node, name_to_order_node):
- """Gets layer parameters relevant for RF computation.
-
- Currently, only these nodes are supported:
- - Conv2D
- - DepthwiseConv2dNative
- - Pad
- - MaxPool
- - AvgPool
- - all nodes listed in _UNCHANGED_RF_LAYER_OPS
-
- Args:
- node: Tensorflow node (NodeDef proto).
- name_to_order_node: Map from name to {order, node}. Output of
- graph_compute_order.get_compute_order().
-
- Returns:
- kernel_size_x: Kernel size for horizontal direction (integer).
- kernel_size_y: Kernel size for vertical direction (integer).
- stride_x: Stride size for horizontal direction (integer).
- stride_y: Stride size for vertical direction (integer).
- padding_x: Padding size for horizontal direction (integer).
- padding_y: Padding size for vertical direction (integer).
-
- Raises:
- ValueError: If layer op is unknown.
- """
- logging.vlog(3, "node.op = %s", node.op)
- logging.vlog(4, "node = %s", node)
- if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative":
- stride_x, stride_y = _stride_size(node)
- kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_order_node)
- # Compute the padding for this node separately for each direction.
- padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
- padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
- elif node.op == "Pad":
- # Kernel and stride are simply 1 in this case.
- kernel_size_x = 1
- kernel_size_y = 1
- stride_x = 1
- stride_y = 1
- padding_x, padding_y = _padding_size_pad_layer(node, name_to_order_node)
- elif node.op == "MaxPool" or node.op == "AvgPool":
- stride_x, stride_y = _stride_size(node)
- kernel_size_x, kernel_size_y = _pool_kernel_size(node)
- # Compute the padding for this node separately for each direction.
- padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
- padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
- elif node.op in _UNCHANGED_RF_LAYER_OPS:
- # These nodes do not modify the RF parameters.
- kernel_size_x = 1
- kernel_size_y = 1
- stride_x = 1
- stride_y = 1
- padding_x = 0
- padding_y = 0
- else:
- raise ValueError("Unknown layer for operation '%s': %s" % (node.name,
- node.op))
- return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
-
-
-def _reverse_sort_by_order(name_to_order_node):
- """Sorts map of name_to_order_node nodes in reverse order.
-
- The output is such that the nodes in name_to_order_node are sorted in
- descending order of the "order" field.
-
- Args:
- name_to_order_node: Map from name to {order, node}. Output of
- graph_compute_order.get_compute_order().
-
- Returns:
- sorted_name_to_order_node: Sorted version of the input, in descending order.
- """
- return sorted(name_to_order_node.items(), key=lambda x: -x[1].order)
+from tensorflow.python.platform import tf_logging as logging
def _get_rf_size_node_input(stride, kernel_size, rf_size_output):
@@ -308,7 +76,7 @@ def _get_effective_padding_node_input(stride, padding,
return stride * effective_padding_output + padding
-class ReceptiveField:
+class ReceptiveField(object):
"""Receptive field of a convolutional neural network.
Args:
@@ -350,8 +118,8 @@ class ReceptiveField:
raise ValueError("Dimensionality of the feature coordinates `y` (%d) "
"does not match dimensionality of `axis` (%d)" %
(y.shape[-1], len(axis)))
- return - self.padding[axis] + y * self.stride[axis] + \
- (self.size[axis] - 1) / 2
+ return -self.padding[axis] + y * self.stride[axis] + (
+ self.size[axis] - 1) / 2
def compute_feature_coordinates(self, x, axis=None):
"""Computes the position of a feature given the center of a receptive field.
@@ -380,8 +148,8 @@ class ReceptiveField:
raise ValueError("Dimensionality of the input center coordinates `x` "
"(%d) does not match dimensionality of `axis` (%d)" %
(x.shape[-1], len(axis)))
- return (x + self.padding[axis] + (1 - self.size[axis]) / 2) / \
- self.stride[axis]
+ return (x + self.padding[axis] +
+ (1 - self.size[axis]) / 2) / self.stride[axis]
def __iter__(self):
return iter(np.concatenate([self.size, self.stride, self.padding]))
@@ -390,7 +158,8 @@ class ReceptiveField:
def compute_receptive_field_from_graph_def(graph_def,
input_node,
output_node,
- stop_propagation=None):
+ stop_propagation=None,
+ input_resolution=None):
"""Computes receptive field (RF) parameters from a Graph or GraphDef object.
The algorithm stops the calculation of the receptive field whenever it
@@ -403,8 +172,14 @@ def compute_receptive_field_from_graph_def(graph_def,
graph_def: Graph or GraphDef object.
input_node: Name of the input node or Tensor object from graph.
output_node: Name of the output node or Tensor object from graph.
- stop_propagation: List of operation or scope names for which to stop the
+ stop_propagation: List of operations or scope names for which to stop the
propagation of the receptive field.
+ input_resolution: 2D list. If the input resolution to the model is fixed and
+ known, this may be set. This is helpful for cases where the RF parameters
+ vary depending on the input resolution (this happens since SAME padding in
+ tensorflow depends on input resolution in general). If this is None, it is
+ assumed that the input resolution is unknown, so some RF parameters may be
+ unknown (depending on the model architecture).
Returns:
rf_size_x: Receptive field size of network in the horizontal direction, with
@@ -438,11 +213,13 @@ def compute_receptive_field_from_graph_def(graph_def,
stop_propagation = stop_propagation or []
# Computes order of computation for a given graph.
- name_to_order_node = graph_compute_order.get_compute_order(
- graph_def=graph_def)
+ node_info, name_to_node = graph_compute_order.get_compute_order(
+ graph_def=graph_def,
+ input_node_name=input_node,
+ input_node_size=input_resolution)
# Sort in reverse topological order.
- order = _reverse_sort_by_order(name_to_order_node)
+ ordered_node_info = sorted(node_info.items(), key=lambda x: -x[1].order)
# Dictionaries to keep track of receptive field, effective stride and
# effective padding of different nodes.
@@ -471,7 +248,7 @@ def compute_receptive_field_from_graph_def(graph_def,
# alignment checks are skipped, and the effective padding is None.
undefined_padding = False
- for _, (o, node) in order:
+ for _, (o, node, _, _) in ordered_node_info:
if node:
logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op))
else:
@@ -497,13 +274,14 @@ def compute_receptive_field_from_graph_def(graph_def,
continue
# Get params for this layer.
- kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y = (
- _get_layer_params(node, name_to_order_node))
+ (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
+ padding_y, _, _) = parse_layer_parameters.get_layer_params(
+ node, name_to_node, node_info[node.name].input_size)
logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, "
"stride_x = %s, stride_y = %s, "
- "padding_x = %s, padding_y = %s" %
+ "padding_x = %s, padding_y = %s, input size = %s" %
(kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
- padding_y))
+ padding_y, node_info[node.name].input_size))
if padding_x is None or padding_y is None:
undefined_padding = True
@@ -525,12 +303,19 @@ def compute_receptive_field_from_graph_def(graph_def,
else:
effective_padding_input_x = None
effective_padding_input_y = None
+ logging.vlog(
+ 4, "rf_size_input_x = %s, rf_size_input_y = %s, "
+ "effective_stride_input_x = %s, effective_stride_input_y = %s, "
+ "effective_padding_input_x = %s, effective_padding_input_y = %s" %
+ (rf_size_input_x, rf_size_input_y, effective_stride_input_x,
+ effective_stride_input_y, effective_padding_input_x,
+ effective_padding_input_y))
# Loop over this node's inputs and potentially propagate information down.
for inp_name in node.input:
# Stop the propagation of the receptive field.
if any(inp_name.startswith(stop) for stop in stop_propagation):
- logging.vlog(3, "Skipping explicitly ignored node %s.", node.name)
+ logging.vlog(3, "Skipping explicitly ignored node %s.", inp_name)
continue
logging.vlog(4, "inp_name = %s", inp_name)
@@ -539,58 +324,66 @@ def compute_receptive_field_from_graph_def(graph_def,
# can be safely ignored.
continue
- inp_node = name_to_order_node[inp_name].node
+ inp_node = name_to_node[inp_name]
logging.vlog(4, "inp_node = \n%s", inp_node)
- if inp_node.name in rf_sizes_x:
- assert inp_node.name in rf_sizes_y, (
- "Node %s is in rf_sizes_x, but "
- "not in rf_sizes_y" % inp_node.name)
+ if inp_name in rf_sizes_x:
+ assert inp_name in rf_sizes_y, ("Node %s is in rf_sizes_x, but "
+ "not in rf_sizes_y" % inp_name)
+ logging.vlog(
+ 4, "rf_sizes_x[inp_name] = %s,"
+ " rf_sizes_y[inp_name] = %s, "
+ "effective_strides_x[inp_name] = %s,"
+ " effective_strides_y[inp_name] = %s, "
+ "effective_paddings_x[inp_name] = %s,"
+ " effective_paddings_y[inp_name] = %s" %
+ (rf_sizes_x[inp_name], rf_sizes_y[inp_name],
+ effective_strides_x[inp_name], effective_strides_y[inp_name],
+ effective_paddings_x[inp_name], effective_paddings_y[inp_name]))
# This node was already discovered through a previous path, so we need
# to make sure that graph is aligned. This alignment check is skipped
# if the padding is not defined, since in this case alignment cannot
# be checked.
if not undefined_padding:
- if effective_strides_x[inp_node.name] != effective_stride_input_x:
+ if effective_strides_x[inp_name] != effective_stride_input_x:
raise ValueError(
"Graph is not aligned since effective stride from different "
"paths is different in horizontal direction")
- if effective_strides_y[inp_node.name] != effective_stride_input_y:
+ if effective_strides_y[inp_name] != effective_stride_input_y:
raise ValueError(
"Graph is not aligned since effective stride from different "
"paths is different in vertical direction")
- if (rf_sizes_x[inp_node.name] - 1
- ) / 2 - effective_paddings_x[inp_node.name] != (
+ if (rf_sizes_x[inp_name] - 1
+ ) / 2 - effective_paddings_x[inp_name] != (
rf_size_input_x - 1) / 2 - effective_padding_input_x:
raise ValueError(
"Graph is not aligned since center shift from different "
"paths is different in horizontal direction")
- if (rf_sizes_y[inp_node.name] - 1
- ) / 2 - effective_paddings_y[inp_node.name] != (
+ if (rf_sizes_y[inp_name] - 1
+ ) / 2 - effective_paddings_y[inp_name] != (
rf_size_input_y - 1) / 2 - effective_padding_input_y:
raise ValueError(
"Graph is not aligned since center shift from different "
"paths is different in vertical direction")
# Keep track of path with largest RF, for both directions.
- if rf_sizes_x[inp_node.name] < rf_size_input_x:
- rf_sizes_x[inp_node.name] = rf_size_input_x
- effective_strides_x[inp_node.name] = effective_stride_input_x
- effective_paddings_x[inp_node.name] = effective_padding_input_x
- if rf_sizes_y[inp_node.name] < rf_size_input_y:
- rf_sizes_y[inp_node.name] = rf_size_input_y
- effective_strides_y[inp_node.name] = effective_stride_input_y
- effective_paddings_y[inp_node.name] = effective_padding_input_y
+ if rf_sizes_x[inp_name] < rf_size_input_x:
+ rf_sizes_x[inp_name] = rf_size_input_x
+ effective_strides_x[inp_name] = effective_stride_input_x
+ effective_paddings_x[inp_name] = effective_padding_input_x
+ if rf_sizes_y[inp_name] < rf_size_input_y:
+ rf_sizes_y[inp_name] = rf_size_input_y
+ effective_strides_y[inp_name] = effective_stride_input_y
+ effective_paddings_y[inp_name] = effective_padding_input_y
else:
- assert inp_node.name not in rf_sizes_y, (
- "Node %s is in rf_sizes_y, but "
- "not in rf_sizes_x" % inp_node.name)
+ assert inp_name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but "
+ "not in rf_sizes_x" % inp_name)
# In this case, it is the first time we encounter this node. So we
# propagate the RF parameters.
- rf_sizes_x[inp_node.name] = rf_size_input_x
- rf_sizes_y[inp_node.name] = rf_size_input_y
- effective_strides_x[inp_node.name] = effective_stride_input_x
- effective_strides_y[inp_node.name] = effective_stride_input_y
- effective_paddings_x[inp_node.name] = effective_padding_input_x
- effective_paddings_y[inp_node.name] = effective_padding_input_y
+ rf_sizes_x[inp_name] = rf_size_input_x
+ rf_sizes_y[inp_name] = rf_size_input_y
+ effective_strides_x[inp_name] = effective_stride_input_x
+ effective_strides_y[inp_name] = effective_stride_input_y
+ effective_paddings_x[inp_name] = effective_padding_input_x
+ effective_paddings_y[inp_name] = effective_padding_input_y
if not found_output_node:
raise ValueError("Output node was not found")
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
index 1ea72c0f29..cf55da2723 100644
--- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
@@ -44,8 +44,9 @@ def create_test_network_1():
"""
g = ops.Graph()
with g.as_default():
- # An 8x8 test image.
- x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
@@ -71,8 +72,9 @@ def create_test_network_2():
"""
g = ops.Graph()
with g.as_default():
- # An 8x8 test image.
- x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
@@ -95,8 +97,9 @@ def create_test_network_3():
"""
g = ops.Graph()
with g.as_default():
- # An 8x8 test image.
- x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Left branch.
l1_pad = array_ops.pad(x, [[0, 0], [2, 1], [2, 1], [0, 0]])
l1 = slim.conv2d(l1_pad, 1, [5, 5], stride=2, scope='L1', padding='VALID')
@@ -122,8 +125,9 @@ def create_test_network_4():
"""
g = ops.Graph()
with g.as_default():
- # An 8x8 test image.
- x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
@@ -146,8 +150,9 @@ def create_test_network_5():
"""
g = ops.Graph()
with g.as_default():
- # An 8x8 test image.
- x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Two convolutional layers, where the first one has non-square kernel.
l1 = slim.conv2d(x, 1, [3, 5], stride=2, scope='L1', padding='VALID')
l2 = slim.conv2d(l1, 1, [3, 1], stride=2, scope='L2', padding='VALID')
@@ -167,8 +172,9 @@ def create_test_network_6():
"""
g = ops.Graph()
with g.as_default():
- # An 8x8 test image.
- x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
@@ -223,9 +229,9 @@ def create_test_network_8():
"""
g = ops.Graph()
with g.as_default():
- # A 16x16 test image.
+ # An input test image with unknown spatial resolution.
x = array_ops.placeholder(
- dtypes.float32, (1, 16, 16, 1), name='input_image')
+ dtypes.float32, (None, None, None, 1), name='input_image')
# Left branch before first addition.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch before first addition.
@@ -245,7 +251,38 @@ def create_test_network_8():
return g
-class RfUtilsTest(test.TestCase):
+def create_test_network_9():
+ """Aligned network for test, including an intermediate addition.
+
+ The graph is the same as create_test_network_8(), except that VALID padding is
+ changed to SAME.
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An input test image with unknown spatial resolution.
+ x = array_ops.placeholder(
+ dtypes.float32, (None, None, None, 1), name='input_image')
+ # Left branch before first addition.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='SAME')
+ # Right branch before first addition.
+ l2 = slim.conv2d(x, 1, [3, 3], stride=2, scope='L2', padding='SAME')
+ l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='SAME')
+ # First addition.
+ l4 = nn.relu(l1 + l3)
+ # Left branch after first addition.
+ l5 = slim.conv2d(l4, 1, [1, 1], stride=2, scope='L5', padding='SAME')
+ # Right branch after first addition.
+ l6 = slim.conv2d(l4, 1, [3, 3], stride=2, scope='L6', padding='SAME')
+ # Final addition.
+ nn.relu(l5 + l6, name='output')
+
+ return g
+
+
+class ReceptiveFieldTest(test.TestCase):
def testComputeRFFromGraphDefAligned(self):
graph_def = create_test_network_1().as_graph_def()
@@ -285,7 +322,7 @@ class RfUtilsTest(test.TestCase):
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node)
- def testComputeRFFromGraphDefUnaligned2(self):
+ def testComputeRFFromGraphDefUndefinedPadding(self):
graph_def = create_test_network_4().as_graph_def()
input_node = 'input_image'
output_node = 'output'
@@ -300,6 +337,29 @@ class RfUtilsTest(test.TestCase):
self.assertEqual(effective_padding_x, None)
self.assertEqual(effective_padding_y, None)
+ def testComputeRFFromGraphDefFixedInputDim(self):
+ graph_def = create_test_network_4().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node, input_resolution=[9, 9]))
+ self.assertEqual(receptive_field_x, 3)
+ self.assertEqual(receptive_field_y, 3)
+ self.assertEqual(effective_stride_x, 4)
+ self.assertEqual(effective_stride_y, 4)
+ self.assertEqual(effective_padding_x, 1)
+ self.assertEqual(effective_padding_y, 1)
+
+ def testComputeRFFromGraphDefUnalignedFixedInputDim(self):
+ graph_def = create_test_network_4().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ with self.assertRaises(ValueError):
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node, input_resolution=[8, 8])
+
def testComputeRFFromGraphDefNonSquareRF(self):
graph_def = create_test_network_5().as_graph_def()
input_node = 'input_image'
@@ -376,6 +436,22 @@ class RfUtilsTest(test.TestCase):
self.assertEqual(effective_padding_x, 5)
self.assertEqual(effective_padding_y, 5)
+ def testComputeRFFromGraphDefWithIntermediateAddNodeSamePaddingFixedInputDim(
+ self):
+ graph_def = create_test_network_9().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node, input_resolution=[17, 17]))
+ self.assertEqual(receptive_field_x, 11)
+ self.assertEqual(receptive_field_y, 11)
+ self.assertEqual(effective_stride_x, 8)
+ self.assertEqual(effective_stride_y, 8)
+ self.assertEqual(effective_padding_x, 5)
+ self.assertEqual(effective_padding_y, 5)
+
if __name__ == '__main__':
test.main()