aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--tensorflow/python/eager/function.py656
1 files changed, 486 insertions, 170 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 29e234efd8..f87d88040f 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -24,6 +24,7 @@ import functools
import threading
import numpy as np
+import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
@@ -35,69 +36,77 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import distribute
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def create_substitute_placeholder(value, name, dtype=None):
+ """Creates a placeholder for `value` and propagates shape info to it."""
+ # Note: setting ops.control_dependencies(None) ensures we always put
+ # capturing placeholders outside of any control flow context.
+ with ops.control_dependencies(None):
+ placeholder = graph_placeholder(
+ dtype=dtype or value.dtype, shape=value.shape, name=name)
+ if placeholder.dtype == dtypes_module.resource:
+ if isinstance(value, ops.EagerTensor):
+ handle_data = value._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(value)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetResourceHandleShapeAndType(
+ placeholder.graph._c_graph, placeholder._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ placeholder._op._graph._c_graph, # pylint: disable=protected-access
+ placeholder._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+ return placeholder
def capture_value(tensor_map, value, dtype, name):
"""Capture a value from outside the function, to pass in as an extra arg."""
- captured_value = tensor_map.get(ops.tensor_id(value), None)
+ captured_value = tensor_map.get(value, None)
if captured_value is None:
- # Note: setting ops.control_dependencies(None) ensures we always put
- # capturing placeholders outside of any control flow context.
- with ops.control_dependencies(None):
- captured_value = graph_placeholder(
- dtype=dtype or value.dtype, shape=value.shape, name=name)
- if captured_value.dtype == dtypes_module.resource:
- if ops._USE_C_SHAPES: # pylint: disable=protected-access
- if isinstance(value, ops.EagerTensor):
- handle_data = value._handle_data # pylint: disable=protected-access
- else:
- handle_data = resource_variable_ops.get_resource_handle_data(value)
- else:
- handle_data = value._handle_data # pylint: disable=protected-access
- if handle_data is not None and handle_data.is_set:
- # pylint: disable=protected-access
- if ops._USE_C_SHAPES:
- pywrap_tensorflow.SetResourceHandleShapeAndType(
- captured_value.graph._c_graph, captured_value._as_tf_output(),
- handle_data.SerializeToString())
- else:
- captured_value._handle_data = handle_data
- # pylint: enable=protected-access
- # Ensure that shapes and dtypes are propagated.
- shapes, types = zip(*[(pair.shape, pair.dtype)
- for pair in handle_data.shape_and_type])
- ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
- shapes = [[d.size for d in s.dim]
- if not s.unknown_rank else None for s in shapes]
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- captured_value._op._graph._c_graph, # pylint: disable=protected-access
- captured_value._as_tf_output(), # pylint: disable=protected-access
- shapes, ranks, types)
-
- tensor_map[ops.tensor_id(value)] = (value, captured_value)
- else:
- captured_value = captured_value[1]
+ captured_value = create_substitute_placeholder(value, name=name,
+ dtype=dtype)
+ tensor_map[value] = captured_value
tape.record_operation("captured_value", [captured_value], [value],
lambda x: [x])
return captured_value
class CapturingGraph(ops.Graph):
- """Graph used when constructing eager functions."""
+ """Graph that can capture tensors from other graphs.
+
+ Attributes:
+ captures: Maps external tensor -> internal tensor (e.g. input placeholder).
+ The entries are in the order they were captured.
+ """
def __init__(self):
super(CapturingGraph, self).__init__()
+
+ self.captures = collections.OrderedDict()
self._building_function = True
- # Maps external tensor id -> internal tensor (e.g. input placeholder).
- self.captures = {}
+
# Map from resource tensor name to last op (in program order) which uses
# this tensor. Used to enforce that execution order matches program order
# for resource tensors.
@@ -110,7 +119,22 @@ class CapturingGraph(ops.Graph):
def clear_resource_control_flow_state(self):
self._last_op_using_resource_tensor = {}
+ # TODO(skyewm): get rid of name and use the name of `tensor`.
def capture(self, tensor, name=None):
+ """Capture `tensor` if it's external to this graph.
+
+ If `tensor` is from a different graph, returns a placeholder for it.
+ `tensor` and the placeholder will also appears in self.captures. Multiple
+ calls to this method with the same `tensor` argument will return the same
+ placeholder. If `tensor` is from this graph, returns `tensor`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+
+ Returns:
+ Tensor from this FuncGraph.
+ """
if isinstance(tensor, ops.EagerTensor):
if name is None:
name = str(ops.uid())
@@ -132,6 +156,7 @@ class CapturingGraph(ops.Graph):
op_def=None,
compute_shapes=True,
compute_device=True):
+ """Captures an external inputs before calling Graph.capture_op."""
# This capturing logic interacts poorly with control flow contexts which
# want to replace inputs of ops far too late in the process. This can lead
# the context to get confused and try to create an Enter for an Enter. We
@@ -154,6 +179,70 @@ class CapturingGraph(ops.Graph):
compute_device=compute_device)
+class FuncGraph(CapturingGraph):
+ """Graph representing a function body.
+
+ Attributes:
+ name: The name of the function.
+
+ inputs: Placeholder tensors representing the inputs to this function. The
+ tensors are in this FuncGraph. This represents "regular" inputs as well as
+ captured inputs (i.e. the values of self.captures), with the regular
+ inputs coming first.
+ outputs: Tensors that will be returned by this function. The tensors are in
+ this FuncGraph.
+ structured_outputs: A possibly-nested python object which will be returned
+ by this function. The Tensors in this structure are the same as those of
+ self.outputs. Note that this structure might contain Python `None`s.
+ variables: Variables that should be watched during function execution.
+ seed: The graph-level random seed.
+ """
+
+ def __init__(self, name, graph=None):
+ """Construct a new FuncGraph.
+
+ Args:
+ name: the name of the function.
+ graph: if specified, this FuncGraph will inherit its graph key,
+ collections, and seed from `graph`.
+ """
+ super(FuncGraph, self).__init__()
+
+ self.name = name
+ self.inputs = []
+ self.outputs = []
+ self.structured_outputs = None
+ self.variables = []
+
+ if graph is not None:
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ self._graph_key = graph._graph_key # pylint: disable=protected-access
+
+ # Copy the graph collections to ensure summaries and other things work.
+ # This lets the function access (but not mutate) collections of the
+ # containing graph, such as the global step and the summary writer
+ # collections.
+ for collection in graph.collections:
+ self.get_collection_ref(collection)[:] = graph.get_collection(
+ collection)
+
+ if context.executing_eagerly():
+ self.seed = context.global_seed()
+ else:
+ self.seed = graph.seed
+
+ def capture(self, tensor, name=None):
+ """Calls CapturingGraph.capture and updates self.inputs if necessary."""
+ new_capture = tensor not in self.captures
+ internal_tensor = super(FuncGraph, self).capture(tensor, name)
+
+ if new_capture and tensor is not internal_tensor:
+ self.inputs.append(internal_tensor)
+
+ return internal_tensor
+
+
# pylint: disable=invalid-name
class HelperContext(object):
"""ControlFlowContext with a customizable AddOp method."""
@@ -476,6 +565,20 @@ class GraphModeFunction(object):
self._output_shapes = output_shapes
self._variables = variables if variables is not None else []
+ # Find the variables that are components of something distributed and
+ # put them into a {handle_tensor -> distributed variable object} map.
+ self._distributed_variables = {}
+ strategy = distribute.get_distribution_strategy()
+ for variable in self._variables:
+ # If variable is not distributed, unwrap returns [variable].
+ component_variables = strategy.unwrap(variable)
+ # Only add to the dictionary when the variable is actually distributed,
+ # i.e. more than one component or the component is different from the
+ # variable itself. component_variables cannot be empty.
+ if (len(component_variables) > 1 or component_variables[0] != variable):
+ for component_variable in component_variables:
+ self._distributed_variables[component_variable.handle] = variable
+
@property
def variables(self):
return self._variables
@@ -483,6 +586,7 @@ class GraphModeFunction(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
filtered_outputs = [x for x in self._python_returns if x is not None]
+ # TODO(skyewm): use FuncGraph
backwards_graph = CapturingGraph()
backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
for collection in self._graph.collections:
@@ -502,13 +606,8 @@ class GraphModeFunction(object):
grad for grad in _flatten(in_gradients) if grad is not None)
output_shapes = tuple(grad.shape for grad in backward_outputs)
- captures = backwards_graph.captures
- ids = list(sorted(captures.keys()))
- if ids:
- extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
- else:
- extra_inputs = []
- extra_placeholders = []
+ extra_inputs = backwards_graph.captures.keys()
+ extra_placeholders = backwards_graph.captures.values()
forward_name = _forward_name(self._func_name)
# Note: we cannot have placeholder ops in the graph or the TPU compilation
@@ -542,13 +641,12 @@ class GraphModeFunction(object):
(Only records results on a tape if the function has outputs)
Args:
- args: The tensor inputs to the function.
+ args: All inputs to the function, including resolved extra inputs
Returns:
The call output.
"""
- all_args = args + self._extra_inputs
ctx = context.context()
- outputs = self._forward_fdef.call(ctx, all_args, self._output_shapes)
+ outputs = self._forward_fdef.call(ctx, args, self._output_shapes)
if isinstance(outputs, ops.Operation) or outputs is None:
return outputs
@@ -564,7 +662,7 @@ class GraphModeFunction(object):
tape.record_operation(
self._forward_fdef.signature.name,
real_outputs,
- (args + self._extra_inputs),
+ args,
backward_function)
return self._build_call_outputs(real_outputs)
@@ -604,21 +702,50 @@ class GraphModeFunction(object):
"""Returns the name of the function in Eager-compatible format."""
return self._function_def.name.encode("utf-8")
+ def _resolve_extra_inputs(self):
+ """Resolve captured distributed variables to their current values.
+
+ Some inputs can be distributed variables. Such variables yield a different
+ component (i.e. actual tf.Variable) variables depending on the context of
+ execution.
+
+ Returns:
+ a list of resolved extra input tensors.
+ """
+ if self._distributed_variables:
+ # Loop over each extra_inputs and check if it corresponds to something
+ # distributed. If so, get its _distributed_container and fetch the
+ # component appropriate for the current execution context.
+ resolved_extra_inputs = self._extra_inputs[:]
+ for i, extra_input in enumerate(self._extra_inputs):
+ distributed_var = self._distributed_variables.get(extra_input, None)
+ if distributed_var is not None:
+ # distributed variables override __getattr__ and substitute the
+ # right component variable. In here, `distributed_var.handle`
+ # actually does the equivalent of
+ # distributed_var.get_current_component_var().handle.
+ resolved_extra_inputs[i] = distributed_var.handle
+ return resolved_extra_inputs
+
+ return self._extra_inputs
+
def __call__(self, *args):
"""Executes the passed function in eager mode."""
for v in self._variables:
if v.trainable:
tape.watch_variable(v)
+ resolved_extra_inputs = self._resolve_extra_inputs()
+
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+ args = tensor_inputs + resolved_extra_inputs
if tape.should_record(tensor_inputs) or tape.should_record(
- self._extra_inputs):
+ resolved_extra_inputs):
if self._backward_function is None:
self._construct_backprop_function()
- return self._backprop_call(tensor_inputs)
+ return self._backprop_call(args)
ctx = context.context()
- args = tensor_inputs + self._extra_inputs
outputs = self._function_def.call(ctx, args, self._output_shapes)
return self._build_call_outputs(outputs)
@@ -659,43 +786,64 @@ class GraphModeFunction(object):
return ret
-def _get_defun_inputs(args):
- """Maps the inputs args to graph inputs."""
- ret = []
- flat_args = nest.flatten(args)
- for a in flat_args:
- if isinstance(a, ops.Tensor):
- ret.append(graph_placeholder(a.dtype, a.shape))
+def _get_defun_inputs_from_signature(signature):
+ """Maps a signature to graph-construction inputs."""
+ function_inputs = [
+ graph_placeholder(spec.dtype, spec.shape)
+ for spec in nest.flatten(signature)
+ ]
+ return nest.pack_sequence_as(signature, function_inputs)
+
+
+def _get_defun_inputs_from_args(args):
+ """Maps python function args to graph-construction inputs."""
+ function_inputs = [
+ graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor)
+ else arg for arg in nest.flatten(args)
+ ]
+ return nest.pack_sequence_as(args, function_inputs)
+
+
+def _trace_and_define_function(name, python_func, compiled, args, kwds,
+ signature=None):
+ """Defines and returns graph-mode version of `python_func`.
+
+ Args:
+ name: an identifier for the function.
+ python_func: the Python function to trace.
+ compiled: whether the graph function should be compiled through XLA.
+ args: the positional args with which the Python function should be called;
+ ignored if a signature is provided.
+ kwds: the keyword args with which the Python function should be called;
+ ignored if a signature is provided.
+ signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
+ and dtypes of the arguments. When a signature is provided, `args` and
+ `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ to `signature`. If `None`, the shapes and dtypes are inferred from the
+ inputs.
+
+ Returns:
+ A GraphModeFunction.
+
+ Raises:
+ TypeError: If any of `python_func`'s return values is neither `None` nor a
+ `Tensor`.
+ """
+ func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
+
+ with func_graph.as_default(), AutomaticControlDependencies() as a:
+ if signature is None:
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwds = _get_defun_inputs_from_args(kwds)
else:
- ret.append(a)
- return nest.pack_sequence_as(args, ret)
-
-
-def _deterministic_dict_values(kwds):
- return tuple(kwds[key] for key in sorted(kwds))
-
-
-def _trace_and_define_function(name, func, compiled, args, kwds):
- """Defines and returns graph-mode version of func."""
- graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
- tmp_graph = CapturingGraph()
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- # Copy the graph collections to ensure summaries and other things work. This
- # lets the function access (but not mutate) collections of the containing
- # graph, such as the global step and the summary writer collections.
- curr_graph = ops.get_default_graph()
- for collection in curr_graph.collections:
- tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
- collection)
- if context.executing_eagerly():
- tmp_graph.seed = context.global_seed()
- else:
- tmp_graph.seed = curr_graph.seed
- with tmp_graph.as_default(), AutomaticControlDependencies() as a:
- func_args = _get_defun_inputs(args)
- func_kwds = _get_defun_inputs(kwds)
+ func_args = _get_defun_inputs_from_signature(signature)
+ func_kwds = {}
+
+ # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
+ func_graph.inputs.extend(
+ x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
+ if isinstance(x, ops.Tensor)
+ )
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
@@ -703,15 +851,23 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
def convert(x):
+ """Converts an argument to a Tensor."""
if x is None:
return None
- x = ops.convert_to_tensor_or_indexed_slices(x)
+ try:
+ x = ops.convert_to_tensor_or_indexed_slices(x)
+ except (ValueError, TypeError):
+ raise TypeError(
+ "To be compatible with tf.contrib.eager.defun, Python functions "
+ "must return zero or more Tensors; in compilation of %s, found "
+ "return value of type %s, which is not a Tensor." %
+ (str(python_func), type(x)))
x = a.mark_as_return(x)
return x
this_tape = tape.push_new_tape()
try:
- func_outputs = func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwds)
func_outputs = nest.map_structure(convert, func_outputs)
def check_mutation(n1, n2):
@@ -734,41 +890,39 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
finally:
tape.pop_tape(this_tape)
- variables = this_tape.watched_variables()
+ func_graph.structured_outputs = func_outputs
+ variables = list(this_tape.watched_variables())
+
+ # Some variables captured by the tape can come from a DistributedValue.
+ # At call time, DistributedValue can return another variable (e.g. if
+ # the function is run on a different device). Thus, instead of storing
+ # the specific captured variable, we replace it with its distributed
+ # container.
+ strategy = distribute.get_distribution_strategy()
+ for i, variable in enumerate(variables):
+ # If variable is not distributed value_container returns itself.
+ variables[i] = strategy.value_container(variable)
+
+ func_graph.variables = variables
# Returning a closed-over tensor as an output does not trigger a
# call to convert_to_tensor, so we manually capture all such tensors.
- outputs_list = _flatten(func_outputs)
- func_def_outputs = [
- tmp_graph.capture(x) for x in outputs_list
+ func_graph.outputs.extend(
+ func_graph.capture(x) for x in _flatten(func_graph.structured_outputs)
if x is not None
- ]
+ )
- captures = tmp_graph.captures
- ids = list(sorted(captures.keys()))
- if ids:
- extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
- else:
- extra_inputs = []
- extra_placeholders = []
output_shapes = tuple(
x.shape if isinstance(x, ops.Tensor) else None
- for x in func_def_outputs)
+ for x in func_graph.outputs)
- func_kwds_values = _deterministic_dict_values(func_kwds)
- flat_inputs = [
- x for x in nest.flatten(func_args) + nest.flatten(func_kwds_values)
- if isinstance(x, ops.Tensor)
- ]
- all_inputs = flat_inputs + list(extra_placeholders)
- all_ignored_ops = frozenset(x.op for x in all_inputs)
- fname = _inference_name(name)
- operations = tuple(x for x in tmp_graph.get_operations()
+ all_ignored_ops = frozenset(x.op for x in func_graph.inputs)
+ operations = tuple(x for x in func_graph.get_operations()
if x not in all_ignored_ops)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
if context.executing_eagerly():
- for f in tmp_graph._functions.values(): # pylint: disable=protected-access
+ for f in func_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
@@ -777,41 +931,55 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
return GraphModeFunction(
- fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
- func_outputs, output_shapes, variables, attrs)
+ func_graph.name, func_graph.inputs, func_graph.captures.keys(),
+ func_graph, operations, func_graph.outputs, func_graph.structured_outputs,
+ output_shapes, func_graph.variables, attrs)
+
+_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
-# Defun uses this instead of Tensor as a cache key. Using dtype because
-# TensorFlow graphs are not parametric wrt dtypes, and using shapes for
-# performance reasons, as much TensorFlow code specializes on known shapes to
-# produce slimmer graphs.
-_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
-_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
+def _encode_arg(arg):
+ """A canonical representation for this argument, for use in a cache key."""
-def _cache_key(x):
- """Cache key for tfe functions."""
- if isinstance(x, ops.Tensor):
- return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
- if isinstance(x, ops.IndexedSlices):
- if x.dense_shape is not None:
+ # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
+ # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
+ # are used for both performance reasons, as much TensorFlow code specializes
+ # on known shapes to produce slimmer graphs, and correctness, as some
+ # high-level APIs require shapes to be fully-known.
+ #
+ # TODO(akshayka): Add support for sparse tensors.
+ #
+ # pylint: disable=protected-access
+ if isinstance(arg, ops.Tensor):
+ return _TensorType(arg.dtype, arg._shape_tuple())
+ elif isinstance(arg, ops.IndexedSlices):
+ if arg.dense_shape is not None:
return tuple([
- _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access
- _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access
- _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access
+ _TensorType(arg.values.dtype, arg.values._shape_tuple()),
+ _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
+ _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()),
])
else:
return tuple([
- _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access
- _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access
+ _TensorType(arg.values.dtype, arg.values._shape_tuple()),
+ _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
])
- if isinstance(x, np.ndarray):
- return ("array", x.shape, tuple(x.reshape(-1)))
- if isinstance(x, (list, tuple)):
- return tuple([_cache_key(a) for a in x])
- if isinstance(x, dict):
- return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items())
- return x
+ elif isinstance(arg, np.ndarray):
+ tensor = ops.convert_to_tensor(arg)
+ return _TensorType(tensor.dtype, tensor._shape_tuple())
+ # pylint: enable=protected-access
+ elif isinstance(arg, (list, tuple)):
+ return tuple([_encode_arg(elem) for elem in arg])
+ elif isinstance(arg, dict):
+ return tuple(
+ (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
+ else:
+ return arg
+
+
+def _deterministic_dict_values(dictionary):
+ return tuple(dictionary[key] for key in sorted(dictionary))
class _PolymorphicFunction(object):
@@ -826,16 +994,37 @@ class _PolymorphicFunction(object):
synchronization is necessary.
"""
- def __init__(self, python_function, name, compiled=False):
+ def __init__(self,
+ python_function,
+ name,
+ input_signature=None,
+ compiled=False):
"""Initializes a polymorphic function.
Args:
python_function: the function to be wrapped.
name: the name given to it.
+ input_signature: a possibly nested sequence of `TensorSpec` objects
+ specifying the input signature of this function. If `None`, a separate
+ function is instantiated for each inferred input signature.
compiled: if True, the framework will attempt to compile func with XLA.
+
+ Raises:
+ ValueError: if `input_signature` is not None and the `python_function`'s
+ argspec has keyword arguments.
+ TypeError: if `input_signature` contains anything other than
+ `TensorSpec` objects, or (if not None) is anything other than a tuple or
+ list.
"""
- self._python_function = python_function
+ if isinstance(python_function, functools.partial):
+ self._python_function = python_function.func
+ self._args_to_prepend = python_function.args or tuple()
+ self._kwds_to_include = python_function.keywords or {}
+ else:
+ self._python_function = python_function
+ self._args_to_prepend = tuple()
+ self._kwds_to_include = {}
self._name = name
self._compiled = compiled
self._arguments_to_functions = {}
@@ -843,6 +1032,41 @@ class _PolymorphicFunction(object):
self._lock = threading.Lock()
+ fullargspec = tf_inspect.getfullargspec(self._python_function)
+ if tf_inspect.ismethod(self._python_function):
+ # Remove `self`: default arguments shouldn't be matched to it.
+ args = fullargspec.args[1:]
+ else:
+ args = fullargspec.args
+
+ # A cache mapping from argument name to index, for canonicalizing
+ # arguments that are called in a keyword-like fashion.
+ self._args_to_indices = {arg: i for i, arg in enumerate(args)}
+ # A cache mapping from arg index to default value, for canonicalization.
+ offset = len(args) - len(fullargspec.defaults or [])
+ self._arg_indices_to_default_values = {
+ offset + index: default
+ for index, default in enumerate(fullargspec.defaults or [])
+ }
+ if input_signature is None:
+ self._input_signature = None
+ else:
+ if fullargspec.varkw is not None or fullargspec.kwonlyargs:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+
+ if not isinstance(input_signature, (tuple, list)):
+ raise TypeError("input_signature must be either a tuple or a "
+ "list, received " + str(type(input_signature)))
+
+ self._input_signature = tuple(input_signature)
+ self._flat_input_signature = tuple(nest.flatten(input_signature))
+ if any(not isinstance(arg, tensor_spec.TensorSpec)
+ for arg in self._flat_input_signature):
+ raise TypeError("Invalid input_signature %s; input_signature must be "
+ "a possibly nested sequence of TensorSpec objects.")
+
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
@@ -861,36 +1085,119 @@ class _PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
+ def _cache_key(self, args, kwds):
+ """Computes the cache key given inputs."""
+ if self._input_signature is None:
+ inputs = (args, kwds) if kwds else args
+ cache_key = tuple(_encode_arg(arg) for arg in inputs)
+ else:
+ del args, kwds
+ cache_key = self._flat_input_signature
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # cache key so we don't improperly capture tensors such as variables.
+ return cache_key + (context.executing_eagerly() or ops.get_default_graph(),)
+
+ def _canonicalize_function_inputs(self, *args, **kwds):
+ """Canonicalizes `args` and `kwds`.
+
+ Canonicalize the inputs to the Python function using its fullargspec. In
+ particular, we parse the varags and kwargs that this
+ `_PolymorphicFunction` was called with into a tuple corresponding to the
+ Python function's positional (named) arguments and a dictionary
+ corresponding to its kwargs.
+
+ Args:
+ *args: The varargs this object was called with.
+ **kwds: The keyword args this function was called with.
+
+ Returns:
+ A canonicalized ordering of the inputs.
+
+ Raises:
+ ValueError: If a keyword in `kwds` cannot be matched with a positional
+ argument when an input signature is specified, or when the inputs
+ do not conform to the input signature.
+ """
+ args = self._args_to_prepend + args
+ kwds = dict(kwds, **self._kwds_to_include)
+ # Maps from index of arg to its corresponding value, according to `args`
+ # and `kwds`; seeded with the default values for the named args that aren't
+ # in `args`.
+ arg_indices_to_values = {
+ index: default
+ for index, default in six.iteritems(self._arg_indices_to_default_values)
+ if index >= len(args)
+ }
+ consumed_args = []
+ for arg, value in six.iteritems(kwds):
+ index = self._args_to_indices.get(arg, None)
+ if index is not None:
+ arg_indices_to_values[index] = value
+ consumed_args.append(arg)
+ elif self._input_signature is not None:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+ for arg in consumed_args:
+ # After this loop, `kwds` will only contain true keyword arguments, as
+ # opposed to named arguments called in a keyword-like fashion.
+ kwds.pop(arg)
+ inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ if self._input_signature is None:
+ return inputs, kwds
+ else:
+ assert not kwds
+ try:
+ nest.assert_same_structure(self._input_signature, inputs)
+ except (ValueError, TypeError):
+ raise ValueError("Structure of Python function inputs does not match "
+ "input_signature.")
+ flat_inputs = nest.flatten(inputs)
+ if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
+ raise ValueError("When input_signature is provided, all inputs to "
+ "the Python function must be Tensors.")
+ tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor)
+ for tensor in flat_inputs]
+ if any(not spec.is_compatible_with(other)
+ for spec, other in zip(self._flat_input_signature, tensor_specs)):
+ raise ValueError("Python inputs incompatible with input_signature: "
+ "inputs (%s), input_signature (%s)" %
+ (str(inputs), str(self._input_signature)))
+ return inputs, {}
+
def _maybe_define_function(self, *args, **kwds):
"""Gets a function for these inputs, defining it if necessary.
Args:
- *args: args for the Python function; used to compute the signature
- **kwds: kwds for the Python function; used to compute the signature
+ *args: args for the Python function.
+ **kwds: keywords for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
kwds, as well as the inputs that the object should be called with.
- """
- # TODO(apassos): Better error messages for non-hashable arguments.
- kwd_values = _deterministic_dict_values(kwds)
- inputs = args + kwd_values
- signature = tuple(_cache_key(x) for x in inputs)
- # The graph, or whether we're executing eagerly, should be a part of the
- # signature so we don't improperly capture tensors such as variables.
- signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
+ Raises:
+ ValueError: If inputs are incompatible with the input signature.
+ TypeError: If the function inputs include non-hashable objects
+ """
+ args, kwds = self._canonicalize_function_inputs(*args, **kwds)
+ cache_key = self._cache_key(args, kwds)
with self._lock:
- if signature not in self._arguments_to_functions:
+ try:
+ graph_function = self._arguments_to_functions.get(cache_key, None)
+ except TypeError:
+ raise TypeError("Arguments supplied to `defun`-generated functions "
+ "must be hashable.")
+
+ if graph_function is None:
graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds)
- self._arguments_to_functions[signature] = graph_function
+ self._name, self._python_function, self._compiled, args, kwds,
+ self._input_signature)
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
- return graph_function, inputs
- else:
- return self._arguments_to_functions[signature], inputs
+ self._arguments_to_functions[cache_key] = graph_function
+ return graph_function, (args, kwds)
def __call__(self, *args, **kwds):
"""Calls a graph function specialized for this input signature."""
@@ -910,11 +1217,11 @@ class _PolymorphicFunction(object):
# TODO(akshayka): Remove the `compiled` flag and create a separate
# API for xla compilation (`defun` is already complicated enough
# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, compiled=False):
+def defun(func=None, input_signature=None, compiled=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
- composed of TensorFlow operations into a callable that executes a @{tf.Graph}
+ composed of TensorFlow operations into a callable that executes a `tf.Graph`
containing those operations. The callable produced by `defun` contains only
the subgraph of TensorFlow operations that were executed when the Python
function was called with a particular input signature, defined as a list
@@ -937,9 +1244,9 @@ def defun(func=None, compiled=False):
For a Python function to be compatible with `defun`, all of its arguments must
be hashable Python objects or lists thereof. The function itself may not
modify the list/map structure of its arguments. Additionally, it must return
- zero or more @{tf.Tensor} objects. If the Python function returns
- a @{tf.Variable}, its compiled version will return the value of that variable
- as a @{tf.Tensor}.
+ zero or more `tf.Tensor` objects. If the Python function returns
+ a `tf.Variable`, its compiled version will return the value of that variable
+ as a `tf.Tensor`.
Executing a graph generated by `defun` respects device annotations (i.e.,
all `with tf.device` directives present in a Python function will also be
@@ -1008,7 +1315,7 @@ def defun(func=None, compiled=False):
When using `defun`, there are subtleties regarding inputs, Python control
flow, and variable creation that one should be aware of. For concreteness, let
- `f` be a Python function that returns zero or more @{tf.Tensor} objects and
+ `f` be a Python function that returns zero or more `tf.Tensor` objects and
let `F = defun(f)`. `F` builds a graph for each unique input signature it
sees, Python control flow is baked into graphs, and operations related to
variable initialization are automatically lifted out of the graphs that `F`
@@ -1091,10 +1398,10 @@ def defun(func=None, compiled=False):
On the other hand, because `defun` generates graphs by tracing and not by
source code analysis, it fully unrolls Python `for` and `while` loops,
potentially creating large graphs. If your Python function has native loops
- that run for many iterations, consider replacing them with @{tf.while_loop}
+ that run for many iterations, consider replacing them with `tf.while_loop`
operations.
- When constructing graphs, @{tf.Tensor} objects cannot be used as Python
+ When constructing graphs, `tf.Tensor` objects cannot be used as Python
`bool` objects. This means, for example, that you should replace code in `f`
resembling
@@ -1113,7 +1420,7 @@ def defun(func=None, compiled=False):
automatically lifted out of the graphs generated by `defun`. In practice, this
implies that variable creation and initialization only happen the first time
`F` is called, and that variables are reused every time thereafter. Many
- TensorFlow APIs, like @{tf.keras.layers.Layer} objects, create variables the
+ TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
first time they are called and reuse them thereafter. Automatic variable
lifting makes it possible to compile these APIs without extra effort, at the
cost of introducing a discrepancy between the semantics of executing Python
@@ -1152,7 +1459,7 @@ def defun(func=None, compiled=False):
to reference the same set of variables, add logic to your Python function that
ensures that variables are only created the first time it is called and are
reused for every subsequent invocation; note that this is precisely what
- @{tf.keras.layers.Layer} objects do, so we recommend using them to represent
+ `tf.keras.layers.Layer` objects do, so we recommend using them to represent
variable-bearing computations whenever possible.
Args:
@@ -1165,6 +1472,13 @@ def defun(func=None, compiled=False):
def foo(...):
...
+ input_signature: A possibly nested sequence of
+ `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
+ the Tensors that will be supplied to this function. If `None`, a separate
+ function is instantiated for each inferred input signature. If a
+ signature is specified, every input to `func` must be a `Tensor`, and
+ `func` cannot accept `**kwargs`.
+
compiled: If True, an attempt to compile `func` with XLA will be made.
If it fails, function will be run normally. Experimental. Currently
supported only for execution on TPUs. For the vast majority of users,
@@ -1183,7 +1497,9 @@ def defun(func=None, compiled=False):
except AttributeError:
name = "function"
return tf_decorator.make_decorator(
- function, _PolymorphicFunction(function, name, compiled=compiled))
+ function,
+ _PolymorphicFunction(
+ function, name, input_signature=input_signature, compiled=compiled))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None: