diff options
author | 2018-01-22 11:27:33 -0800 | |
---|---|---|
committer | 2018-01-22 11:31:49 -0800 | |
commit | 398852bb70a64a03465cb712a9616a0d56c4c3de (patch) | |
tree | f838ec86298bdea4dc971444288536472893f0f5 /tensorflow/contrib/py2tf | |
parent | 8451cee9722f11b3d28d234e2c383d59562eec15 (diff) |
Extend the API with a pair of decorators that can convert functions inline.
Add logic to detect these decorators and treat them appropriately (e.g. when converting recursively, do all decorated functions as they are).
PiperOrigin-RevId: 182808673
Diffstat (limited to 'tensorflow/contrib/py2tf')
-rw-r--r-- | tensorflow/contrib/py2tf/api.py | 134 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/api_test.py | 140 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/conversion.py | 54 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/conversion_test.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/call_trees.py | 109 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/call_trees_test.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/decorators.py | 56 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/naming.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/naming_test.py | 12 |
10 files changed, 477 insertions, 50 deletions
diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/api.py index 3a36720969..9a2b70c53c 100644 --- a/tensorflow/contrib/py2tf/api.py +++ b/tensorflow/contrib/py2tf/api.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from functools import wraps + import gast import six @@ -32,7 +34,111 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) -def to_graph(o, arg_value_hints=None): +def graph_ready(f): + """No-op decorator that explicitly marks a function as graph-ready. + + Graph-ready functions are assumed to not need any conversion. + + Args: + f: Any callable. + Returns: + f itself. + """ + setattr(f, '__pyct_is_compile_decorator', True) + return f + + +def convert_inline(f, *args, **kwargs): + """Shorthand to convert and call a function. + + For example, the following two statements are equivalent: + + @convert() + def foo(): + ... + foo(bar) + + def foo(): + ... + convert_inline(foo, bar) + + Args: + f: Function to convert. Only this call will be converted. + *args: Passed through to f. + **kwargs: Passed through to f, with the following exceptions: + * arg_value_hints: A dict mapping parameter names to objects that can + hint at the type of those parameters. + + Returns: + The result of the converted f applied to args and kwargs. + """ + if 'arg_value_hints' in kwargs: + arg_value_hints = kwargs['arg_value_hints'] + del kwargs['arg_value_hints'] + else: + arg_value_hints = None + if tf_inspect.ismethod(f): + # When converting methods, the result is still an unbound function. + args = (f.__self__,) + args + return convert(arg_value_hints)(f)(*args, **kwargs) + + +def convert(recursive=False, arg_value_hints=None): + """Decorator that compiles a function to graph mode. + + The decorator is dynamic - invoking compilation whenever the decorated fuction + is called. This means the parameter values are known at compilation. + + Args: + recursive: Whether to recusrively convert any functions that the decorator + function may call. + arg_value_hints: A dict mapping parameter names to objects that can hint + at the type of those parameters. + + Returns: + A decorator that compiles the given function to graph mode. + + Raises: + ValueError: If any of the arguments are illegal. + """ + if arg_value_hints is None: + arg_value_hints = {} + + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def wrapper(*args, **kwargs): + """Wrapper that calls the compiled version of the wrapped function.""" + partial_types = () + arg_names = tf_inspect.getargspec(f)[0] + for name, arg in zip(arg_names, args): + arg_class = arg.__class__ + if tf_inspect.isclass(arg_class): + # If arg_value_hints specifies any name, use that instead. + # TODO(mdan): Shouldn't this just be in the func's globals? + if name not in arg_value_hints: + arg_value_hints[name] = (arg_class.__name__, arg_class) + # Annotated methods need to specify that their owner type is partial, + # otherwise other members they call will not be converted. + if name == 'self': + partial_types = (arg_class,) + wrapped = to_graph( + f, + recursive=recursive, + arg_value_hints=arg_value_hints, + partial_types=partial_types) + return wrapped(*args, **kwargs) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +def to_graph(o, recursive=True, arg_value_hints=None, partial_types=None): """Compile a Python entity into equivalent TensorFlow code. Currently supported entities: @@ -43,14 +149,22 @@ def to_graph(o, arg_value_hints=None): Args: o: A Python function or class. + recursive: Whether to recusrively convert any functions that the decorator + function may call. arg_value_hints: A dict mapping parameter names to objects that can hint at the type of those parameters. + partial_types: A set of types (e.g. classes) that will not be converted + entirely. Calls to member functions for these types will be renamed + independently. Returns: A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. """ - conversion_map = conversion.ConversionMap() + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, graph_ready, convert_inline), + partial_types=partial_types) _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints) module = gast.Module([]) @@ -69,21 +183,33 @@ def to_graph(o, arg_value_hints=None): return compiled_fn -def to_code(o, arg_value_hints=None, indentation=' '): +def to_code(o, + recursive=True, + arg_value_hints=None, + partial_types=None, + indentation=' '): """Return the equivalent of an entity in TensorFlow code. See `to_graph` for more details. Args: o: A Python function or class. + recursive: Whether to recusrively convert any functions that the decorator + function may call. arg_value_hints: A dict mapping parameter names to objects that can hint at the type of those parameters. + partial_types: A set of types (e.g. classes) that will not be converted + entirely. Calls to member functions for these types will be renamed + independently. indentation: String, when to use for each level of indentation. Returns: String. """ - conversion_map = conversion.ConversionMap() + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, graph_ready, convert_inline), + partial_types=partial_types) conversion.object_to_graph(o, conversion_map, arg_value_hints) imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/api_test.py index 225b6d305f..2384447708 100644 --- a/tensorflow/contrib/py2tf/api_test.py +++ b/tensorflow/contrib/py2tf/api_test.py @@ -28,17 +28,146 @@ from tensorflow.python.platform import test class ApiTest(test.TestCase): + def setUp(self): + config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) + config.COMPILED_IMPORT_STATEMENTS = ( + 'from tensorflow.python.ops ' + 'import control_flow_ops as tf',) + + def test_decorator_recurses(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_decorator_does_not_recurse(self): + + class TestClass(object): + + def called_member(self, a): + return math_ops.negative(a) + + @api.convert(recursive=False) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_decorator_calls_converted(self): + + class TestClass(object): + + @api.graph_ready + def called_member(self, a): + return math_ops.negative(a) + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_decorator_calls_decorated(self): + + class TestClass(object): + + @api.convert() + def called_member(self, a): + if a < 0: + a = -a + return a + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_convert_call_site_decorator(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= api.convert_inline(self.called_member, a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_graph_ready_call_site_decorator(self): + + class TestClass(object): + + def called_member(self, a): + return math_ops.negative(a) + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= api.graph_ready(self.called_member(a)) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_to_graph_basic(self): def test_fn(x, s): while math_ops.reduce_sum(x) > s: x //= 2 return x - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) - config.COMPILED_IMPORT_STATEMENTS = ( - 'from tensorflow.python.ops ' - 'import control_flow_ops as tf', - ) compiled_fn = api.to_graph(test_fn) with self.test_session() as sess: @@ -51,7 +180,6 @@ class ApiTest(test.TestCase): x /= 2 return x - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) compiled_code = api.to_code(test_fn) # Just check for some key words and that it is parseable Python code. diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py index 43bccae953..3bdbc66a99 100644 --- a/tensorflow/contrib/py2tf/conversion.py +++ b/tensorflow/contrib/py2tf/conversion.py @@ -28,6 +28,7 @@ from tensorflow.contrib.py2tf.convert import builtin_functions from tensorflow.contrib.py2tf.convert import call_trees from tensorflow.contrib.py2tf.convert import continue_canonicalization from tensorflow.contrib.py2tf.convert import control_flow +from tensorflow.contrib.py2tf.convert import decorators from tensorflow.contrib.py2tf.convert import for_canonicalization from tensorflow.contrib.py2tf.convert import logical_expressions from tensorflow.contrib.py2tf.convert import print_functions @@ -39,22 +40,35 @@ from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.util import tf_inspect +# TODO(mdan): Might we not need any renaming at all? + + class ConversionMap(object): """ConversionMaps keep track of converting function hierarchies. Attributes: + recursive: Whether to recusrively convert any functions that the decorator + function may call. + nocompile_decorators: tuple of decorator functions that toggle compilation + off. dependency_cache: dict[object]: ast; maps original objects to their converted AST name_map: dict[string]: string; maps original objects to the name of their converted counterparts """ - def __init__(self): + # TODO(mdan): Rename to ConversionContext, and pull in additional flags. + + def __init__(self, recursive, nocompile_decorators, partial_types): + self.recursive = recursive + self.nocompile_decorators = nocompile_decorators + self.partial_types = partial_types if partial_types else () self.dependency_cache = {} self.name_map = {} def new_namer(self, global_symbols): - return naming.Namer(global_symbols, self.name_map) + return naming.Namer(global_symbols, self.recursive, self.name_map, + self.partial_types) def update_name_map(self, namer): for o, name in namer.renamed_calls.items(): @@ -102,19 +116,23 @@ def object_to_graph(o, conversion_map, value_hints): node, new_name = class_to_graph(o, conversion_map, value_hints) elif tf_inspect.isfunction(o): node, new_name = function_to_graph(o, conversion_map, value_hints) + elif tf_inspect.ismethod(o): + node, new_name = function_to_graph(o, conversion_map, value_hints) else: raise ValueError( - 'Unsupported object type %s. Only functions and classes are supported' - ' for now.') + 'Entity "%s" has unsupported type "%s". Only functions and classes are ' + 'supported for now.' % (o, type(o))) conversion_map.add_to_cache(o, node) - # Recursively convert remaining dependencies. - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: - if hasattr(obj, 'im_class'): - # Class members are converted with their objects. - continue - object_to_graph(obj, conversion_map, None) + if conversion_map.recursive: + for obj in conversion_map.name_map.keys(): + if obj not in conversion_map.dependency_cache: + if (hasattr(obj, 'im_class') and + getattr(obj, 'im_class') not in conversion_map.partial_types): + # Class members are converted with their objects, unless they're + # only converted partially. + continue + object_to_graph(obj, conversion_map, None) return node, new_name @@ -163,7 +181,8 @@ def function_to_graph(f, conversion_map, param_value_hints, owner_type=None): node_globals[fn.__name__] = fn namer = conversion_map.new_namer(node_globals) - node = node_to_graph(node, namer, node_globals, param_value_hints) + node = node_to_graph(node, namer, node_globals, param_value_hints, + conversion_map.nocompile_decorators) # Simulate a rename to ensure the top level is in the name map. This is needed # for top level functions, and it also helps the consistency verification made @@ -184,7 +203,7 @@ def _static_analysis_pass(node, namespace, value_hints): return node -def node_to_graph(node, namer, namespace, value_hints): +def node_to_graph(node, namer, namespace, value_hints, nocompile_decorators): """Convert Python code to equivalent TF graph mode code. Args: @@ -193,6 +212,8 @@ def node_to_graph(node, namer, namespace, value_hints): namespace: Dict mapping symbol names to their corresponding live objects. value_hints: A dict containing value hints for symbols like function parameters. + nocompile_decorators: A tuple containing decorators to be stripped from + functions during conversion. Returns: A tuple (node, deps): @@ -200,6 +221,8 @@ def node_to_graph(node, namer, namespace, value_hints): * deps: A set of strings, the fully qualified names of object dependencies that this node has. """ + # TODO(mdan): Verify arguments for correctness. + # TODO(mdan): Factor out common elements. # These include: # * keeping track of symbols that have been created @@ -213,6 +236,7 @@ def node_to_graph(node, namer, namespace, value_hints): # to re-run the analysis. node = _static_analysis_pass(node, namespace, value_hints) + node = decorators.transform(node, nocompile_decorators) node = break_canonicalization.transform(node, namer) # Note: sequencing continue canonicalization before for loop one avoids @@ -230,7 +254,9 @@ def node_to_graph(node, namer, namespace, value_hints): node = _static_analysis_pass(node, namespace, value_hints) node = print_functions.transform(node) - node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES) + node = call_trees.transform(node, namer, namespace, + config.DEFAULT_UNCOMPILED_MODULES, + nocompile_decorators) node = control_flow.transform(node, namer) node = logical_expressions.transform(node) node = side_effect_guards.transform(node, namer) diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/conversion_test.py index d76f141809..e48bfe4464 100644 --- a/tensorflow/contrib/py2tf/conversion_test.py +++ b/tensorflow/contrib/py2tf/conversion_test.py @@ -28,13 +28,13 @@ class ConversionTest(test.TestCase): def test_object_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion.object_to_graph('dummy', {}, {}) + conversion.object_to_graph('dummy', None, {}) def test_object_to_graph_callable(self): def f(a): return a - conversion_map = conversion.ConversionMap() + conversion_map = conversion.ConversionMap(True, (), ()) ast, new_name = conversion.object_to_graph(f, conversion_map, {}) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', new_name) @@ -46,7 +46,7 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = conversion.ConversionMap() + conversion_map = conversion.ConversionMap(True, (), ()) conversion.object_to_graph(f, conversion_map, {}) self.assertTrue(f in conversion_map.dependency_cache) diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/convert/BUILD index 0eb7998dc4..050e2ef108 100644 --- a/tensorflow/contrib/py2tf/convert/BUILD +++ b/tensorflow/contrib/py2tf/convert/BUILD @@ -22,6 +22,7 @@ py_library( "call_trees.py", "continue_canonicalization.py", "control_flow.py", + "decorators.py", "for_canonicalization.py", "logical_expressions.py", "print_functions.py", diff --git a/tensorflow/contrib/py2tf/convert/call_trees.py b/tensorflow/contrib/py2tf/convert/call_trees.py index 92c3439101..df071f596f 100644 --- a/tensorflow/contrib/py2tf/convert/call_trees.py +++ b/tensorflow/contrib/py2tf/convert/call_trees.py @@ -27,6 +27,7 @@ import types import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import templates @@ -64,16 +65,75 @@ class FunctionNamer(object): class CallTreeTransformer(gast.NodeTransformer): """Transforms the call tree by renaming transformed symbols.""" - def __init__(self, namer, uncompiled_modules): + def __init__(self, namer, namespace, uncompiled_modules, + nocompile_decorators): self.namer = namer + self.namespace = namespace self.uncompiled_modules = uncompiled_modules + self.nocompile_decorators = nocompile_decorators # pylint:disable=invalid-name - def _should_compile(self, fqn): + def _resolve_name(self, node): + if isinstance(node, gast.Call): + return self._resolve_name(node.func) + if isinstance(node, gast.Name): + return self.namespace.get(node.id) + if isinstance(node, gast.Attribute): + parent = self._resolve_name(node.value) + if parent is not None: + return getattr(parent, node.attr) + return None + raise ValueError(node) + + def _try_resolve_target(self, node): + """Works for methods of objects of known type.""" + if anno.hasanno(node, 'live_val'): + return anno.getanno(node, 'live_val') + if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'): + member = getattr(anno.getanno(node, 'type'), node.attr) + return member + return None + + def _should_compile(self, node, fqn): for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False + + # Check for local decorations + if anno.hasanno(node, 'graph_ready'): + return False + + # The decorators themselves are not to be converted. + # If present, the decorators should appear as static functions. + target_obj = self._try_resolve_target(node.func) + if target_obj is not None: + # This attribute is set by the decorator itself. + # TODO(mdan): This may not play nicely with other wrapping decorators. + if hasattr(target_obj, '__pyct_is_compile_decorator'): + return False + + if target_obj in self.nocompile_decorators: + return False + + # Inspect the target function decorators. If any include a @convert + # or @graph_ready annotation, then they must be called as they are. + # TODO(mdan): This may be quite heavy. + # To parse and re-analize each function for every call site could be quite + # wasteful. Maybe we could cache the parsed AST? + try: + target_node = parser.parse_object(target_obj).body[0] + except TypeError: + # Functions whose source we cannot access are compilable (e.g. wrapped + # to py_func). + return True + + for dec in target_node.decorator_list: + decorator_fn = self._resolve_name(dec) + if (decorator_fn is not None and + decorator_fn in self.nocompile_decorators): + return False + return True def _rename_compilable_function(self, node): @@ -82,15 +142,15 @@ class CallTreeTransformer(gast.NodeTransformer): target_obj = anno.getanno(node.func, 'live_val') target_fqn = anno.getanno(node.func, 'fqn') - if not self._should_compile(target_fqn): + if not self._should_compile(node, target_fqn): return node if anno.hasanno(node, 'is_constructor'): new_name = self.namer.compiled_class_name( - '.'.join(target_fqn), live_object=target_obj) + '__'.join(target_fqn), live_object=target_obj) else: new_name = self.namer.compiled_function_name( - '.'.join(target_fqn), live_object=target_obj) + '__'.join(target_fqn), live_object=target_obj) node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None) return node @@ -101,15 +161,24 @@ class CallTreeTransformer(gast.NodeTransformer): assert anno.hasanno(node.func, 'type') target_type = anno.getanno(node.func, 'type') - if not self._should_compile(type_fqn): + if not self._should_compile(node, type_fqn): return node # TODO(mdan): We should not assume that the namer only needs the # member function name. + method_name = node.func.attr + method_object = getattr(target_type, method_name) new_name = self.namer.compiled_function_name( - node.func.attr, live_object=None, owner_type=target_type) - node.func.attr = new_name - + method_name, live_object=method_object, owner_type=target_type) + if new_name != node.func.attr: + # If a member function call is renamed, then the new function is no + # longer bound to the target object. We then refactor the call from: + # foo.bar(...) + # to: + # renamed_foo(bar, ...) + # TODO(mdan): This risks causing duplication, if target_type is renamed. + node.args = [node.func.value] + node.args + node.func = gast.Name(new_name, gast.Load(), None) return node def _wrap_to_py_func_no_return(self, node): @@ -136,6 +205,7 @@ class CallTreeTransformer(gast.NodeTransformer): wrapper=gast.Name(wrapper_name, gast.Load(), None), args=args) anno.setanno(call_expr.value, 'args_scope', args_scope) + # TODO(mdan): Rename this annotation to 'graph_ready' anno.setanno(wrapper_def, 'skip_processing', True) return (wrapper_def, call_expr) @@ -151,7 +221,7 @@ class CallTreeTransformer(gast.NodeTransformer): if not self._function_is_compilable(target_obj): if anno.hasanno(node.value.func, 'fqn'): target_fqn = anno.getanno(node.value.func, 'fqn') - if not self._should_compile(target_fqn): + if not self._should_compile(node.value, target_fqn): return node node = self._wrap_to_py_func_no_return(node.value) return node @@ -163,6 +233,17 @@ class CallTreeTransformer(gast.NodeTransformer): return node def visit_Call(self, node): + # If the function is wrapped by one of the marker decorators, + # consider it graph ready. + if anno.hasanno(node.func, 'live_val'): + target_obj = anno.getanno(node.func, 'live_val') + if target_obj in self.nocompile_decorators: + if len(node.args) < 1: + raise ValueError( + 'Found call to decorator function "%s", but it had no arguments. ' + 'A decorator needs at least an argument.') + anno.setanno(node.args[0], 'graph_ready', True) + self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): target_obj = anno.getanno(node.func, 'live_val') @@ -180,20 +261,24 @@ class CallTreeTransformer(gast.NodeTransformer): # pylint:enable=invalid-name -def transform(node, namer, uncompiled_modules): +def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators): """Transform function call to the compiled counterparts. Args: node: AST to transform. namer: FunctionNamer-like. + namespace: Dict mapping symbol names to their corresponding live objects. uncompiled_modules: set of string tuples, each tuple represents the fully qualified name of a package containing functions that will not be compiled. + nocompile_decorators: A tuple containing decorators to be stripped from + functions during conversion. Returns: A tuple (node, new_names): node: The transformed AST new_names: set(string), containing any newly-generated names """ - transformer = CallTreeTransformer(namer, uncompiled_modules) + transformer = CallTreeTransformer(namer, namespace, uncompiled_modules, + nocompile_decorators) node = transformer.visit(node) return node diff --git a/tensorflow/contrib/py2tf/convert/call_trees_test.py b/tensorflow/contrib/py2tf/convert/call_trees_test.py index 38c701eaad..3367d41db3 100644 --- a/tensorflow/contrib/py2tf/convert/call_trees_test.py +++ b/tensorflow/contrib/py2tf/convert/call_trees_test.py @@ -56,7 +56,7 @@ class CallTreesTest(test.TestCase): return test_fn_1(a) + 1 node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, TestNamer(), set()) + node = call_trees.transform(node, TestNamer(), {}, (), ()) result = compiler.ast_to_object(node) # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually. setattr(result, 'renamed_test_fn_1', renamed_test_fn_1) @@ -74,9 +74,9 @@ class CallTreesTest(test.TestCase): 'math_ops': math_ops, 'constant_op': constant_op }) - node = call_trees.transform(node, TestNamer(), + node = call_trees.transform(node, TestNamer(), {}, set(((math_ops.__name__,), - (constant_op.__name__,)))) + (constant_op.__name__,))), ()) result = compiler.ast_to_object(node) setattr(result, 'math_ops', math_ops) setattr(result, 'constant_op', constant_op) diff --git a/tensorflow/contrib/py2tf/convert/decorators.py b/tensorflow/contrib/py2tf/convert/decorators.py new file mode 100644 index 0000000000..a4313bfa51 --- /dev/null +++ b/tensorflow/contrib/py2tf/convert/decorators.py @@ -0,0 +1,56 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles decorators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import pretty_printer + + +class DecoratorsTransformer(gast.NodeTransformer): + """Converts or removes decorators.""" + + def __init__(self, remove_decorators): + self.remove_decorators = remove_decorators + + # pylint:disable=invalid-name + + def visit_FunctionDef(self, node): + self.generic_visit(node) + for dec in node.decorator_list: + if isinstance(dec, gast.Call): + dec = dec.func + if not anno.hasanno(dec, 'live_val'): + raise ValueError( + 'Could not resolve decorator: %s' % pretty_printer.fmt(dec)) + dec_value = anno.getanno(dec, 'live_val') + if dec_value in self.remove_decorators: + continue + raise ValueError('Dont know how to convert decorators for now.') + node.decorator_list = [] + return node + + # pylint:enable=invalid-name + + +def transform(node, remove_decorators): + transformer = DecoratorsTransformer(remove_decorators) + node = transformer.visit(node) + return node diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/naming.py index 61772ec07b..a90758962b 100644 --- a/tensorflow/contrib/py2tf/naming.py +++ b/tensorflow/contrib/py2tf/naming.py @@ -34,8 +34,10 @@ class Namer(object): * side_effect_guards.SymbolNamer """ - def __init__(self, global_namespace, name_map=None): + def __init__(self, global_namespace, recursive, name_map, partial_types): self.global_namespace = global_namespace + self.recursive = recursive + self.partial_types = partial_types self.renamed_calls = {} if name_map is not None: @@ -54,6 +56,7 @@ class Namer(object): while new_name in self.global_namespace: n += 1 new_name = '%s_%d' % (new_name_root, n) + if live_object is not None: self.renamed_calls[live_object] = new_name self.generated_names.add(new_name) @@ -67,7 +70,9 @@ class Namer(object): if live_object is not None and live_object in self.renamed_calls: return self.renamed_calls[live_object] - if owner_type is None: + if not self.recursive: + new_name = original_name + elif owner_type is None or owner_type in self.partial_types: # Top level functions: rename new_name_root = 'tf__%s' % original_name new_name = new_name_root diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/naming_test.py index 9403d9ae1f..7bfc9b8733 100644 --- a/tensorflow/contrib/py2tf/naming_test.py +++ b/tensorflow/contrib/py2tf/naming_test.py @@ -28,7 +28,7 @@ class NamerTest(test.TestCase): def bar(): pass - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('tf__foo', namer.compiled_function_name('foo')) self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar)) self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls) @@ -38,7 +38,7 @@ class NamerTest(test.TestCase): def foo(): pass - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) @@ -46,22 +46,22 @@ class NamerTest(test.TestCase): def foo(): pass - namer = naming.Namer(set(('tf__foo',))) + namer = naming.Namer({'tf__foo': 1}, True, None, ()) self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo)) def test_new_symbol_tracks_names(self): - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('temp', namer.new_symbol('temp', set())) self.assertItemsEqual(('temp',), namer.generated_names) def test_new_symbol_avoids_duplicates(self): - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('temp', namer.new_symbol('temp', set())) self.assertEqual('temp_1', namer.new_symbol('temp', set())) self.assertItemsEqual(('temp', 'temp_1'), namer.generated_names) def test_new_symbol_avoids_conflicts(self): - namer = naming.Namer(set(('temp',))) + namer = naming.Namer({'temp': 1}, True, None, ()) # temp is reserved in the global namespace self.assertEqual('temp_1', namer.new_symbol('temp', set())) # temp_2 is reserved in the local namespace |