Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r-- | tensorflow/python/eager/function.py | 202
1 file changed, 108 insertions, 94 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index df83d673ad..5e4f9e29da 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 
 import collections
 import functools
+import threading
 
 import numpy as np
 
@@ -137,7 +138,7 @@ class CapturingGraph(ops.Graph):
       inputs[i] = self.capture(inp)
     return super(CapturingGraph, self).create_op(
         op_type, inputs, dtypes, input_types, name, attrs, op_def,
-        compute_shapes, compute_device)
+        compute_device=compute_device)
 
 
 # pylint: disable=invalid-name
@@ -469,37 +470,39 @@ class GraphModeFunction(object):
 
   def _construct_backprop_function(self):
     """Constructs the backprop function object for this function."""
-    with self._graph.as_default(), context.graph_mode():
-      c_known_ops = set()
-      c_captured_tensors = set()
-
-      existing_op_len = len(self._graph.get_operations())
-      filtered_outputs = [x for x in self._python_returns if x is not None]
+    filtered_outputs = [x for x in self._python_returns if x is not None]
+    captures = {}
+    backwards_graph = CapturingGraph(captures)
+    backwards_graph._graph_key = self._graph._graph_key  # pylint: disable=protected-access
+    for collection in self._graph.collections:
+      backwards_graph.get_collection_ref(
+          collection)[:] = self._graph.get_collection(collection)
+    backwards_graph.seed = self._graph.seed
+    with backwards_graph.as_default():
       self._out_grad_placeholders = [
           graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
-      in_gradients = gradients_impl.gradients(
+      in_gradients = gradients_impl._GradientsHelper(  # pylint: disable=protected-access
          filtered_outputs,
          self._input_placeholders,
-          grad_ys=self._out_grad_placeholders)
-      for op in self._graph.get_operations()[existing_op_len:]:
-        if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
-          raise ValueError("defun cannot capture variables created without "
-                           "using tf.get_variable. Op: %s" % op)
-        c_known_ops.add(op)
-        for i in op.inputs:
-          if i.op not in c_known_ops:
-            c_captured_tensors.add(i)
+          grad_ys=self._out_grad_placeholders,
+          src_graph=self._graph)
 
     backward_outputs = tuple(
         grad for grad in _flatten(in_gradients) if grad is not None)
     output_shapes = tuple(grad.shape for grad in backward_outputs)
 
-    captures = list(sorted(c_captured_tensors, key=lambda x: x.name))
+    ids = list(sorted(captures.keys()))
+    if ids:
+      extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
+    else:
+      extra_inputs = []
+      extra_placeholders = []
+
     forward_name = _forward_name(self._func_name)
     self._forward_fdef = _EagerDefinedFunction(
         forward_name, self._graph, self._ops, self._input_placeholders,
-        filtered_outputs + captures, self._attrs)
-    all_inputs = self._out_grad_placeholders + captures
+        filtered_outputs + list(extra_inputs), self._attrs)
+    all_inputs = self._out_grad_placeholders + list(extra_placeholders)
     # Excluding input ops from the body as we do not intend to execute these
     # operations when the function is executed.
     all_ignored_ops = frozenset(x.op for x in all_inputs)
@@ -507,11 +510,12 @@ class GraphModeFunction(object):
     # means rerunning the function-defining code will always define the same
     # function, which is useful if we serialize this etc.
     function_def_ops = tuple(x
-                             for x in sorted(c_known_ops, key=lambda x: x.name)
+                             for x in sorted(backwards_graph.get_operations(),
+                                             key=lambda x: x.name)
                              if x not in all_ignored_ops)
     bname = _backward_name(self._func_name)
     self._backward_function = GraphModeFunction(
-        bname, all_inputs, [], self._graph, function_def_ops,
+        bname, all_inputs, [], backwards_graph, function_def_ops,
         backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
 
   def _backprop_call(self, args):
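
The rewritten `_construct_backprop_function` above builds the backward pass in its own `CapturingGraph` rather than appending ops to the forward graph: any forward-graph tensor the gradient computation touches is recorded in `captures` as an `id -> (external tensor, placeholder)` pair, replacing the old manual walk over `c_known_ops`/`c_captured_tensors`. A minimal plain-Python sketch of that bookkeeping (the `FakeTensor` and `CapturingScope` names are hypothetical, for illustration only):

```
import collections

# Stand-in for a graph tensor; only used to illustrate the pattern.
FakeTensor = collections.namedtuple("FakeTensor", ["name"])


class CapturingScope(object):
  """Maps external tensors to local placeholders, keyed by id()."""

  def __init__(self, captures):
    # `captures` is a dict shared with the caller:
    # id(external) -> (external, placeholder), as in CapturingGraph.
    self.captures = captures

  def capture(self, external):
    key = id(external)
    if key not in self.captures:
      placeholder = FakeTensor(name=external.name + "_captured")
      self.captures[key] = (external, placeholder)
    return self.captures[key][1]


captures = {}
scope = CapturingScope(captures)
x = FakeTensor(name="x")
assert scope.capture(x) is scope.capture(x)  # captured once, then reused

# Recovering the paired lists, exactly as the hunk above does:
ids = list(sorted(captures.keys()))
if ids:
  extra_inputs, extra_placeholders = zip(*[captures[k] for k in ids])
else:
  extra_inputs, extra_placeholders = [], []
```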
@@ -656,55 +660,58 @@ def _deterministic_dict_values(kwds):
 def _trace_and_define_function(name, func, compiled, args, kwds):
   """Defines and returns graph-mode version of func."""
   graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
-  with context.graph_mode():
-    captures = {}
-    tmp_graph = CapturingGraph(captures)
-    # Inherit the graph key, since this is used for matching variables in
-    # optimizers.
-    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
-    # Copy the graph collections to ensure summaries and other things work. This
-    # lets the function access (but not mutate) collections of the containing
-    # graph, such as the global step and the summary writer collections.
-    curr_graph = ops.get_default_graph()
-    for collection in curr_graph.collections:
-      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
-          collection)
-    with tmp_graph.as_default(), AutomaticControlDependencies() as a:
-      func_args = _get_defun_inputs(args)
-      func_kwds = _get_defun_inputs(kwds)
-
-      def convert(x):
-        if x is None:
-          return None
-        x = ops.convert_to_tensor_or_indexed_slices(x)
-        x = a.mark_as_return(x)
-        return x
-
-      this_tape = tape.push_new_tape()
-      try:
-        func_outputs = func(*func_args, **func_kwds)
-        func_outputs = nest.map_structure(convert, func_outputs)
-      finally:
-        tape.pop_tape(this_tape)
-      variables = this_tape.watched_variables()
-
-      # Returning a closed-over tensor as an output does not trigger a
-      # call to convert_to_tensor, so we manually capture all such tensors.
-      outputs_list = _flatten(func_outputs)
-      func_def_outputs = [
-          tmp_graph.capture(x) for x in outputs_list
-          if x is not None
-      ]
-
-      ids = list(sorted(captures.keys()))
-      if ids:
-        extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
-      else:
-        extra_inputs = []
-        extra_placeholders = []
-      output_shapes = tuple(
-          x.shape if isinstance(x, ops.Tensor) else None
-          for x in func_def_outputs)
+  captures = {}
+  tmp_graph = CapturingGraph(captures)
+  # Inherit the graph key, since this is used for matching variables in
+  # optimizers.
+  tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
+  # Copy the graph collections to ensure summaries and other things work. This
+  # lets the function access (but not mutate) collections of the containing
+  # graph, such as the global step and the summary writer collections.
+  curr_graph = ops.get_default_graph()
+  for collection in curr_graph.collections:
+    tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+        collection)
+  if context.executing_eagerly():
+    tmp_graph.seed = context.global_seed()
+  else:
+    tmp_graph.seed = curr_graph.seed
+  with tmp_graph.as_default(), AutomaticControlDependencies() as a:
+    func_args = _get_defun_inputs(args)
+    func_kwds = _get_defun_inputs(kwds)
+
+    def convert(x):
+      if x is None:
+        return None
+      x = ops.convert_to_tensor_or_indexed_slices(x)
+      x = a.mark_as_return(x)
+      return x
+
+    this_tape = tape.push_new_tape()
+    try:
+      func_outputs = func(*func_args, **func_kwds)
+      func_outputs = nest.map_structure(convert, func_outputs)
+    finally:
+      tape.pop_tape(this_tape)
+    variables = this_tape.watched_variables()
+
+    # Returning a closed-over tensor as an output does not trigger a
+    # call to convert_to_tensor, so we manually capture all such tensors.
+    outputs_list = _flatten(func_outputs)
+    func_def_outputs = [
+        tmp_graph.capture(x) for x in outputs_list
+        if x is not None
+    ]
+
+    ids = list(sorted(captures.keys()))
+    if ids:
+      extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
+    else:
+      extra_inputs = []
+      extra_placeholders = []
+    output_shapes = tuple(
+        x.shape if isinstance(x, ops.Tensor) else None
+        for x in func_def_outputs)
 
   func_kwds_values = _deterministic_dict_values(func_kwds)
   flat_inputs = [
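
Two behavioral points in this hunk: tracing no longer wraps the temporary graph in `context.graph_mode()`, and `tmp_graph` now inherits the random seed, either from the eager context via `context.global_seed()` or from the enclosing graph, so seeded random ops behave consistently inside a defun. A minimal usage sketch, assuming a TF 1.x build containing this change, with `tf.contrib.eager.defun` as the public entry point to the `defun` defined in this file:

```
import tensorflow as tf

tf.enable_eager_execution()
tf.set_random_seed(42)  # with this change, the trace graph inherits the seed

@tf.contrib.eager.defun
def noisy_scale(x):
  # Traced once per input signature; the random op is baked into that graph.
  return x * tf.random_uniform([])

print(noisy_scale(tf.constant(2.0)))
```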
@@ -770,6 +777,11 @@ class _PolymorphicFunction(object):
 
   See the documentation for `defun` for more information on the semantics of
   defined functions.
+
+  The _PolymorphicFunction class is thread-compatible, meaning that minimal
+  usage of defuns (defining and calling) is thread-safe, but if users call
+  other methods or invoke the base `python_function` themselves, external
+  synchronization is necessary.
   """
 
   def __init__(self, python_function, name, compiled=False):
@@ -787,6 +799,8 @@ class _PolymorphicFunction(object):
     self._arguments_to_functions = {}
     self._variables = []
 
+    self._lock = threading.Lock()
+
   def __get__(self, instance, owner):
     """Makes it possible to defun instance methods."""
     del owner
@@ -825,15 +839,16 @@ class _PolymorphicFunction(object):
     # The graph, or whether we're executing eagerly, should be a part of the
     # signature so we don't improperly capture tensors such as variables.
     signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
-    if signature not in self._arguments_to_functions:
-      graph_function = _trace_and_define_function(
-          self._name, self._python_function, self._compiled, args, kwds)
-      self._arguments_to_functions[signature] = graph_function
-      self._variables.extend(
-          [v for v in graph_function.variables if v not in self._variables])
-      return graph_function, inputs
-    else:
-      return self._arguments_to_functions[signature], inputs
+    with self._lock:
+      if signature not in self._arguments_to_functions:
+        graph_function = _trace_and_define_function(
+            self._name, self._python_function, self._compiled, args, kwds)
+        self._arguments_to_functions[signature] = graph_function
+        self._variables.extend(
+            [v for v in graph_function.variables if v not in self._variables])
+        return graph_function, inputs
+      else:
+        return self._arguments_to_functions[signature], inputs
 
   def __call__(self, *args, **kwds):
     """Calls a graph function specialized for this input signature."""
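
The new `self._lock` closes the check-then-insert race in `_maybe_define_function`: without it, two threads calling a defun with the same new signature could both trace the function and interleave their appends to `self._variables`. The pattern, reduced to plain Python (the `TraceCache` name is hypothetical; it mirrors only the locking structure of the code above):

```
import threading


class TraceCache(object):
  """Lock-guarded memoization, one trace per signature."""

  def __init__(self, trace_fn):
    self._trace_fn = trace_fn
    self._cache = {}
    self._lock = threading.Lock()

  def lookup(self, signature):
    with self._lock:
      if signature not in self._cache:
        # Check-then-insert is atomic under the lock, so each signature
        # is traced exactly once even under concurrent callers.
        self._cache[signature] = self._trace_fn(signature)
      return self._cache[signature]


cache = TraceCache(lambda sig: "traced:%r" % (sig,))
assert cache.lookup((1, 2)) == cache.lookup((1, 2))
```

Note that, as in the hunk above, tracing happens while the lock is held: first calls for a new signature are serialized across threads, while cache hits pay only the lock acquisition.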
@@ -1065,7 +1080,7 @@ def defun(func=None, compiled=False):
     tf.enable_eager_execution()
 
     def fn():
-      x = tf.contrib.eager.Variable(0.0)
+      x = tf.Variable(0.0)
       x.assign_add(1.0)
       return x.read_value()
 
@@ -1082,19 +1097,18 @@ def defun(func=None, compiled=False):
   ```
 
   Finally, because each input signature is bound to a unique graph, if your
-  Python function constructs `tf.contrib.eager.Variable` objects, then each
-  graph constructed for that Python function will reference a unique set of
-  variables. To circumvent this problem, we recommend against compiling Python
-  functions that create `tf.contrib.eager.Variable` objects. Instead, Python
-  functions should either lexically close over `tf.contrib.eager.Variable`
-  objects or accept them as arguments, preferably encapsulated in an
-  object-oriented container. If you must create variables inside your Python
-  function and you want each graph generated for it to reference the same set of
-  variables, add logic to your Python function that ensures that variables are
-  only created the first time it is called and are reused for every subsequent
-  invocation; note that this is precisely what @{tf.keras.layers.Layer} objects
-  do, so we recommend using them to represent variable-bearing computations
-  whenever possible.
+  Python function constructs `tf.Variable` objects, then each graph constructed
+  for that Python function will reference a unique set of variables. To
+  circumvent this problem, we recommend against compiling Python functions that
+  create `tf.Variable` objects. Instead, Python functions should either
+  lexically close over `tf.Variable` objects or accept them as arguments,
+  preferably encapsulated in an object-oriented container. If you must create
+  variables inside your Python function and you want each graph generated for it
+  to reference the same set of variables, add logic to your Python function that
+  ensures that variables are only created the first time it is called and are
+  reused for every subsequent invocation; note that this is precisely what
+  @{tf.keras.layers.Layer} objects do, so we recommend using them to represent
+  variable-bearing computations whenever possible.
 
   Args:
     func: function to be compiled. If `func` is None, returns a
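
The docstring's recommended pattern, as a short sketch (assumes an eager-enabled TF 1.x build with `tf.contrib.eager.defun` available): create the variable once, outside the compiled function, and close over it, so every graph traced for the function references the same variable rather than minting a new one per signature.

```
import tensorflow as tf

tf.enable_eager_execution()

v = tf.Variable(1.0)  # created exactly once, outside the defun

@tf.contrib.eager.defun
def add_and_read(x):
  v.assign_add(x)        # every traced graph captures this same variable
  return v.read_value()

print(add_and_read(tf.constant(2.0)))  # 3.0
print(add_and_read(tf.constant(2.0)))  # 5.0
```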