aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/receptive_field/python/util/receptive_field.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/receptive_field/python/util/receptive_field.py')
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field.py134
1 files changed, 122 insertions, 12 deletions
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
index db190a1a41..8b34465d21 100644
--- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py
@@ -27,13 +27,15 @@ import math
from tensorflow.contrib.receptive_field.python.util import graph_compute_order
from tensorflow.contrib.util import make_ndarray
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.framework import ops as framework_ops
+import numpy as np
# White-listed layer operations, which do not affect the receptive field
# computation.
_UNCHANGED_RF_LAYER_OPS = [
- "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity",
- "VariableV2", "Sub", "Rsqrt", "ConcatV2"
-]
+ 'Add', 'BiasAdd', 'Ceil', 'ConcatV2', 'Const', 'Floor', 'Identity', 'Log',
+ 'Mul', 'Pow', 'RealDiv', 'Relu', 'Round', 'Rsqrt', 'Softplus', 'Sub',
+ 'VariableV2']
# Different ways in which padding modes may be spelled.
_VALID_PADDING = ["VALID", b"VALID"]
@@ -238,7 +240,8 @@ def _get_layer_params(node, name_to_order_node):
padding_x = 0
padding_y = 0
else:
- raise ValueError("Unknown layer op: %s" % node.op)
+ raise ValueError("Unknown layer for operation '%s': %s" %
+ (node.name, node.op))
return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
@@ -304,13 +307,103 @@ def _get_effective_padding_node_input(stride, padding,
return stride * effective_padding_output + padding
-def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
- """Computes receptive field (RF) parameters from a GraphDef object.
+class ReceptiveField:
+ """
+ Receptive field of a convolutional neural network.
+
+ Args:
+ size: Receptive field size.
+ stride: Effective stride.
+ padding: Effective padding.
+ """
+ def __init__(self, size, stride, padding):
+ self.size = np.asarray(size)
+ self.stride = np.asarray(stride)
+ self.padding = np.asarray(padding)
+
+ def compute_input_center_coordinates(self, y, axis=None):
+ """
+ Computes the center of the receptive field that generated a feature.
+
+ Args:
+ y: An array of feature coordinates with shape `(..., d)`, where `d` is the
+ number of dimensions of the coordinates.
+ axis: The dimensions for which to compute the input center coordinates.
+ If `None` (the default), compute the input center coordinates for all
+ dimensions.
+
+ Returns:
+ x: Center of the receptive field that generated the features, at the input
+ of the network.
+
+ Raises:
+ ValueError: If the number of dimensions of the feature coordinates does
+ not match the number of elements in `axis`.
+ """
+ # Use all dimensions.
+ if axis is None:
+ axis = range(self.size.size)
+ # Ensure axis is a list because tuples have different indexing behavior.
+ axis = list(axis)
+ y = np.asarray(y)
+ if y.shape[-1] != len(axis):
+ raise ValueError("Dimensionality of the feature coordinates `y` (%d) "
+ "does not match dimensionality of `axis` (%d)" %
+ (y.shape[-1], len(axis)))
+ return - self.padding[axis] + y * self.stride[axis] + \
+ (self.size[axis] - 1) / 2
+
+ def compute_feature_coordinates(self, x, axis=None):
+ """
+ Computes the position of a feature given the center of a receptive field.
+
+ Args:
+ x: An array of input center coordinates with shape `(..., d)`, where `d`
+ is the number of dimensions of the coordinates.
+ axis: The dimensions for which to compute the feature coordinates.
+ If `None` (the default), compute the feature coordinates for all
+ dimensions.
+
+ Returns:
+ y: Coordinates of the features.
+
+ Raises:
+ ValueError: If the number of dimensions of the input center coordinates
+ does not match the number of elements in `axis`.
+ """
+ # Use all dimensions.
+ if axis is None:
+ axis = range(self.size.size)
+ # Ensure axis is a list because tuples have different indexing behavior.
+ axis = list(axis)
+ x = np.asarray(x)
+ if x.shape[-1] != len(axis):
+ raise ValueError("Dimensionality of the input center coordinates `x` "
+ "(%d) does not match dimensionality of `axis` (%d)" %
+ (x.shape[-1], len(axis)))
+ return (x + self.padding[axis] + (1 - self.size[axis]) / 2) / \
+ self.stride[axis]
+
+ def __iter__(self):
+ return iter(np.concatenate([self.size, self.stride, self.padding]))
+
+
+def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
+ stop_propagation=None):
+ """Computes receptive field (RF) parameters from a Graph or GraphDef object.
+
+ The algorithm stops the calculation of the receptive field whenever it
+ encounters an operation in the list `stop_propagation`. Stopping the
+ calculation early can be useful to calculate the receptive field of a
+ subgraph such as a single branch of the
+ [inception network](https://arxiv.org/abs/1512.00567).
Args:
- graph_def: GraphDef object.
- input_node: Name of the input node from graph.
- output_node: Name of the output node from graph.
+ graph_def: Graph or GraphDef object.
+ input_node: Name of the input node or Tensor object from graph.
+ output_node: Name of the output node or Tensor object from graph.
+ stop_propagation: List of operation or scope names for which to stop the
+ propagation of the receptive field.
Returns:
rf_size_x: Receptive field size of network in the horizontal direction, with
@@ -331,6 +424,18 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
cannot be found. For network criterion alignment, see
photos/vision/features/delf/g3doc/rf_computation.md
"""
+ # Convert a graph to graph_def if necessary.
+ if isinstance(graph_def, framework_ops.Graph):
+ graph_def = graph_def.as_graph_def()
+
+ # Convert tensors to names.
+ if isinstance(input_node, framework_ops.Tensor):
+ input_node = input_node.op.name
+ if isinstance(output_node, framework_ops.Tensor):
+ output_node = output_node.op.name
+
+ stop_propagation = stop_propagation or []
+
# Computes order of computation for a given graph.
name_to_order_node = graph_compute_order.get_compute_order(
graph_def=graph_def)
@@ -422,6 +527,10 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
# Loop over this node's inputs and potentially propagate information down.
for inp_name in node.input:
+ # Stop the propagation of the receptive field.
+ if any(inp_name.startswith(stop) for stop in stop_propagation):
+ logging.vlog(3, "Skipping explicitly ignored node %s.", node.name)
+ continue
logging.vlog(4, "inp_name = %s", inp_name)
inp_node = name_to_order_node[inp_name].node
logging.vlog(4, "inp_node = \n%s", inp_node)
@@ -480,6 +589,7 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
raise ValueError("Output node was not found")
if input_node not in rf_sizes_x:
raise ValueError("Input node was not found")
- return (rf_sizes_x[input_node], rf_sizes_y[input_node],
- effective_strides_x[input_node], effective_strides_y[input_node],
- effective_paddings_x[input_node], effective_paddings_y[input_node])
+ return ReceptiveField(
+ (rf_sizes_x[input_node], rf_sizes_y[input_node]),
+ (effective_strides_x[input_node], effective_strides_y[input_node]),
+ (effective_paddings_x[input_node], effective_paddings_y[input_node]))