path: root/tensorflow/python/autograph
author     Dan Moldovan <mdan@google.com>                    2018-09-18 05:22:55 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-18 05:27:24 -0700
commit     c6a060c83cc56c8c0cc0f1105550def4bff93c0d (patch)
tree       113d44c285991ff85e2f5f574453042768dad7fe /tensorflow/python/autograph
parent     0cf3690400e46bd89b48a206eff8dd08a660aced (diff)
Simplify the interface of converted_call to accept a ConversionOptions object that can be extended more easily. Currently, adding any new argument requires changing a large number of call sites, and there is redundancy in the argument documentation.
Note: this does not modify the public symbols yet - it's not clear whether we want to complicate their interface. However, we may want to use it in to_graph and to_code.

PiperOrigin-RevId: 213433379
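For illustration, a minimal sketch of the call-site change introduced here, assuming the package is imported from tensorflow.python; the function `square` and the option values below are hypothetical examples, not part of this diff:

    from tensorflow.python import autograph as ag

    def square(x):
      return x * x

    # Old signature: each flag was a separate positional argument, e.g.
    #   ag.converted_call(square, False, False, False, {}, 3)
    # New signature: the flags are bundled into a single options object.
    result = ag.converted_call(
        square,
        ag.ConversionOptions.new(recursive=False, verbose=False),
        3)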
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--  tensorflow/python/autograph/__init__.py                2
-rw-r--r--  tensorflow/python/autograph/converters/call_trees.py   11
-rw-r--r--  tensorflow/python/autograph/core/converter_testing.py  12
-rw-r--r--  tensorflow/python/autograph/impl/api.py                83
-rw-r--r--  tensorflow/python/autograph/impl/api_test.py           24
-rw-r--r--  tensorflow/python/autograph/impl/conversion.py         1
6 files changed, 102 insertions, 31 deletions
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
index c3448e6e58..5ed5e85158 100644
--- a/tensorflow/python/autograph/__init__.py
+++ b/tensorflow/python/autograph/__init__.py
@@ -27,6 +27,7 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core.errors import GraphConstructionError
from tensorflow.python.autograph.core.errors import TfRuntimeError
from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import ConversionOptions
from tensorflow.python.autograph.impl.api import RunMode
from tensorflow.python.autograph.impl.api import convert
from tensorflow.python.autograph.impl.api import converted_call
@@ -42,6 +43,7 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
# Main API
+ 'ConversionOptions',
'RunMode',
'convert',
'converted_call',
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 6a606c450d..fc2075b781 100644
--- a/tensorflow/python/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -238,9 +238,16 @@ class CallTreeTransformer(converter.Base):
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
- ag__.converted_call(func, True, False, False, {}, args)
+ ag__.converted_call(
+ func,
+ ag__.ConversionOptions.new(recursive=recursive_val),
+ args)
"""
- call_expr = templates.replace(template, func=node.func, args=node.args)
+ call_expr = templates.replace(
+ template,
+ func=node.func,
+ recursive_val=parser.parse_expression(str(self.ctx.program.recursive)),
+ args=node.args)
new_call = call_expr[0].value
# TODO(mdan): Improve the template mechanism to better support this.
new_call.keywords = node.keywords
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 0a0c6f9002..7ce1b7c4c5 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -93,11 +93,21 @@ class TestCase(test.TestCase):
self.dynamic_calls.append(args)
return 7
+ class ConversionOptions(object):
+ """Mock version of api.ConversionOptions."""
+
+ def __init__(self, recursive):
+ self.recursive = recursive
+
+ @classmethod
+ def new(cls, recursive):
+ return cls(recursive)
+
try:
result, source = compiler.ast_to_object(node, include_source_map=True)
result.tf = self.make_fake_mod('fake_tf', *symbols)
- fake_ag = self.make_fake_mod('fake_ag', converted_call)
+ fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 669d36bd28..ee2467e0dc 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -18,7 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from functools import wraps
+import collections
+import functools
from enum import Enum
@@ -38,6 +39,41 @@ from tensorflow.python.util import tf_inspect
# (currently we require (module + class name, type))
+class ConversionOptions(
+ collections.namedtuple('ConversionOptions',
+ ('recursive', 'verbose', 'strip_decorators',
+ 'force_conversion', 'arg_types'))):
+ """Container for conversion flags.
+
+ Attributes:
+ recursive: bool, whether to recursively convert any user functions or
+ classes that the converted function may use.
+ verbose: bool, whether to log the compiled code.
+ strip_decorators: Tuple[Callable], contains decorators that should be
+ excluded from the compiled output. By default, when converting a
+ function before the decorators are applied, the compiled output will
+ include those decorators.
+ force_conversion: bool, whether to force converting the target entity.
+ When force_conversion is turned off, the converter may decide to
+ return the function as-is.
+ arg_types: Optional[Dict[Text, Type]], type hints for symbols including
+ function arguments.
+ """
+
+ @classmethod
+ def new(cls,
+ recursive=False,
+ verbose=False,
+ strip_decorators=None,
+ force_conversion=False,
+ arg_types=None):
+ return cls(recursive=recursive,
+ verbose=verbose,
+ strip_decorators=strip_decorators or (),
+ force_conversion=force_conversion,
+ arg_types=arg_types or {})
+
+
# TODO(mdan): This should behave like to_graph (e.g. convert statically).
def convert(recursive=False, verbose=False):
"""Decorator that compiles a function to use TensorFlow ops.
@@ -59,9 +95,15 @@ def convert(recursive=False, verbose=False):
def decorator(f):
"""Decorator implementation."""
- @wraps(f)
+ @functools.wraps(f)
def wrapper(*args, **kwargs):
- return converted_call(f, recursive, verbose, True, {}, *args, **kwargs)
+ return converted_call(
+ f,
+ ConversionOptions.new(
+ recursive=recursive,
+ verbose=verbose,
+ force_conversion=True,
+ ), *args, **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -107,11 +149,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
def decorator(f):
"""Decorator implementation."""
- @wraps(f)
+ @functools.wraps(f)
def graph_wrapper(*args, **kwargs):
return f(*args, **kwargs)
- @wraps(f)
+ @functools.wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
@@ -135,12 +177,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
# TODO(mdan): Move to a private, undocumented module.
-def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
- **kwargs):
+def converted_call(f, options, *args, **kwargs):
"""Compiles a function call inline. For internal use only."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
- if not force_conversion and conversion.is_whitelisted_for_graph(f):
+ if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
@@ -183,8 +224,8 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
continue
arg_class = arg.__class__
# If arg_value_hints specifies any name, use that instead.
- if name not in arg_types:
- arg_types[name] = (arg_class.__name__, arg_class)
+ if name not in options.arg_types:
+ options.arg_types[name] = (arg_class.__name__, arg_class)
# When called from within a decorator, this is the only indication that
# the function is a method - it appears that the decorator is applied
@@ -199,23 +240,25 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
converted_f = to_graph(
target_entity,
- recursive=recursive,
- verbose=verbose,
+ recursive=options.recursive,
+ verbose=options.verbose,
arg_values=arg_values,
- arg_types=arg_types,
- partial_types=partial_types)
+ arg_types=options.arg_types,
+ partial_types=partial_types,
+ strip_decorators=options.strip_decorators)
return converted_f(*effective_args, **kwargs)
# TODO(mdan): Rename: to_ops?
-# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun.
+# TODO(mdan): Look into overloading as function and decorator, like tfe.defun?
# TODO(mdan): Remove partial_types.
def to_graph(e,
recursive=True,
verbose=False,
arg_values=None,
arg_types=None,
- partial_types=None):
+ partial_types=None,
+ strip_decorators=None):
"""Converts a Python entity into equivalent code that uses TensorFlow ops.
Supported Python entities include:
@@ -234,6 +277,8 @@ def to_graph(e,
arg_types: Optional[Dict[Text, Type]], type hints for symbols including
function arguments.
partial_types: Set[Type], reserved for internal use.
+ strip_decorators: Tuple[Callable], same as
+ ConversionOptions.strip_decorators.
Returns:
Union[Callable, Type], the converted entity, which is the same kind as e
@@ -243,9 +288,13 @@ def to_graph(e,
Raises:
ValueError: If the entity could not be converted.
"""
+ if strip_decorators is None:
+ strip_decorators = ()
+ strip_decorators += (convert, do_not_convert, converted_call)
+
program_ctx = converter.ProgramContext(
recursive=recursive,
- autograph_decorators=(convert, do_not_convert, converted_call),
+ autograph_decorators=strip_decorators,
partial_types=partial_types,
autograph_module=tf_inspect.getmodule(to_graph),
uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 54e12f0223..e0770ef4c6 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -32,7 +32,6 @@ from tensorflow.python.util import tf_inspect
tf = utils.fake_tf()
-
class ApiTest(test.TestCase):
def setUp(self):
@@ -180,8 +179,9 @@ class ApiTest(test.TestCase):
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
- x //= api.converted_call(self.called_member, False, False, False, {},
- self, a)
+ x //= api.converted_call(
+ self.called_member,
+ api.ConversionOptions.new(), self, a)
return x
tc = TestClass()
@@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_converted_call_builtin(self):
- x = api.converted_call(range, False, False, False, {}, 3)
+ x = api.converted_call(range, api.ConversionOptions.new(), 3)
self.assertEqual((0, 1, 2), tuple(x))
def test_converted_call_function(self):
@@ -203,7 +203,7 @@ class ApiTest(test.TestCase):
return x
with self.test_session() as sess:
- x = api.converted_call(test_fn, False, False, False, {},
+ x = api.converted_call(test_fn, api.ConversionOptions.new(),
constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
@@ -221,7 +221,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc.test_method, False, False, False, {}, tc)
+ x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_method_by_class(self):
@@ -238,7 +238,9 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
+ x = api.converted_call(
+ TestClass.test_method,
+ api.ConversionOptions.new(), tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_callable_object(self):
@@ -255,7 +257,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc, False, False, False, {})
+ x = api.converted_call(tc, api.ConversionOptions.new())
self.assertEqual(1, sess.run(x))
def test_converted_call_constructor(self):
@@ -271,7 +273,7 @@ class ApiTest(test.TestCase):
return self.x
with self.test_session() as sess:
- tc = api.converted_call(TestClass, False, False, False, {},
+ tc = api.converted_call(TestClass, api.ConversionOptions.new(),
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
@@ -283,12 +285,12 @@ class ApiTest(test.TestCase):
return x == 0
with self.test_session() as sess:
- x = api.converted_call(f, False, False, False, {},
+ x = api.converted_call(f, api.ConversionOptions.new(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
converted_f = api.to_graph(f)
- x = api.converted_call(converted_f, False, False, False, {},
+ x = api.converted_call(converted_f, api.ConversionOptions.new(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 928ff9e7ea..a0d13c82a8 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -255,6 +255,7 @@ def _add_self_references(namespace, autograph_module):
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.converted_call = autograph_module.converted_call
+ ag_internal.ConversionOptions = autograph_module.ConversionOptions
ag_internal.utils = utils
ag_internal.rewrite_graph_construction_error = (
errors.rewrite_graph_construction_error)