about summary refs log tree commit diff homepage
path: root/tensorflow/python/eager/function.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py  214
1 files changed, 112 insertions, 102 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6c87dccaf1..d56c1457e0 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -55,8 +55,11 @@ from tensorflow.python.util import tf_inspect
# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+# This is to avoid a circular dependency with gradients_impl
+gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-def create_substitute_placeholder(value, name, dtype=None):
+
+def _create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
# capturing placeholders outside of any control flow context.
@@ -88,100 +91,6 @@ def create_substitute_placeholder(value, name, dtype=None):
return placeholder
-def capture_value(tensor_map, value, dtype, name):
- """Capture a value from outside the function, to pass in as an extra arg."""
- captured_value = tensor_map.get(value, None)
- if captured_value is None:
- captured_value = create_substitute_placeholder(value, name=name,
- dtype=dtype)
- tensor_map[value] = captured_value
- tape.record_operation("captured_value", [captured_value], [value],
- lambda x: [x])
- return captured_value
-
-
-class CapturingGraph(ops.Graph):
- """Graph that can capture tensors from other graphs.
-
- Attributes:
- captures: Maps external tensor -> internal tensor (e.g. input placeholder).
- The entries are in the order they were captured.
- """
-
- def __init__(self):
- super(CapturingGraph, self).__init__()
-
- self.captures = collections.OrderedDict()
- self._building_function = True
-
- # Map from resource tensor name to last op (in program order) which uses
- # this tensor. Used to enforce that execution order matches program order
- # for resource tensors.
- self._last_op_using_resource_tensor = {}
-
- def clear_resource_control_flow_state(self):
- self._last_op_using_resource_tensor = {}
-
- # TODO(skyewm): get rid of name and use the name of `tensor`.
- def capture(self, tensor, name=None):
- """Capture `tensor` if it's external to this graph.
-
- If `tensor` is from a different graph, returns a placeholder for it.
- `tensor` and the placeholder will also appears in self.captures. Multiple
- calls to this method with the same `tensor` argument will return the same
- placeholder. If `tensor` is from this graph, returns `tensor`.
-
- Args:
- tensor: Tensor. May be from this FuncGraph or a different graph.
- name: Optional name if a placeholder is created.
-
- Returns:
- Tensor from this FuncGraph.
- """
- if isinstance(tensor, ops.EagerTensor):
- if name is None:
- name = str(ops.uid())
- return capture_value(self.captures, tensor, tensor.dtype, name)
- if tensor.graph is not self:
- if name is None:
- name = tensor.op.name
- return capture_value(self.captures, tensor, tensor.dtype, name)
- return tensor
-
- def create_op(
- self,
- op_type,
- inputs,
- dtypes, # pylint: disable=redefined-outer-name
- input_types=None,
- name=None,
- attrs=None,
- op_def=None,
- compute_shapes=True,
- compute_device=True):
- """Captures an external inputs before calling Graph.capture_op."""
- # This capturing logic interacts poorly with control flow contexts which
- # want to replace inputs of ops far too late in the process. This can lead
- # the context to get confused and try to create an Enter for an Enter. We
- # can detect this here and skip the additional Enter which can confuse loop
- # validation logic.
- if op_type == "Enter" and inputs[0].op.type == "Enter":
- if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
- return inputs[0].op
- # Calling AddValue on the control flow contexts to force creation of the
- # backward accumulators in the original graph before we create placeholders
- # to capture the inputs.
- ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
- for i, inp in enumerate(inputs):
- if ctxt is not None and hasattr(ctxt, "AddValue"):
- inp = ctxt.AddValue(inp)
- inp = self.capture(inp)
- inputs[i] = inp
- return super(CapturingGraph, self).create_op(
- op_type, inputs, dtypes, input_types, name, attrs, op_def,
- compute_device=compute_device)
-
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
if ctx.executing_eagerly():
@@ -190,7 +99,7 @@ def _get_device_functions(ctx, graph):
return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
-class FuncGraph(CapturingGraph):
+class FuncGraph(ops.Graph):
"""Graph representing a function body.
Attributes:
@@ -207,6 +116,8 @@ class FuncGraph(CapturingGraph):
variables: Variables that should be watched during function execution.
outer_graph: The graph this function is defined in. May be another FuncGraph
or the global default Graph.
+ captures: Maps external tensor -> internal tensor (i.e. input placeholder).
+ The entries are in the order they were captured.
seed: The graph-level random seed.
"""
@@ -227,6 +138,13 @@ class FuncGraph(CapturingGraph):
self.structured_outputs = None
self.variables = []
self.outer_graph = ops.get_default_graph()
+ self.captures = collections.OrderedDict()
+
+ self._building_function = True
+ # Map from resource tensor name to last op (in program order) which uses
+ # this tensor. Used to enforce that execution order matches program order
+ # for resource tensors.
+ self._last_op_using_resource_tensor = {}
graph = self.outer_graph
@@ -255,15 +173,107 @@ class FuncGraph(CapturingGraph):
self._graph_key = graph._graph_key
# pylint: enable=protected-access
+ def create_op(
+ self,
+ op_type,
+ inputs,
+ dtypes,
+ input_types=None,
+ name=None,
+ attrs=None,
+ op_def=None,
+ compute_shapes=True,
+ compute_device=True):
+ """Like Graph.create_op, except handles external input tensors.
+
+ This overload adds functionality to create_op to "capture" any external
+ input tensors, i.e. tensors from the eager context or outer function graphs
+ if this is a nested function. See `capture` for more information.
+
+ Args:
+ op_type: The `Operation` type to create. This corresponds to the
+ `OpDef.name` field for the proto that defines the operation.
+ inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+ dtypes: A list of `DType` objects that will be the types of the tensors
+ that the operation produces.
+ input_types: (Optional.) A list of `DType`s that will be the types of
+ the tensors that the operation consumes. By default, uses the base
+ `DType` of each input in `inputs`. Operations that expect
+ reference-typed inputs must specify `input_types` explicitly.
+ name: (Optional.) A string name for the operation. If not specified, a
+ name is generated based on `op_type`.
+ attrs: (Optional.) A dictionary where the key is the attribute name (a
+ string) and the value is the respective `attr` attribute of the
+ `NodeDef` proto that will represent the operation (an `AttrValue`
+ proto).
+ op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+ the operation will have.
+ compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+ computed).
+ compute_device: (Optional.) If True, device functions will be executed
+ to compute the device property of the Operation.
+
+ Returns:
+ An `Operation` object.
+ """
+ # This capturing logic interacts poorly with control flow contexts which
+ # want to replace inputs of ops far too late in the process. This can lead
+ # the context to get confused and try to create an Enter for an Enter. We
+ # can detect this here and skip the additional Enter which can confuse loop
+ # validation logic.
+ if op_type == "Enter" and inputs[0].op.type == "Enter":
+ if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+ return inputs[0].op
+ # Calling AddValue on the control flow contexts to force creation of the
+ # backward accumulators in the original graph before we create placeholders
+ # to capture the inputs.
+ ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
+ for i, inp in enumerate(inputs):
+ # TPU Estimator defines a control flow context with no AddValue method.
+ if ctxt is not None and hasattr(ctxt, "AddValue"):
+ inp = ctxt.AddValue(inp)
+ inp = self.capture(inp)
+ inputs[i] = inp
+ return super(FuncGraph, self).create_op(
+ op_type, inputs, dtypes, input_types, name, attrs, op_def,
+ compute_device=compute_device)
+
def capture(self, tensor, name=None):
- """Calls CapturingGraph.capture and updates self.inputs if necessary."""
- new_capture = tensor not in self.captures
- internal_tensor = super(FuncGraph, self).capture(tensor, name)
+ """Captures `tensor` if it's external to this graph.
- if new_capture and tensor is not internal_tensor:
- self.inputs.append(internal_tensor)
+ If `tensor` is from a different graph, returns a placeholder for it.
+ `tensor` and the placeholder will appear in self.captures, and the
+ placeholder will appear in self.inputs. Multiple calls to this method with
+ the same `tensor` argument will return the same placeholder. If `tensor` is
+ from this graph, returns `tensor`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+
+ Returns:
+ Tensor from this FuncGraph.
+ """
+ if isinstance(tensor, ops.EagerTensor):
+ if name is None:
+ name = str(ops.uid())
+ return self._capture_helper(tensor, name)
+ if tensor.graph is not self:
+ if name is None:
+ name = tensor.op.name
+ return self._capture_helper(tensor, name)
+ return tensor
- return internal_tensor
+ def _capture_helper(self, tensor, name):
+ captured_tensor = self.captures.get(tensor, None)
+ if captured_tensor is None:
+ captured_tensor = _create_substitute_placeholder(tensor, name=name,
+ dtype=tensor.dtype)
+ self.captures[tensor] = captured_tensor
+ self.inputs.append(captured_tensor)
+ tape.record_operation("captured_value", [captured_tensor], [tensor],
+ lambda x: [x])
+ return captured_tensor
@property
def external_captures(self):