Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r-- | tensorflow/python/eager/function.py | 214 |
1 file changed, 112 insertions(+), 102 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6c87dccaf1..d56c1457e0 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -55,8 +55,11 @@ from tensorflow.python.util import tf_inspect
 
 # (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
 cond_v2_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
+# This is to avoid a circular dependency with gradients_impl
+gradients_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
 
-def create_substitute_placeholder(value, name, dtype=None):
+
+def _create_substitute_placeholder(value, name, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
@@ -88,100 +91,6 @@
   return placeholder
 
 
-def capture_value(tensor_map, value, dtype, name):
-  """Capture a value from outside the function, to pass in as an extra arg."""
-  captured_value = tensor_map.get(value, None)
-  if captured_value is None:
-    captured_value = create_substitute_placeholder(value, name=name,
-                                                   dtype=dtype)
-    tensor_map[value] = captured_value
-  tape.record_operation("captured_value", [captured_value], [value],
-                        lambda x: [x])
-  return captured_value
-
-
-class CapturingGraph(ops.Graph):
-  """Graph that can capture tensors from other graphs.
-
-  Attributes:
-    captures: Maps external tensor -> internal tensor (e.g. input placeholder).
-      The entries are in the order they were captured.
-  """
-
-  def __init__(self):
-    super(CapturingGraph, self).__init__()
-
-    self.captures = collections.OrderedDict()
-    self._building_function = True
-
-    # Map from resource tensor name to last op (in program order) which uses
-    # this tensor. Used to enforce that execution order matches program order
-    # for resource tensors.
-    self._last_op_using_resource_tensor = {}
-
-  def clear_resource_control_flow_state(self):
-    self._last_op_using_resource_tensor = {}
-
-  # TODO(skyewm): get rid of name and use the name of `tensor`.
-  def capture(self, tensor, name=None):
-    """Capture `tensor` if it's external to this graph.
-
-    If `tensor` is from a different graph, returns a placeholder for it.
-    `tensor` and the placeholder will also appears in self.captures. Multiple
-    calls to this method with the same `tensor` argument will return the same
-    placeholder. If `tensor` is from this graph, returns `tensor`.
-
-    Args:
-      tensor: Tensor. May be from this FuncGraph or a different graph.
-      name: Optional name if a placeholder is created.
-
-    Returns:
-      Tensor from this FuncGraph.
-    """
-    if isinstance(tensor, ops.EagerTensor):
-      if name is None:
-        name = str(ops.uid())
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    if tensor.graph is not self:
-      if name is None:
-        name = tensor.op.name
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    return tensor
-
-  def create_op(
-      self,
-      op_type,
-      inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
-      input_types=None,
-      name=None,
-      attrs=None,
-      op_def=None,
-      compute_shapes=True,
-      compute_device=True):
-    """Captures an external inputs before calling Graph.capture_op."""
-    # This capturing logic interacts poorly with control flow contexts which
-    # want to replace inputs of ops far too late in the process. This can lead
-    # the context to get confused and try to create an Enter for an Enter. We
-    # can detect this here and skip the additional Enter which can confuse loop
-    # validation logic.
-    if op_type == "Enter" and inputs[0].op.type == "Enter":
-      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
-        return inputs[0].op
-    # Calling AddValue on the control flow contexts to force creation of the
-    # backward accumulators in the original graph before we create placeholders
-    # to capture the inputs.
-    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
-    for i, inp in enumerate(inputs):
-      if ctxt is not None and hasattr(ctxt, "AddValue"):
-        inp = ctxt.AddValue(inp)
-      inp = self.capture(inp)
-      inputs[i] = inp
-    return super(CapturingGraph, self).create_op(
-        op_type, inputs, dtypes, input_types, name, attrs, op_def,
-        compute_device=compute_device)
-
-
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
   if ctx.executing_eagerly():
@@ -190,7 +99,7 @@ def _get_device_functions(ctx, graph):
     return tuple(graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
 
 
-class FuncGraph(CapturingGraph):
+class FuncGraph(ops.Graph):
   """Graph representing a function body.
 
   Attributes:
@@ -207,6 +116,8 @@ class FuncGraph(CapturingGraph):
     variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another
      FuncGraph or the global default Graph.
+    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
+      The entries are in the order they were captured.
     seed: The graph-level random seed.
   """
 
@@ -227,6 +138,13 @@ class FuncGraph(CapturingGraph):
     self.structured_outputs = None
     self.variables = []
     self.outer_graph = ops.get_default_graph()
+    self.captures = collections.OrderedDict()
+
+    self._building_function = True
+    # Map from resource tensor name to last op (in program order) which uses
+    # this tensor. Used to enforce that execution order matches program order
+    # for resource tensors.
+    self._last_op_using_resource_tensor = {}
 
     graph = self.outer_graph
 
@@ -255,15 +173,107 @@ class FuncGraph(CapturingGraph):
     self._graph_key = graph._graph_key
     # pylint: enable=protected-access
 
+  def create_op(
+      self,
+      op_type,
+      inputs,
+      dtypes,
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_shapes=True,
+      compute_device=True):
+    """Like Graph.create_op, except handles external input tensors.
+
+    This overload adds functionality to create_op to "capture" any external
+    input tensors, i.e. tensors from the eager context or outer function graphs
+    if this is a nested function. See `capture` for more information.
+
+    Args:
+      op_type: The `Operation` type to create. This corresponds to the
+        `OpDef.name` field for the proto that defines the operation.
+      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+      dtypes: A list of `DType` objects that will be the types of the tensors
+        that the operation produces.
+      input_types: (Optional.) A list of `DType`s that will be the types of
+        the tensors that the operation consumes. By default, uses the base
+        `DType` of each input in `inputs`. Operations that expect
+        reference-typed inputs must specify `input_types` explicitly.
+      name: (Optional.) A string name for the operation. If not specified, a
+        name is generated based on `op_type`.
+      attrs: (Optional.) A dictionary where the key is the attribute name (a
+        string) and the value is the respective `attr` attribute of the
+        `NodeDef` proto that will represent the operation (an `AttrValue`
+        proto).
+      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+        the operation will have.
+      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+        computed).
+      compute_device: (Optional.) If True, device functions will be executed
+        to compute the device property of the Operation.
+
+    Returns:
+      An `Operation` object.
+    """
+    # This capturing logic interacts poorly with control flow contexts which
+    # want to replace inputs of ops far too late in the process. This can lead
+    # the context to get confused and try to create an Enter for an Enter. We
+    # can detect this here and skip the additional Enter which can confuse loop
+    # validation logic.
+    if op_type == "Enter" and inputs[0].op.type == "Enter":
+      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+        return inputs[0].op
+    # Calling AddValue on the control flow contexts to force creation of the
+    # backward accumulators in the original graph before we create placeholders
+    # to capture the inputs.
+    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
+    for i, inp in enumerate(inputs):
+      # TPU Estimator defines a control flow context with no AddValue method.
+      if ctxt is not None and hasattr(ctxt, "AddValue"):
+        inp = ctxt.AddValue(inp)
+      inp = self.capture(inp)
+      inputs[i] = inp
+    return super(FuncGraph, self).create_op(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device=compute_device)
+
   def capture(self, tensor, name=None):
-    """Calls CapturingGraph.capture and updates self.inputs if necessary."""
-    new_capture = tensor not in self.captures
-    internal_tensor = super(FuncGraph, self).capture(tensor, name)
+    """Captures `tensor` if it's external to this graph.
 
-    if new_capture and tensor is not internal_tensor:
-      self.inputs.append(internal_tensor)
+    If `tensor` is from a different graph, returns a placeholder for it.
+    `tensor` and the placeholder will appear in self.captures, and the
+    placeholder will appear in self.inputs. Multiple calls to this method with
+    the same `tensor` argument will return the same placeholder. If `tensor` is
+    from this graph, returns `tensor`.
+
+    Args:
+      tensor: Tensor. May be from this FuncGraph or a different graph.
+      name: Optional name if a placeholder is created.
+
+    Returns:
+      Tensor from this FuncGraph.
+    """
+    if isinstance(tensor, ops.EagerTensor):
+      if name is None:
+        name = str(ops.uid())
+      return self._capture_helper(tensor, name)
+    if tensor.graph is not self:
+      if name is None:
+        name = tensor.op.name
+      return self._capture_helper(tensor, name)
+    return tensor
 
-    return internal_tensor
+  def _capture_helper(self, tensor, name):
+    captured_tensor = self.captures.get(tensor, None)
+    if captured_tensor is None:
+      captured_tensor = _create_substitute_placeholder(tensor, name=name,
+                                                       dtype=tensor.dtype)
+      self.captures[tensor] = captured_tensor
+      self.inputs.append(captured_tensor)
+    tape.record_operation("captured_value", [captured_tensor], [tensor],
+                          lambda x: [x])
+    return captured_tensor
 
   @property
   def external_captures(self):
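The capture machinery this change consolidates into FuncGraph is easiest to see in isolation. Below is a minimal, self-contained sketch of the memoized external-tensor -> placeholder mapping that `capture`/`_capture_helper` implement; the `ToyPlaceholder` and `ToyFuncGraph` names are hypothetical stand-ins for illustration, not TensorFlow APIs.

import collections

class ToyPlaceholder(object):
  """Hypothetical stand-in for the placeholder made by
  _create_substitute_placeholder (illustration only)."""

  def __init__(self, name, dtype):
    self.name = name
    self.dtype = dtype

class ToyFuncGraph(object):
  """Illustrates FuncGraph's capture logic in miniature."""

  def __init__(self):
    # OrderedDict keeps captures in the order they were captured, matching
    # the `captures` attribute documented in the diff above.
    self.captures = collections.OrderedDict()
    self.inputs = []

  def capture(self, tensor, name):
    # Mirrors _capture_helper: on first capture, create a placeholder,
    # memoize it, and register it as a formal input of the function body;
    # repeated captures of the same tensor return the same placeholder.
    placeholder = self.captures.get(tensor)
    if placeholder is None:
      placeholder = ToyPlaceholder(name, getattr(tensor, "dtype", None))
      self.captures[tensor] = placeholder
      self.inputs.append(placeholder)
    return placeholder

g = ToyFuncGraph()
external = "external-tensor"                # stand-in for an eager tensor
p1 = g.capture(external, name="captured_0")
p2 = g.capture(external, name="unused")     # memoized: same placeholder back
assert p1 is p2
assert g.inputs == [p1]

Appending the placeholder to `inputs` at capture time is the point of folding `_capture_helper` into FuncGraph: the old code had CapturingGraph create the placeholder and FuncGraph.capture patch up `inputs` afterwards, split across two classes.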