diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-31 11:00:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-31 11:04:09 -0700 |
commit | 863163067ad38e093ae7c4557c5065f589eef70c (patch) | |
tree | 8d4f97f266950bc66daad93f9bbc0f76ee3830da /tensorflow/contrib/receptive_field | |
parent | 8dbd2b91f9b387e6f3b0084671e46325fc9915be (diff) |
Introducing tf.contrib.receptive_field
PiperOrigin-RevId: 167160384
Diffstat (limited to 'tensorflow/contrib/receptive_field')
10 files changed, 1674 insertions, 0 deletions
diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD new file mode 100644 index 0000000000..ed2f3af08c --- /dev/null +++ b/tensorflow/contrib/receptive_field/BUILD @@ -0,0 +1,71 @@ +# Description: +# Contains modules to compute receptive field parameters for CNN models. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "receptive_field_pip", + deps = [ + ":graph_compute_order_py", + ":receptive_field_py", + ], +) + +py_library( + name = "graph_compute_order_py", + srcs = [ + "__init__.py", + "python/util/graph_compute_order.py", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "receptive_field_py", + srcs = [ + "__init__.py", + "python/util/receptive_field.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":graph_compute_order_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +py_test( + name = "receptive_field_test", + srcs = ["python/util/receptive_field_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":receptive_field_py", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/slim", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:nn", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md new file mode 100644 index 0000000000..f7539ec145 --- /dev/null +++ b/tensorflow/contrib/receptive_field/README.md @@ -0,0 +1,164 @@ +# Receptive field computation for convnets + +This library enables you to easily compute the receptive field parameters of +your favorite convnet. You can use it to understand how big of an input image +region your output features depend on. Better yet, using the parameters computed +by the library, you can easily find the exact image region which is used to +compute each convnet feature. + +## Basic usage + +The main function to be called is `compute_receptive_field_from_graph_def`, +which will return the receptive field, effective stride and effective padding +for both horizontal and vertical directions. + +For example, if your model is constructed using the function +`my_model_construction()`, you can use the library as follows: + +```python +import tensorflow as tf +from tensorflow.contrib import receptive_field + +# Construct graph. +g = tf.Graph() +with g.as_default(): + images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image') + my_model_construction(images) + +# Compute receptive field parameters. +rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \ + receptive_field.compute_receptive_field_from_graph_def( \ + g.as_graph_def(), 'input_image', 'my_output_endpoint') +``` + +Here's a simple example of computing the receptive field parameters for +Inception-Resnet-v2. To get this to work, be sure to checkout +[tensorflow/models](https://github.com/tensorflow/models), so that the Inception +models are available to you. This can be done in three simple commands: + +```sh +git clone https://github.com/tensorflow/models +cd models/slim +sudo python setup.py install_lib +``` + +You can then compute the receptive field parameters for Inception-Resnet-v2 as: + +```python +from nets import inception +import tensorflow as tf +from tensorflow.contrib import receptive_field + +# Construct graph. +g = tf.Graph() +with g.as_default(): + images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image') + inception.inception_resnet_v2_base(images) + +# Compute receptive field parameters. +rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \ + receptive_field.compute_receptive_field_from_graph_def( \ + g.as_graph_def(), 'input_image', 'InceptionResnetV2/Conv2d_7b_1x1/Relu') +``` + +This will give you `rf_x = rf_y = 3039`, `eff_stride_x = eff_stride_y = 32`, and +`eff_pad_x = eff_pad_y = 1482`. This means that each feature that is output at +the node `'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is computed from a region +which is of size `3039x3039`. Further, by using the expressions + +```python +center_x = -eff_pad_x + feature_x*eff_stride_x + (rf_x - 1)/2 +center_y = -eff_pad_y + feature_y*eff_stride_y + (rf_y - 1)/2 +``` + +one can compute the center of the region in the input image that is used to +compute the output feature at position `[feature_x, feature_y]`. For example, +the feature at position `[0, 2]` at the output of the layer +`'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is centered in the original image in +the position `[37, 101]`. + +TODO: include link to derivations and definitions of different parameters. + +## Receptive field benchmark + +As you might expect, it is straightforward to run this library on the popular +convnets, and gather their receptive fields. We provide a python script which +does exactly that, available under `python/util/examples/rf_benchmark.py`. + +To get this to work, be sure to checkout +[tensorflow/models](https://github.com/tensorflow/models) (see the 3-command +instructions for this above). Then, simply: + +```sh +cd python/util/examples +python rf_benchmark.py --csv_path /tmp/rf_benchmark_results.csv +``` + +The script will write to stdout the receptive field parameters for many variants +of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They +are also written to the file `/tmp/rf_benchmark_results.csv`. + +TODO: include here a plot for receptive field sizes of different convnets. + +TODO: include table/link to pre-computed RF parameters. + +## Compute RF parameters from a graph pbtxt + +We also provide a utility to compute the receptive field parameters directly +from a graph protobuf file. + +Have a `graph.pbtxt` file and want to compute its receptive field parameters? We +got you covered. The only prerequisite is to install +[google/protobuf](https://github.com/google/protobuf), which you probably +already have if you're using tensorflow (otherwise, follow installation +instructions [here](https://github.com/google/protobuf/tree/master/python)). + +This should work: + +```sh +cd python/util/examples +python compute_rf.py \ + --graph_path /path/to/graph.pbtxt \ + --output_path /path/to/output/rf_info.txt \ + --input_node my_input_node \ + --output_node my_output_node +``` + +Don't know how to generate a graph protobuf file? Take a look at the +`write_inception_resnet_v2_graph.py` script, which shows how to save it for the +Inception-Resnet-v2 model: + +```sh +cd python/util/examples +python write_inception_resnet_v2_graph.py --graph_dir /tmp --graph_filename graph.pbtxt +``` + +This will write the Inception-Resnet-v2 graph protobuf to `/tmp/graph.pbtxt`. + +For completeness, here's how you would use this file to get the receptive field +parameters of the Inception-Resnet-v2 model: + +```sh +cd python/util/examples +python compute_rf.py \ + --graph_path /tmp/graph.pbtxt \ + --output_path /tmp/rf_info.txt \ + --input_node input_image \ + --output_node InceptionResnetV2/Conv2d_7b_1x1/Relu +``` + +This will write the receptive field parameters of the model to +`/tmp/rf_info.txt`, which will look like: + +```sh +Receptive field size (horizontal) = 3039 +Receptive field size (vertical) = 3039 +Effective stride (horizontal) = 32 +Effective stride (vertical) = 32 +Effective padding (horizontal) = 1482 +Effective padding (vertical) = 1482 +``` + +## Authors + +André Araujo (andrefaraujo@) and Mark Sandler diff --git a/tensorflow/contrib/receptive_field/__init__.py b/tensorflow/contrib/receptive_field/__init__.py new file mode 100644 index 0000000000..10745a6a53 --- /dev/null +++ b/tensorflow/contrib/receptive_field/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module to compute receptive field parameters for CNN tensorflow models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.receptive_field.python.util.graph_compute_order import get_compute_order +from tensorflow.contrib.receptive_field.python.util.receptive_field import compute_receptive_field_from_graph_def +# pylint: enable=unused-import diff --git a/tensorflow/contrib/receptive_field/python/__init__.py b/tensorflow/contrib/receptive_field/python/__init__.py new file mode 100644 index 0000000000..217047f92d --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module to compute receptive field parameters for CNN tensorflow models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py new file mode 100644 index 0000000000..70a0d11dff --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -0,0 +1,90 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Computes Receptive Field (RF) information given a graph protobuf. + +For an example of usage, see accompanying file compute_rf.sh +""" + +import argparse +import sys + +from google.protobuf import text_format + +from tensorflow.contrib import receptive_field +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.platform import app +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging + +cmd_args = None + + +def _load_graphdef(path): + """Helper function to load GraphDef from file. + + Args: + path: Path to pbtxt file. + + Returns: + graph_def: A GraphDef object. + """ + graph_def = graph_pb2.GraphDef() + pbstr = gfile.Open(path).read() + text_format.Parse(pbstr, graph_def) + return graph_def + + +def main(unused_argv): + + graph_def = _load_graphdef(cmd_args.graph_path) + + (receptive_field_x, receptive_field_y, effective_stride_x, effective_stride_y, + effective_padding_x, effective_padding_y + ) = receptive_field.compute_receptive_field_from_graph_def( + graph_def, cmd_args.input_node, cmd_args.output_node) + + logging.info('Receptive field size (horizontal) = %s', receptive_field_x) + logging.info('Receptive field size (vertical) = %s', receptive_field_y) + logging.info('Effective stride (horizontal) = %s', effective_stride_x) + logging.info('Effective stride (vertical) = %s', effective_stride_y) + logging.info('Effective padding (horizontal) = %s', effective_padding_x) + logging.info('Effective padding (vertical) = %s', effective_padding_y) + + f = gfile.GFile('%s' % cmd_args.output_path, 'w') + f.write('Receptive field size (horizontal) = %s\n' % receptive_field_x) + f.write('Receptive field size (vertical) = %s\n' % receptive_field_y) + f.write('Effective stride (horizontal) = %s\n' % effective_stride_x) + f.write('Effective stride (vertical) = %s\n' % effective_stride_y) + f.write('Effective padding (horizontal) = %s\n' % effective_padding_x) + f.write('Effective padding (vertical) = %s\n' % effective_padding_y) + f.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--graph_path', type=str, default='', help='Graph path (pbtxt format).') + parser.add_argument( + '--output_path', + type=str, + default='', + help='Path to output text file where RF information will be written to.') + parser.add_argument( + '--input_node', type=str, default='', help='Name of input node.') + parser.add_argument( + '--output_node', type=str, default='', help='Name of output node.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py new file mode 100644 index 0000000000..94228dfa61 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -0,0 +1,460 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Computes Receptive Field (RF) information for different models. + +The receptive field (and related parameters) for the different models are +printed to stdout, and may also optionally be written to a CSV file. + +For an example of usage, see rf_benchmark.sh +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import sys + +from nets import alexnet +from nets import inception +from nets import mobilenet_v1 +from nets import resnet_v1 +from nets import resnet_v2 +from nets import vgg +from tensorflow.contrib import framework +from tensorflow.contrib import receptive_field +from tensorflow.contrib import slim +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import app + +cmd_args = None + +# Input node name for all architectures. +_INPUT_NODE = 'input_image' + +# Variants of different network architectures. + +# - resnet: different versions and sizes. +_SUPPORTED_RESNET_VARIANTS = [ + 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200', + 'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200' +] + +# - inception_resnet_v2: default, and version with SAME padding. +_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [ + 'inception_resnet_v2', 'inception_resnet_v2-same' +] + +# - inception_v2: default, and version with no separable conv. +_SUPPORTED_INCEPTIONV2_VARIANTS = [ + 'inception_v2', 'inception_v2-no-separable-conv' +] + +# - inception_v3: default version. +_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3'] + +# - inception_v4: default version. +_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4'] + +# - alexnet_v2: default version. +_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2'] + +# - vgg: vgg_a (with 11 layers) and vgg_16 (version D). +_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16'] + +# - mobilenet_v1: 100% and 75%. +_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075'] + + +def _construct_model(model_type='resnet_v1_50'): + """Constructs model for the desired type of CNN. + + Args: + model_type: Type of model to be used. + + Returns: + end_points: A dictionary from components of the network to the corresponding + activations. + + Raises: + ValueError: If the model_type is not supported. + """ + # Placeholder input. + images = array_ops.placeholder( + dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE) + + # Construct model. + if model_type == 'inception_resnet_v2': + _, end_points = inception.inception_resnet_v2_base(images) + elif model_type == 'inception_resnet_v2-same': + _, end_points = inception.inception_resnet_v2_base( + images, align_feature_maps=True) + elif model_type == 'inception_v2': + _, end_points = inception.inception_v2_base(images) + elif model_type == 'inception_v2-no-separable-conv': + _, end_points = inception.inception_v2_base( + images, use_separable_conv=False) + elif model_type == 'inception_v3': + _, end_points = inception.inception_v3_base(images) + elif model_type == 'inception_v4': + _, end_points = inception.inception_v4_base(images) + elif model_type == 'alexnet_v2': + _, end_points = alexnet.alexnet_v2(images) + elif model_type == 'vgg_a': + _, end_points = vgg.vgg_a(images) + elif model_type == 'vgg_16': + _, end_points = vgg.vgg_16(images) + elif model_type == 'mobilenet_v1': + _, end_points = mobilenet_v1.mobilenet_v1_base(images) + elif model_type == 'mobilenet_v1_075': + _, end_points = mobilenet_v1.mobilenet_v1_base( + images, depth_multiplier=0.75) + elif model_type == 'resnet_v1_50': + _, end_points = resnet_v1.resnet_v1_50( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v1_101': + _, end_points = resnet_v1.resnet_v1_101( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v1_152': + _, end_points = resnet_v1.resnet_v1_152( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v1_200': + _, end_points = resnet_v1.resnet_v1_200( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_50': + _, end_points = resnet_v2.resnet_v2_50( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_101': + _, end_points = resnet_v2.resnet_v2_101( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_152': + _, end_points = resnet_v2.resnet_v2_152( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_200': + _, end_points = resnet_v2.resnet_v2_200( + images, num_classes=None, is_training=False, global_pool=False) + else: + raise ValueError('Unsupported model_type %s.' % model_type) + + return end_points + + +def _get_desired_end_point_keys(model_type='resnet_v1_50'): + """Gets list of desired end point keys for a type of CNN. + + Args: + model_type: Type of model to be used. + + Returns: + desired_end_point_types: A list containing the desired end-points. + + Raises: + ValueError: If the model_type is not supported. + """ + if model_type in _SUPPORTED_RESNET_VARIANTS: + blocks = ['block1', 'block2', 'block3', 'block4'] + desired_end_point_keys = ['%s/%s' % (model_type, i) for i in blocks] + elif model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3', + 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b', + 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1' + ] + elif model_type in _SUPPORTED_INCEPTIONV2_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', + 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b', + 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c' + ] + elif model_type in _SUPPORTED_INCEPTIONV3_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3', + 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b', + 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', + 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c' + ] + elif model_type in _SUPPORTED_INCEPTIONV4_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a', + 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e', + 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f', + 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d' + ] + elif model_type in _SUPPORTED_ALEXNETV2_VARIANTS: + ep = ['conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5'] + desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep] + elif model_type in _SUPPORTED_VGG_VARIANTS: + ep = [ + 'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1', + 'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4', + 'conv5/conv5_1', 'conv5/conv5_2', 'pool5' + ] + desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep] + elif model_type in _SUPPORTED_MOBILENETV1_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise', + 'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise', + 'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise', + 'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise', + 'Conv2d_12_pointwise', 'Conv2d_13_pointwise' + ] + else: + raise ValueError('Unsupported model_type %s.' % model_type) + + return desired_end_point_keys + + +def _model_graph_def(model_type='resnet_v1_50', arg_sc=None): + """Constructs a model graph, returning GraphDef and end-points. + + Args: + model_type: Type of model to be used. + arg_sc: Optional arg scope to use in constructing the graph. + + Returns: + graph_def: GraphDef of constructed graph. + end_points: A dictionary from components of the network to the corresponding + activations. + """ + if arg_sc is None: + arg_sc = {} + g = ops.Graph() + with g.as_default(): + with framework.arg_scope(arg_sc): + end_points = _construct_model(model_type) + + return g.as_graph_def(), end_points + + +def _model_rf(graphdef, + end_points, + desired_end_point_keys, + model_type='resnet_v1_50', + csv_writer=None): + """Computes receptive field information for a given CNN model. + + The information will be printed to stdout. If the RF parameters are the same + for the horizontal and vertical directions, it will be printed only once. + Otherwise, they are printed once for the horizontal and once for the vertical + directions. + + Args: + graphdef: GraphDef of given model. + end_points: A dictionary from components of the model to the corresponding + activations. + desired_end_point_keys: List of desired end points for which receptive field + information will be computed. + model_type: Type of model to be used, used only for printing purposes. + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for desired_end_point_key in desired_end_point_keys: + print('- %s:' % desired_end_point_key) + output_node_with_colon = end_points[desired_end_point_key].name + pos = output_node_with_colon.rfind(':') + output_node = output_node_with_colon[:pos] + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y + ) = receptive_field.compute_receptive_field_from_graph_def( + graphdef, _INPUT_NODE, output_node) + # If values are the same in horizontal/vertical directions, just report one + # of them. Otherwise, report both. + if (receptive_field_x == receptive_field_y) and ( + effective_stride_x == effective_stride_y) and ( + effective_padding_x == effective_padding_y): + print('Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) + else: + print('Receptive field size: horizontal = %5s, vertical = %5s. ' + 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' + 'padding: horizontal = %5s, vertical = %5s' % + (str(receptive_field_x), str(receptive_field_y), + str(effective_stride_x), str(effective_stride_y), + str(effective_padding_x), str(effective_padding_y))) + if csv_writer is not None: + csv_writer.writerow({ + 'CNN': model_type, + 'end_point': desired_end_point_key, + 'RF size hor': str(receptive_field_x), + 'RF size ver': str(receptive_field_y), + 'effective stride hor': str(effective_stride_x), + 'effective stride ver': str(effective_stride_y), + 'effective padding hor': str(effective_padding_x), + 'effective padding ver': str(effective_padding_y) + }) + + +def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None): + """Contructs model graph and desired end-points, and compute RF. + + The computed RF parameters are printed to stdout by the _model_rf function. + + Args: + model_type: Type of model to be used. + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + arg_sc: Optional arg scope to use in constructing the graph. + + """ + print('********************%s' % model_type) + graphdef, end_points = _model_graph_def(model_type, arg_sc) + desired_end_point_keys = _get_desired_end_point_keys(model_type) + _model_rf(graphdef, end_points, desired_end_point_keys, model_type, + csv_writer) + + +def _resnet_rf(csv_writer=None): + """Computes RF and associated parameters for resnet models. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_RESNET_VARIANTS: + arg_sc = resnet_v1.resnet_arg_scope() + _process_model_rf(model_type, csv_writer, arg_sc) + + +def _inception_resnet_v2_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_resnet_v2 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _inception_v2_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_v2 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONV2_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _inception_v3_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_v3 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONV3_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _inception_v4_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_v4 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONV4_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _alexnet_v2_rf(csv_writer=None): + """Computes RF and associated parameters for the alexnet_v2 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_ALEXNETV2_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _vgg_rf(csv_writer=None): + """Computes RF and associated parameters for the vgg model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_VGG_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _mobilenet_v1_rf(csv_writer=None): + """Computes RF and associated parameters for the mobilenet_v1 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: + with slim.arg_scope( + [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + _process_model_rf(model_type, csv_writer, arg_sc) + + +def main(unused_argv): + # Configure CSV file which will be written, if desired. + if cmd_args.csv_path: + csv_file = open(cmd_args.csv_path, 'w') + field_names = [ + 'CNN', 'end_point', 'RF size hor', 'RF size ver', + 'effective stride hor', 'effective stride ver', 'effective padding hor', + 'effective padding ver' + ] + rf_writer = csv.DictWriter(csv_file, fieldnames=field_names) + rf_writer.writeheader() + else: + rf_writer = None + + # Compute RF parameters for each network architecture. + _alexnet_v2_rf(rf_writer) + _vgg_rf(rf_writer) + _inception_v2_rf(rf_writer) + _inception_v3_rf(rf_writer) + _inception_v4_rf(rf_writer) + _inception_resnet_v2_rf(rf_writer) + _mobilenet_v1_rf(rf_writer) + _resnet_rf(rf_writer) + + # Close CSV file, if it was opened. + if cmd_args.csv_path: + csv_file.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--csv_path', + type=str, + default='', + help="""\ + Path to CSV file that will be written with RF parameters.If empty, no + file will be written.\ + """) + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py b/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py new file mode 100644 index 0000000000..a2384f66ce --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple script to write Inception-ResNet-v2 model to graph file. +""" + +import argparse +import sys + +from nets import inception +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_io +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import app + +cmd_args = None + + +def main(unused_argv): + # Model definition. + g = ops.Graph() + with g.as_default(): + images = array_ops.placeholder( + dtypes.float32, shape=(1, None, None, 3), name='input_image') + inception.inception_resnet_v2_base(images) + + graph_io.write_graph(g.as_graph_def(), cmd_args.graph_dir, + cmd_args.graph_filename) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--graph_dir', + type=str, + default='/tmp', + help='Directory where graph will be saved.') + parser.add_argument( + '--graph_filename', + type=str, + default='graph.pbtxt', + help='Filename of graph that will be saved.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py new file mode 100644 index 0000000000..8af4be16d6 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py @@ -0,0 +1,88 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library to compute order of computations in a graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + + +class GraphDefHelper(object): + """Helper class to collect node names and definitions. + + Example: + b = GraphDefHelper(graph_def) + # Prints node that produces given output. + print b.output_of['conv/foo/bar'] + """ + + def __init__(self, gd): + self.output_of = {} + for each in gd.node: + self.output_of[each.name] = each + + +# pylint: disable=invalid-name +_NodeEntry = collections.namedtuple('NodeEntry', field_names=['order', 'node']) + + +def _get_computed_nodes(g, output, seen): + """Traverses the graph in topological order. + + Args: + g: GraphDefHelper object. + output: current node. + seen: map of nodes we've already traversed. + Returns: + order in topological sort for 'output'. + """ + if output in seen: + return seen[output].order + node_def = g.output_of.get(output, None) + if node_def is None: + seen[output] = _NodeEntry(0, None) + return 0 + + r = 0 + for each in node_def.input: + # Parses name of input node. + if each.startswith('^'): + each = each[1:] + each = each.split(':')[0] + # Recursively computes ordering. + new_v = _get_computed_nodes(g, each, seen) + r = max(r, new_v + 1) + + seen[output] = _NodeEntry(r, node_def) + + return seen[output].order + + +def get_compute_order(graph_def): + """Computes order of computation for a given graph. + + Args: + graph_def: GraphDef object. + Returns: + map: name -> {order, node} + """ + helper = GraphDefHelper(graph_def) + seen = collections.defaultdict(_NodeEntry) + for each in graph_def.node: + _get_computed_nodes(helper, each.name, seen) + return seen diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py new file mode 100644 index 0000000000..4e723829bf --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -0,0 +1,481 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions to compute receptive field of a fully-convolutional network. + +Please refer to the following g3doc for detailed explanation on how this +computation is performed, and why it is important: +g3doc/photos/vision/features/delf/g3doc/rf_computation.md +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from tensorflow.contrib.receptive_field.python.util import graph_compute_order +from tensorflow.contrib.util import make_ndarray +from tensorflow.python.platform import tf_logging as logging + +# White-listed layer operations, which do not affect the receptive field +# computation. +_UNCHANGED_RF_LAYER_OPS = [ + "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity", + "VariableV2", "Sub", "Rsqrt", "ConcatV2" +] + + +def _stride_size(node): + """Computes stride size given a TF node. + + Args: + node: Tensorflow node (NodeDef proto). + + Returns: + stride_x: Stride size for horizontal direction (integer). + stride_y: Stride size for vertical direction (integer). + """ + strides_attr = node.attr["strides"] + logging.vlog(4, "strides_attr = %s", strides_attr) + stride_y = strides_attr.list.i[1] + stride_x = strides_attr.list.i[2] + return stride_x, stride_y + + +def _conv_kernel_size(node, name_to_order_node): + """Computes kernel size given a TF convolution or pooling node. + + Args: + node: Tensorflow node (NodeDef proto). + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + kernel_size_x: Kernel size for horizontal direction (integer). + kernel_size_y: Kernel size for vertical direction (integer). + + Raises: + ValueError: If the weight layer node is invalid. + """ + weights_layer_read_name = node.input[1] + if not weights_layer_read_name.endswith("/read"): + raise ValueError( + "Weight layer's name input to conv layer does not end with '/read'") + weights_layer_param_name = weights_layer_read_name[:-5] + weights_node = name_to_order_node[weights_layer_param_name].node + if weights_node.op != "VariableV2": + raise ValueError("Weight layer is not of type VariableV2") + shape = weights_node.attr["shape"] + logging.vlog(4, "weight shape = %s", shape) + kernel_size_y = shape.shape.dim[0].size + kernel_size_x = shape.shape.dim[1].size + return kernel_size_x, kernel_size_y + + +def _padding_size_conv_pool(node, kernel_size, stride): + """Computes padding size given a TF convolution or pooling node. + + Args: + node: Tensorflow node (NodeDef proto). + kernel_size: Kernel size of node (integer). + stride: Stride size of node (integer). + + Returns: + padding: Padding size (integer). + + Raises: + ValueError: If padding is invalid. + """ + # In this case, we need to carefully consider the different TF padding modes. + # The padding depends on kernel size, and may depend on input size. If it + # depends on input size, we raise an exception. + padding_attr = node.attr["padding"] + logging.vlog(4, "padding_attr = %s", padding_attr) + if padding_attr.s == "VALID": + padding = 0 + elif padding_attr.s == "SAME": + if kernel_size == 1: + padding = 0 + elif stride == 1: + padding = int(math.floor((float(kernel_size) - 1) / 2)) + elif stride == 2 and kernel_size % 2 == 0: + padding = int(math.floor((float(kernel_size) - 1) / 2)) + else: + padding = None + logging.warning( + "Padding depends on input size, which means that the effective " + "padding may be different depending on the input image " + "dimensionality. In this case, alignment check will be skipped.") + else: + raise ValueError("Invalid padding operation") + return padding + + +def _pool_kernel_size(node): + """Computes kernel size given a TF pooling node. + + Args: + node: Tensorflow node (NodeDef proto). + + Returns: + kernel_size_x: Kernel size for horizontal direction (integer). + kernel_size_y: Kernel size for vertical direction (integer). + + Raises: + ValueError: If pooling is invalid. + """ + ksize = node.attr["ksize"] + kernel_size_y = ksize.list.i[1] + kernel_size_x = ksize.list.i[2] + if ksize.list.i[0] != 1: + raise ValueError("pool ksize for first dim is not 1") + if ksize.list.i[3] != 1: + raise ValueError("pool ksize for last dim is not 1") + return kernel_size_x, kernel_size_y + + +def _padding_size_pad_layer(node, name_to_order_node): + """Computes padding size given a TF padding node. + + Args: + node: Tensorflow node (NodeDef proto). + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + padding_x: Padding size for horizontal direction (integer). + padding_y: Padding size for vertical direction (integer). + + Raises: + ValueError: If padding layer is invalid. + """ + paddings_layer_name = node.input[1] + if not paddings_layer_name.endswith("/paddings"): + raise ValueError("Padding layer name does not end with '/paddings'") + paddings_node = name_to_order_node[paddings_layer_name].node + if paddings_node.op != "Const": + raise ValueError("Padding op is not Const") + value = paddings_node.attr["value"] + t = make_ndarray(value.tensor) + padding_y = t[1][0] + padding_x = t[2][0] + if t[0][0] != 0: + raise ValueError("padding is not zero for first tensor dim") + if t[3][0] != 0: + raise ValueError("padding is not zero for last tensor dim") + return padding_x, padding_y + + +def _get_layer_params(node, name_to_order_node): + """Gets layer parameters relevant for RF computation. + + Currently, only these nodes are supported: + - Conv2D + - DepthwiseConv2dNative + - Pad + - MaxPool + - AvgPool + - all nodes listed in _UNCHANGED_RF_LAYER_OPS + + Args: + node: Tensorflow node (NodeDef proto). + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + kernel_size_x: Kernel size for horizontal direction (integer). + kernel_size_y: Kernel size for vertical direction (integer). + stride_x: Stride size for horizontal direction (integer). + stride_y: Stride size for vertical direction (integer). + padding_x: Padding size for horizontal direction (integer). + padding_y: Padding size for vertical direction (integer). + + Raises: + ValueError: If layer op is unknown. + """ + logging.vlog(3, "node.op = %s", node.op) + logging.vlog(4, "node = %s", node) + if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative": + stride_x, stride_y = _stride_size(node) + kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_order_node) + # Compute the padding for this node separately for each direction. + padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x) + padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y) + elif node.op == "Pad": + # Kernel and stride are simply 1 in this case. + kernel_size_x = 1 + kernel_size_y = 1 + stride_x = 1 + stride_y = 1 + padding_x, padding_y = _padding_size_pad_layer(node, name_to_order_node) + elif node.op == "MaxPool" or node.op == "AvgPool": + stride_x, stride_y = _stride_size(node) + kernel_size_x, kernel_size_y = _pool_kernel_size(node) + # Compute the padding for this node separately for each direction. + padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x) + padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y) + elif node.op in _UNCHANGED_RF_LAYER_OPS: + # These nodes do not modify the RF parameters. + kernel_size_x = 1 + kernel_size_y = 1 + stride_x = 1 + stride_y = 1 + padding_x = 0 + padding_y = 0 + else: + raise ValueError("Unknown layer op: %s" % node.op) + return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y + + +def _reverse_sort_by_order(name_to_order_node): + """Sorts map of name_to_order_node nodes in reverse order. + + The output is such that the nodes in name_to_order_node are sorted in + descending order of the "order" field. + + Args: + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + sorted_name_to_order_node: Sorted version of the input, in descending order. + """ + return sorted(name_to_order_node.items(), key=lambda x: -x[1].order) + + +def _get_rf_size_node_input(stride, kernel_size, rf_size_output): + """Computes RF size at the input of a given layer. + + Args: + stride: Stride of given layer (integer). + kernel_size: Kernel size of given layer (integer). + rf_size_output: RF size at output of given layer (integer). + + Returns: + rf_size_input: RF size at input of given layer (integer). + """ + return stride * rf_size_output + kernel_size - stride + + +def _get_effective_stride_node_input(stride, effective_stride_output): + """Computes effective stride at the input of a given layer. + + Args: + stride: Stride of given layer (integer). + effective_stride_output: Effective stride at output of given layer + (integer). + + Returns: + effective_stride_input: Effective stride at input of given layer + (integer). + """ + return stride * effective_stride_output + + +def _get_effective_padding_node_input(stride, padding, + effective_padding_output): + """Computes effective padding at the input of a given layer. + + Args: + stride: Stride of given layer (integer). + padding: Padding of given layer (integer). + effective_padding_output: Effective padding at output of given layer + (integer). + + Returns: + effective_padding_input: Effective padding at input of given layer + (integer). + """ + return stride * effective_padding_output + padding + + +def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): + """Computes receptive field (RF) parameters from a GraphDef object. + + Args: + graph_def: GraphDef object. + input_node: Name of the input node from graph. + output_node: Name of the output node from graph. + + Returns: + rf_size_x: Receptive field size of network in the horizontal direction, with + respect to specified input and output. + rf_size_y: Receptive field size of network in the vertical direction, with + respect to specified input and output. + effective_stride_x: Effective stride of network in the horizontal direction, + with respect to specified input and output. + effective_stride_y: Effective stride of network in the vertical direction, + with respect to specified input and output. + effective_padding_x: Effective padding of network in the horizontal + direction, with respect to specified input and output. + effective_padding_y: Effective padding of network in the vertical + direction, with respect to specified input and output. + + Raises: + ValueError: If network is not aligned or if either input or output nodes + cannot be found. For network criterion alignment, see + photos/vision/features/delf/g3doc/rf_computation.md + """ + # Computes order of computation for a given graph. + name_to_order_node = graph_compute_order.get_compute_order( + graph_def=graph_def) + + # Sort in reverse topological order. + order = _reverse_sort_by_order(name_to_order_node) + + # Dictionaries to keep track of receptive field, effective stride and + # effective padding of different nodes. + rf_sizes_x = {} + rf_sizes_y = {} + effective_strides_x = {} + effective_strides_y = {} + effective_paddings_x = {} + effective_paddings_y = {} + + # Initialize dicts for output_node. + rf_sizes_x[output_node] = 1 + rf_sizes_y[output_node] = 1 + effective_strides_x[output_node] = 1 + effective_strides_y[output_node] = 1 + effective_paddings_x[output_node] = 0 + effective_paddings_y[output_node] = 0 + + # Flag to denote if we found output node yet. If we have not, we skip nodes + # until the output node is found. + found_output_node = False + + # Flag to denote if padding is undefined. This happens when SAME padding mode + # is used in conjunction with stride and kernel sizes which make it such that + # the padding to be applied would depend on the input size. In this case, + # alignment checks are skipped, and the effective padding is None. + undefined_padding = False + + for _, (o, node) in order: + if node: + logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op)) + else: + continue + + # When we find input node, we can stop. + if node.name == input_node: + break + + # Loop until we find the output node. All nodes before finding the output + # one are irrelevant, so they can be skipped. + if not found_output_node: + if node.name == output_node: + found_output_node = True + + if found_output_node: + if node.name not in rf_sizes_x: + assert node.name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but " + "not in rf_sizes_x" % node.name) + # In this case, node is not relevant since it's not part of the + # computation we're interested in. + logging.vlog(3, "Irrelevant node %s, skipping it...", node.name) + continue + + # Get params for this layer. + kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y = ( + _get_layer_params(node, name_to_order_node)) + logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y)) + if padding_x is None or padding_y is None: + undefined_padding = True + + # Get parameters at input of this layer which may or may not be propagated + # to the input layers. + rf_size_input_x = _get_rf_size_node_input(stride_x, kernel_size_x, + rf_sizes_x[node.name]) + rf_size_input_y = _get_rf_size_node_input(stride_y, kernel_size_y, + rf_sizes_y[node.name]) + effective_stride_input_x = _get_effective_stride_node_input( + stride_x, effective_strides_x[node.name]) + effective_stride_input_y = _get_effective_stride_node_input( + stride_y, effective_strides_y[node.name]) + if not undefined_padding: + effective_padding_input_x = _get_effective_padding_node_input( + stride_x, padding_x, effective_paddings_x[node.name]) + effective_padding_input_y = _get_effective_padding_node_input( + stride_y, padding_y, effective_paddings_y[node.name]) + else: + effective_padding_input_x = None + effective_padding_input_y = None + + # Loop over this node's inputs and potentially propagate information down. + for inp_name in node.input: + logging.vlog(4, "inp_name = %s", inp_name) + inp_node = name_to_order_node[inp_name].node + logging.vlog(4, "inp_node = \n%s", inp_node) + if inp_node.name in rf_sizes_x: + assert inp_node.name in rf_sizes_y, ( + "Node %s is in rf_sizes_x, but " + "not in rf_sizes_y" % inp_node.name) + # This node was already discovered through a previous path, so we need + # to make sure that graph is aligned. This alignment check is skipped + # if the padding is not defined, since in this case alignment cannot + # be checked. + if not undefined_padding: + if effective_strides_x[inp_node.name] != effective_stride_input_x: + raise ValueError( + "Graph is not aligned since effective stride from different " + "paths is different in horizontal direction") + if effective_strides_y[inp_node.name] != effective_stride_input_y: + raise ValueError( + "Graph is not aligned since effective stride from different " + "paths is different in vertical direction") + if (rf_sizes_x[inp_node.name] - 1 + ) / 2 - effective_paddings_x[inp_node.name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: + raise ValueError( + "Graph is not aligned since center shift from different " + "paths is different in horizontal direction") + if (rf_sizes_y[inp_node.name] - 1 + ) / 2 - effective_paddings_y[inp_node.name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: + raise ValueError( + "Graph is not aligned since center shift from different " + "paths is different in vertical direction") + # Keep track of path with largest RF, for both directions. + if rf_sizes_x[inp_node.name] < rf_size_input_x: + rf_sizes_x[inp_node.name] = rf_size_input_x + effective_strides_x[inp_node.name] = effective_stride_input_x + effective_paddings_x[inp_node.name] = effective_padding_input_x + if rf_sizes_y[inp_node.name] < rf_size_input_y: + rf_sizes_y[inp_node.name] = rf_size_input_y + effective_strides_y[inp_node.name] = effective_stride_input_y + effective_paddings_y[inp_node.name] = effective_padding_input_y + else: + assert inp_node.name not in rf_sizes_y, ( + "Node %s is in rf_sizes_y, but " + "not in rf_sizes_x" % inp_node.name) + # In this case, it is the first time we encounter this node. So we + # propagate the RF parameters. + rf_sizes_x[inp_node.name] = rf_size_input_x + rf_sizes_y[inp_node.name] = rf_size_input_y + effective_strides_x[inp_node.name] = effective_stride_input_x + effective_strides_y[inp_node.name] = effective_stride_input_y + effective_paddings_x[inp_node.name] = effective_padding_input_x + effective_paddings_y[inp_node.name] = effective_padding_input_y + + if not found_output_node: + raise ValueError("Output node was not found") + if input_node not in rf_sizes_x: + raise ValueError("Input node was not found") + return (rf_sizes_x[input_node], rf_sizes_y[input_node], + effective_strides_x[input_node], effective_strides_y[input_node], + effective_paddings_x[input_node], effective_paddings_y[input_node]) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py new file mode 100644 index 0000000000..44e5beda60 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py @@ -0,0 +1,221 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for receptive_fields module.""" + +from tensorflow.contrib import slim +from tensorflow.contrib.receptive_field.python.util import receptive_field +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +def create_test_network_1(): + """Aligned network for test. + + The graph corresponds to the example from the second figure in + go/cnn-rf-computation#arbitrary-computation-graphs + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_2(): + """Aligned network for test. + + The graph corresponds to a variation to the example from the second figure in + go/cnn-rf-computation#arbitrary-computation-graphs. Layers 2 and 3 are changed + to max-pooling operations. Since the functionality is the same as convolution, + the network is aligned and the receptive field size is the same as from the + network created using create_test_network_1(). + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.max_pool2d(l2_pad, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.max_pool2d(l2, [1, 1], stride=2, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_3(): + """Misaligned network for test. + + The graph corresponds to the example from the first figure in + go/cnn-rf-computation#arbitrary-computation-graphs + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1_pad = array_ops.pad(x, [[0, 0], [2, 1], [2, 1], [0, 0]]) + l1 = slim.conv2d(l1_pad, 1, [5, 5], stride=2, scope='L1', padding='VALID') + # Right branch. + l2 = slim.conv2d(x, 1, [3, 3], stride=1, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [3, 3], stride=1, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_4(): + """Misaligned network for test. + + The graph corresponds to a variation from the example from the second figure + in go/cnn-rf-computation#arbitrary-computation-graphs. Layer 2 uses 'SAME' + padding, which makes its padding dependent on the input image dimensionality. + In this case, the effective padding will be undetermined, and the utility is + not able to check the network alignment. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2 = slim.conv2d(x, 1, [3, 3], stride=2, scope='L2', padding='SAME') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_5(): + """Single-path network for testing non-square kernels. + + The graph is similar to the right branch of the graph from + create_test_network_1(), except that the kernel sizes are changed to be + non-square. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Two convolutional layers, where the first one has non-square kernel. + l1 = slim.conv2d(x, 1, [3, 5], stride=2, scope='L1', padding='VALID') + l2 = slim.conv2d(l1, 1, [3, 1], stride=2, scope='L2', padding='VALID') + # ReLU. + nn.relu(l2, name='output') + return g + + +class RfUtilsTest(test.TestCase): + + def testComputeRFFromGraphDefAligned(self): + graph_def = create_test_network_1().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeRFFromGraphDefAligned2(self): + graph_def = create_test_network_2().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeRFFromGraphDefUnaligned(self): + graph_def = create_test_network_3().as_graph_def() + input_node = 'input_image' + output_node = 'output' + with self.assertRaises(ValueError): + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node) + + def testComputeRFFromGraphDefUnaligned2(self): + graph_def = create_test_network_4().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, None) + self.assertEqual(effective_padding_y, None) + + def testComputeRFFromGraphDefNonSquareRF(self): + graph_def = create_test_network_5().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 5) + self.assertEqual(receptive_field_y, 7) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 0) + self.assertEqual(effective_padding_y, 0) + + +if __name__ == '__main__': + test.main() |