Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py  153
1 file changed, 62 insertions(+), 91 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f87d88040f..5afba466bc 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -42,7 +42,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import distribute
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -227,6 +228,9 @@ class FuncGraph(CapturingGraph):
self.get_collection_ref(collection)[:] = graph.get_collection(
collection)
+ # Copy distribution strategy scope from the containing graph as well.
+ self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
+
if context.executing_eagerly():
self.seed = context.global_seed()
else:
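Copying `_distribution_strategy_stack` means a function traced inside a distribution strategy scope sees the same strategy as the enclosing graph. A minimal sketch of that behaviour, assuming the `tf.contrib.distribute.MirroredStrategy` and `tf.contrib.eager.defun` APIs of this TensorFlow version (not part of this change):

```python
import tensorflow as tf

tf.enable_eager_execution()
# Hypothetical single-device strategy, used here only to open a strategy scope.
strategy = tf.contrib.distribute.MirroredStrategy(devices=["/cpu:0"])

with strategy.scope():
  @tf.contrib.eager.defun
  def step(x):
    # With the stack copied into the FuncGraph, the strategy returned by
    # distribution_strategy_context.get_distribution_strategy() during
    # tracing is `strategy`, not the default no-op strategy.
    return x * 2.0

  step(tf.constant(1.0))
```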
@@ -243,78 +247,6 @@ class FuncGraph(CapturingGraph):
return internal_tensor
-# pylint: disable=invalid-name
-class HelperContext(object):
- """ControlFlowContext with a customizable AddOp method."""
-
- def __init__(self, add_op_internal):
- self._add_op_internal = add_op_internal
- self._values = set() # control flow code sometimes updates this.
-
- def _AddOpInternal(self, op):
- self._add_op_internal(op)
-
- @property
- def outer_context(self):
- return self._outer_context
-
- def GetWhileContext(self):
- if self._outer_context:
- return self._outer_context.GetWhileContext()
-
- def IsWhileContext(self):
- return False
-
- def IsCondContext(self):
- return False
-
- def IsXLAContext(self):
- return False
-
- def AddOp(self, op): # pylint: disable=invalid-name
- self._AddOpInternal(op)
- if self._outer_context:
- self._outer_context.AddOp(op)
-
- def AddName(self, _):
- pass
-
- def AddInnerOp(self, op):
- self._AddOpInternal(op)
- if self._outer_context:
- self._outer_context.AddInnerOp(op)
-
- def AddValue(self, val):
- if self._outer_context:
- return self._outer_context.AddValue(val)
- else:
- return val
-
- def EnterGradientColocation(self, op, gradient_uid):
- """Start building a gradient colocated with an op."""
- if self._outer_context:
- self._outer_context.EnterGradientColocation(op, gradient_uid)
-
- def ExitGradientColocation(self, op, gradient_uid):
- """Start building a gradient colocated with an op."""
- if self._outer_context:
- self._outer_context.ExitGradientColocation(op, gradient_uid)
-
- def __enter__(self):
- # pylint: disable=protected-access
- self._g = ops.get_default_graph()
- self._outer_context = self._g._get_control_flow_context()
- self._g._set_control_flow_context(self)
- self._nested_contexts = (
- self._outer_context._nested_contexts
- if self._outer_context is not None else None)
- # pylint: enable=protected-access
-
- def __exit__(self, *_):
- self._g._set_control_flow_context(self._outer_context) # pylint: disable=protected-access
-# pylint: enable=invalid-name
-
-
def _forward_name(n):
"""The name of a generated forward defun named n."""
return "__forward_%s_%s" % (n, ops.uid())
@@ -479,11 +411,6 @@ class _EagerDefinedFunction(object):
return outputs
-def _map_sequence_obj_to_idx(sequence):
- """Maps objs in the sequence from id(obj) to sequence index."""
- return {id(x): i for i, x in enumerate(sequence)}
-
-
def _flatten(sequence):
"""A wrapper around `nest.flatten` that also unpacks `IndexedSlices`."""
# TODO(akshayka): Support `SparseTensor` in a similar fashion.
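The body of `_flatten` is not visible in this hunk; a sketch of what such a wrapper can look like, assuming the behaviour described in its docstring (names are illustrative, not the file's exact code):

```python
from tensorflow.python.framework import ops
from tensorflow.python.util import nest

def _flatten_sketch(sequence):
  """Flattens `sequence`, expanding any IndexedSlices into its components."""
  outputs = []
  for item in nest.flatten(sequence):
    if isinstance(item, ops.IndexedSlices):
      # An IndexedSlices contributes its indices, values and, when known,
      # its dense_shape to the flat list.
      if item.dense_shape is not None:
        outputs.extend([item.indices, item.values, item.dense_shape])
      else:
        outputs.extend([item.indices, item.values])
    else:
      outputs.append(item)
  return outputs
```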
@@ -568,7 +495,7 @@ class GraphModeFunction(object):
# Find the variables that are components of something distributed and
# put them into a {handle_tensor -> distributed variable object} map.
self._distributed_variables = {}
- strategy = distribute.get_distribution_strategy()
+ strategy = distribution_strategy_context.get_distribution_strategy()
for variable in self._variables:
# If variable is not distributed, unwrap returns [variable].
component_variables = strategy.unwrap(variable)
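The rest of this loop, only partly visible here, fills the `{handle_tensor -> distributed variable}` map described in the comment above. A standalone sketch, assuming `strategy.unwrap` returns the per-device components of a distributed variable and `[variable]` otherwise (the helper name is illustrative):

```python
def map_components_to_containers(variables, strategy):
  """Sketch: maps each component's handle tensor to its distributed variable."""
  distributed_variables = {}
  for variable in variables:
    # unwrap returns [variable] for a plain variable, so the map is only
    # populated when the variable is actually distributed.
    components = strategy.unwrap(variable)
    if len(components) > 1 or components[0] is not variable:
      for component in components:
        distributed_variables[component.handle] = variable
  return distributed_variables
```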
@@ -832,6 +759,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
with func_graph.as_default(), AutomaticControlDependencies() as a:
+ variable_scope.get_variable_scope().set_use_resource(True)
+
if signature is None:
func_args = _get_defun_inputs_from_args(args)
func_kwds = _get_defun_inputs_from_args(kwds)
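Forcing `use_resource` on the current variable scope makes every variable created while tracing a resource variable, which is what the eager runtime can capture and reuse. A minimal sketch of the effect, using the same `variable_scope` module (the variable name is only for illustration):

```python
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope

variable_scope.get_variable_scope().set_use_resource(True)
# Any variable created from here on in this scope is a ResourceVariable.
v = variable_scope.get_variable("illustrative_v", shape=[])
assert isinstance(v, resource_variable_ops.ResourceVariable)
```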
@@ -898,7 +827,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
# the function is run on a different device). Thus, instead of storing
# the specific captured variable, we replace it with its distributed
# container.
- strategy = distribute.get_distribution_strategy()
+ strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
variables[i] = strategy.value_container(variable)
@@ -1322,18 +1251,60 @@ def defun(func=None, input_signature=None, compiled=False):
generates and placed in the eager context if executing eagerly or into an
outer graph otherwise.
- _Tracing and Input Signatures_.
- The signature of inputs supplied to `F` is defined to be a tuple of the shapes
- and dtypes of Tensor-typed arguments and the values of non-Tensor arguments,
- where "arguments" includes both args and kwargs. Every time `F` is invoked,
- the signature of its inputs are inferred. The first time `F(*args, **kwargs)`
- is invoked with a particular signature, `f(*args, **kwargs)` is executed and
- all the TensorFlow operations that `f` executes, along with the Tensors that
- flow between them, are recorded in a TensorFlow graph. `F` caches this graph
- and binds it to the inputs' signature; every subsequent invocation of `F` with
- inputs conforming to this signature will immediately retrieve the cached graph
- and pass it to the TensorFlow runtime for execution.
+ _Input Signatures_
+ By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
+ for every unique sequence of the shapes and dtypes of Tensor arguments and
+ the values of Python objects it is invoked with. For example, calling
+ `F(tf.random_uniform([2]))` will execute a different graph than
+ `F(tf.random_uniform([3]))` because the two inputs have different shapes.
+ The first time that `F(*args, **kwargs)` is called with a particular sequence
+ of Tensor shapes and dtypes and Python values, it constructs a graph by
+ tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
+ input signature inferred from `(*args, **kwargs)` and cached for future reuse.
+
+ `tf.contrib.eager.defun` caches graphs for your convenience, letting you
+ define TensorFlow functions without explicitly specifying their signatures.
+ However, this policy is conservative and potentially expensive; for example,
+ when different invocations of your function have differently-shaped Tensor
+ inputs, this policy might generate more graph functions than necessary. To
+ eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
+ optional `input_signature` argument specifying the shapes and dtypes of the
+ inputs. In particular, the shapes may be partially unspecified, with `None`s
+ in the unknown dimensions. When an input signature is provided,
+ `tf.contrib.eager.defun` will only instantiate a single graph for the
+ decorated Python function. The following is an example:
+
+ ```python
+ import tensorflow as tf
+
+ # The first `TensorSpec` below describes the shape and dtype of `words`,
+ # and the second describes the shape and dtype of `another_tensor`. Note that
+ # the last dimension of the `words` `TensorSpec` is left unspecified.
+ @tf.contrib.eager.defun(input_signature=[
+ tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
+ tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
+ ])
+ def my_sequence_model(words, another_tensor):
+ ...
+
+ # Note how the third dimension of the first input can vary freely.
+ words = tf.random_uniform([50, 300, 10])
+ second_input = tf.random_uniform([300, 100])
+ my_sequence_model(words, second_input)
+
+ words = tf.random_uniform([50, 300, 20])
+ my_sequence_model(words, second_input)
+
+ # Passing an input with an incompatible shape will raise an error.
+ words = tf.random_uniform([50, 100, 20])
+ my_sequence_model(words, second_input) # <---- This will raise an error.
+
+ ```
+
+ Python functions that are compiled with an `input_signature` must only accept
+ Tensors as arguments and must not take arbitrary keyword arguments (`**kwargs`).
+
+ _Tracing_
Be aware that because `F` only logs TensorFlow operations, all the other
Python code that `f` executes will only shape the _construction_ of the graphs
that `F` executes: the Python code won't be executed when the graphs