Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py | 393
1 file changed, 229 insertions, 164 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index a81ef90513..5e4f9e29da 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import threading
import numpy as np
@@ -36,6 +37,7 @@ from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
@@ -136,7 +138,7 @@ class CapturingGraph(ops.Graph):
inputs[i] = self.capture(inp)
return super(CapturingGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
- compute_shapes, compute_device)
+ compute_device=compute_device)
# pylint: disable=invalid-name
@@ -231,11 +233,20 @@ def _register(fn):
context.context().add_function(fn)
+_xla_compile_attr = "_XlaCompile"
+
+
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
- """Function object with the interface of tf _DefinedFunction."""
+ """Callable with the interface of `framework.function._DefinedFunction.`
+
+ `_EagerDefinedFunction` encapsulates a function definition and its properties,
+ and it provides a method for calling the encapsulated function. Some Ops
+ take functions as attributes of type `func`; an instance of this class may
+ be supplied as the value of such `func` attributes.
+ """
def __init__(self, name, graph, operations, inputs, outputs, attrs):
"""Initializes an eager defined function.
@@ -266,6 +277,7 @@ class _EagerDefinedFunction(object):
# It might be worth creating a convenient way to re-use status.
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
fn, compat.as_str(name), serialized)
+ self._xla_compile = _xla_compile_attr in attrs
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
@@ -277,12 +289,92 @@ class _EagerDefinedFunction(object):
if context.executing_eagerly():
_register(fn)
self.definition = function_def
- self.name = function_def.signature.name
+ self.name = compat.as_bytes(function_def.signature.name)
self.signature = function_def.signature
+ self._num_outputs = len(self.signature.output_arg)
+ self._output_types = [o.type for o in self.signature.output_arg]
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
self._grad_func = None
+ self._graph = graph
+ self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful)
+
+ def add_to_graph(self, g):
+ # pylint: disable=protected-access
+ if self.name not in g._functions:
+ g._add_function(self)
+ for f in self._graph._functions.values():
+ if f.name not in g._functions:
+ g._add_function(f)
+ # pylint: enable=protected-access
+
+ @property
+ def stateful_ops(self):
+ return self._stateful_ops
+
+ def call(self, ctx, args, output_shapes):
+ """Calls this function with `args` as inputs.
+
+ Function execution respects device annotations only if the function is not
+ compiled with XLA.
+
+ Args:
+ ctx: a Context object.
+ args: a list of arguments to supply to this function.
+ output_shapes: shapes to which outputs should be set; ignored when
+ executing eagerly.
+
+ Returns:
+ The outputs of the function call.
+ """
+
+ executing_eagerly = ctx.executing_eagerly()
+
+ xla_compile = self._xla_compile or (executing_eagerly and
+ ctx.device_spec.device_type == "TPU")
+
+ if xla_compile:
+ # XLA compilation relies upon a custom kernel creator to run functions.
+ signature = self.signature
+ if executing_eagerly:
+ outputs = execute.execute(
+ str(signature.name),
+ num_outputs=self._num_outputs,
+ inputs=args,
+ attrs=None,
+ ctx=ctx)
+ else:
+ g = ops.get_default_graph()
+ self.add_to_graph(g)
+ op = g.create_op(
+ signature.name,
+ [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
+ tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
+ op_def=signature,
+ name="FunctionCall",
+ compute_shapes=False)
+ outputs = op.outputs
+ if not outputs:
+ return op
+ outputs = [outputs] if isinstance(
+ outputs, (ops.Tensor, type(None))) else list(outputs)
+ else:
+ # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
+ # creates `PartitionedCallOp` kernels by default, or remove the previous
+ # branch if a TPU kernel is registered for `PartitionedCall`.
+ outputs = functional_ops.partitioned_call(
+ args=args,
+ f=self,
+ tout=self._output_types,
+ executing_eagerly=executing_eagerly)
+
+ if executing_eagerly:
+ return outputs
+ else:
+ for i, shape in enumerate(output_shapes):
+ outputs[i].set_shape(shape)
+ return outputs
def _map_sequence_obj_to_idx(sequence):
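The hunk above funnels all function execution through `_EagerDefinedFunction.call`. As a reading aid (not part of the patch), its dispatch can be distilled to plain Python; `device_type` stands in for `ctx.device_spec.device_type`:

```python
# Reading aid, not part of the patch: the three call paths chosen by
# `_EagerDefinedFunction.call`, reduced to the branching logic alone.
def call_path(has_xla_compile_attr, executing_eagerly, device_type):
    xla_compile = has_xla_compile_attr or (executing_eagerly and
                                           device_type == "TPU")
    if xla_compile and executing_eagerly:
        return "execute.execute on the registered function"
    if xla_compile:
        return "function-call op added to the default graph"
    return "functional_ops.partitioned_call"

assert call_path(False, True, "CPU") == "functional_ops.partitioned_call"
assert call_path(True, False, "CPU") == "function-call op added to the default graph"
assert call_path(False, True, "TPU") == "execute.execute on the registered function"
```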
@@ -306,8 +398,12 @@ def _flatten(sequence):
return outputs
+# TODO(akshayka): Perhaps rename to something more appropriate.
class GraphModeFunction(object):
- """Callable object representing a graph-mode function.
+ """Callable object encapsulating a function definition and its gradient.
+
+ `GraphModeFunction` is a callable that encapsulates a function definition and
+ is differentiable under `tf.GradientTape` objects.
"""
def __init__(self,
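Since the new docstring advertises differentiability under `tf.GradientTape`, here is a minimal usage sketch (not part of the patch, assuming the contemporaneous `tf.contrib.eager.defun` entry point):

```python
# Sketch only: a defun-ed function differentiated under tf.GradientTape.
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun
def square(x):
    return x * x

x = tf.constant(3.0)
with tf.GradientTape() as t:
    t.watch(x)   # x is a constant, so watch it explicitly
    y = square(x)
print(t.gradient(y, x))  # expected: 6.0
```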
@@ -374,37 +470,39 @@ class GraphModeFunction(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
- with self._graph.as_default(), context.graph_mode():
- c_known_ops = set()
- c_captured_tensors = set()
-
- existing_op_len = len(self._graph.get_operations())
- filtered_outputs = [x for x in self._python_returns if x is not None]
+ filtered_outputs = [x for x in self._python_returns if x is not None]
+ captures = {}
+ backwards_graph = CapturingGraph(captures)
+ backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
+ for collection in self._graph.collections:
+ backwards_graph.get_collection_ref(
+ collection)[:] = self._graph.get_collection(collection)
+ backwards_graph.seed = self._graph.seed
+ with backwards_graph.as_default():
self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
- in_gradients = gradients_impl.gradients(
+ in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access
filtered_outputs,
self._input_placeholders,
- grad_ys=self._out_grad_placeholders)
- for op in self._graph.get_operations()[existing_op_len:]:
- if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
- raise ValueError("tfe.defun cannot capture variables created without "
- "using tf.get_variable. Op: %s" % op)
- c_known_ops.add(op)
- for i in op.inputs:
- if i.op not in c_known_ops:
- c_captured_tensors.add(i)
+ grad_ys=self._out_grad_placeholders,
+ src_graph=self._graph)
backward_outputs = tuple(
grad for grad in _flatten(in_gradients) if grad is not None)
output_shapes = tuple(grad.shape for grad in backward_outputs)
- captures = list(sorted(c_captured_tensors, key=lambda x: x.name))
+ ids = list(sorted(captures.keys()))
+ if ids:
+ extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
+ else:
+ extra_inputs = []
+ extra_placeholders = []
+
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + captures, self._attrs)
- all_inputs = self._out_grad_placeholders + captures
+ filtered_outputs + list(extra_inputs), self._attrs)
+ all_inputs = self._out_grad_placeholders + list(extra_placeholders)
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
all_ignored_ops = frozenset(x.op for x in all_inputs)
@@ -412,11 +510,12 @@ class GraphModeFunction(object):
# means rerunning the function-defining code will always define the same
# function, which is useful if we serialize this etc.
function_def_ops = tuple(x
- for x in sorted(c_known_ops, key=lambda x: x.name)
+ for x in sorted(backwards_graph.get_operations(),
+ key=lambda x: x.name)
if x not in all_ignored_ops)
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
- bname, all_inputs, [], self._graph, function_def_ops,
+ bname, all_inputs, [], backwards_graph, function_def_ops,
backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
def _backprop_call(self, args):
@@ -430,35 +529,10 @@ class GraphModeFunction(object):
The call output.
"""
all_args = args + self._extra_inputs
- signature = self._forward_fdef.signature
ctx = context.context()
- if ctx.executing_eagerly():
- outputs = execute.execute(
- str(signature.name),
- num_outputs=len(signature.output_arg),
- inputs=all_args,
- attrs=None,
- ctx=ctx)
- if not outputs:
- return None
- else:
- g = ops.get_default_graph()
- g._add_function(self._forward_fdef) # pylint: disable=protected-access
- op = g.create_op(
- signature.name,
- [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args],
- tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
- op_def=signature,
- name="FunctionCall",
- compute_shapes=False)
- outputs = op.outputs
- if not outputs:
- return op
- outputs = [outputs] if isinstance(outputs, ops.Tensor) else list(outputs)
-
- shapes = [shape for shape in self._output_shapes if shape is not None]
- for i, shape in enumerate(shapes):
- outputs[i].set_shape(shape)
+ outputs = self._forward_fdef.call(ctx, all_args, self._output_shapes)
+ if isinstance(outputs, ops.Operation) or outputs is None:
+ return outputs
# `real_outputs` are the actual outputs of the inference graph function;
# `side_outputs` are the intermediate Tensors that were added as outputs to
@@ -470,7 +544,7 @@ class GraphModeFunction(object):
return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
tape.record_operation(
- signature.name,
+ self._forward_fdef.signature.name,
real_outputs,
(args + self._extra_inputs),
backward_function)
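For orientation (not part of the patch): the forward function returns the user-visible results followed by the captured intermediates; only the former are handed back to the caller, while the latter feed `backward_function`. A toy sketch of that split:

```python
# Toy sketch, not part of the patch: plain strings stand in for Tensors.
outputs = ["y0", "y1", "side0", "side1"]  # forward-function outputs
num_user_outputs = 2                      # number of values func returned
real_outputs = outputs[:num_user_outputs]   # returned to the caller
side_outputs = outputs[num_user_outputs:]   # consumed by backward_function
assert real_outputs == ["y0", "y1"]
assert side_outputs == ["side0", "side1"]
```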
@@ -512,13 +586,6 @@ class GraphModeFunction(object):
"""Returns the name of the function in Eager-compatible format."""
return self._function_def.name.encode("utf-8")
- def add_to_graph(self, g):
- if self._function_def.name not in g._functions: # pylint: disable=protected-access
- g._add_function(self._function_def) # pylint: disable=protected-access
- for f in self._graph._functions.values(): # pylint: disable=protected-access
- if f.name not in g._functions: # pylint: disable=protected-access
- g._add_function(f) # pylint: disable=protected-access
-
def __call__(self, *args):
"""Executes the passed function in eager mode."""
for v in self._variables:
@@ -533,34 +600,9 @@ class GraphModeFunction(object):
return self._backprop_call(tensor_inputs)
ctx = context.context()
- if ctx.executing_eagerly():
- result = execute.execute(
- str(self._func_name),
- num_outputs=self._num_outputs,
- inputs=tensor_inputs + self._extra_inputs,
- attrs=None,
- ctx=ctx)
- else:
- g = ops.get_default_graph()
- self.add_to_graph(g)
- signature = self._function_def.definition.signature
- args = list(tensor_inputs) + self._extra_inputs
- op = g.create_op(
- signature.name,
- [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
- tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
- op_def=signature,
- name="FunctionCall",
- compute_shapes=False)
- result = op.outputs
- if not result:
- return op
-
- shapes = [shape for shape in self._output_shapes if shape is not None]
- for i, shape in enumerate(shapes):
- result[i].set_shape(shape)
-
- return self._build_call_outputs(result)
+ args = tensor_inputs + self._extra_inputs
+ outputs = self._function_def.call(ctx, args, self._output_shapes)
+ return self._build_call_outputs(outputs)
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -571,7 +613,8 @@ class GraphModeFunction(object):
The actual call output.
"""
if self._python_func_outputs is None:
- return None
+ return result
+
# Use `nest.flatten` instead of `_flatten` in order to preserve any
# IndexedSlices in `self._python_func_outputs`.
outputs_list = nest.flatten(self._python_func_outputs)
@@ -617,55 +660,58 @@ def _deterministic_dict_values(kwds):
def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
- with context.graph_mode():
- captures = {}
- tmp_graph = CapturingGraph(captures)
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- # Copy the graph collections to ensure summaries and other things work. This
- # lets the function access (but not mutate) collections of the containing
- # graph, such as the global step and the summary writer collections.
- curr_graph = ops.get_default_graph()
- for collection in curr_graph.collections:
- tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
- collection)
- with tmp_graph.as_default(), AutomaticControlDependencies() as a:
- func_args = _get_defun_inputs(args)
- func_kwds = _get_defun_inputs(kwds)
-
- def convert(x):
- if x is None:
- return None
- x = ops.convert_to_tensor_or_indexed_slices(x)
- x = a.mark_as_return(x)
- return x
+ captures = {}
+ tmp_graph = CapturingGraph(captures)
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ tmp_graph._graph_key = graph_key # pylint: disable=protected-access
+ # Copy the graph collections to ensure summaries and other things work. This
+ # lets the function access (but not mutate) collections of the containing
+ # graph, such as the global step and the summary writer collections.
+ curr_graph = ops.get_default_graph()
+ for collection in curr_graph.collections:
+ tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+ collection)
+ if context.executing_eagerly():
+ tmp_graph.seed = context.global_seed()
+ else:
+ tmp_graph.seed = curr_graph.seed
+ with tmp_graph.as_default(), AutomaticControlDependencies() as a:
+ func_args = _get_defun_inputs(args)
+ func_kwds = _get_defun_inputs(kwds)
- this_tape = tape.push_new_tape()
- try:
- func_outputs = func(*func_args, **func_kwds)
- func_outputs = nest.map_structure(convert, func_outputs)
- finally:
- tape.pop_tape(this_tape)
- variables = this_tape.watched_variables()
-
- # Returning a closed-over tensor as an output does not trigger a
- # call to convert_to_tensor, so we manually capture all such tensors.
- outputs_list = _flatten(func_outputs)
- func_def_outputs = [
- tmp_graph.capture(x) for x in outputs_list
- if x is not None
- ]
-
- ids = list(sorted(captures.keys()))
- if ids:
- extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
- else:
- extra_inputs = []
- extra_placeholders = []
- output_shapes = tuple(
- x.shape if isinstance(x, ops.Tensor) else None
- for x in outputs_list)
+ def convert(x):
+ if x is None:
+ return None
+ x = ops.convert_to_tensor_or_indexed_slices(x)
+ x = a.mark_as_return(x)
+ return x
+
+ this_tape = tape.push_new_tape()
+ try:
+ func_outputs = func(*func_args, **func_kwds)
+ func_outputs = nest.map_structure(convert, func_outputs)
+ finally:
+ tape.pop_tape(this_tape)
+ variables = this_tape.watched_variables()
+
+ # Returning a closed-over tensor as an output does not trigger a
+ # call to convert_to_tensor, so we manually capture all such tensors.
+ outputs_list = _flatten(func_outputs)
+ func_def_outputs = [
+ tmp_graph.capture(x) for x in outputs_list
+ if x is not None
+ ]
+
+ ids = list(sorted(captures.keys()))
+ if ids:
+ extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
+ else:
+ extra_inputs = []
+ extra_placeholders = []
+ output_shapes = tuple(
+ x.shape if isinstance(x, ops.Tensor) else None
+ for x in func_def_outputs)
func_kwds_values = _deterministic_dict_values(func_kwds)
flat_inputs = [
@@ -686,7 +732,7 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
attrs = {}
if compiled:
- attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
+ attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
@@ -731,6 +777,11 @@ class _PolymorphicFunction(object):
See the documentation for `defun` for more information on the semantics of
defined functions.
+
+ The `_PolymorphicFunction` class is thread-compatible, meaning that minimal
+ usage of defuns (defining and calling) is thread-safe, but if users call other
+ methods or invoke the base `python_function` themselves, external
+ synchronization is necessary.
"""
def __init__(self, python_function, name, compiled=False):
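A sketch of what thread-compatibility means in practice (not part of the patch, assuming the `tf.contrib.eager.defun` entry point):

```python
# Sketch only: concurrently defining and calling the same defun is safe
# because the trace cache is guarded by a lock; anything beyond
# define-and-call needs external synchronization.
import threading
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun
def double(x):
    return 2 * x

threads = [threading.Thread(target=double, args=(tf.constant(i),))
           for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```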
@@ -748,6 +799,8 @@ class _PolymorphicFunction(object):
self._arguments_to_functions = {}
self._variables = []
+ self._lock = threading.Lock()
+
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
@@ -782,22 +835,30 @@ class _PolymorphicFunction(object):
kwd_values = _deterministic_dict_values(kwds)
inputs = args + kwd_values
signature = tuple(_cache_key(x) for x in inputs)
-
- if signature not in self._arguments_to_functions:
- graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds)
- self._arguments_to_functions[signature] = graph_function
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
- return graph_function, inputs
- else:
- return self._arguments_to_functions[signature], inputs
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # signature so we don't improperly capture tensors such as variables.
+ signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
+
+ with self._lock:
+ if signature not in self._arguments_to_functions:
+ graph_function = _trace_and_define_function(
+ self._name, self._python_function, self._compiled, args, kwds)
+ self._arguments_to_functions[signature] = graph_function
+ self._variables.extend(
+ [v for v in graph_function.variables if v not in self._variables])
+ return graph_function, inputs
+ else:
+ return self._arguments_to_functions[signature], inputs
def __call__(self, *args, **kwds):
"""Calls a graph function specialized for this input signature."""
graph_function, inputs = self._maybe_define_function(*args, **kwds)
return graph_function(*inputs)
+ def call_python_function(self, *args, **kwargs):
+ """Directly calls the wrapped python function."""
+ return self._python_function(*args, **kwargs)
+
@property
def variables(self):
"""Returns a list of variables used in any of the defined functions."""
@@ -835,6 +896,11 @@ def defun(func=None, compiled=False):
be hashable Python objects or lists thereof. Additionally, it must return zero
or more @{tf.Tensor} objects.
+ Executing a graph generated by `defun` respects device annotations (i.e.,
+ all `with tf.device` directives present in a Python function will also be
+ present in its corresponding graph), but it is not yet possible to execute the
+ generated graphs across multiple machines.
+
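A minimal sketch of the device-annotation behavior described above (not part of the patch; single-machine placement assumed):

```python
# Sketch only: a tf.device directive inside the Python function is
# preserved in the traced graph (single machine only).
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun
def matmul_on_cpu(a, b):
    with tf.device("/cpu:0"):
        return tf.matmul(a, b)

print(matmul_on_cpu(tf.ones([2, 2]), tf.ones([2, 2])))
```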
_Example Usage_
```python
@@ -1014,7 +1080,7 @@ def defun(func=None, compiled=False):
tf.enable_eager_execution()
def fn():
- x = tf.contrib.eager.Variable(0.0)
+ x = tf.Variable(0.0)
x.assign_add(1.0)
return x.read_value()
@@ -1031,19 +1097,18 @@ def defun(func=None, compiled=False):
```
Finally, because each input signature is bound to a unique graph, if your
- Python function constructs `tf.contrib.eager.Variable` objects, then each
- graph constructed for that Python function will reference a unique set of
- variables. To circumvent this problem, we recommend against compiling Python
- functions that create `tf.contrib.eager.Variable` objects. Instead, Python
- functions should either lexically close over `tf.contrib.eager.Variable`
- objects or accept them as arguments, preferably encapsulated in an
- object-oriented container. If you must create variables inside your Python
- function and you want each graph generated for it to reference the same set of
- variables, add logic to your Python function that ensures that variables are
- only created the first time it is called and are reused for every subsequent
- invocation; note that this is precisely what @{tf.keras.layers.Layer} objects
- do, so we recommend using them to represent variable-bearing computations
- whenever possible.
+ Python function constructs `tf.Variable` objects, then each graph constructed
+ for that Python function will reference a unique set of variables. To
+ circumvent this problem, we recommend against compiling Python functions that
+ create `tf.Variable` objects. Instead, Python functions should either
+ lexically close over `tf.Variable` objects or accept them as arguments,
+ preferably encapsulated in an object-oriented container. If you must create
+ variables inside your Python function and you want each graph generated for it
+ to reference the same set of variables, add logic to your Python function that
+ ensures that variables are only created the first time it is called and are
+ reused for every subsequent invocation; note that this is precisely what
+ @{tf.keras.layers.Layer} objects do, so we recommend using them to represent
+ variable-bearing computations whenever possible.
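A sketch of the recommended pattern from the paragraph above (not part of the patch): lexically close over the variable instead of creating it inside the compiled function.

```python
# Sketch only: every traced graph references the same variable `v`.
import tensorflow as tf
tf.enable_eager_execution()

v = tf.Variable(2.0)

@tf.contrib.eager.defun
def scale(x):
    return v * x  # closes over v; no variable is created during tracing

print(scale(tf.constant(3.0)))  # expected: 6.0
```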
Args:
func: function to be compiled. If `func` is None, returns a
@@ -1245,7 +1310,7 @@ class AutomaticControlDependencies(object):
# Ensures the merge always runs
ops_which_must_run.add(new_merge[0].op)
if inp in last_op_using_resource_tensor:
- # Ensures the switch exectutes after the previous op using the resource.
+ # Ensures the switch executes after the previous op using the resource.
switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access
# Ensure the next op outside the cond happens after the merge.
last_op_using_resource_tensor[inp] = new_merge[0].op