aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/receptive_field
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-31 11:00:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-31 11:04:09 -0700
commit863163067ad38e093ae7c4557c5065f589eef70c (patch)
tree8d4f97f266950bc66daad93f9bbc0f76ee3830da /tensorflow/contrib/receptive_field
parent8dbd2b91f9b387e6f3b0084671e46325fc9915be (diff)
Introducing tf.contrib.receptive_field
PiperOrigin-RevId: 167160384
Diffstat (limited to 'tensorflow/contrib/receptive_field')
-rw-r--r--tensorflow/contrib/receptive_field/BUILD71
-rw-r--r--tensorflow/contrib/receptive_field/README.md164
-rw-r--r--tensorflow/contrib/receptive_field/__init__.py23
-rw-r--r--tensorflow/contrib/receptive_field/python/__init__.py19
-rw-r--r--tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py90
-rw-r--r--tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py460
-rw-r--r--tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py57
-rw-r--r--tensorflow/contrib/receptive_field/python/util/graph_compute_order.py88
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field.py481
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field_test.py221
10 files changed, 1674 insertions, 0 deletions
diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD
new file mode 100644
index 0000000000..ed2f3af08c
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/BUILD
@@ -0,0 +1,71 @@
+# Description:
+# Contains modules to compute receptive field parameters for CNN models.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+# Transitive dependencies of this target will be included in the pip package.
+py_library(
+ name = "receptive_field_pip",
+ deps = [
+ ":graph_compute_order_py",
+ ":receptive_field_py",
+ ],
+)
+
+py_library(
+ name = "graph_compute_order_py",
+ srcs = [
+ "__init__.py",
+ "python/util/graph_compute_order.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "receptive_field_py",
+ srcs = [
+ "__init__.py",
+ "python/util/receptive_field.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":graph_compute_order_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "receptive_field_test",
+ srcs = ["python/util/receptive_field_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":receptive_field_py",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/slim",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:nn",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md
new file mode 100644
index 0000000000..f7539ec145
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/README.md
@@ -0,0 +1,164 @@
+# Receptive field computation for convnets
+
+This library enables you to easily compute the receptive field parameters of
+your favorite convnet. You can use it to understand how big of an input image
+region your output features depend on. Better yet, using the parameters computed
+by the library, you can easily find the exact image region which is used to
+compute each convnet feature.
+
+## Basic usage
+
+The main function to be called is `compute_receptive_field_from_graph_def`,
+which will return the receptive field, effective stride and effective padding
+for both horizontal and vertical directions.
+
+For example, if your model is constructed using the function
+`my_model_construction()`, you can use the library as follows:
+
+```python
+import tensorflow as tf
+from tensorflow.contrib import receptive_field
+
+# Construct graph.
+g = tf.Graph()
+with g.as_default():
+ images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image')
+ my_model_construction(images)
+
+# Compute receptive field parameters.
+rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
+ receptive_field.compute_receptive_field_from_graph_def( \
+ g.as_graph_def(), 'input_image', 'my_output_endpoint')
+```
+
+Here's a simple example of computing the receptive field parameters for
+Inception-Resnet-v2. To get this to work, be sure to checkout
+[tensorflow/models](https://github.com/tensorflow/models), so that the Inception
+models are available to you. This can be done in three simple commands:
+
+```sh
+git clone https://github.com/tensorflow/models
+cd models/slim
+sudo python setup.py install_lib
+```
+
+You can then compute the receptive field parameters for Inception-Resnet-v2 as:
+
+```python
+from nets import inception
+import tensorflow as tf
+from tensorflow.contrib import receptive_field
+
+# Construct graph.
+g = tf.Graph()
+with g.as_default():
+ images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image')
+ inception.inception_resnet_v2_base(images)
+
+# Compute receptive field parameters.
+rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
+ receptive_field.compute_receptive_field_from_graph_def( \
+ g.as_graph_def(), 'input_image', 'InceptionResnetV2/Conv2d_7b_1x1/Relu')
+```
+
+This will give you `rf_x = rf_y = 3039`, `eff_stride_x = eff_stride_y = 32`, and
+`eff_pad_x = eff_pad_y = 1482`. This means that each feature that is output at
+the node `'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is computed from a region
+which is of size `3039x3039`. Further, by using the expressions
+
+```python
+center_x = -eff_pad_x + feature_x*eff_stride_x + (rf_x - 1)/2
+center_y = -eff_pad_y + feature_y*eff_stride_y + (rf_y - 1)/2
+```
+
+one can compute the center of the region in the input image that is used to
+compute the output feature at position `[feature_x, feature_y]`. For example,
+the feature at position `[0, 2]` at the output of the layer
+`'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is centered in the original image in
+the position `[37, 101]`.
+
+TODO: include link to derivations and definitions of different parameters.
+
+## Receptive field benchmark
+
+As you might expect, it is straightforward to run this library on the popular
+convnets, and gather their receptive fields. We provide a python script which
+does exactly that, available under `python/util/examples/rf_benchmark.py`.
+
+To get this to work, be sure to checkout
+[tensorflow/models](https://github.com/tensorflow/models) (see the 3-command
+instructions for this above). Then, simply:
+
+```sh
+cd python/util/examples
+python rf_benchmark.py --csv_path /tmp/rf_benchmark_results.csv
+```
+
+The script will write to stdout the receptive field parameters for many variants
+of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They
+are also written to the file `/tmp/rf_benchmark_results.csv`.
+
+TODO: include here a plot for receptive field sizes of different convnets.
+
+TODO: include table/link to pre-computed RF parameters.
+
+## Compute RF parameters from a graph pbtxt
+
+We also provide a utility to compute the receptive field parameters directly
+from a graph protobuf file.
+
+Have a `graph.pbtxt` file and want to compute its receptive field parameters? We
+got you covered. The only prerequisite is to install
+[google/protobuf](https://github.com/google/protobuf), which you probably
+already have if you're using tensorflow (otherwise, follow installation
+instructions [here](https://github.com/google/protobuf/tree/master/python)).
+
+This should work:
+
+```sh
+cd python/util/examples
+python compute_rf.py \
+ --graph_path /path/to/graph.pbtxt \
+ --output_path /path/to/output/rf_info.txt \
+ --input_node my_input_node \
+ --output_node my_output_node
+```
+
+Don't know how to generate a graph protobuf file? Take a look at the
+`write_inception_resnet_v2_graph.py` script, which shows how to save it for the
+Inception-Resnet-v2 model:
+
+```sh
+cd python/util/examples
+python write_inception_resnet_v2_graph.py --graph_dir /tmp --graph_filename graph.pbtxt
+```
+
+This will write the Inception-Resnet-v2 graph protobuf to `/tmp/graph.pbtxt`.
+
+For completeness, here's how you would use this file to get the receptive field
+parameters of the Inception-Resnet-v2 model:
+
+```sh
+cd python/util/examples
+python compute_rf.py \
+ --graph_path /tmp/graph.pbtxt \
+ --output_path /tmp/rf_info.txt \
+ --input_node input_image \
+ --output_node InceptionResnetV2/Conv2d_7b_1x1/Relu
+```
+
+This will write the receptive field parameters of the model to
+`/tmp/rf_info.txt`, which will look like:
+
+```sh
+Receptive field size (horizontal) = 3039
+Receptive field size (vertical) = 3039
+Effective stride (horizontal) = 32
+Effective stride (vertical) = 32
+Effective padding (horizontal) = 1482
+Effective padding (vertical) = 1482
+```
+
+## Authors
+
+Andr&eacute; Araujo (andrefaraujo@) and Mark Sandler
diff --git a/tensorflow/contrib/receptive_field/__init__.py b/tensorflow/contrib/receptive_field/__init__.py
new file mode 100644
index 0000000000..10745a6a53
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module to compute receptive field parameters for CNN tensorflow models."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from tensorflow.contrib.receptive_field.python.util.graph_compute_order import get_compute_order
+from tensorflow.contrib.receptive_field.python.util.receptive_field import compute_receptive_field_from_graph_def
+# pylint: enable=unused-import
diff --git a/tensorflow/contrib/receptive_field/python/__init__.py b/tensorflow/contrib/receptive_field/python/__init__.py
new file mode 100644
index 0000000000..217047f92d
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module to compute receptive field parameters for CNN tensorflow models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py
new file mode 100644
index 0000000000..70a0d11dff
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py
@@ -0,0 +1,90 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Computes Receptive Field (RF) information given a graph protobuf.
+
+For an example of usage, see accompanying file compute_rf.sh
+"""
+
+import argparse
+import sys
+
+from google.protobuf import text_format
+
+from tensorflow.contrib import receptive_field
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.platform import app
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+
+cmd_args = None
+
+
+def _load_graphdef(path):
+ """Helper function to load GraphDef from file.
+
+ Args:
+ path: Path to pbtxt file.
+
+ Returns:
+ graph_def: A GraphDef object.
+ """
+ graph_def = graph_pb2.GraphDef()
+ pbstr = gfile.Open(path).read()
+ text_format.Parse(pbstr, graph_def)
+ return graph_def
+
+
+def main(unused_argv):
+
+ graph_def = _load_graphdef(cmd_args.graph_path)
+
+ (receptive_field_x, receptive_field_y, effective_stride_x, effective_stride_y,
+ effective_padding_x, effective_padding_y
+ ) = receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, cmd_args.input_node, cmd_args.output_node)
+
+ logging.info('Receptive field size (horizontal) = %s', receptive_field_x)
+ logging.info('Receptive field size (vertical) = %s', receptive_field_y)
+ logging.info('Effective stride (horizontal) = %s', effective_stride_x)
+ logging.info('Effective stride (vertical) = %s', effective_stride_y)
+ logging.info('Effective padding (horizontal) = %s', effective_padding_x)
+ logging.info('Effective padding (vertical) = %s', effective_padding_y)
+
+ f = gfile.GFile('%s' % cmd_args.output_path, 'w')
+ f.write('Receptive field size (horizontal) = %s\n' % receptive_field_x)
+ f.write('Receptive field size (vertical) = %s\n' % receptive_field_y)
+ f.write('Effective stride (horizontal) = %s\n' % effective_stride_x)
+ f.write('Effective stride (vertical) = %s\n' % effective_stride_y)
+ f.write('Effective padding (horizontal) = %s\n' % effective_padding_x)
+ f.write('Effective padding (vertical) = %s\n' % effective_padding_y)
+ f.close()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--graph_path', type=str, default='', help='Graph path (pbtxt format).')
+ parser.add_argument(
+ '--output_path',
+ type=str,
+ default='',
+ help='Path to output text file where RF information will be written to.')
+ parser.add_argument(
+ '--input_node', type=str, default='', help='Name of input node.')
+ parser.add_argument(
+ '--output_node', type=str, default='', help='Name of output node.')
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py
new file mode 100644
index 0000000000..94228dfa61
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py
@@ -0,0 +1,460 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Computes Receptive Field (RF) information for different models.
+
+The receptive field (and related parameters) for the different models are
+printed to stdout, and may also optionally be written to a CSV file.
+
+For an example of usage, see rf_benchmark.sh
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import sys
+
+from nets import alexnet
+from nets import inception
+from nets import mobilenet_v1
+from nets import resnet_v1
+from nets import resnet_v2
+from nets import vgg
+from tensorflow.contrib import framework
+from tensorflow.contrib import receptive_field
+from tensorflow.contrib import slim
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import app
+
+cmd_args = None
+
+# Input node name for all architectures.
+_INPUT_NODE = 'input_image'
+
+# Variants of different network architectures.
+
+# - resnet: different versions and sizes.
+_SUPPORTED_RESNET_VARIANTS = [
+ 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200',
+ 'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200'
+]
+
+# - inception_resnet_v2: default, and version with SAME padding.
+_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [
+ 'inception_resnet_v2', 'inception_resnet_v2-same'
+]
+
+# - inception_v2: default, and version with no separable conv.
+_SUPPORTED_INCEPTIONV2_VARIANTS = [
+ 'inception_v2', 'inception_v2-no-separable-conv'
+]
+
+# - inception_v3: default version.
+_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3']
+
+# - inception_v4: default version.
+_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4']
+
+# - alexnet_v2: default version.
+_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2']
+
+# - vgg: vgg_a (with 11 layers) and vgg_16 (version D).
+_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16']
+
+# - mobilenet_v1: 100% and 75%.
+_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075']
+
+
+def _construct_model(model_type='resnet_v1_50'):
+ """Constructs model for the desired type of CNN.
+
+ Args:
+ model_type: Type of model to be used.
+
+ Returns:
+ end_points: A dictionary from components of the network to the corresponding
+ activations.
+
+ Raises:
+ ValueError: If the model_type is not supported.
+ """
+ # Placeholder input.
+ images = array_ops.placeholder(
+ dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE)
+
+ # Construct model.
+ if model_type == 'inception_resnet_v2':
+ _, end_points = inception.inception_resnet_v2_base(images)
+ elif model_type == 'inception_resnet_v2-same':
+ _, end_points = inception.inception_resnet_v2_base(
+ images, align_feature_maps=True)
+ elif model_type == 'inception_v2':
+ _, end_points = inception.inception_v2_base(images)
+ elif model_type == 'inception_v2-no-separable-conv':
+ _, end_points = inception.inception_v2_base(
+ images, use_separable_conv=False)
+ elif model_type == 'inception_v3':
+ _, end_points = inception.inception_v3_base(images)
+ elif model_type == 'inception_v4':
+ _, end_points = inception.inception_v4_base(images)
+ elif model_type == 'alexnet_v2':
+ _, end_points = alexnet.alexnet_v2(images)
+ elif model_type == 'vgg_a':
+ _, end_points = vgg.vgg_a(images)
+ elif model_type == 'vgg_16':
+ _, end_points = vgg.vgg_16(images)
+ elif model_type == 'mobilenet_v1':
+ _, end_points = mobilenet_v1.mobilenet_v1_base(images)
+ elif model_type == 'mobilenet_v1_075':
+ _, end_points = mobilenet_v1.mobilenet_v1_base(
+ images, depth_multiplier=0.75)
+ elif model_type == 'resnet_v1_50':
+ _, end_points = resnet_v1.resnet_v1_50(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v1_101':
+ _, end_points = resnet_v1.resnet_v1_101(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v1_152':
+ _, end_points = resnet_v1.resnet_v1_152(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v1_200':
+ _, end_points = resnet_v1.resnet_v1_200(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v2_50':
+ _, end_points = resnet_v2.resnet_v2_50(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v2_101':
+ _, end_points = resnet_v2.resnet_v2_101(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v2_152':
+ _, end_points = resnet_v2.resnet_v2_152(
+ images, num_classes=None, is_training=False, global_pool=False)
+ elif model_type == 'resnet_v2_200':
+ _, end_points = resnet_v2.resnet_v2_200(
+ images, num_classes=None, is_training=False, global_pool=False)
+ else:
+ raise ValueError('Unsupported model_type %s.' % model_type)
+
+ return end_points
+
+
+def _get_desired_end_point_keys(model_type='resnet_v1_50'):
+ """Gets list of desired end point keys for a type of CNN.
+
+ Args:
+ model_type: Type of model to be used.
+
+ Returns:
+ desired_end_point_types: A list containing the desired end-points.
+
+ Raises:
+ ValueError: If the model_type is not supported.
+ """
+ if model_type in _SUPPORTED_RESNET_VARIANTS:
+ blocks = ['block1', 'block2', 'block3', 'block4']
+ desired_end_point_keys = ['%s/%s' % (model_type, i) for i in blocks]
+ elif model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
+ desired_end_point_keys = [
+ 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
+ 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
+ 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1'
+ ]
+ elif model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
+ desired_end_point_keys = [
+ 'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
+ 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
+ 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c'
+ ]
+ elif model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
+ desired_end_point_keys = [
+ 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
+ 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
+ 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
+ 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'
+ ]
+ elif model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
+ desired_end_point_keys = [
+ 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
+ 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e',
+ 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f',
+ 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'
+ ]
+ elif model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
+ ep = ['conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5']
+ desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep]
+ elif model_type in _SUPPORTED_VGG_VARIANTS:
+ ep = [
+ 'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1',
+ 'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4',
+ 'conv5/conv5_1', 'conv5/conv5_2', 'pool5'
+ ]
+ desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep]
+ elif model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
+ desired_end_point_keys = [
+ 'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
+ 'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
+ 'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
+ 'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
+ 'Conv2d_12_pointwise', 'Conv2d_13_pointwise'
+ ]
+ else:
+ raise ValueError('Unsupported model_type %s.' % model_type)
+
+ return desired_end_point_keys
+
+
+def _model_graph_def(model_type='resnet_v1_50', arg_sc=None):
+ """Constructs a model graph, returning GraphDef and end-points.
+
+ Args:
+ model_type: Type of model to be used.
+ arg_sc: Optional arg scope to use in constructing the graph.
+
+ Returns:
+ graph_def: GraphDef of constructed graph.
+ end_points: A dictionary from components of the network to the corresponding
+ activations.
+ """
+ if arg_sc is None:
+ arg_sc = {}
+ g = ops.Graph()
+ with g.as_default():
+ with framework.arg_scope(arg_sc):
+ end_points = _construct_model(model_type)
+
+ return g.as_graph_def(), end_points
+
+
+def _model_rf(graphdef,
+ end_points,
+ desired_end_point_keys,
+ model_type='resnet_v1_50',
+ csv_writer=None):
+ """Computes receptive field information for a given CNN model.
+
+ The information will be printed to stdout. If the RF parameters are the same
+ for the horizontal and vertical directions, it will be printed only once.
+ Otherwise, they are printed once for the horizontal and once for the vertical
+ directions.
+
+ Args:
+ graphdef: GraphDef of given model.
+ end_points: A dictionary from components of the model to the corresponding
+ activations.
+ desired_end_point_keys: List of desired end points for which receptive field
+ information will be computed.
+ model_type: Type of model to be used, used only for printing purposes.
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for desired_end_point_key in desired_end_point_keys:
+ print('- %s:' % desired_end_point_key)
+ output_node_with_colon = end_points[desired_end_point_key].name
+ pos = output_node_with_colon.rfind(':')
+ output_node = output_node_with_colon[:pos]
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y
+ ) = receptive_field.compute_receptive_field_from_graph_def(
+ graphdef, _INPUT_NODE, output_node)
+ # If values are the same in horizontal/vertical directions, just report one
+ # of them. Otherwise, report both.
+ if (receptive_field_x == receptive_field_y) and (
+ effective_stride_x == effective_stride_y) and (
+ effective_padding_x == effective_padding_y):
+ print('Receptive field size = %5s, effective stride = %5s, effective '
+ 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x),
+ str(effective_padding_x)))
+ else:
+ print('Receptive field size: horizontal = %5s, vertical = %5s. '
+ 'Effective stride: horizontal = %5s, vertical = %5s. Effective '
+ 'padding: horizontal = %5s, vertical = %5s' %
+ (str(receptive_field_x), str(receptive_field_y),
+ str(effective_stride_x), str(effective_stride_y),
+ str(effective_padding_x), str(effective_padding_y)))
+ if csv_writer is not None:
+ csv_writer.writerow({
+ 'CNN': model_type,
+ 'end_point': desired_end_point_key,
+ 'RF size hor': str(receptive_field_x),
+ 'RF size ver': str(receptive_field_y),
+ 'effective stride hor': str(effective_stride_x),
+ 'effective stride ver': str(effective_stride_y),
+ 'effective padding hor': str(effective_padding_x),
+ 'effective padding ver': str(effective_padding_y)
+ })
+
+
+def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None):
+ """Contructs model graph and desired end-points, and compute RF.
+
+ The computed RF parameters are printed to stdout by the _model_rf function.
+
+ Args:
+ model_type: Type of model to be used.
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ arg_sc: Optional arg scope to use in constructing the graph.
+
+ """
+ print('********************%s' % model_type)
+ graphdef, end_points = _model_graph_def(model_type, arg_sc)
+ desired_end_point_keys = _get_desired_end_point_keys(model_type)
+ _model_rf(graphdef, end_points, desired_end_point_keys, model_type,
+ csv_writer)
+
+
+def _resnet_rf(csv_writer=None):
+ """Computes RF and associated parameters for resnet models.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_RESNET_VARIANTS:
+ arg_sc = resnet_v1.resnet_arg_scope()
+ _process_model_rf(model_type, csv_writer, arg_sc)
+
+
+def _inception_resnet_v2_rf(csv_writer=None):
+ """Computes RF and associated parameters for the inception_resnet_v2 model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
+ _process_model_rf(model_type, csv_writer)
+
+
+def _inception_v2_rf(csv_writer=None):
+ """Computes RF and associated parameters for the inception_v2 model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
+ _process_model_rf(model_type, csv_writer)
+
+
+def _inception_v3_rf(csv_writer=None):
+ """Computes RF and associated parameters for the inception_v3 model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
+ _process_model_rf(model_type, csv_writer)
+
+
+def _inception_v4_rf(csv_writer=None):
+ """Computes RF and associated parameters for the inception_v4 model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
+ _process_model_rf(model_type, csv_writer)
+
+
+def _alexnet_v2_rf(csv_writer=None):
+ """Computes RF and associated parameters for the alexnet_v2 model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
+ _process_model_rf(model_type, csv_writer)
+
+
+def _vgg_rf(csv_writer=None):
+ """Computes RF and associated parameters for the vgg model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_VGG_VARIANTS:
+ _process_model_rf(model_type, csv_writer)
+
+
+def _mobilenet_v1_rf(csv_writer=None):
+ """Computes RF and associated parameters for the mobilenet_v1 model.
+
+ The computed values are written to stdout.
+
+ Args:
+ csv_writer: A CSV writer for RF parameters, which is used if it is not None.
+ """
+ for model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
+ with slim.arg_scope(
+ [slim.batch_norm, slim.dropout], is_training=False) as arg_sc:
+ _process_model_rf(model_type, csv_writer, arg_sc)
+
+
+def main(unused_argv):
+ # Configure CSV file which will be written, if desired.
+ if cmd_args.csv_path:
+ csv_file = open(cmd_args.csv_path, 'w')
+ field_names = [
+ 'CNN', 'end_point', 'RF size hor', 'RF size ver',
+ 'effective stride hor', 'effective stride ver', 'effective padding hor',
+ 'effective padding ver'
+ ]
+ rf_writer = csv.DictWriter(csv_file, fieldnames=field_names)
+ rf_writer.writeheader()
+ else:
+ rf_writer = None
+
+ # Compute RF parameters for each network architecture.
+ _alexnet_v2_rf(rf_writer)
+ _vgg_rf(rf_writer)
+ _inception_v2_rf(rf_writer)
+ _inception_v3_rf(rf_writer)
+ _inception_v4_rf(rf_writer)
+ _inception_resnet_v2_rf(rf_writer)
+ _mobilenet_v1_rf(rf_writer)
+ _resnet_rf(rf_writer)
+
+ # Close CSV file, if it was opened.
+ if cmd_args.csv_path:
+ csv_file.close()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--csv_path',
+ type=str,
+ default='',
+ help="""\
+ Path to CSV file that will be written with RF parameters.If empty, no
+ file will be written.\
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py b/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py
new file mode 100644
index 0000000000..a2384f66ce
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple script to write Inception-ResNet-v2 model to graph file.
+"""
+
+import argparse
+import sys
+
+from nets import inception
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_io
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import app
+
+cmd_args = None
+
+
+def main(unused_argv):
+ # Model definition.
+ g = ops.Graph()
+ with g.as_default():
+ images = array_ops.placeholder(
+ dtypes.float32, shape=(1, None, None, 3), name='input_image')
+ inception.inception_resnet_v2_base(images)
+
+ graph_io.write_graph(g.as_graph_def(), cmd_args.graph_dir,
+ cmd_args.graph_filename)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--graph_dir',
+ type=str,
+ default='/tmp',
+ help='Directory where graph will be saved.')
+ parser.add_argument(
+ '--graph_filename',
+ type=str,
+ default='graph.pbtxt',
+ help='Filename of graph that will be saved.')
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
new file mode 100644
index 0000000000..8af4be16d6
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
@@ -0,0 +1,88 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library to compute order of computations in a graph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+class GraphDefHelper(object):
+ """Helper class to collect node names and definitions.
+
+ Example:
+ b = GraphDefHelper(graph_def)
+ # Prints node that produces given output.
+ print b.output_of['conv/foo/bar']
+ """
+
+ def __init__(self, gd):
+ self.output_of = {}
+ for each in gd.node:
+ self.output_of[each.name] = each
+
+
+# pylint: disable=invalid-name
+_NodeEntry = collections.namedtuple('NodeEntry', field_names=['order', 'node'])
+
+
+def _get_computed_nodes(g, output, seen):
+ """Traverses the graph in topological order.
+
+ Args:
+ g: GraphDefHelper object.
+ output: current node.
+ seen: map of nodes we've already traversed.
+ Returns:
+ order in topological sort for 'output'.
+ """
+ if output in seen:
+ return seen[output].order
+ node_def = g.output_of.get(output, None)
+ if node_def is None:
+ seen[output] = _NodeEntry(0, None)
+ return 0
+
+ r = 0
+ for each in node_def.input:
+ # Parses name of input node.
+ if each.startswith('^'):
+ each = each[1:]
+ each = each.split(':')[0]
+ # Recursively computes ordering.
+ new_v = _get_computed_nodes(g, each, seen)
+ r = max(r, new_v + 1)
+
+ seen[output] = _NodeEntry(r, node_def)
+
+ return seen[output].order
+
+
+def get_compute_order(graph_def):
+ """Computes order of computation for a given graph.
+
+ Args:
+ graph_def: GraphDef object.
+ Returns:
+ map: name -> {order, node}
+ """
+ helper = GraphDefHelper(graph_def)
+ seen = collections.defaultdict(_NodeEntry)
+ for each in graph_def.node:
+ _get_computed_nodes(helper, each.name, seen)
+ return seen
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
new file mode 100644
index 0000000000..4e723829bf
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
@@ -0,0 +1,481 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to compute receptive field of a fully-convolutional network.
+
+Please refer to the following g3doc for detailed explanation on how this
+computation is performed, and why it is important:
+g3doc/photos/vision/features/delf/g3doc/rf_computation.md
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from tensorflow.contrib.receptive_field.python.util import graph_compute_order
+from tensorflow.contrib.util import make_ndarray
+from tensorflow.python.platform import tf_logging as logging
+
+# White-listed layer operations, which do not affect the receptive field
+# computation.
+_UNCHANGED_RF_LAYER_OPS = [
+ "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity",
+ "VariableV2", "Sub", "Rsqrt", "ConcatV2"
+]
+
+
+def _stride_size(node):
+ """Computes stride size given a TF node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+
+ Returns:
+ stride_x: Stride size for horizontal direction (integer).
+ stride_y: Stride size for vertical direction (integer).
+ """
+ strides_attr = node.attr["strides"]
+ logging.vlog(4, "strides_attr = %s", strides_attr)
+ stride_y = strides_attr.list.i[1]
+ stride_x = strides_attr.list.i[2]
+ return stride_x, stride_y
+
+
+def _conv_kernel_size(node, name_to_order_node):
+ """Computes kernel size given a TF convolution or pooling node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ name_to_order_node: Map from name to {order, node}. Output of
+ graph_compute_order.get_compute_order().
+
+ Returns:
+ kernel_size_x: Kernel size for horizontal direction (integer).
+ kernel_size_y: Kernel size for vertical direction (integer).
+
+ Raises:
+ ValueError: If the weight layer node is invalid.
+ """
+ weights_layer_read_name = node.input[1]
+ if not weights_layer_read_name.endswith("/read"):
+ raise ValueError(
+ "Weight layer's name input to conv layer does not end with '/read'")
+ weights_layer_param_name = weights_layer_read_name[:-5]
+ weights_node = name_to_order_node[weights_layer_param_name].node
+ if weights_node.op != "VariableV2":
+ raise ValueError("Weight layer is not of type VariableV2")
+ shape = weights_node.attr["shape"]
+ logging.vlog(4, "weight shape = %s", shape)
+ kernel_size_y = shape.shape.dim[0].size
+ kernel_size_x = shape.shape.dim[1].size
+ return kernel_size_x, kernel_size_y
+
+
+def _padding_size_conv_pool(node, kernel_size, stride):
+ """Computes padding size given a TF convolution or pooling node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ kernel_size: Kernel size of node (integer).
+ stride: Stride size of node (integer).
+
+ Returns:
+ padding: Padding size (integer).
+
+ Raises:
+ ValueError: If padding is invalid.
+ """
+ # In this case, we need to carefully consider the different TF padding modes.
+ # The padding depends on kernel size, and may depend on input size. If it
+ # depends on input size, we raise an exception.
+ padding_attr = node.attr["padding"]
+ logging.vlog(4, "padding_attr = %s", padding_attr)
+ if padding_attr.s == "VALID":
+ padding = 0
+ elif padding_attr.s == "SAME":
+ if kernel_size == 1:
+ padding = 0
+ elif stride == 1:
+ padding = int(math.floor((float(kernel_size) - 1) / 2))
+ elif stride == 2 and kernel_size % 2 == 0:
+ padding = int(math.floor((float(kernel_size) - 1) / 2))
+ else:
+ padding = None
+ logging.warning(
+ "Padding depends on input size, which means that the effective "
+ "padding may be different depending on the input image "
+ "dimensionality. In this case, alignment check will be skipped.")
+ else:
+ raise ValueError("Invalid padding operation")
+ return padding
+
+
+def _pool_kernel_size(node):
+ """Computes kernel size given a TF pooling node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+
+ Returns:
+ kernel_size_x: Kernel size for horizontal direction (integer).
+ kernel_size_y: Kernel size for vertical direction (integer).
+
+ Raises:
+ ValueError: If pooling is invalid.
+ """
+ ksize = node.attr["ksize"]
+ kernel_size_y = ksize.list.i[1]
+ kernel_size_x = ksize.list.i[2]
+ if ksize.list.i[0] != 1:
+ raise ValueError("pool ksize for first dim is not 1")
+ if ksize.list.i[3] != 1:
+ raise ValueError("pool ksize for last dim is not 1")
+ return kernel_size_x, kernel_size_y
+
+
+def _padding_size_pad_layer(node, name_to_order_node):
+ """Computes padding size given a TF padding node.
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ name_to_order_node: Map from name to {order, node}. Output of
+ graph_compute_order.get_compute_order().
+
+ Returns:
+ padding_x: Padding size for horizontal direction (integer).
+ padding_y: Padding size for vertical direction (integer).
+
+ Raises:
+ ValueError: If padding layer is invalid.
+ """
+ paddings_layer_name = node.input[1]
+ if not paddings_layer_name.endswith("/paddings"):
+ raise ValueError("Padding layer name does not end with '/paddings'")
+ paddings_node = name_to_order_node[paddings_layer_name].node
+ if paddings_node.op != "Const":
+ raise ValueError("Padding op is not Const")
+ value = paddings_node.attr["value"]
+ t = make_ndarray(value.tensor)
+ padding_y = t[1][0]
+ padding_x = t[2][0]
+ if t[0][0] != 0:
+ raise ValueError("padding is not zero for first tensor dim")
+ if t[3][0] != 0:
+ raise ValueError("padding is not zero for last tensor dim")
+ return padding_x, padding_y
+
+
+def _get_layer_params(node, name_to_order_node):
+ """Gets layer parameters relevant for RF computation.
+
+ Currently, only these nodes are supported:
+ - Conv2D
+ - DepthwiseConv2dNative
+ - Pad
+ - MaxPool
+ - AvgPool
+ - all nodes listed in _UNCHANGED_RF_LAYER_OPS
+
+ Args:
+ node: Tensorflow node (NodeDef proto).
+ name_to_order_node: Map from name to {order, node}. Output of
+ graph_compute_order.get_compute_order().
+
+ Returns:
+ kernel_size_x: Kernel size for horizontal direction (integer).
+ kernel_size_y: Kernel size for vertical direction (integer).
+ stride_x: Stride size for horizontal direction (integer).
+ stride_y: Stride size for vertical direction (integer).
+ padding_x: Padding size for horizontal direction (integer).
+ padding_y: Padding size for vertical direction (integer).
+
+ Raises:
+ ValueError: If layer op is unknown.
+ """
+ logging.vlog(3, "node.op = %s", node.op)
+ logging.vlog(4, "node = %s", node)
+ if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative":
+ stride_x, stride_y = _stride_size(node)
+ kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_order_node)
+ # Compute the padding for this node separately for each direction.
+ padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
+ padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
+ elif node.op == "Pad":
+ # Kernel and stride are simply 1 in this case.
+ kernel_size_x = 1
+ kernel_size_y = 1
+ stride_x = 1
+ stride_y = 1
+ padding_x, padding_y = _padding_size_pad_layer(node, name_to_order_node)
+ elif node.op == "MaxPool" or node.op == "AvgPool":
+ stride_x, stride_y = _stride_size(node)
+ kernel_size_x, kernel_size_y = _pool_kernel_size(node)
+ # Compute the padding for this node separately for each direction.
+ padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
+ padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
+ elif node.op in _UNCHANGED_RF_LAYER_OPS:
+ # These nodes do not modify the RF parameters.
+ kernel_size_x = 1
+ kernel_size_y = 1
+ stride_x = 1
+ stride_y = 1
+ padding_x = 0
+ padding_y = 0
+ else:
+ raise ValueError("Unknown layer op: %s" % node.op)
+ return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
+
+
+def _reverse_sort_by_order(name_to_order_node):
+ """Sorts map of name_to_order_node nodes in reverse order.
+
+ The output is such that the nodes in name_to_order_node are sorted in
+ descending order of the "order" field.
+
+ Args:
+ name_to_order_node: Map from name to {order, node}. Output of
+ graph_compute_order.get_compute_order().
+
+ Returns:
+ sorted_name_to_order_node: Sorted version of the input, in descending order.
+ """
+ return sorted(name_to_order_node.items(), key=lambda x: -x[1].order)
+
+
+def _get_rf_size_node_input(stride, kernel_size, rf_size_output):
+ """Computes RF size at the input of a given layer.
+
+ Args:
+ stride: Stride of given layer (integer).
+ kernel_size: Kernel size of given layer (integer).
+ rf_size_output: RF size at output of given layer (integer).
+
+ Returns:
+ rf_size_input: RF size at input of given layer (integer).
+ """
+ return stride * rf_size_output + kernel_size - stride
+
+
+def _get_effective_stride_node_input(stride, effective_stride_output):
+ """Computes effective stride at the input of a given layer.
+
+ Args:
+ stride: Stride of given layer (integer).
+ effective_stride_output: Effective stride at output of given layer
+ (integer).
+
+ Returns:
+ effective_stride_input: Effective stride at input of given layer
+ (integer).
+ """
+ return stride * effective_stride_output
+
+
+def _get_effective_padding_node_input(stride, padding,
+ effective_padding_output):
+ """Computes effective padding at the input of a given layer.
+
+ Args:
+ stride: Stride of given layer (integer).
+ padding: Padding of given layer (integer).
+ effective_padding_output: Effective padding at output of given layer
+ (integer).
+
+ Returns:
+ effective_padding_input: Effective padding at input of given layer
+ (integer).
+ """
+ return stride * effective_padding_output + padding
+
+
+def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
+ """Computes receptive field (RF) parameters from a GraphDef object.
+
+ Args:
+ graph_def: GraphDef object.
+ input_node: Name of the input node from graph.
+ output_node: Name of the output node from graph.
+
+ Returns:
+ rf_size_x: Receptive field size of network in the horizontal direction, with
+ respect to specified input and output.
+ rf_size_y: Receptive field size of network in the vertical direction, with
+ respect to specified input and output.
+ effective_stride_x: Effective stride of network in the horizontal direction,
+ with respect to specified input and output.
+ effective_stride_y: Effective stride of network in the vertical direction,
+ with respect to specified input and output.
+ effective_padding_x: Effective padding of network in the horizontal
+ direction, with respect to specified input and output.
+ effective_padding_y: Effective padding of network in the vertical
+ direction, with respect to specified input and output.
+
+ Raises:
+ ValueError: If network is not aligned or if either input or output nodes
+ cannot be found. For network criterion alignment, see
+ photos/vision/features/delf/g3doc/rf_computation.md
+ """
+ # Computes order of computation for a given graph.
+ name_to_order_node = graph_compute_order.get_compute_order(
+ graph_def=graph_def)
+
+ # Sort in reverse topological order.
+ order = _reverse_sort_by_order(name_to_order_node)
+
+ # Dictionaries to keep track of receptive field, effective stride and
+ # effective padding of different nodes.
+ rf_sizes_x = {}
+ rf_sizes_y = {}
+ effective_strides_x = {}
+ effective_strides_y = {}
+ effective_paddings_x = {}
+ effective_paddings_y = {}
+
+ # Initialize dicts for output_node.
+ rf_sizes_x[output_node] = 1
+ rf_sizes_y[output_node] = 1
+ effective_strides_x[output_node] = 1
+ effective_strides_y[output_node] = 1
+ effective_paddings_x[output_node] = 0
+ effective_paddings_y[output_node] = 0
+
+ # Flag to denote if we found output node yet. If we have not, we skip nodes
+ # until the output node is found.
+ found_output_node = False
+
+ # Flag to denote if padding is undefined. This happens when SAME padding mode
+ # is used in conjunction with stride and kernel sizes which make it such that
+ # the padding to be applied would depend on the input size. In this case,
+ # alignment checks are skipped, and the effective padding is None.
+ undefined_padding = False
+
+ for _, (o, node) in order:
+ if node:
+ logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op))
+ else:
+ continue
+
+ # When we find input node, we can stop.
+ if node.name == input_node:
+ break
+
+ # Loop until we find the output node. All nodes before finding the output
+ # one are irrelevant, so they can be skipped.
+ if not found_output_node:
+ if node.name == output_node:
+ found_output_node = True
+
+ if found_output_node:
+ if node.name not in rf_sizes_x:
+ assert node.name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but "
+ "not in rf_sizes_x" % node.name)
+ # In this case, node is not relevant since it's not part of the
+ # computation we're interested in.
+ logging.vlog(3, "Irrelevant node %s, skipping it...", node.name)
+ continue
+
+ # Get params for this layer.
+ kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y = (
+ _get_layer_params(node, name_to_order_node))
+ logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, "
+ "stride_x = %s, stride_y = %s, "
+ "padding_x = %s, padding_y = %s" %
+ (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
+ padding_y))
+ if padding_x is None or padding_y is None:
+ undefined_padding = True
+
+ # Get parameters at input of this layer which may or may not be propagated
+ # to the input layers.
+ rf_size_input_x = _get_rf_size_node_input(stride_x, kernel_size_x,
+ rf_sizes_x[node.name])
+ rf_size_input_y = _get_rf_size_node_input(stride_y, kernel_size_y,
+ rf_sizes_y[node.name])
+ effective_stride_input_x = _get_effective_stride_node_input(
+ stride_x, effective_strides_x[node.name])
+ effective_stride_input_y = _get_effective_stride_node_input(
+ stride_y, effective_strides_y[node.name])
+ if not undefined_padding:
+ effective_padding_input_x = _get_effective_padding_node_input(
+ stride_x, padding_x, effective_paddings_x[node.name])
+ effective_padding_input_y = _get_effective_padding_node_input(
+ stride_y, padding_y, effective_paddings_y[node.name])
+ else:
+ effective_padding_input_x = None
+ effective_padding_input_y = None
+
+ # Loop over this node's inputs and potentially propagate information down.
+ for inp_name in node.input:
+ logging.vlog(4, "inp_name = %s", inp_name)
+ inp_node = name_to_order_node[inp_name].node
+ logging.vlog(4, "inp_node = \n%s", inp_node)
+ if inp_node.name in rf_sizes_x:
+ assert inp_node.name in rf_sizes_y, (
+ "Node %s is in rf_sizes_x, but "
+ "not in rf_sizes_y" % inp_node.name)
+ # This node was already discovered through a previous path, so we need
+ # to make sure that graph is aligned. This alignment check is skipped
+ # if the padding is not defined, since in this case alignment cannot
+ # be checked.
+ if not undefined_padding:
+ if effective_strides_x[inp_node.name] != effective_stride_input_x:
+ raise ValueError(
+ "Graph is not aligned since effective stride from different "
+ "paths is different in horizontal direction")
+ if effective_strides_y[inp_node.name] != effective_stride_input_y:
+ raise ValueError(
+ "Graph is not aligned since effective stride from different "
+ "paths is different in vertical direction")
+ if (rf_sizes_x[inp_node.name] - 1
+ ) / 2 - effective_paddings_x[inp_node.name] != (
+ rf_size_input_x - 1) / 2 - effective_padding_input_x:
+ raise ValueError(
+ "Graph is not aligned since center shift from different "
+ "paths is different in horizontal direction")
+ if (rf_sizes_y[inp_node.name] - 1
+ ) / 2 - effective_paddings_y[inp_node.name] != (
+ rf_size_input_y - 1) / 2 - effective_padding_input_y:
+ raise ValueError(
+ "Graph is not aligned since center shift from different "
+ "paths is different in vertical direction")
+ # Keep track of path with largest RF, for both directions.
+ if rf_sizes_x[inp_node.name] < rf_size_input_x:
+ rf_sizes_x[inp_node.name] = rf_size_input_x
+ effective_strides_x[inp_node.name] = effective_stride_input_x
+ effective_paddings_x[inp_node.name] = effective_padding_input_x
+ if rf_sizes_y[inp_node.name] < rf_size_input_y:
+ rf_sizes_y[inp_node.name] = rf_size_input_y
+ effective_strides_y[inp_node.name] = effective_stride_input_y
+ effective_paddings_y[inp_node.name] = effective_padding_input_y
+ else:
+ assert inp_node.name not in rf_sizes_y, (
+ "Node %s is in rf_sizes_y, but "
+ "not in rf_sizes_x" % inp_node.name)
+ # In this case, it is the first time we encounter this node. So we
+ # propagate the RF parameters.
+ rf_sizes_x[inp_node.name] = rf_size_input_x
+ rf_sizes_y[inp_node.name] = rf_size_input_y
+ effective_strides_x[inp_node.name] = effective_stride_input_x
+ effective_strides_y[inp_node.name] = effective_stride_input_y
+ effective_paddings_x[inp_node.name] = effective_padding_input_x
+ effective_paddings_y[inp_node.name] = effective_padding_input_y
+
+ if not found_output_node:
+ raise ValueError("Output node was not found")
+ if input_node not in rf_sizes_x:
+ raise ValueError("Input node was not found")
+ return (rf_sizes_x[input_node], rf_sizes_y[input_node],
+ effective_strides_x[input_node], effective_strides_y[input_node],
+ effective_paddings_x[input_node], effective_paddings_y[input_node])
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
new file mode 100644
index 0000000000..44e5beda60
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
@@ -0,0 +1,221 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for receptive_fields module."""
+
+from tensorflow.contrib import slim
+from tensorflow.contrib.receptive_field.python.util import receptive_field
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+def create_test_network_1():
+ """Aligned network for test.
+
+ The graph corresponds to the example from the second figure in
+ go/cnn-rf-computation#arbitrary-computation-graphs
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An 8x8 test image.
+ x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # Left branch.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
+ # Right branch.
+ l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
+ l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
+ l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
+ # Addition.
+ nn.relu(l1 + l3, name='output')
+ return g
+
+
+def create_test_network_2():
+ """Aligned network for test.
+
+ The graph corresponds to a variation to the example from the second figure in
+ go/cnn-rf-computation#arbitrary-computation-graphs. Layers 2 and 3 are changed
+ to max-pooling operations. Since the functionality is the same as convolution,
+ the network is aligned and the receptive field size is the same as from the
+ network created using create_test_network_1().
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An 8x8 test image.
+ x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # Left branch.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
+ # Right branch.
+ l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
+ l2 = slim.max_pool2d(l2_pad, [3, 3], stride=2, scope='L2', padding='VALID')
+ l3 = slim.max_pool2d(l2, [1, 1], stride=2, scope='L3', padding='VALID')
+ # Addition.
+ nn.relu(l1 + l3, name='output')
+ return g
+
+
+def create_test_network_3():
+ """Misaligned network for test.
+
+ The graph corresponds to the example from the first figure in
+ go/cnn-rf-computation#arbitrary-computation-graphs
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An 8x8 test image.
+ x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # Left branch.
+ l1_pad = array_ops.pad(x, [[0, 0], [2, 1], [2, 1], [0, 0]])
+ l1 = slim.conv2d(l1_pad, 1, [5, 5], stride=2, scope='L1', padding='VALID')
+ # Right branch.
+ l2 = slim.conv2d(x, 1, [3, 3], stride=1, scope='L2', padding='VALID')
+ l3 = slim.conv2d(l2, 1, [3, 3], stride=1, scope='L3', padding='VALID')
+ # Addition.
+ nn.relu(l1 + l3, name='output')
+ return g
+
+
+def create_test_network_4():
+ """Misaligned network for test.
+
+ The graph corresponds to a variation from the example from the second figure
+ in go/cnn-rf-computation#arbitrary-computation-graphs. Layer 2 uses 'SAME'
+ padding, which makes its padding dependent on the input image dimensionality.
+ In this case, the effective padding will be undetermined, and the utility is
+ not able to check the network alignment.
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An 8x8 test image.
+ x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # Left branch.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
+ # Right branch.
+ l2 = slim.conv2d(x, 1, [3, 3], stride=2, scope='L2', padding='SAME')
+ l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
+ # Addition.
+ nn.relu(l1 + l3, name='output')
+ return g
+
+
+def create_test_network_5():
+ """Single-path network for testing non-square kernels.
+
+ The graph is similar to the right branch of the graph from
+ create_test_network_1(), except that the kernel sizes are changed to be
+ non-square.
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An 8x8 test image.
+ x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # Two convolutional layers, where the first one has non-square kernel.
+ l1 = slim.conv2d(x, 1, [3, 5], stride=2, scope='L1', padding='VALID')
+ l2 = slim.conv2d(l1, 1, [3, 1], stride=2, scope='L2', padding='VALID')
+ # ReLU.
+ nn.relu(l2, name='output')
+ return g
+
+
+class RfUtilsTest(test.TestCase):
+
+ def testComputeRFFromGraphDefAligned(self):
+ graph_def = create_test_network_1().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node))
+ self.assertEqual(receptive_field_x, 3)
+ self.assertEqual(receptive_field_y, 3)
+ self.assertEqual(effective_stride_x, 4)
+ self.assertEqual(effective_stride_y, 4)
+ self.assertEqual(effective_padding_x, 1)
+ self.assertEqual(effective_padding_y, 1)
+
+ def testComputeRFFromGraphDefAligned2(self):
+ graph_def = create_test_network_2().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node))
+ self.assertEqual(receptive_field_x, 3)
+ self.assertEqual(receptive_field_y, 3)
+ self.assertEqual(effective_stride_x, 4)
+ self.assertEqual(effective_stride_y, 4)
+ self.assertEqual(effective_padding_x, 1)
+ self.assertEqual(effective_padding_y, 1)
+
+ def testComputeRFFromGraphDefUnaligned(self):
+ graph_def = create_test_network_3().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ with self.assertRaises(ValueError):
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node)
+
+ def testComputeRFFromGraphDefUnaligned2(self):
+ graph_def = create_test_network_4().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node))
+ self.assertEqual(receptive_field_x, 3)
+ self.assertEqual(receptive_field_y, 3)
+ self.assertEqual(effective_stride_x, 4)
+ self.assertEqual(effective_stride_y, 4)
+ self.assertEqual(effective_padding_x, None)
+ self.assertEqual(effective_padding_y, None)
+
+ def testComputeRFFromGraphDefNonSquareRF(self):
+ graph_def = create_test_network_5().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node))
+ self.assertEqual(receptive_field_x, 5)
+ self.assertEqual(receptive_field_y, 7)
+ self.assertEqual(effective_stride_x, 4)
+ self.assertEqual(effective_stride_y, 4)
+ self.assertEqual(effective_padding_x, 0)
+ self.assertEqual(effective_padding_y, 0)
+
+
+if __name__ == '__main__':
+ test.main()