aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/receptive_field
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-08 10:43:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-08 10:47:18 -0800
commit428ef1d987ab5b75baedefd29620c7b4874bc836 (patch)
tree296464a13bf128dd2a583d8c144e2e9fe129a6e5 /tensorflow/contrib/receptive_field
parent3392e77ccc85ccb3a21b7d8350c62ff907b2a205 (diff)
Small changes to tf.contrib.receptive_field, in order to
1) handle control-dependency input nodes. 2) handle a few more types of operations. The new functionality is tested by a unit test. PiperOrigin-RevId: 181185267
Diffstat (limited to 'tensorflow/contrib/receptive_field')
-rw-r--r--tensorflow/contrib/receptive_field/python/util/graph_compute_order.py4
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field.py37
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field_test.py48
3 files changed, 72 insertions, 17 deletions
diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
index 8af4be16d6..6153607656 100644
--- a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
+++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
@@ -62,7 +62,9 @@ def _get_computed_nodes(g, output, seen):
for each in node_def.input:
# Parses name of input node.
if each.startswith('^'):
- each = each[1:]
+ # The character '^' denotes a control dependency, so this input node can
+ # be safely ignored.
+ continue
each = each.split(':')[0]
# Recursively computes ordering.
new_v = _get_computed_nodes(g, each, seen)
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
index 8b34465d21..6d207347d8 100644
--- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
@@ -33,9 +33,10 @@ import numpy as np
# White-listed layer operations, which do not affect the receptive field
# computation.
_UNCHANGED_RF_LAYER_OPS = [
- 'Add', 'BiasAdd', 'Ceil', 'ConcatV2', 'Const', 'Floor', 'Identity', 'Log',
- 'Mul', 'Pow', 'RealDiv', 'Relu', 'Round', 'Rsqrt', 'Softplus', 'Sub',
- 'VariableV2']
+ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "Identity",
+ "Log", "Mul", "Pow", "RealDiv", "Relu", "Relu6", "Round", "Rsqrt",
+ "Softplus", "Sub", "VariableV2"
+]
# Different ways in which padding modes may be spelled.
_VALID_PADDING = ["VALID", b"VALID"]
@@ -240,8 +241,8 @@ def _get_layer_params(node, name_to_order_node):
padding_x = 0
padding_y = 0
else:
- raise ValueError("Unknown layer for operation '%s': %s" %
- (node.name, node.op))
+ raise ValueError("Unknown layer for operation '%s': %s" % (node.name,
+ node.op))
return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
@@ -308,22 +309,21 @@ def _get_effective_padding_node_input(stride, padding,
class ReceptiveField:
- """
- Receptive field of a convolutional neural network.
+ """Receptive field of a convolutional neural network.
Args:
size: Receptive field size.
stride: Effective stride.
padding: Effective padding.
"""
+
def __init__(self, size, stride, padding):
self.size = np.asarray(size)
self.stride = np.asarray(stride)
self.padding = np.asarray(padding)
def compute_input_center_coordinates(self, y, axis=None):
- """
- Computes the center of the receptive field that generated a feature.
+ """Computes the center of the receptive field that generated a feature.
Args:
y: An array of feature coordinates with shape `(..., d)`, where `d` is the
@@ -354,8 +354,7 @@ class ReceptiveField:
(self.size[axis] - 1) / 2
def compute_feature_coordinates(self, x, axis=None):
- """
- Computes the position of a feature given the center of a receptive field.
+ """Computes the position of a feature given the center of a receptive field.
Args:
x: An array of input center coordinates with shape `(..., d)`, where `d`
@@ -388,7 +387,9 @@ class ReceptiveField:
return iter(np.concatenate([self.size, self.stride, self.padding]))
-def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
+def compute_receptive_field_from_graph_def(graph_def,
+ input_node,
+ output_node,
stop_propagation=None):
"""Computes receptive field (RF) parameters from a Graph or GraphDef object.
@@ -531,7 +532,13 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
if any(inp_name.startswith(stop) for stop in stop_propagation):
logging.vlog(3, "Skipping explicitly ignored node %s.", node.name)
continue
+
logging.vlog(4, "inp_name = %s", inp_name)
+ if inp_name.startswith("^"):
+ # The character "^" denotes a control dependency, so this input node
+ # can be safely ignored.
+ continue
+
inp_node = name_to_order_node[inp_name].node
logging.vlog(4, "inp_node = \n%s", inp_node)
if inp_node.name in rf_sizes_x:
@@ -590,6 +597,6 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
if input_node not in rf_sizes_x:
raise ValueError("Input node was not found")
return ReceptiveField(
- (rf_sizes_x[input_node], rf_sizes_y[input_node]),
- (effective_strides_x[input_node], effective_strides_y[input_node]),
- (effective_paddings_x[input_node], effective_paddings_y[input_node]))
+ (rf_sizes_x[input_node], rf_sizes_y[input_node]),
+ (effective_strides_x[input_node], effective_strides_y[input_node]),
+ (effective_paddings_x[input_node], effective_paddings_y[input_node]))
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
index 8d7d5440f6..de860fbd3c 100644
--- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
@@ -23,6 +23,8 @@ from tensorflow.contrib.receptive_field.python.util import receptive_field
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
import numpy as np
@@ -176,6 +178,34 @@ def create_test_network_6():
return g
+def create_test_network_7():
+ """Aligned network for test, with a control dependency.
+
+ The graph is similar to create_test_network_1(), except that it includes an
+ assert operation on the left branch.
+
+ Returns:
+ g: Tensorflow graph object (Graph proto).
+ """
+ g = ops.Graph()
+ with g.as_default():
+ # An 8x8 test image.
+ x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
+ # Left branch.
+ l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
+ l1_shape = array_ops.shape(l1)
+ assert_op = control_flow_ops.Assert(
+ gen_math_ops.equal(l1_shape[1], 2), [l1_shape], summarize=4)
+ # Right branch.
+ l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
+ l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
+ l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
+ # Addition.
+ with ops.control_dependencies([assert_op]):
+ nn.relu(l1 + l3, name='output')
+ return g
+
+
class RfUtilsTest(test.TestCase):
def testComputeRFFromGraphDefAligned(self):
@@ -269,7 +299,7 @@ class RfUtilsTest(test.TestCase):
input_node = 'input_image'
output_node = 'output'
rf = receptive_field.compute_receptive_field_from_graph_def(
- graph_def, input_node, output_node)
+ graph_def, input_node, output_node)
x = np.random.randint(0, 100, (50, 2))
y = rf.compute_feature_coordinates(x)
@@ -277,5 +307,21 @@ class RfUtilsTest(test.TestCase):
self.assertAllEqual(x, x2)
+ def testComputeRFFromGraphDefAlignedWithControlDependencies(self):
+ graph_def = create_test_network_7().as_graph_def()
+ input_node = 'input_image'
+ output_node = 'output'
+ (receptive_field_x, receptive_field_y, effective_stride_x,
+ effective_stride_y, effective_padding_x, effective_padding_y) = (
+ receptive_field.compute_receptive_field_from_graph_def(
+ graph_def, input_node, output_node))
+ self.assertEqual(receptive_field_x, 3)
+ self.assertEqual(receptive_field_y, 3)
+ self.assertEqual(effective_stride_x, 4)
+ self.assertEqual(effective_stride_y, 4)
+ self.assertEqual(effective_padding_x, 1)
+ self.assertEqual(effective_padding_y, 1)
+
+
if __name__ == '__main__':
test.main()