Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py | 294
1 file changed, 170 insertions(+), 124 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b28befeb62..93168826b1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import re
import sys
import threading
import weakref
@@ -30,6 +31,7 @@ import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
+from tensorflow.python import autograph
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
@@ -61,9 +63,15 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
# This is to avoid a circular dependency with gradients_impl
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
+BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
-WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [
+ "experimental_.*",
+ FORWARD_FUNCTION_ATTRIBUTE_NAME,
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME
+]
def _create_substitute_placeholder(value, name=None, dtype=None):
@@ -140,10 +148,11 @@ def _parse_func_attrs(attributes):
"""
attrs = {}
for key, value in attributes.items():
- if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ if not any([re.match(reg, key)
+ for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]):
raise ValueError("Attribute name is not whitelisted. "
"Whitelisted: prefix %s, got: %s" %
- (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+ (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key))
if isinstance(value, attr_value_pb2.AttrValue):
attrs[key] = value
@@ -154,7 +163,7 @@ def _parse_func_attrs(attributes):
attrs[key] = attr_value_pb2.AttrValue(i=value)
elif isinstance(value, float):
attrs[key] = attr_value_pb2.AttrValue(f=value)
- elif isinstance(value, str):
+ elif isinstance(value, (str, bytes)):
attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
else:
raise ValueError("Unsupported attribute type for %s with type %s" %
@@ -486,6 +495,9 @@ class _EagerDefinedFunction(object):
Returns:
The outputs of the function call.
+
+ Raises:
+ ValueError: if the number of arguments is incorrect.
"""
executing_eagerly = ctx.executing_eagerly()
@@ -519,6 +531,10 @@ class _EagerDefinedFunction(object):
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
# branch if a TPU kernel is registered for `PartitionedCall`.
+ if len(args) != len(self.signature.input_arg):
+ raise ValueError(
+ "Arguments and signature arguments do not match: %s %s " %
+ (len(args), len(list(self.signature.input_arg))))
outputs = functional_ops.partitioned_call(
args=args,
f=self,
@@ -705,6 +721,7 @@ class Function(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
+ forward_function_name = _forward_name(self._func_graph.name)
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
@@ -715,11 +732,11 @@ class Function(object):
grad_ys=gradients_wrt_outputs,
src_graph=self._func_graph)
- self._forward_function = _EagerDefinedFunction(
- _forward_name(
- self._func_graph.name), self._func_graph, self._func_graph.inputs,
- self._func_graph.outputs + list(backwards_graph.captures.keys()),
- self._attrs)
+ backwards_graph_captures = list(backwards_graph.captures.keys())
+
+ backward_function_attr = _parse_func_attrs(
+ {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+ backward_function_attr.update(self._attrs)
# The ordering of `backwards_graph.inputs` is important: inputs of
# `self._backward_graph_function` correspond to outputs of
@@ -732,7 +749,16 @@ class Function(object):
grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
self._backward_graph_function = Function(
- backwards_graph, attrs=self._attrs)
+ backwards_graph, attrs=backward_function_attr)
+
+ forward_function_attr = _parse_func_attrs({
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+ self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
+ forward_function_attr.update(self._attrs)
+ self._forward_function = _EagerDefinedFunction(
+ forward_function_name, self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + backwards_graph_captures,
+ forward_function_attr)
def _backprop_call(self, args):
"""Calls the forward function and records the result on a tape.
@@ -829,20 +855,12 @@ class Function(object):
return ret
-def _get_defun_inputs_from_signature(signature):
- """Maps a signature to graph-construction inputs."""
- function_inputs = [
- graph_placeholder(spec.dtype, spec.shape)
- for spec in nest.flatten(signature)
- ]
- return nest.pack_sequence_as(signature, function_inputs)
-
-
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
graph_placeholder(arg.dtype, arg.shape)
- if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
+ if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+ else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
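With `_get_defun_inputs_from_signature` folded away, `TensorSpec`s flow through the same helper as `Tensor`s and become placeholders during tracing. A small sketch, assuming `graph_placeholder` may be called under any default graph:

from tensorflow.python.eager.function import _get_defun_inputs_from_args
from tensorflow.python.framework import dtypes, ops, tensor_spec

with ops.Graph().as_default():
  spec = tensor_spec.TensorSpec([None, 3], dtypes.float32)
  placeholder, py_arg = _get_defun_inputs_from_args((spec, 7))
  # `spec` became a float32 [?, 3] placeholder; the plain Python 7 passed
  # through unchanged.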
@@ -852,7 +870,8 @@ def func_graph_from_py_func(name,
args,
kwargs,
signature=None,
- func_graph=None):
+ func_graph=None,
+ experimental_autograph=False):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -869,6 +888,8 @@ def func_graph_from_py_func(name,
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
this graph else a new one is built and returned.
+ experimental_autograph: whether to use autograph to compile `python_func`.
+ See https://www.tensorflow.org/guide/autograph for more information.
Returns:
A FuncGraph.
@@ -883,12 +904,12 @@ def func_graph_from_py_func(name,
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
- if signature is None:
- func_args = _get_defun_inputs_from_args(args)
- func_kwargs = _get_defun_inputs_from_args(kwargs)
- else:
- func_args = _get_defun_inputs_from_signature(signature)
- func_kwargs = {}
+ if signature is not None:
+ args = signature
+ kwargs = {}
+
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
@@ -914,7 +935,17 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwargs)
+ if experimental_autograph:
+ func_outputs = autograph.converted_call(
+ python_func,
+ autograph.ConversionOptions(
+ verbose=True,
+ recursive=True,
+ force_conversion=False,
+ strip_decorators=(defun,),
+ arg_types={}), *func_args, **func_kwargs)
+ else:
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
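`autograph.converted_call` rewrites `python_func` (recursively, per the options above) so that data-dependent Python control flow becomes graph control flow before tracing, and `strip_decorators=(defun,)` keeps the conversion from descending into the defun wrapper itself. Illustrative only, under the autograph semantics of this vintage:

# With experimental_autograph=True, a data-dependent loop like this one is
# converted (e.g. to a tf.while_loop) rather than unrolled at trace time.
def count_down(n):
  while n > 0:  # n may be a Tensor; autograph rewrites the condition
    n = n - 1
  return n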
@@ -986,52 +1017,8 @@ def func_graph_from_py_func(name,
return func_graph
-_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
-
-
-def _encode_arg(arg):
- """A canonical representation for this argument, for use in a cache key."""
-
- # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
- # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
- # are used for both performance reasons, as much TensorFlow code specializes
- # on known shapes to produce slimmer graphs, and correctness, as some
- # high-level APIs require shapes to be fully-known.
- #
- # TODO(akshayka): Add support for sparse tensors.
- #
- # pylint: disable=protected-access
- if isinstance(arg, ops.Tensor):
- return _TensorType(arg.dtype, arg._shape_tuple())
- elif isinstance(arg, ops.IndexedSlices):
- if arg.dense_shape is not None:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()),
- ])
- else:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- ])
- # pylint: enable=protected-access
- elif isinstance(arg, (list, tuple)):
- return tuple([_encode_arg(elem) for elem in arg])
- elif isinstance(arg, dict):
- return tuple(
- (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
- else:
- try:
- # If possible, keep only a weak reference to Python objects. Weak
- # references hash to the same value as the original object.
- # TODO(allenl): Clean up dead functions and their cache keys if the cache
- # gets large. Right now creating objects with a defunned method, calling
- # the method, and losing a reference to the object in a loop will leak
- # memory here.
- return weakref.ref(arg)
- except TypeError:
- return arg
+pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
+pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
def _deterministic_dict_values(dictionary):
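The Python-level `_encode_arg` deleted above now lives in C++ as `TFE_Py_EncodeArg`; the `RegisterType` calls tell that encoder which Python types to treat as `Tensor` and `IndexedSlices`. The key still encodes dtypes and shapes rather than values, so retracing behavior is unchanged, as in this hedged sketch:

import tensorflow as tf
from tensorflow.python.eager import function

tf.enable_eager_execution()

f = function.defun(lambda x: x + 1)
f(tf.constant([1.0]))        # first call: traces a new graph function
f(tf.constant([2.0]))        # same dtype and shape: cache hit, no retrace
f(tf.constant([1.0, 2.0]))   # new shape: a second graph function is traced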
@@ -1054,7 +1041,8 @@ class PolymorphicFunction(object):
python_function,
name,
input_signature=None,
- attributes=None):
+ attributes=None,
+ experimental_autograph=False):
"""Initializes a polymorphic function.
Args:
@@ -1064,7 +1052,10 @@ class PolymorphicFunction(object):
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
attributes: dict, extra keyword arguments that will be added as attribute
- of the function.
+ of the function.
+ experimental_autograph: whether to use autograph to compile
+ `python_function`. See https://www.tensorflow.org/guide/autograph for
+ more information.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -1080,6 +1071,7 @@ class PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwargs_to_include = {}
self._name = name
+ self._experimental_autograph = experimental_autograph
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1101,6 +1093,8 @@ class PolymorphicFunction(object):
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
}
+ self._default_values = fullargspec.defaults
+ self._default_values_start_index = offset
if input_signature is None:
self._input_signature = None
else:
@@ -1161,30 +1155,29 @@ class PolymorphicFunction(object):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
- cache_key = tuple(_encode_arg(arg) for arg in inputs)
+ cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
else:
del args, kwargs
cache_key = self._flat_input_signature
+ ctx = context.context()
with ops.init_scope():
- init_graph = ops.get_default_graph()
-
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = context.executing_eagerly()
- execution_context = executing_eagerly or init_graph
+ executing_eagerly = ctx.executing_eagerly()
+ execution_context = executing_eagerly or ops.get_default_graph()
- default_graph = ops.get_default_graph()
- # Putting the device in the cache key ensures that call-site device
- # annotations are respected.
- device_functions = _get_device_functions(context.context(), default_graph)
-
- # `ops.colocate_with` directives translate into `ops.device` directives when
- # eager execution is enabled.
- colocation_stack = (() if executing_eagerly else
- tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ if executing_eagerly:
+ device_functions = (pydev.merge_device(ctx.device_name),)
+ colocation_stack = ()
+ else:
+ default_graph = ops.get_default_graph()
+ # Putting the device in the cache key ensures that call-site device
+ # annotations are respected.
+ device_functions = tuple(default_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+ colocation_stack = tuple(default_graph._colocation_stack.peek_objs()) # pylint: disable=protected-access
- return cache_key + (execution_context, device_functions, colocation_stack)
+ return (cache_key, execution_context, device_functions, colocation_stack)
def _canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.
@@ -1209,35 +1202,44 @@ class PolymorphicFunction(object):
"""
args = self._args_to_prepend + args
kwargs = dict(kwargs, **self._kwargs_to_include)
- # Maps from index of arg to its corresponding value, according to `args`
- # and `kwargs`; seeded with the default values for the named args that
- # aren't in `args`.
- arg_indices_to_values = {
- index: default
- for index, default in six.iteritems(self._arg_indices_to_default_values)
- if index >= len(args)
- }
- consumed_args = []
- for arg, value in six.iteritems(kwargs):
- index = self._args_to_indices.get(arg, None)
- if index is not None:
- arg_indices_to_values[index] = value
- consumed_args.append(arg)
- elif self._input_signature is not None:
- raise ValueError("Cannot define a TensorFlow function from a Python "
- "function with keyword arguments when "
- "input_signature is provided.")
- for arg in consumed_args:
- # After this loop, `kwargs` will only contain true keyword arguments, as
- # opposed to named arguments called in a keyword-like fashion.
- kwargs.pop(arg)
- inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ if not kwargs:
+ if self._default_values:
+ inputs = args + self._default_values[len(args) -
+ self._default_values_start_index:]
+ else:
+ inputs = args
+ else:
+ # Maps from index of arg to its corresponding value, according to `args`
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
+ arg_indices_to_values = {
+ index: default for index, default in six.iteritems(
+ self._arg_indices_to_default_values) if index >= len(args)
+ }
+ consumed_args = []
+ for arg, value in six.iteritems(kwargs):
+ index = self._args_to_indices.get(arg, None)
+ if index is not None:
+ arg_indices_to_values[index] = value
+ consumed_args.append(arg)
+ elif self._input_signature is not None:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+ for arg in consumed_args:
+ # After this loop, `kwargs` will only contain true keyword arguments, as
+ # opposed to named arguments called in a keyword-like fashion.
+ kwargs.pop(arg)
+ inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
# Check for NumPy arrays in arguments and convert them to Tensors.
+ # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
+ # finding a way to store them directly in the cache key (currently not
+ # possible since ndarrays are not hashable).
need_packing = False
for index, value in enumerate(flat_inputs):
- if isinstance(value, np.ndarray):
+ if type(value) == np.ndarray:
flat_inputs[index] = constant_op.constant(value)
need_packing = True
if need_packing:
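The new no-kwargs fast path fills in trailing defaults with one slice instead of building the index map. A pure-Python sketch of the arithmetic (the helper name is hypothetical):

def _fast_path_inputs(args, default_values, default_values_start_index):
  # Mirrors: inputs = args + self._default_values[
  #     len(args) - self._default_values_start_index:]
  if default_values:
    return args + default_values[len(args) - default_values_start_index:]
  return args

# For def g(a, b=2, c=3): defaults are (2, 3), and the offset stored in
# _default_values_start_index is 1.
assert _fast_path_inputs((1,), (2, 3), 1) == (1, 2, 3)       # g(1)
assert _fast_path_inputs((1, 5), (2, 3), 1) == (1, 5, 3)     # g(1, 5)
assert _fast_path_inputs((1, 5, 9), (2, 3), 1) == (1, 5, 9)  # g(1, 5, 9)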
@@ -1295,8 +1297,13 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
- func_graph_from_py_func(self._name, self._python_function, args,
- kwargs, self._input_signature),
+ func_graph_from_py_func(
+ self._name,
+ self._python_function,
+ args,
+ kwargs,
+ self._input_signature,
+ experimental_autograph=self._experimental_autograph),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
@@ -1328,8 +1335,25 @@ def register(func, *args, **kwargs):
"Got type: %s" % type(func))
concrete_func = func.get_concrete_function(*args, **kwargs)
graph = ops.get_default_graph()
- concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
- # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+
+  # There are two situations for the actual call of a defun:
+  # 1. If none of the input args are resource variables or watched by any tape,
+  #    it will run the _inference_function of concrete_func for the forward
+  #    pass, and the gradient will be generated by the standard mechanism.
+  # 2. Otherwise, defun will create two functions: one for the forward pass,
+  #    and the backward pass will be created via tape.
+  # When registering the function, we put both cases into the graph.
+ # pylint: disable=protected-access
+ concrete_func._inference_function.add_to_graph(graph)
+
+ if concrete_func._backward_graph_function is None:
+ concrete_func._construct_backprop_function()
+ forward_function = concrete_func._forward_function
+ backward_function = concrete_func._backward_graph_function._inference_function
+ forward_function.add_to_graph(graph)
+ backward_function.add_to_graph(graph)
+ # pylint: enable=protected-access
+
return concrete_func
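A hedged usage sketch of the expanded `register`: it now adds the inference FunctionDef plus the forward/backward pair to the default graph, so a graph-mode consumer can both call and differentiate the registered function. The matmul example is illustrative:

import tensorflow as tf
from tensorflow.python.eager import function

matmul = function.defun(tf.matmul)

with tf.Graph().as_default() as g:
  t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  concrete = function.register(matmul, t, t)
  # g.as_graph_def().library now contains the inference function and the
  # __forward_/__backward_ pair added above.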
@@ -1340,7 +1364,7 @@ def _validate_signature(signature):
"a possibly nested sequence of TensorSpec objects.")
-def defun(func=None, input_signature=None):
+def defun(func=None, input_signature=None, experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1649,6 +1673,10 @@ def defun(func=None, input_signature=None):
function is instantiated for each inferred input signature. If a
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
+    experimental_autograph: Whether `func` should be compiled with autograph
+      before constructing the graph. See
+      https://www.tensorflow.org/guide/autograph for more information.
+
Returns:
If `func` is not None, returns a callable that will execute the compiled
@@ -1660,10 +1688,16 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
- return defun_with_attributes(func=func, input_signature=input_signature)
+ return defun_with_attributes(
+ func=func,
+ input_signature=input_signature,
+ experimental_autograph=experimental_autograph)
-def defun_with_attributes(func=None, input_signature=None, attributes=None):
+def defun_with_attributes(func=None,
+ input_signature=None,
+ attributes=None,
+ experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
This function supports adding extra function attributes. See detailed
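End to end, the new flag is simply threaded from `defun` down to `func_graph_from_py_func`. A usage sketch, assuming an eager-enabled build of this vintage:

import tensorflow as tf
from tensorflow.python.eager import function

tf.enable_eager_execution()

@function.defun(experimental_autograph=True)
def popcount(n):
  count = tf.constant(0)
  while n > 0:        # data-dependent loop, converted by autograph
    count += n % 2
    n = n // 2
  return count

print(popcount(tf.constant(6)))  # 2, since 6 is 0b110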
@@ -1678,6 +1712,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
attributes. Currently only primitive types are supported as values, and only
whitelisted attribute names are allowed. An unwhitelisted attribute name or
unsupported value will result in a ValueError.
+    experimental_autograph: Same as the experimental_autograph argument of
+      defun().
Returns:
Same as the return value of defun, with attributes added to the function in
@@ -1694,8 +1729,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature,
- attributes=attributes))
+ PolymorphicFunction(
+ function,
+ name,
+ input_signature=input_signature,
+ attributes=attributes,
+ experimental_autograph=experimental_autograph))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1898,8 +1937,10 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue
+ found_resource = False
for inp in op.inputs:
if inp.dtype == dtypes_module.resource:
+ found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
@@ -1914,6 +1955,11 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
+ if (op.op_def.is_stateful and not found_resource
+ and op._control_flow_context is None): # pylint: disable=protected-access
+ if None in last_op_using_resource_tensor:
+ op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
+ last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
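The final hunk serializes stateful ops that touch no resource inputs: each such op gets a control edge from the previous one, tracked under the `None` key of `last_op_using_resource_tensor`, so their program order survives into the runtime graph. A sketch of the visible effect, assuming a stateful print-style op such as `tf.print` is available in this build:

import tensorflow as tf
from tensorflow.python.eager import function

tf.enable_eager_execution()

@function.defun
def f():
  tf.print("first")   # stateful, no resource inputs
  tf.print("second")  # now forced, via a control edge, to run after "first"
  return tf.constant(0)

f()  # prints "first" then "second", in program order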