Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py  202
1 file changed, 108 insertions, 94 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index df83d673ad..5e4f9e29da 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import threading
import numpy as np
@@ -137,7 +138,7 @@ class CapturingGraph(ops.Graph):
inputs[i] = self.capture(inp)
return super(CapturingGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
- compute_shapes, compute_device)
+ compute_device=compute_device)
# pylint: disable=invalid-name
@@ -469,37 +470,39 @@ class GraphModeFunction(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
- with self._graph.as_default(), context.graph_mode():
- c_known_ops = set()
- c_captured_tensors = set()
-
- existing_op_len = len(self._graph.get_operations())
- filtered_outputs = [x for x in self._python_returns if x is not None]
+ filtered_outputs = [x for x in self._python_returns if x is not None]
+ captures = {}
+ backwards_graph = CapturingGraph(captures)
+ backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
+ for collection in self._graph.collections:
+ backwards_graph.get_collection_ref(
+ collection)[:] = self._graph.get_collection(collection)
+ backwards_graph.seed = self._graph.seed
+ with backwards_graph.as_default():
self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
- in_gradients = gradients_impl.gradients(
+ in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access
filtered_outputs,
self._input_placeholders,
- grad_ys=self._out_grad_placeholders)
- for op in self._graph.get_operations()[existing_op_len:]:
- if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
- raise ValueError("defun cannot capture variables created without "
- "using tf.get_variable. Op: %s" % op)
- c_known_ops.add(op)
- for i in op.inputs:
- if i.op not in c_known_ops:
- c_captured_tensors.add(i)
+ grad_ys=self._out_grad_placeholders,
+ src_graph=self._graph)
backward_outputs = tuple(
grad for grad in _flatten(in_gradients) if grad is not None)
output_shapes = tuple(grad.shape for grad in backward_outputs)
- captures = list(sorted(c_captured_tensors, key=lambda x: x.name))
+ ids = list(sorted(captures.keys()))
+ if ids:
+ extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
+ else:
+ extra_inputs = []
+ extra_placeholders = []
+
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + captures, self._attrs)
- all_inputs = self._out_grad_placeholders + captures
+ filtered_outputs + list(extra_inputs), self._attrs)
+ all_inputs = self._out_grad_placeholders + list(extra_placeholders)
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
all_ignored_ops = frozenset(x.op for x in all_inputs)
@@ -507,11 +510,12 @@ class GraphModeFunction(object):
# means rerunning the function-defining code will always define the same
# function, which is useful if we serialize this etc.
function_def_ops = tuple(x
- for x in sorted(c_known_ops, key=lambda x: x.name)
+ for x in sorted(backwards_graph.get_operations(),
+ key=lambda x: x.name)
if x not in all_ignored_ops)
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
- bname, all_inputs, [], self._graph, function_def_ops,
+ bname, all_inputs, [], backwards_graph, function_def_ops,
backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
def _backprop_call(self, args):
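For readers skimming the hunk above: the backward pass is now traced into its own CapturingGraph, and gradients are requested against the forward graph through the private gradients_impl._GradientsHelper(..., src_graph=...), with any forward tensors it needs captured as extra inputs. A minimal sketch of the same idea using only public TF 1.x graph-mode APIs (tensor names below are illustrative, not taken from this change): placeholders stand in for the incoming output gradients, and tf.gradients produces the backward outputs with respect to the forward inputs.

```python
import tensorflow as tf

# Forward graph: y = x ** 2.
g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, shape=[], name="x")
  y = tf.square(x)
  # Plays the role of self._out_grad_placeholders: the incoming gradient of y.
  dy = tf.placeholder(tf.float32, shape=[], name="dy")
  # Backward output: gradient of y with respect to x, scaled by dy.
  (dx,) = tf.gradients(y, [x], grad_ys=[dy])

with tf.Session(graph=g) as sess:
  print(sess.run(dx, feed_dict={x: 3.0, dy: 1.0}))  # 6.0, since d(x**2)/dx = 2x
```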
@@ -656,55 +660,58 @@ def _deterministic_dict_values(kwds):
def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
- with context.graph_mode():
- captures = {}
- tmp_graph = CapturingGraph(captures)
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- # Copy the graph collections to ensure summaries and other things work. This
- # lets the function access (but not mutate) collections of the containing
- # graph, such as the global step and the summary writer collections.
- curr_graph = ops.get_default_graph()
- for collection in curr_graph.collections:
- tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
- collection)
- with tmp_graph.as_default(), AutomaticControlDependencies() as a:
- func_args = _get_defun_inputs(args)
- func_kwds = _get_defun_inputs(kwds)
-
- def convert(x):
- if x is None:
- return None
- x = ops.convert_to_tensor_or_indexed_slices(x)
- x = a.mark_as_return(x)
- return x
-
- this_tape = tape.push_new_tape()
- try:
- func_outputs = func(*func_args, **func_kwds)
- func_outputs = nest.map_structure(convert, func_outputs)
- finally:
- tape.pop_tape(this_tape)
- variables = this_tape.watched_variables()
-
- # Returning a closed-over tensor as an output does not trigger a
- # call to convert_to_tensor, so we manually capture all such tensors.
- outputs_list = _flatten(func_outputs)
- func_def_outputs = [
- tmp_graph.capture(x) for x in outputs_list
- if x is not None
- ]
-
- ids = list(sorted(captures.keys()))
- if ids:
- extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
- else:
- extra_inputs = []
- extra_placeholders = []
- output_shapes = tuple(
- x.shape if isinstance(x, ops.Tensor) else None
- for x in func_def_outputs)
+ captures = {}
+ tmp_graph = CapturingGraph(captures)
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ tmp_graph._graph_key = graph_key # pylint: disable=protected-access
+ # Copy the graph collections to ensure summaries and other things work. This
+ # lets the function access (but not mutate) collections of the containing
+ # graph, such as the global step and the summary writer collections.
+ curr_graph = ops.get_default_graph()
+ for collection in curr_graph.collections:
+ tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+ collection)
+ if context.executing_eagerly():
+ tmp_graph.seed = context.global_seed()
+ else:
+ tmp_graph.seed = curr_graph.seed
+ with tmp_graph.as_default(), AutomaticControlDependencies() as a:
+ func_args = _get_defun_inputs(args)
+ func_kwds = _get_defun_inputs(kwds)
+
+ def convert(x):
+ if x is None:
+ return None
+ x = ops.convert_to_tensor_or_indexed_slices(x)
+ x = a.mark_as_return(x)
+ return x
+
+ this_tape = tape.push_new_tape()
+ try:
+ func_outputs = func(*func_args, **func_kwds)
+ func_outputs = nest.map_structure(convert, func_outputs)
+ finally:
+ tape.pop_tape(this_tape)
+ variables = this_tape.watched_variables()
+
+ # Returning a closed-over tensor as an output does not trigger a
+ # call to convert_to_tensor, so we manually capture all such tensors.
+ outputs_list = _flatten(func_outputs)
+ func_def_outputs = [
+ tmp_graph.capture(x) for x in outputs_list
+ if x is not None
+ ]
+
+ ids = list(sorted(captures.keys()))
+ if ids:
+ extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
+ else:
+ extra_inputs = []
+ extra_placeholders = []
+ output_shapes = tuple(
+ x.shape if isinstance(x, ops.Tensor) else None
+ for x in func_def_outputs)
func_kwds_values = _deterministic_dict_values(func_kwds)
flat_inputs = [
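Both this hunk and the backprop hunk above lean on the captures dictionary populated by CapturingGraph: it maps the id of a tensor in the outer graph to a pair of (outer tensor, placeholder in the traced graph), and sorting by key plus unzipping yields the parallel extra_inputs / extra_placeholders tuples. A plain-Python sketch of that unzip step, with strings standing in for tensors and made-up ids:

```python
# Hypothetical contents of `captures`: id(outer tensor) -> (outer tensor, placeholder).
captures = {
    4410: ("outer_tensor_a", "placeholder_a"),
    4427: ("outer_tensor_b", "placeholder_b"),
}

ids = list(sorted(captures.keys()))
if ids:
  # zip(*pairs) "unzips" a list of 2-tuples into two parallel tuples.
  extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
else:
  extra_inputs, extra_placeholders = (), ()

print(extra_inputs)        # ('outer_tensor_a', 'outer_tensor_b')
print(extra_placeholders)  # ('placeholder_a', 'placeholder_b')
```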
@@ -770,6 +777,11 @@ class _PolymorphicFunction(object):
See the documentation for `defun` for more information on the semantics of
defined functions.
+
+ _PolymorphicFunction class is thread-compatible meaning that minimal
+ usage of defuns (defining and calling) is thread-safe, but if users call other
+ methods or invoke the base `python_function` themselves, external
+ synchronization is necessary.
"""
def __init__(self, python_function, name, compiled=False):
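A hedged illustration of the thread-compatibility note added to the docstring above: once this change is in, defining a defun and calling it from several threads is expected to be safe, while anything beyond that (mutating attributes, invoking the wrapped python_function directly) needs caller-side synchronization. The snippet assumes a TensorFlow build of this era with eager execution and tf.contrib.eager.defun.

```python
import threading

import tensorflow as tf

tf.enable_eager_execution()

@tf.contrib.eager.defun
def square(x):
  return x * x

def worker():
  # Concurrent callers may race to trace a brand-new input signature; the lock
  # added in this change serializes the trace-and-cache step.
  print(square(tf.constant(3.0)).numpy())  # 9.0

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
  t.start()
for t in threads:
  t.join()
```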
@@ -787,6 +799,8 @@ class _PolymorphicFunction(object):
self._arguments_to_functions = {}
self._variables = []
+ self._lock = threading.Lock()
+
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
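As the trailing context notes, __get__ is what makes it possible to defun instance methods: the decorated function is bound to the instance on attribute access. A small usage sketch (class and attribute names invented for illustration; assumes eager execution and tf.contrib.eager.defun):

```python
import tensorflow as tf

tf.enable_eager_execution()

class Scaler(object):

  def __init__(self, factor):
    self.factor = tf.Variable(factor)  # created once per instance, outside the defun

  @tf.contrib.eager.defun
  def scale(self, x):
    # `self` is bound through __get__, so the traced graph closes over this
    # instance's variable rather than creating a new one per trace.
    return self.factor * x

s = Scaler(2.0)
print(s.scale(tf.constant(3.0)).numpy())  # 6.0
```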
@@ -825,15 +839,16 @@ class _PolymorphicFunction(object):
# signature so we don't improperly capture tensors such as variables.
signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
- if signature not in self._arguments_to_functions:
- graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds)
- self._arguments_to_functions[signature] = graph_function
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
- return graph_function, inputs
- else:
- return self._arguments_to_functions[signature], inputs
+ with self._lock:
+ if signature not in self._arguments_to_functions:
+ graph_function = _trace_and_define_function(
+ self._name, self._python_function, self._compiled, args, kwds)
+ self._arguments_to_functions[signature] = graph_function
+ self._variables.extend(
+ [v for v in graph_function.variables if v not in self._variables])
+ return graph_function, inputs
+ else:
+ return self._arguments_to_functions[signature], inputs
def __call__(self, *args, **kwds):
"""Calls a graph function specialized for this input signature."""
@@ -1065,7 +1080,7 @@ def defun(func=None, compiled=False):
tf.enable_eager_execution()
def fn():
- x = tf.contrib.eager.Variable(0.0)
+ x = tf.Variable(0.0)
x.assign_add(1.0)
return x.read_value()
@@ -1082,19 +1097,18 @@ def defun(func=None, compiled=False):
```
Finally, because each input signature is bound to a unique graph, if your
- Python function constructs `tf.contrib.eager.Variable` objects, then each
- graph constructed for that Python function will reference a unique set of
- variables. To circumvent this problem, we recommend against compiling Python
- functions that create `tf.contrib.eager.Variable` objects. Instead, Python
- functions should either lexically close over `tf.contrib.eager.Variable`
- objects or accept them as arguments, preferably encapsulated in an
- object-oriented container. If you must create variables inside your Python
- function and you want each graph generated for it to reference the same set of
- variables, add logic to your Python function that ensures that variables are
- only created the first time it is called and are reused for every subsequent
- invocation; note that this is precisely what @{tf.keras.layers.Layer} objects
- do, so we recommend using them to represent variable-bearing computations
- whenever possible.
+ Python function constructs `tf.Variable` objects, then each graph constructed
+ for that Python function will reference a unique set of variables. To
+ circumvent this problem, we recommend against compiling Python functions that
+ create `tf.Variable` objects. Instead, Python functions should either
+ lexically close over `tf.Variable` objects or accept them as arguments,
+ preferably encapsulated in an object-oriented container. If you must create
+ variables inside your Python function and you want each graph generated for it
+ to reference the same set of variables, add logic to your Python function that
+ ensures that variables are only created the first time it is called and are
+ reused for every subsequent invocation; note that this is precisely what
+ @{tf.keras.layers.Layer} objects do, so we recommend using them to represent
+ variable-bearing computations whenever possible.
Args:
func: function to be compiled. If `func` is None, returns a
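To make the recommendation in the rewritten docstring paragraph concrete, here is a hedged sketch of the "lexically close over a tf.Variable" pattern: the variable is created once, outside the compiled function, so every graph traced for the function references the same variable (assumes eager execution and tf.contrib.eager.defun, consistent with the surrounding examples).

```python
import tensorflow as tf

tf.enable_eager_execution()

v = tf.Variable(1.0)  # created once, outside the compiled function

@tf.contrib.eager.defun
def scale(x):
  return v * x        # closes over `v` instead of creating variables inside

print(scale(tf.constant(3.0)).numpy())  # 3.0
v.assign(2.0)
print(scale(tf.constant(3.0)).numpy())  # 6.0 -- same variable, new value
```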