diff options
Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r-- | tensorflow/python/eager/function.py | 294 |
1 files changed, 170 insertions, 124 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index b28befeb62..93168826b1 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import functools +import re import sys import threading import weakref @@ -30,6 +31,7 @@ import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 +from tensorflow.python import autograph from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute @@ -61,9 +63,15 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce # This is to avoid a circular dependency with gradients_impl gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access +FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name" +BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name" # TODO(scottzhu): Update this to allow arbitrary attribute names in future. -WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_" +WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [ + "experimental_.*", + FORWARD_FUNCTION_ATTRIBUTE_NAME, + BACKWARD_FUNCTION_ATTRIBUTE_NAME +] def _create_substitute_placeholder(value, name=None, dtype=None): @@ -140,10 +148,11 @@ def _parse_func_attrs(attributes): """ attrs = {} for key, value in attributes.items(): - if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX): + if not any([re.match(reg, key) + for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]): raise ValueError("Attribute name is not whitelisted. " "Whitelisted: prefix %s, got: %s" % - (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key)) + (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key)) if isinstance(value, attr_value_pb2.AttrValue): attrs[key] = value @@ -154,7 +163,7 @@ def _parse_func_attrs(attributes): attrs[key] = attr_value_pb2.AttrValue(i=value) elif isinstance(value, float): attrs[key] = attr_value_pb2.AttrValue(f=value) - elif isinstance(value, str): + elif isinstance(value, (str, bytes)): attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) else: raise ValueError("Unsupported attribute type for %s with type %s" % @@ -486,6 +495,9 @@ class _EagerDefinedFunction(object): Returns: The outputs of the function call. + + Raises: + ValueError: if the number of arguments is incorrect. """ executing_eagerly = ctx.executing_eagerly() @@ -519,6 +531,10 @@ class _EagerDefinedFunction(object): # TODO(akshayka): Either remove this if the FunctionLibraryRuntime # creates `PartitionedCallOp` kernels by default, or remove the previous # branch if a TPU kernel is registered for `PartitionedCall`. + if len(args) != len(self.signature.input_arg): + raise ValueError( + "Arguments and signature arguments do not match: %s %s " % + (len(args), len(list(self.signature.input_arg)))) outputs = functional_ops.partitioned_call( args=args, f=self, @@ -705,6 +721,7 @@ class Function(object): def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) + forward_function_name = _forward_name(self._func_graph.name) with backwards_graph.as_default(): gradients_wrt_outputs = [ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs @@ -715,11 +732,11 @@ class Function(object): grad_ys=gradients_wrt_outputs, src_graph=self._func_graph) - self._forward_function = _EagerDefinedFunction( - _forward_name( - self._func_graph.name), self._func_graph, self._func_graph.inputs, - self._func_graph.outputs + list(backwards_graph.captures.keys()), - self._attrs) + backwards_graph_captures = list(backwards_graph.captures.keys()) + + backward_function_attr = _parse_func_attrs( + {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name}) + backward_function_attr.update(self._attrs) # The ordering of `backwards_graph.inputs` is important: inputs of # `self._backward_graph_function` correspond to outputs of @@ -732,7 +749,16 @@ class Function(object): grad for grad in _flatten(gradients_wrt_inputs) if grad is not None) backwards_graph.structured_outputs = gradients_wrt_inputs self._backward_graph_function = Function( - backwards_graph, attrs=self._attrs) + backwards_graph, attrs=backward_function_attr) + + forward_function_attr = _parse_func_attrs({ + BACKWARD_FUNCTION_ATTRIBUTE_NAME: + self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access + forward_function_attr.update(self._attrs) + self._forward_function = _EagerDefinedFunction( + forward_function_name, self._func_graph, self._func_graph.inputs, + self._func_graph.outputs + backwards_graph_captures, + forward_function_attr) def _backprop_call(self, args): """Calls the forward function and records the result on a tape. @@ -829,20 +855,12 @@ class Function(object): return ret -def _get_defun_inputs_from_signature(signature): - """Maps a signature to graph-construction inputs.""" - function_inputs = [ - graph_placeholder(spec.dtype, spec.shape) - for spec in nest.flatten(signature) - ] - return nest.pack_sequence_as(signature, function_inputs) - - def _get_defun_inputs_from_args(args): """Maps python function args to graph-construction inputs.""" function_inputs = [ graph_placeholder(arg.dtype, arg.shape) - if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args) + if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)) + else arg for arg in nest.flatten(args) ] return nest.pack_sequence_as(args, function_inputs) @@ -852,7 +870,8 @@ def func_graph_from_py_func(name, args, kwargs, signature=None, - func_graph=None): + func_graph=None, + experimental_autograph=False): """Returns a `FuncGraph` generated from `python_func`. Args: @@ -869,6 +888,8 @@ def func_graph_from_py_func(name, inputs. func_graph: Optional. An instance of FuncGraph. If provided, we will use this graph else a new one is built and returned. + experimental_autograph: whether to use autograph to compile `python_func`. + See https://www.tensorflow.org/guide/autograph for more information. Returns: A FuncGraph. @@ -883,12 +904,12 @@ def func_graph_from_py_func(name, with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) - if signature is None: - func_args = _get_defun_inputs_from_args(args) - func_kwargs = _get_defun_inputs_from_args(kwargs) - else: - func_args = _get_defun_inputs_from_signature(signature) - func_kwargs = {} + if signature is not None: + args = signature + kwargs = {} + + func_args = _get_defun_inputs_from_args(args) + func_kwargs = _get_defun_inputs_from_args(kwargs) # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. # Variables to help check whether mutation happens in calling the function @@ -914,7 +935,17 @@ def func_graph_from_py_func(name, this_tape = tape.push_new_tape() try: - func_outputs = python_func(*func_args, **func_kwargs) + if experimental_autograph: + func_outputs = autograph.converted_call( + python_func, + autograph.ConversionOptions( + verbose=True, + recursive=True, + force_conversion=False, + strip_decorators=(defun,), + arg_types={}), *func_args, **func_kwargs) + else: + func_outputs = python_func(*func_args, **func_kwargs) # invariant: `func_outputs` contains only Tensors and `None`s. func_outputs = nest.map_structure(convert, func_outputs) @@ -986,52 +1017,8 @@ def func_graph_from_py_func(name, return func_graph -_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"]) - - -def _encode_arg(arg): - """A canonical representation for this argument, for use in a cache key.""" - - # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes - # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes - # are used for both performance reasons, as much TensorFlow code specializes - # on known shapes to produce slimmer graphs, and correctness, as some - # high-level APIs require shapes to be fully-known. - # - # TODO(akshayka): Add support for sparse tensors. - # - # pylint: disable=protected-access - if isinstance(arg, ops.Tensor): - return _TensorType(arg.dtype, arg._shape_tuple()) - elif isinstance(arg, ops.IndexedSlices): - if arg.dense_shape is not None: - return tuple([ - _TensorType(arg.values.dtype, arg.values._shape_tuple()), - _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), - _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()), - ]) - else: - return tuple([ - _TensorType(arg.values.dtype, arg.values._shape_tuple()), - _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), - ]) - # pylint: enable=protected-access - elif isinstance(arg, (list, tuple)): - return tuple([_encode_arg(elem) for elem in arg]) - elif isinstance(arg, dict): - return tuple( - (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg)) - else: - try: - # If possible, keep only a weak reference to Python objects. Weak - # references hash to the same value as the original object. - # TODO(allenl): Clean up dead functions and their cache keys if the cache - # gets large. Right now creating objects with a defunned method, calling - # the method, and losing a reference to the object in a loop will leak - # memory here. - return weakref.ref(arg) - except TypeError: - return arg +pywrap_tensorflow.RegisterType("Tensor", ops.Tensor) +pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices) def _deterministic_dict_values(dictionary): @@ -1054,7 +1041,8 @@ class PolymorphicFunction(object): python_function, name, input_signature=None, - attributes=None): + attributes=None, + experimental_autograph=False): """Initializes a polymorphic function. Args: @@ -1064,7 +1052,10 @@ class PolymorphicFunction(object): specifying the input signature of this function. If `None`, a separate function is instantiated for each inferred input signature. attributes: dict, extra keyword arguments that will be added as attribute - of the function. + of the function. + experimental_autograph: whether to use autograph to compile + `python_function`. See https://www.tensorflow.org/guide/autograph for + more information. Raises: ValueError: if `input_signature` is not None and the `python_function`'s @@ -1080,6 +1071,7 @@ class PolymorphicFunction(object): self._args_to_prepend = tuple() self._kwargs_to_include = {} self._name = name + self._experimental_autograph = experimental_autograph self._function_cache = collections.OrderedDict() self._function_attributes = attributes or {} @@ -1101,6 +1093,8 @@ class PolymorphicFunction(object): offset + index: default for index, default in enumerate(fullargspec.defaults or []) } + self._default_values = fullargspec.defaults + self._default_values_start_index = offset if input_signature is None: self._input_signature = None else: @@ -1161,30 +1155,29 @@ class PolymorphicFunction(object): """Computes the cache key given inputs and execution context.""" if self._input_signature is None: inputs = (args, kwargs) if kwargs else args - cache_key = tuple(_encode_arg(arg) for arg in inputs) + cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs) else: del args, kwargs cache_key = self._flat_input_signature + ctx = context.context() with ops.init_scope(): - init_graph = ops.get_default_graph() - # The graph, or whether we're executing eagerly, should be a part of the # cache key so we don't improperly capture tensors such as variables. - executing_eagerly = context.executing_eagerly() - execution_context = executing_eagerly or init_graph + executing_eagerly = ctx.executing_eagerly() + execution_context = executing_eagerly or ops.get_default_graph() - default_graph = ops.get_default_graph() - # Putting the device in the cache key ensures that call-site device - # annotations are respected. - device_functions = _get_device_functions(context.context(), default_graph) - - # `ops.colocate_with` directives translate into `ops.device` directives when - # eager execution is enabled. - colocation_stack = (() if executing_eagerly else - tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access + if executing_eagerly: + device_functions = (pydev.merge_device(ctx.device_name),) + colocation_stack = () + else: + default_graph = ops.get_default_graph() + # Putting the device in the cache key ensures that call-site device + # annotations are respected. + device_functions = tuple(default_graph._device_functions_outer_to_inner) # pylint: disable=protected-access + colocation_stack = tuple(default_graph._colocation_stack.peek_objs()) # pylint: disable=protected-access - return cache_key + (execution_context, device_functions, colocation_stack) + return (cache_key, execution_context, device_functions, colocation_stack) def _canonicalize_function_inputs(self, *args, **kwargs): """Canonicalizes `args` and `kwargs`. @@ -1209,35 +1202,44 @@ class PolymorphicFunction(object): """ args = self._args_to_prepend + args kwargs = dict(kwargs, **self._kwargs_to_include) - # Maps from index of arg to its corresponding value, according to `args` - # and `kwargs`; seeded with the default values for the named args that - # aren't in `args`. - arg_indices_to_values = { - index: default - for index, default in six.iteritems(self._arg_indices_to_default_values) - if index >= len(args) - } - consumed_args = [] - for arg, value in six.iteritems(kwargs): - index = self._args_to_indices.get(arg, None) - if index is not None: - arg_indices_to_values[index] = value - consumed_args.append(arg) - elif self._input_signature is not None: - raise ValueError("Cannot define a TensorFlow function from a Python " - "function with keyword arguments when " - "input_signature is provided.") - for arg in consumed_args: - # After this loop, `kwargs` will only contain true keyword arguments, as - # opposed to named arguments called in a keyword-like fashion. - kwargs.pop(arg) - inputs = args + _deterministic_dict_values(arg_indices_to_values) + if not kwargs: + if self._default_values: + inputs = args + self._default_values[len(args) - + self._default_values_start_index:] + else: + inputs = args + else: + # Maps from index of arg to its corresponding value, according to `args` + # and `kwargs`; seeded with the default values for the named args that + # aren't in `args`. + arg_indices_to_values = { + index: default for index, default in six.iteritems( + self._arg_indices_to_default_values) if index >= len(args) + } + consumed_args = [] + for arg, value in six.iteritems(kwargs): + index = self._args_to_indices.get(arg, None) + if index is not None: + arg_indices_to_values[index] = value + consumed_args.append(arg) + elif self._input_signature is not None: + raise ValueError("Cannot define a TensorFlow function from a Python " + "function with keyword arguments when " + "input_signature is provided.") + for arg in consumed_args: + # After this loop, `kwargs` will only contain true keyword arguments, as + # opposed to named arguments called in a keyword-like fashion. + kwargs.pop(arg) + inputs = args + _deterministic_dict_values(arg_indices_to_values) flat_inputs = nest.flatten(inputs) # Check for NumPy arrays in arguments and convert them to Tensors. + # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps + # finding a way to store them directly in the cache key (currently not + # possible since ndarrays are not hashable). need_packing = False for index, value in enumerate(flat_inputs): - if isinstance(value, np.ndarray): + if type(value) == np.ndarray: flat_inputs[index] = constant_op.constant(value) need_packing = True if need_packing: @@ -1295,8 +1297,13 @@ class PolymorphicFunction(object): if graph_function is None: graph_function = Function( - func_graph_from_py_func(self._name, self._python_function, args, - kwargs, self._input_signature), + func_graph_from_py_func( + self._name, + self._python_function, + args, + kwargs, + self._input_signature, + experimental_autograph=self._experimental_autograph), self._function_attributes) self._function_cache[cache_key] = graph_function return graph_function, [ @@ -1328,8 +1335,25 @@ def register(func, *args, **kwargs): "Got type: %s" % type(func)) concrete_func = func.get_concrete_function(*args, **kwargs) graph = ops.get_default_graph() - concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access - # TODO(scottzhu): support concrete_func._backward_graph_function in future. + + # There are two situations for the actual call of a defun: + # 1. If none of the input args are resource variables or watch by any tape, + # it will run the _inference_function of concrete_func for forward pass, and + # the gradient will be generated by standard mechanism. + # 2. Otherwise, defun will create two functions, one for forward pass, and the + # backward pass will be created via tape. + # When registering the function, we put both cases into graph. + # pylint: disable=protected-access + concrete_func._inference_function.add_to_graph(graph) + + if concrete_func._backward_graph_function is None: + concrete_func._construct_backprop_function() + forward_function = concrete_func._forward_function + backward_function = concrete_func._backward_graph_function._inference_function + forward_function.add_to_graph(graph) + backward_function.add_to_graph(graph) + # pylint: enable=protected-access + return concrete_func @@ -1340,7 +1364,7 @@ def _validate_signature(signature): "a possibly nested sequence of TensorSpec objects.") -def defun(func=None, input_signature=None): +def defun(func=None, input_signature=None, experimental_autograph=False): """Compiles a Python function into a callable TensorFlow graph. `defun` (short for "define function") trace-compiles a Python function @@ -1649,6 +1673,10 @@ def defun(func=None, input_signature=None): function is instantiated for each inferred input signature. If a signature is specified, every input to `func` must be a `Tensor`, and `func` cannot accept `**kwargs`. + experimental_autograph: Whether `func` should be compiled before + constructing the graph. See https://www.tensorflow.org/guide/autograph + for more information. + Returns: If `func` is not None, returns a callable that will execute the compiled @@ -1660,10 +1688,16 @@ def defun(func=None, input_signature=None): TypeError: If `input_signature` is neither `None` nor a sequence of `tf.contrib.eager.TensorSpec` objects. """ - return defun_with_attributes(func=func, input_signature=input_signature) + return defun_with_attributes( + func=func, + input_signature=input_signature, + experimental_autograph=experimental_autograph) -def defun_with_attributes(func=None, input_signature=None, attributes=None): +def defun_with_attributes(func=None, + input_signature=None, + attributes=None, + experimental_autograph=False): """Compiles a Python function into a callable TensorFlow graph. This function supports adding extra function attributes. See detailed @@ -1678,6 +1712,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None): attributes. Currently only support primitive types as value, and only whitelisted attribute name is allowed. Unwhitelisted attribute name or unsupported value will result into ValueError. + experimental_autograph: same as defun()'s experimental_autograph. Returns: Same as the return value of defun, with attributes added to the function in @@ -1694,8 +1729,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None): name = "function" return tf_decorator.make_decorator( function, - PolymorphicFunction(function, name, input_signature=input_signature, - attributes=attributes)) + PolymorphicFunction( + function, + name, + input_signature=input_signature, + attributes=attributes, + experimental_autograph=experimental_autograph)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: @@ -1898,8 +1937,10 @@ class AutomaticControlDependencies(object): last_op_using_resource_tensor[inp] = op ops_which_must_run = set([op]) continue + found_resource = False for inp in op.inputs: if inp.dtype == dtypes_module.resource: + found_resource = True # Deal with switches, finally. if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, @@ -1914,6 +1955,11 @@ class AutomaticControlDependencies(object): if inp in merge_for_resource: merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access last_op_using_resource_tensor[inp] = op + if (op.op_def.is_stateful and not found_resource + and op._control_flow_context is None): # pylint: disable=protected-access + if None in last_op_using_resource_tensor: + op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access + last_op_using_resource_tensor[None] = op control_inputs = [c for c in control_inputs if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access op._add_control_inputs(control_inputs) # pylint: disable=protected-access |