diff options
author | Dan Moldovan <mdan@google.com> | 2018-10-05 20:07:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 20:11:34 -0700 |
commit | 7d3bfc143a74d8e49f138841a07f7f4693b0a911 (patch) | |
tree | d86e258eb0131d2ec437a895ce204a28d1879c0e /tensorflow/python | |
parent | 45f594a0bce42787356700c0e20f5fbc47193fa3 (diff) |
Add the plumbing for an autograph flag to defun. Disabled and experimental for now.
PiperOrigin-RevId: 216003028
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 61 |
2 files changed, 51 insertions, 11 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index d0c1a93118..cae809a7c3 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -251,6 +251,7 @@ py_library( "//tensorflow/python:gradients_impl", "//tensorflow/python:graph_to_function_def", "//tensorflow/python:util", + "//tensorflow/python/autograph", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", "//tensorflow/python/eager:execute", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index f06148b5d2..bafe07de2b 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -31,6 +31,7 @@ import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 +from tensorflow.python import autograph from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute @@ -877,7 +878,8 @@ def func_graph_from_py_func(name, args, kwargs, signature=None, - func_graph=None): + func_graph=None, + experimental_autograph=False): """Returns a `FuncGraph` generated from `python_func`. Args: @@ -894,6 +896,8 @@ def func_graph_from_py_func(name, inputs. func_graph: Optional. An instance of FuncGraph. If provided, we will use this graph else a new one is built and returned. + experimental_autograph: whether to use autograph to compile `python_func`. + See https://www.tensorflow.org/guide/autograph for more information. Returns: A FuncGraph. @@ -939,7 +943,17 @@ def func_graph_from_py_func(name, this_tape = tape.push_new_tape() try: - func_outputs = python_func(*func_args, **func_kwargs) + if experimental_autograph: + func_outputs = autograph.converted_call( + python_func, + autograph.ConversionOptions( + verbose=True, + recursive=True, + force_conversion=False, + strip_decorators=(defun,), + arg_types={}), *func_args, **func_kwargs) + else: + func_outputs = python_func(*func_args, **func_kwargs) # invariant: `func_outputs` contains only Tensors and `None`s. func_outputs = nest.map_structure(convert, func_outputs) @@ -1035,7 +1049,8 @@ class PolymorphicFunction(object): python_function, name, input_signature=None, - attributes=None): + attributes=None, + experimental_autograph=False): """Initializes a polymorphic function. Args: @@ -1045,7 +1060,10 @@ class PolymorphicFunction(object): specifying the input signature of this function. If `None`, a separate function is instantiated for each inferred input signature. attributes: dict, extra keyword arguments that will be added as attribute - of the function. + of the function. + experimental_autograph: whether to use autograph to compile + `python_function`. See https://www.tensorflow.org/guide/autograph for + more information. Raises: ValueError: if `input_signature` is not None and the `python_function`'s @@ -1061,6 +1079,7 @@ class PolymorphicFunction(object): self._args_to_prepend = tuple() self._kwargs_to_include = {} self._name = name + self._experimental_autograph = experimental_autograph self._function_cache = collections.OrderedDict() self._function_attributes = attributes or {} @@ -1286,8 +1305,13 @@ class PolymorphicFunction(object): if graph_function is None: graph_function = Function( - func_graph_from_py_func(self._name, self._python_function, args, - kwargs, self._input_signature), + func_graph_from_py_func( + self._name, + self._python_function, + args, + kwargs, + self._input_signature, + experimental_autograph=self._experimental_autograph), self._function_attributes) self._function_cache[cache_key] = graph_function return graph_function, [ @@ -1348,7 +1372,7 @@ def _validate_signature(signature): "a possibly nested sequence of TensorSpec objects.") -def defun(func=None, input_signature=None): +def defun(func=None, input_signature=None, experimental_autograph=False): """Compiles a Python function into a callable TensorFlow graph. `defun` (short for "define function") trace-compiles a Python function @@ -1657,6 +1681,10 @@ def defun(func=None, input_signature=None): function is instantiated for each inferred input signature. If a signature is specified, every input to `func` must be a `Tensor`, and `func` cannot accept `**kwargs`. + experimental_autograph: Whether `func` should be compiled before + constructing the graph. See https://www.tensorflow.org/guide/autograph + for more information. + Returns: If `func` is not None, returns a callable that will execute the compiled @@ -1668,10 +1696,16 @@ def defun(func=None, input_signature=None): TypeError: If `input_signature` is neither `None` nor a sequence of `tf.contrib.eager.TensorSpec` objects. """ - return defun_with_attributes(func=func, input_signature=input_signature) + return defun_with_attributes( + func=func, + input_signature=input_signature, + experimental_autograph=experimental_autograph) -def defun_with_attributes(func=None, input_signature=None, attributes=None): +def defun_with_attributes(func=None, + input_signature=None, + attributes=None, + experimental_autograph=False): """Compiles a Python function into a callable TensorFlow graph. This function supports adding extra function attributes. See detailed @@ -1686,6 +1720,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None): attributes. Currently only support primitive types as value, and only whitelisted attribute name is allowed. Unwhitelisted attribute name or unsupported value will result into ValueError. + experimental_autograph: same as defun()'s experimental_autograph. Returns: Same as the return value of defun, with attributes added to the function in @@ -1702,8 +1737,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None): name = "function" return tf_decorator.make_decorator( function, - PolymorphicFunction(function, name, input_signature=input_signature, - attributes=attributes)) + PolymorphicFunction( + function, + name, + input_signature=input_signature, + attributes=attributes, + experimental_autograph=experimental_autograph)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: |