aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-05 20:07:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 20:11:34 -0700
commit7d3bfc143a74d8e49f138841a07f7f4693b0a911 (patch)
treed86e258eb0131d2ec437a895ce204a28d1879c0e /tensorflow/python
parent45f594a0bce42787356700c0e20f5fbc47193fa3 (diff)
Add the plumbing for an autograph flag to defun. Disabled and experimental for now.
PiperOrigin-RevId: 216003028
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/function.py61
2 files changed, 51 insertions, 11 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d0c1a93118..cae809a7c3 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -251,6 +251,7 @@ py_library(
"//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:execute",
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f06148b5d2..bafe07de2b 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -31,6 +31,7 @@ import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
+from tensorflow.python import autograph
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
@@ -877,7 +878,8 @@ def func_graph_from_py_func(name,
args,
kwargs,
signature=None,
- func_graph=None):
+ func_graph=None,
+ experimental_autograph=False):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -894,6 +896,8 @@ def func_graph_from_py_func(name,
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
this graph else a new one is built and returned.
+ experimental_autograph: whether to use autograph to compile `python_func`.
+ See https://www.tensorflow.org/guide/autograph for more information.
Returns:
A FuncGraph.
@@ -939,7 +943,17 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwargs)
+ if experimental_autograph:
+ func_outputs = autograph.converted_call(
+ python_func,
+ autograph.ConversionOptions(
+ verbose=True,
+ recursive=True,
+ force_conversion=False,
+ strip_decorators=(defun,),
+ arg_types={}), *func_args, **func_kwargs)
+ else:
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -1035,7 +1049,8 @@ class PolymorphicFunction(object):
python_function,
name,
input_signature=None,
- attributes=None):
+ attributes=None,
+ experimental_autograph=False):
"""Initializes a polymorphic function.
Args:
@@ -1045,7 +1060,10 @@ class PolymorphicFunction(object):
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
attributes: dict, extra keyword arguments that will be added as attribute
- of the function.
+ of the function.
+ experimental_autograph: whether to use autograph to compile
+ `python_function`. See https://www.tensorflow.org/guide/autograph for
+ more information.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -1061,6 +1079,7 @@ class PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwargs_to_include = {}
self._name = name
+ self._experimental_autograph = experimental_autograph
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1286,8 +1305,13 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
- func_graph_from_py_func(self._name, self._python_function, args,
- kwargs, self._input_signature),
+ func_graph_from_py_func(
+ self._name,
+ self._python_function,
+ args,
+ kwargs,
+ self._input_signature,
+ experimental_autograph=self._experimental_autograph),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
@@ -1348,7 +1372,7 @@ def _validate_signature(signature):
"a possibly nested sequence of TensorSpec objects.")
-def defun(func=None, input_signature=None):
+def defun(func=None, input_signature=None, experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1657,6 +1681,10 @@ def defun(func=None, input_signature=None):
function is instantiated for each inferred input signature. If a
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
+ experimental_autograph: Whether `func` should be compiled before
+ constructing the graph. See https://www.tensorflow.org/guide/autograph
+ for more information.
+
Returns:
If `func` is not None, returns a callable that will execute the compiled
@@ -1668,10 +1696,16 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
- return defun_with_attributes(func=func, input_signature=input_signature)
+ return defun_with_attributes(
+ func=func,
+ input_signature=input_signature,
+ experimental_autograph=experimental_autograph)
-def defun_with_attributes(func=None, input_signature=None, attributes=None):
+def defun_with_attributes(func=None,
+ input_signature=None,
+ attributes=None,
+ experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
This function supports adding extra function attributes. See detailed
@@ -1686,6 +1720,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
attributes. Currently only support primitive types as value, and only
whitelisted attribute name is allowed. Unwhitelisted attribute name or
unsupported value will result into ValueError.
+ experimental_autograph: same as defun()'s experimental_autograph.
Returns:
Same as the return value of defun, with attributes added to the function in
@@ -1702,8 +1737,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature,
- attributes=attributes))
+ PolymorphicFunction(
+ function,
+ name,
+ input_signature=input_signature,
+ attributes=attributes,
+ experimental_autograph=experimental_autograph))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None: