diff options
author | Dan Moldovan <mdan@google.com> | 2018-09-18 05:22:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 05:27:24 -0700 |
commit | c6a060c83cc56c8c0cc0f1105550def4bff93c0d (patch) | |
tree | 113d44c285991ff85e2f5f574453042768dad7fe /tensorflow/python/autograph | |
parent | 0cf3690400e46bd89b48a206eff8dd08a660aced (diff) |
Simplify the interface of conversion_call to allow a ConversionOptions object that can be more easily extended. Currently any new argument needs changing a lot of call sites and there is redundancy in argument documentation.
Note: this does not modify the public symbols yet - it's not clear whether we want to complicate their interface. However we may want to use it in to_graph and to_code.
PiperOrigin-RevId: 213433379
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/python/autograph/converters/call_trees.py | 11 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/converter_testing.py | 12 | ||||
-rw-r--r-- | tensorflow/python/autograph/impl/api.py | 83 | ||||
-rw-r--r-- | tensorflow/python/autograph/impl/api_test.py | 24 | ||||
-rw-r--r-- | tensorflow/python/autograph/impl/conversion.py | 1 |
6 files changed, 102 insertions, 31 deletions
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py index c3448e6e58..5ed5e85158 100644 --- a/tensorflow/python/autograph/__init__.py +++ b/tensorflow/python/autograph/__init__.py @@ -27,6 +27,7 @@ from tensorflow.python.autograph import utils from tensorflow.python.autograph.core.errors import GraphConstructionError from tensorflow.python.autograph.core.errors import TfRuntimeError from tensorflow.python.autograph.core.errors import improved_errors +from tensorflow.python.autograph.impl.api import ConversionOptions from tensorflow.python.autograph.impl.api import RunMode from tensorflow.python.autograph.impl.api import convert from tensorflow.python.autograph.impl.api import converted_call @@ -42,6 +43,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ # Main API + 'ConversionOptions', 'RunMode', 'convert', 'converted_call', diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 6a606c450d..fc2075b781 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -238,9 +238,16 @@ class CallTreeTransformer(converter.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - ag__.converted_call(func, True, False, False, {}, args) + ag__.converted_call( + func, + ag__.ConversionOptions.new(recursive=recursive_val), + args) """ - call_expr = templates.replace(template, func=node.func, args=node.args) + call_expr = templates.replace( + template, + func=node.func, + recursive_val=parser.parse_expression(str(self.ctx.program.recursive)), + args=node.args) new_call = call_expr[0].value # TODO(mdan): Improve the template mechanism to better support this. new_call.keywords = node.keywords diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 0a0c6f9002..7ce1b7c4c5 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -93,11 +93,21 @@ class TestCase(test.TestCase): self.dynamic_calls.append(args) return 7 + class ConversionOptions(object): + """Mock version of api.ConversionOptions.""" + + def __init__(self, recursive): + self.recursive = recursive + + @classmethod + def new(cls, recursive): + cls(recursive) + try: result, source = compiler.ast_to_object(node, include_source_map=True) result.tf = self.make_fake_mod('fake_tf', *symbols) - fake_ag = self.make_fake_mod('fake_ag', converted_call) + fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions) fake_ag.__dict__.update(operators.__dict__) fake_ag.__dict__['utils'] = utils fake_ag.__dict__['rewrite_graph_construction_error'] = ( diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 669d36bd28..ee2467e0dc 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -18,7 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from functools import wraps +import collections +import functools from enum import Enum @@ -38,6 +39,41 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) +class ConversionOptions( + collections.namedtuple('ConversionOptions', + ('recursive', 'verbose', 'strip_decorators', + 'force_conversion', 'arg_types'))): + """Container for conversion flags. + + Attributes: + recursive: bool, whether to recursively convert any user functions or + classes that the converted function may use. + verbose: bool, whether to log the compiled code. + strip_decorators: Tuple[Callable], contains decorators that should be in + excluded from the compiled output. By default, when converting a + function before the decorators are applied, the compiled output will + include those decorators. + force_conversion: bool, whether to force convertinng the target entity. + When force_conversion is turned off, the converter may decide to + return the function as-is. + arg_types: Optional[Dict[Text, Type]], type hints for symbols including + function arguments. + """ + + @classmethod + def new(cls, + recursive=False, + verbose=False, + strip_decorators=None, + force_conversion=False, + arg_types=None): + return cls(recursive=recursive, + verbose=verbose, + strip_decorators=strip_decorators or (), + force_conversion=force_conversion, + arg_types=arg_types or {}) + + # TODO(mdan): This should behave like to_graph (e.g. convert statically). def convert(recursive=False, verbose=False): """Decorator that compiles a function to use TensorFlow ops. @@ -59,9 +95,15 @@ def convert(recursive=False, verbose=False): def decorator(f): """Decorator implementation.""" - @wraps(f) + @functools.wraps(f) def wrapper(*args, **kwargs): - return converted_call(f, recursive, verbose, True, {}, *args, **kwargs) + return converted_call( + f, + ConversionOptions.new( + recursive=recursive, + verbose=verbose, + force_conversion=True, + ), *args, **kwargs) wrapper = tf_decorator.make_decorator(f, wrapper) @@ -107,11 +149,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): def decorator(f): """Decorator implementation.""" - @wraps(f) + @functools.wraps(f) def graph_wrapper(*args, **kwargs): return f(*args, **kwargs) - @wraps(f) + @functools.wraps(f) def py_func_wrapper(*args, **kwargs): if kwargs: raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') @@ -135,12 +177,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): # TODO(mdan): Move to a private, undocumented module. -def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, - **kwargs): +def converted_call(f, options, *args, **kwargs): """Compiles a function call inline. For internal use only.""" # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. - if not force_conversion and conversion.is_whitelisted_for_graph(f): + if not options.force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) unknown_arg_value = object() # Sentinel for arguments of unknown value @@ -183,8 +224,8 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, continue arg_class = arg.__class__ # If arg_value_hints specifies any name, use that instead. - if name not in arg_types: - arg_types[name] = (arg_class.__name__, arg_class) + if name not in options.arg_types: + options.arg_types[name] = (arg_class.__name__, arg_class) # When called from within a decorator, this is the only indication that # the function is a method - it appears that the decorator is applied @@ -199,23 +240,25 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, converted_f = to_graph( target_entity, - recursive=recursive, - verbose=verbose, + recursive=options.recursive, + verbose=options.verbose, arg_values=arg_values, - arg_types=arg_types, - partial_types=partial_types) + arg_types=options.arg_types, + partial_types=partial_types, + strip_decorators=options.strip_decorators) return converted_f(*effective_args, **kwargs) # TODO(mdan): Rename: to_ops? -# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun. +# TODO(mdan): Look into overloading as function and decorator, like tfe.defun? # TODO(mdan): Remove partial_types. def to_graph(e, recursive=True, verbose=False, arg_values=None, arg_types=None, - partial_types=None): + partial_types=None, + strip_decorators=None): """Converts a Python entity into equivalent code that uses TensorFlow ops. Supported Python entities include: @@ -234,6 +277,8 @@ def to_graph(e, arg_types: Optional[Dict[Text, Type]], type hints for symbols including function arguments. partial_types: Set[Type], reserved for internal use. + strip_decorators: Tuple[Callable], same as + ConversionOptions.strip_decorators. Returns: Union[Callable, Type], the converted entity, which is the same kind as e @@ -243,9 +288,13 @@ def to_graph(e, Raises: ValueError: If the entity could not be converted. """ + if strip_decorators is None: + strip_decorators = () + strip_decorators += (convert, do_not_convert, converted_call) + program_ctx = converter.ProgramContext( recursive=recursive, - autograph_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=strip_decorators, partial_types=partial_types, autograph_module=tf_inspect.getmodule(to_graph), uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 54e12f0223..e0770ef4c6 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -32,7 +32,6 @@ from tensorflow.python.util import tf_inspect tf = utils.fake_tf() - class ApiTest(test.TestCase): def setUp(self): @@ -180,8 +179,9 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, False, False, False, {}, - self, a) + x //= api.converted_call( + self.called_member, + api.ConversionOptions.new(), self, a) return x tc = TestClass() @@ -192,7 +192,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, False, False, False, {}, 3) + x = api.converted_call(range, api.ConversionOptions.new(), 3) self.assertEqual((0, 1, 2), tuple(x)) def test_converted_call_function(self): @@ -203,7 +203,7 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call(test_fn, False, False, False, {}, + x = api.converted_call(test_fn, api.ConversionOptions.new(), constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) @@ -221,7 +221,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, False, False, False, {}, tc) + x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc) self.assertEqual(1, sess.run(x)) def test_converted_call_method_by_class(self): @@ -238,7 +238,9 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, False, False, False, {}, tc) + x = api.converted_call( + TestClass.test_method, + api.ConversionOptions.new(), tc) self.assertEqual(1, sess.run(x)) def test_converted_call_callable_object(self): @@ -255,7 +257,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, False, False, False, {}) + x = api.converted_call(tc, api.ConversionOptions.new()) self.assertEqual(1, sess.run(x)) def test_converted_call_constructor(self): @@ -271,7 +273,7 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call(TestClass, False, False, False, {}, + tc = api.converted_call(TestClass, api.ConversionOptions.new(), constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() @@ -283,12 +285,12 @@ class ApiTest(test.TestCase): return x == 0 with self.test_session() as sess: - x = api.converted_call(f, False, False, False, {}, + x = api.converted_call(f, api.ConversionOptions.new(), constant_op.constant(0)) self.assertTrue(sess.run(x)) converted_f = api.to_graph(f) - x = api.converted_call(converted_f, False, False, False, {}, + x = api.converted_call(converted_f, api.ConversionOptions.new(), constant_op.constant(0)) self.assertTrue(sess.run(x)) diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 928ff9e7ea..a0d13c82a8 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -255,6 +255,7 @@ def _add_self_references(namespace, autograph_module): # internal modules. ag_internal = imp.new_module('autograph') ag_internal.converted_call = autograph_module.converted_call + ag_internal.ConversionOptions = autograph_module.ConversionOptions ag_internal.utils = utils ag_internal.rewrite_graph_construction_error = ( errors.rewrite_graph_construction_error) |