Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r-- | tensorflow/python/eager/function.py | 393
1 files changed, 229 insertions, 164 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a81ef90513..5e4f9e29da 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import functools +import threading import numpy as np @@ -36,6 +37,7 @@ from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import compat @@ -136,7 +138,7 @@ class CapturingGraph(ops.Graph): inputs[i] = self.capture(inp) return super(CapturingGraph, self).create_op( op_type, inputs, dtypes, input_types, name, attrs, op_def, - compute_shapes, compute_device) + compute_device=compute_device) # pylint: disable=invalid-name @@ -231,11 +233,20 @@ def _register(fn): context.context().add_function(fn) +_xla_compile_attr = "_XlaCompile" + + # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction # so it doesn't have the definition-generating logic and is just a container for # an already-defined function. class _EagerDefinedFunction(object): - """Function object with the interface of tf _DefinedFunction.""" + """Callable with the interface of `framework.function._DefinedFunction.` + + `_EagerDefinedFunction` encapsulates a function definition and its properties, + and it provides a method for calling the encapsulated function. Some Ops + take functions as attributes, which have type `func`; an instance of this + class may be provided as the value of these `func` attributes. + """ def __init__(self, name, graph, operations, inputs, outputs, attrs): """Initializes an eager defined function. @@ -266,6 +277,7 @@ class _EagerDefinedFunction(object): # It might be worth creating a convenient way to re-use status. pywrap_tensorflow.TF_FunctionSetAttrValueProto( fn, compat.as_str(name), serialized) + self._xla_compile = _xla_compile_attr in attrs # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. @@ -277,12 +289,92 @@ class _EagerDefinedFunction(object): if context.executing_eagerly(): _register(fn) self.definition = function_def - self.name = function_def.signature.name + self.name = compat.as_bytes(function_def.signature.name) self.signature = function_def.signature + self._num_outputs = len(self.signature.output_arg) + self._output_types = [o.type for o in self.signature.output_arg] self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) self._grad_func = None + self._graph = graph + self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful) + + def add_to_graph(self, g): + # pylint: disable=protected-access + if self.name not in g._functions: + g._add_function(self) + for f in self._graph._functions.values(): + if f.name not in g._functions: + g._add_function(f) + # pylint: enable=protected-access + + @property + def stateful_ops(self): + return self._stateful_ops + + def call(self, ctx, args, output_shapes): + """Calls this function with `args` as inputs. + + Function execution respects device annotations only if the function won't + be compiled with xla. 
+ + Args: + ctx: a Context object + args: a list of arguments to supply this function with. + output_shapes: shapes to which outputs should be set; ignored when + executing eagerly. + + Returns: + The outputs of the function call. + """ + + executing_eagerly = ctx.executing_eagerly() + + xla_compile = self._xla_compile or (executing_eagerly and + ctx.device_spec.device_type == "TPU") + + if xla_compile: + # XLA compilation relies upon a custom kernel creator to run functions. + signature = self.signature + if executing_eagerly: + outputs = execute.execute( + str(signature.name), + num_outputs=self._num_outputs, + inputs=args, + attrs=None, + ctx=ctx) + else: + g = ops.get_default_graph() + self.add_to_graph(g) + op = g.create_op( + signature.name, + [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), + op_def=signature, + name="FunctionCall", + compute_shapes=False) + outputs = op.outputs + if not outputs: + return op + outputs = [outputs] if isinstance( + outputs, (ops.Tensor, type(None))) else list(outputs) + else: + # TODO(akshayka): Either remove this if the FunctionLibraryRuntime + # creates `PartitionedCallOp` kernels by default, or remove the previous + # branch if a TPU kernel is registered for `PartitionedCall`. + outputs = functional_ops.partitioned_call( + args=args, + f=self, + tout=self._output_types, + executing_eagerly=executing_eagerly) + + if executing_eagerly: + return outputs + else: + for i, shape in enumerate(output_shapes): + outputs[i].set_shape(shape) + return outputs def _map_sequence_obj_to_idx(sequence): @@ -306,8 +398,12 @@ def _flatten(sequence): return outputs +# TODO(akshayka): Perhaps rename to something more appropriate. class GraphModeFunction(object): - """Callable object representing a graph-mode function. + """Callable object encapsulating a function definition and its gradient. + + `GraphModeFunction` is a callable that encapsulates a function definition and + is differentiable under `tf.GradientTape` objects. """ def __init__(self, @@ -374,37 +470,39 @@ class GraphModeFunction(object): def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" - with self._graph.as_default(), context.graph_mode(): - c_known_ops = set() - c_captured_tensors = set() - - existing_op_len = len(self._graph.get_operations()) - filtered_outputs = [x for x in self._python_returns if x is not None] + filtered_outputs = [x for x in self._python_returns if x is not None] + captures = {} + backwards_graph = CapturingGraph(captures) + backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access + for collection in self._graph.collections: + backwards_graph.get_collection_ref( + collection)[:] = self._graph.get_collection(collection) + backwards_graph.seed = self._graph.seed + with backwards_graph.as_default(): self._out_grad_placeholders = [ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] - in_gradients = gradients_impl.gradients( + in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access filtered_outputs, self._input_placeholders, - grad_ys=self._out_grad_placeholders) - for op in self._graph.get_operations()[existing_op_len:]: - if op.type in ["Variable", "VariableV2", "VarHandleOp"]: - raise ValueError("tfe.defun cannot capture variables created without " - "using tf.get_variable. 
Op: %s" % op) - c_known_ops.add(op) - for i in op.inputs: - if i.op not in c_known_ops: - c_captured_tensors.add(i) + grad_ys=self._out_grad_placeholders, + src_graph=self._graph) backward_outputs = tuple( grad for grad in _flatten(in_gradients) if grad is not None) output_shapes = tuple(grad.shape for grad in backward_outputs) - captures = list(sorted(c_captured_tensors, key=lambda x: x.name)) + ids = list(sorted(captures.keys())) + if ids: + extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids]) + else: + extra_inputs = [] + extra_placeholders = [] + forward_name = _forward_name(self._func_name) self._forward_fdef = _EagerDefinedFunction( forward_name, self._graph, self._ops, self._input_placeholders, - filtered_outputs + captures, self._attrs) - all_inputs = self._out_grad_placeholders + captures + filtered_outputs + list(extra_inputs), self._attrs) + all_inputs = self._out_grad_placeholders + list(extra_placeholders) # Excluding input ops from the body as we do not intend to execute these # operations when the function is executed. all_ignored_ops = frozenset(x.op for x in all_inputs) @@ -412,11 +510,12 @@ class GraphModeFunction(object): # means rerunning the function-defining code will always define the same # function, which is useful if we serialize this etc. function_def_ops = tuple(x - for x in sorted(c_known_ops, key=lambda x: x.name) + for x in sorted(backwards_graph.get_operations(), + key=lambda x: x.name) if x not in all_ignored_ops) bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( - bname, all_inputs, [], self._graph, function_def_ops, + bname, all_inputs, [], backwards_graph, function_def_ops, backward_outputs, in_gradients, output_shapes, attrs=self._attrs) def _backprop_call(self, args): @@ -430,35 +529,10 @@ class GraphModeFunction(object): The call output. 
""" all_args = args + self._extra_inputs - signature = self._forward_fdef.signature ctx = context.context() - if ctx.executing_eagerly(): - outputs = execute.execute( - str(signature.name), - num_outputs=len(signature.output_arg), - inputs=all_args, - attrs=None, - ctx=ctx) - if not outputs: - return None - else: - g = ops.get_default_graph() - g._add_function(self._forward_fdef) # pylint: disable=protected-access - op = g.create_op( - signature.name, - [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), - op_def=signature, - name="FunctionCall", - compute_shapes=False) - outputs = op.outputs - if not outputs: - return op - outputs = [outputs] if isinstance(outputs, ops.Tensor) else list(outputs) - - shapes = [shape for shape in self._output_shapes if shape is not None] - for i, shape in enumerate(shapes): - outputs[i].set_shape(shape) + outputs = self._forward_fdef.call(ctx, all_args, self._output_shapes) + if isinstance(outputs, ops.Operation) or outputs is None: + return outputs # `real_outputs` are the actual outputs of the inference graph function; # `side_outputs` are the intermediate Tensors that were added as outputs to @@ -470,7 +544,7 @@ class GraphModeFunction(object): return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable tape.record_operation( - signature.name, + self._forward_fdef.signature.name, real_outputs, (args + self._extra_inputs), backward_function) @@ -512,13 +586,6 @@ class GraphModeFunction(object): """Returns the name of the function in Eager-compatible format.""" return self._function_def.name.encode("utf-8") - def add_to_graph(self, g): - if self._function_def.name not in g._functions: # pylint: disable=protected-access - g._add_function(self._function_def) # pylint: disable=protected-access - for f in self._graph._functions.values(): # pylint: disable=protected-access - if f.name not in g._functions: # pylint: disable=protected-access - g._add_function(f) # pylint: disable=protected-access - def __call__(self, *args): """Executes the passed function in eager mode.""" for v in self._variables: @@ -533,34 +600,9 @@ class GraphModeFunction(object): return self._backprop_call(tensor_inputs) ctx = context.context() - if ctx.executing_eagerly(): - result = execute.execute( - str(self._func_name), - num_outputs=self._num_outputs, - inputs=tensor_inputs + self._extra_inputs, - attrs=None, - ctx=ctx) - else: - g = ops.get_default_graph() - self.add_to_graph(g) - signature = self._function_def.definition.signature - args = list(tensor_inputs) + self._extra_inputs - op = g.create_op( - signature.name, - [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), - op_def=signature, - name="FunctionCall", - compute_shapes=False) - result = op.outputs - if not result: - return op - - shapes = [shape for shape in self._output_shapes if shape is not None] - for i, shape in enumerate(shapes): - result[i].set_shape(shape) - - return self._build_call_outputs(result) + args = tensor_inputs + self._extra_inputs + outputs = self._function_def.call(ctx, args, self._output_shapes) + return self._build_call_outputs(outputs) def _build_call_outputs(self, result): """Maps the fdef output list to actual output structure. @@ -571,7 +613,8 @@ class GraphModeFunction(object): The actual call output. 
""" if self._python_func_outputs is None: - return None + return result + # Use `nest.flatten` instead of `_flatten` in order to preserve any # IndexedSlices in `self._python_func_outputs`. outputs_list = nest.flatten(self._python_func_outputs) @@ -617,55 +660,58 @@ def _deterministic_dict_values(kwds): def _trace_and_define_function(name, func, compiled, args, kwds): """Defines and returns graph-mode version of func.""" graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access - with context.graph_mode(): - captures = {} - tmp_graph = CapturingGraph(captures) - # Inherit the graph key, since this is used for matching variables in - # optimizers. - tmp_graph._graph_key = graph_key # pylint: disable=protected-access - # Copy the graph collections to ensure summaries and other things work. This - # lets the function access (but not mutate) collections of the containing - # graph, such as the global step and the summary writer collections. - curr_graph = ops.get_default_graph() - for collection in curr_graph.collections: - tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( - collection) - with tmp_graph.as_default(), AutomaticControlDependencies() as a: - func_args = _get_defun_inputs(args) - func_kwds = _get_defun_inputs(kwds) - - def convert(x): - if x is None: - return None - x = ops.convert_to_tensor_or_indexed_slices(x) - x = a.mark_as_return(x) - return x + captures = {} + tmp_graph = CapturingGraph(captures) + # Inherit the graph key, since this is used for matching variables in + # optimizers. + tmp_graph._graph_key = graph_key # pylint: disable=protected-access + # Copy the graph collections to ensure summaries and other things work. This + # lets the function access (but not mutate) collections of the containing + # graph, such as the global step and the summary writer collections. + curr_graph = ops.get_default_graph() + for collection in curr_graph.collections: + tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( + collection) + if context.executing_eagerly(): + tmp_graph.seed = context.global_seed() + else: + tmp_graph.seed = curr_graph.seed + with tmp_graph.as_default(), AutomaticControlDependencies() as a: + func_args = _get_defun_inputs(args) + func_kwds = _get_defun_inputs(kwds) - this_tape = tape.push_new_tape() - try: - func_outputs = func(*func_args, **func_kwds) - func_outputs = nest.map_structure(convert, func_outputs) - finally: - tape.pop_tape(this_tape) - variables = this_tape.watched_variables() - - # Returning a closed-over tensor as an output does not trigger a - # call to convert_to_tensor, so we manually capture all such tensors. 
- outputs_list = _flatten(func_outputs) - func_def_outputs = [ - tmp_graph.capture(x) for x in outputs_list - if x is not None - ] - - ids = list(sorted(captures.keys())) - if ids: - extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) - else: - extra_inputs = [] - extra_placeholders = [] - output_shapes = tuple( - x.shape if isinstance(x, ops.Tensor) else None - for x in outputs_list) + def convert(x): + if x is None: + return None + x = ops.convert_to_tensor_or_indexed_slices(x) + x = a.mark_as_return(x) + return x + + this_tape = tape.push_new_tape() + try: + func_outputs = func(*func_args, **func_kwds) + func_outputs = nest.map_structure(convert, func_outputs) + finally: + tape.pop_tape(this_tape) + variables = this_tape.watched_variables() + + # Returning a closed-over tensor as an output does not trigger a + # call to convert_to_tensor, so we manually capture all such tensors. + outputs_list = _flatten(func_outputs) + func_def_outputs = [ + tmp_graph.capture(x) for x in outputs_list + if x is not None + ] + + ids = list(sorted(captures.keys())) + if ids: + extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) + else: + extra_inputs = [] + extra_placeholders = [] + output_shapes = tuple( + x.shape if isinstance(x, ops.Tensor) else None + for x in func_def_outputs) func_kwds_values = _deterministic_dict_values(func_kwds) flat_inputs = [ @@ -686,7 +732,7 @@ def _trace_and_define_function(name, func, compiled, args, kwds): attrs = {} if compiled: - attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True) + attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True) return GraphModeFunction( fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs, @@ -731,6 +777,11 @@ class _PolymorphicFunction(object): See the documentation for `defun` for more information on the semantics of defined functions. + + _PolymorphicFunction class is thread-compatible meaning that minimal + usage of defuns (defining and calling) is thread-safe, but if users call other + methods or invoke the base `python_function` themselves, external + synchronization is necessary. """ def __init__(self, python_function, name, compiled=False): @@ -748,6 +799,8 @@ class _PolymorphicFunction(object): self._arguments_to_functions = {} self._variables = [] + self._lock = threading.Lock() + def __get__(self, instance, owner): """Makes it possible to defun instance methods.""" del owner @@ -782,22 +835,30 @@ class _PolymorphicFunction(object): kwd_values = _deterministic_dict_values(kwds) inputs = args + kwd_values signature = tuple(_cache_key(x) for x in inputs) - - if signature not in self._arguments_to_functions: - graph_function = _trace_and_define_function( - self._name, self._python_function, self._compiled, args, kwds) - self._arguments_to_functions[signature] = graph_function - self._variables.extend( - [v for v in graph_function.variables if v not in self._variables]) - return graph_function, inputs - else: - return self._arguments_to_functions[signature], inputs + # The graph, or whether we're executing eagerly, should be a part of the + # signature so we don't improperly capture tensors such as variables. 
+ signature += tuple([context.executing_eagerly() or ops.get_default_graph()]) + + with self._lock: + if signature not in self._arguments_to_functions: + graph_function = _trace_and_define_function( + self._name, self._python_function, self._compiled, args, kwds) + self._arguments_to_functions[signature] = graph_function + self._variables.extend( + [v for v in graph_function.variables if v not in self._variables]) + return graph_function, inputs + else: + return self._arguments_to_functions[signature], inputs def __call__(self, *args, **kwds): """Calls a graph function specialized for this input signature.""" graph_function, inputs = self._maybe_define_function(*args, **kwds) return graph_function(*inputs) + def call_python_function(self, *args, **kwargs): + """Directly calls the wrapped python function.""" + return self._python_function(*args, **kwargs) + @property def variables(self): """Returns a list of variables used in any of the defined functions.""" @@ -835,6 +896,11 @@ def defun(func=None, compiled=False): be hashable Python objects or lists thereof. Additionally, it must return zero or more @{tf.Tensor} objects. + Executing a graph generated by `defun` respects device annotations (i.e., + all `with tf.device` directives present in a Python function will also be + present in its corresponding graph), but it is not yet possible to execute the + generated graphs across multiple machines. + _Example Usage_ ```python @@ -1014,7 +1080,7 @@ def defun(func=None, compiled=False): tf.enable_eager_execution() def fn(): - x = tf.contrib.eager.Variable(0.0) + x = tf.Variable(0.0) x.assign_add(1.0) return x.read_value() @@ -1031,19 +1097,18 @@ def defun(func=None, compiled=False): ``` Finally, because each input signature is bound to a unique graph, if your - Python function constructs `tf.contrib.eager.Variable` objects, then each - graph constructed for that Python function will reference a unique set of - variables. To circumvent this problem, we recommend against compiling Python - functions that create `tf.contrib.eager.Variable` objects. Instead, Python - functions should either lexically close over `tf.contrib.eager.Variable` - objects or accept them as arguments, preferably encapsulated in an - object-oriented container. If you must create variables inside your Python - function and you want each graph generated for it to reference the same set of - variables, add logic to your Python function that ensures that variables are - only created the first time it is called and are reused for every subsequent - invocation; note that this is precisely what @{tf.keras.layers.Layer} objects - do, so we recommend using them to represent variable-bearing computations - whenever possible. + Python function constructs `tf.Variable` objects, then each graph constructed + for that Python function will reference a unique set of variables. To + circumvent this problem, we recommend against compiling Python functions that + create `tf.Variable` objects. Instead, Python functions should either + lexically close over `tf.Variable` objects or accept them as arguments, + preferably encapsulated in an object-oriented container. 
If you must create + variables inside your Python function and you want each graph generated for it + to reference the same set of variables, add logic to your Python function that + ensures that variables are only created the first time it is called and are + reused for every subsequent invocation; note that this is precisely what + @{tf.keras.layers.Layer} objects do, so we recommend using them to represent + variable-bearing computations whenever possible. Args: func: function to be compiled. If `func` is None, returns a @@ -1245,7 +1310,7 @@ class AutomaticControlDependencies(object): # Ensures the merge always runs ops_which_must_run.add(new_merge[0].op) if inp in last_op_using_resource_tensor: - # Ensures the switch exectutes after the previous op using the resource. + # Ensures the switch executes after the previous op using the resource. switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access # Ensure the next op outside the cond happens after the merge. last_op_using_resource_tensor[inp] = new_merge[0].op |
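The `_PolymorphicFunction` change in this diff keys traced functions on the call signature plus the calling context (eager execution or a particular default graph) and guards cache population with a `threading.Lock`. Below is a minimal, self-contained sketch of that caching pattern, not the TensorFlow implementation itself; `_SignatureCache`, `trace_fn`, and `lookup_or_create` are hypothetical names used only for illustration.

```python
import threading


class _SignatureCache(object):
  """Hypothetical stand-in for the per-signature defun cache in this change."""

  def __init__(self, trace_fn):
    self._trace_fn = trace_fn  # builds a graph function for one signature
    self._functions = {}       # signature -> traced graph function
    self._lock = threading.Lock()

  def lookup_or_create(self, signature, *args, **kwds):
    # Populate the cache under a lock so two threads tracing the same
    # signature for the first time do not race; this mirrors the locking
    # added around _maybe_define_function in the diff above.
    with self._lock:
      if signature not in self._functions:
        self._functions[signature] = self._trace_fn(*args, **kwds)
      return self._functions[signature]
```

As in the diff, the signature would also fold in whether the caller is executing eagerly or which graph is currently the default, so a function traced for one graph is never reused while a different graph (or eager execution) is active.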