Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r-- | tensorflow/python/eager/function.py | 153
1 files changed, 62 insertions, 91 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f87d88040f..5afba466bc 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -42,7 +42,8 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import distribute
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
@@ -227,6 +228,9 @@ class FuncGraph(CapturingGraph):
       self.get_collection_ref(collection)[:] = graph.get_collection(
           collection)
 
+    # Copy distribution strategy scope from the containing graph as well.
+    self._distribution_strategy_stack = graph._distribution_strategy_stack  # pylint: disable=protected-access
+
     if context.executing_eagerly():
       self.seed = context.global_seed()
     else:
@@ -243,78 +247,6 @@ class FuncGraph(CapturingGraph):
     return internal_tensor
 
 
-# pylint: disable=invalid-name
-class HelperContext(object):
-  """ControlFlowContext with a customizable AddOp method."""
-
-  def __init__(self, add_op_internal):
-    self._add_op_internal = add_op_internal
-    self._values = set()  # control flow code sometimes updates this.
-
-  def _AddOpInternal(self, op):
-    self._add_op_internal(op)
-
-  @property
-  def outer_context(self):
-    return self._outer_context
-
-  def GetWhileContext(self):
-    if self._outer_context:
-      return self._outer_context.GetWhileContext()
-
-  def IsWhileContext(self):
-    return False
-
-  def IsCondContext(self):
-    return False
-
-  def IsXLAContext(self):
-    return False
-
-  def AddOp(self, op):  # pylint: disable=invalid-name
-    self._AddOpInternal(op)
-    if self._outer_context:
-      self._outer_context.AddOp(op)
-
-  def AddName(self, _):
-    pass
-
-  def AddInnerOp(self, op):
-    self._AddOpInternal(op)
-    if self._outer_context:
-      self._outer_context.AddInnerOp(op)
-
-  def AddValue(self, val):
-    if self._outer_context:
-      return self._outer_context.AddValue(val)
-    else:
-      return val
-
-  def EnterGradientColocation(self, op, gradient_uid):
-    """Start building a gradient colocated with an op."""
-    if self._outer_context:
-      self._outer_context.EnterGradientColocation(op, gradient_uid)
-
-  def ExitGradientColocation(self, op, gradient_uid):
-    """Start building a gradient colocated with an op."""
-    if self._outer_context:
-      self._outer_context.ExitGradientColocation(op, gradient_uid)
-
-  def __enter__(self):
-    # pylint: disable=protected-access
-    self._g = ops.get_default_graph()
-    self._outer_context = self._g._get_control_flow_context()
-    self._g._set_control_flow_context(self)
-    self._nested_contexts = (
-        self._outer_context._nested_contexts
-        if self._outer_context is not None else None)
-    # pylint: enable=protected-access
-
-  def __exit__(self, *_):
-    self._g._set_control_flow_context(self._outer_context)  # pylint: disable=protected-access
-# pylint: enable=invalid-name
-
-
 def _forward_name(n):
   """The name of a generated forward defun named n."""
   return "__forward_%s_%s" % (n, ops.uid())
@@ -479,11 +411,6 @@ class _EagerDefinedFunction(object):
     return outputs
 
 
-def _map_sequence_obj_to_idx(sequence):
-  """Maps objs in the sequence from id(obj) to sequence index."""
-  return {id(x): i for i, x in enumerate(sequence)}
-
-
 def _flatten(sequence):
   """A wrapper around `nest.flatten` that also unpacks `IndexedSlices`."""
   # TODO(akshayka): Support `SparseTensor` in a similar fashion.
@@ -568,7 +495,7 @@ class GraphModeFunction(object):
     # Find the variables that are components of something distributed and
     # put them into a {handle_tensor -> distributed variable object} map.
     self._distributed_variables = {}
-    strategy = distribute.get_distribution_strategy()
+    strategy = distribution_strategy_context.get_distribution_strategy()
     for variable in self._variables:
       # If variable is not distributed, unwrap returns [variable].
       component_variables = strategy.unwrap(variable)
@@ -832,6 +759,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
   func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
 
   with func_graph.as_default(), AutomaticControlDependencies() as a:
+    variable_scope.get_variable_scope().set_use_resource(True)
+
     if signature is None:
       func_args = _get_defun_inputs_from_args(args)
       func_kwds = _get_defun_inputs_from_args(kwds)
@@ -898,7 +827,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
   # the function is run on a different device). Thus, instead of storing
   # the specific captured variable, we replace it with its distributed
   # container.
-  strategy = distribute.get_distribution_strategy()
+  strategy = distribution_strategy_context.get_distribution_strategy()
   for i, variable in enumerate(variables):
     # If variable is not distributed value_container returns itself.
     variables[i] = strategy.value_container(variable)
@@ -1322,18 +1251,60 @@ def defun(func=None, input_signature=None, compiled=False):
   generates and placed in the eager context if executing eagerly or into an
   outer graph otherwise.
 
-  _Tracing and Input Signatures_.
-  The signature of inputs supplied to `F` is defined to be a tuple of the shapes
-  and dtypes of Tensor-typed arguments and the values of non-Tensor arguments,
-  where "arguments" includes both args and kwargs. Every time `F` is invoked,
-  the signature of its inputs are inferred. The first time `F(*args, **kwargs)`
-  is invoked with a particular signature, `f(*args, **kwargs)` is executed and
-  all the TensorFlow operations that `f` executes, along with the Tensors that
-  flow between them, are recorded in a TensorFlow graph. `F` caches this graph
-  and binds it to the inputs' signature; every subsequent invocation of `F` with
-  inputs conforming to this signature will immediately retrieve the cached graph
-  and pass it to the TensorFlow runtime for execution.
+  _Input Signatures_
+  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
+  for every unique sequence of the shapes and dtypes of Tensor arguments and
+  the values of Python objects it is invoked with. For example, calling
+  `F(tf.random_uniform([2]))` will execute a different graph than
+  `F(tf.random_uniform([3]))` because the two inputs have different shapes.
+  The first time that `F(*args, **kwargs)` is called with a particular sequence
+  of Tensor shapes and dtypes and Python values, it constructs a graph by
+  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
+  input signature inferred from `(*args, **kwargs)` and cached for future reuse.
+
+  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
+  define TensorFlow functions without explicitly specifying their signatures.
+  However, this policy is conservative and potentially expensive; for example,
+  when different invocations of your function have differently-shaped Tensor
+  inputs, this policy might generate more graph functions than necessary. To
+  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
+  optional `input_signature` argument specifying the shapes and dtypes of the
+  inputs. In particular, the shapes may be partially unspecified, with `None`s
+  in the unknown dimensions. When an input signature is provided,
+  `tf.contrib.eager.defun` will only instantiate a single graph for the
+  decorated Python function. The following is an example:
+
+  ```python
+  import tensorflow as tf
+
+  # The first `TensorSpec` below describes the shape and dtype of `words`,
+  # and the second describes the shape and dtype of `another_tensor`. Note that
+  # the last dimension of the `words` `TensorSpec` is left unspecified.
+  @tf.contrib.eager.defun(input_signature=[
+    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
+    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
+  ])
+  def my_sequence_model(words, another_tensor):
+    ...
+
+  # Note how the third dimension of the first input can vary freely.
+  words = tf.random_uniform([50, 300, 10])
+  second_input = tf.random_uniform([300, 100])
+  my_sequence_model(words, second_input)
+
+  words = tf.random_uniform([50, 300, 20])
+  my_sequence_model(words, second_input)
+
+  # Passing an input with an incompatible shape will raise an error.
+  words = tf.random_uniform([50, 100, 20])
+  my_sequence_model(words, second_input)  # <---- This will raise an error.
+
+  ```
+
+  Python functions that are compiled with an `input_signature` must only accept
+  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).
 
+  _Tracing_
   Be aware that because `F` only logs TensorFlow operations, all the other
   Python code that `f` executes will only shape the _construction_ of the graphs
   that `F` executes: the Python code won't be executed when the graphs
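To make the caching policy described in the new docstring concrete, here is a minimal, self-contained sketch. It is an editorial illustration, not part of the diff; it assumes the TF 1.x-era `tf.contrib.eager` API that this file targets, and the function names `square`/`square_any` are hypothetical:

```python
import tensorflow as tf

tf.enable_eager_execution()

@tf.contrib.eager.defun
def square(x):
  # Python side effects run only while tracing, so this print fires once per
  # new input signature (shape/dtype), not once per call.
  print("tracing square for shape", x.shape)
  return x * x

square(tf.random_uniform([2]))  # Traces and caches a graph for shape [2].
square(tf.random_uniform([2]))  # Cache hit: the shape-[2] graph is reused.
square(tf.random_uniform([3]))  # New shape, so a second graph is traced.

# An explicit input_signature with a partially unspecified shape pins the
# function to a single graph that accepts any float32 vector.
@tf.contrib.eager.defun(input_signature=[
    tf.contrib.eager.TensorSpec(shape=[None], dtype=tf.float32)
])
def square_any(x):
  return x * x

square_any(tf.random_uniform([2]))  # Traced once here...
square_any(tf.random_uniform([3]))  # ...and reused here despite the new shape.
```

Within the graphs traced under this commit, resource variables are forced on via `variable_scope.get_variable_scope().set_use_resource(True)` and the enclosing graph's `_distribution_strategy_stack` is propagated, per the hunks above.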