aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/py2tf
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-22 11:27:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-22 11:31:49 -0800
commit398852bb70a64a03465cb712a9616a0d56c4c3de (patch)
treef838ec86298bdea4dc971444288536472893f0f5 /tensorflow/contrib/py2tf
parent8451cee9722f11b3d28d234e2c383d59562eec15 (diff)
Extend the API with a pair of decorators that can convert functions inline.
Add logic to detect these decorators and treat them appropriately (e.g. when converting recursively, do all decorated functions as they are). PiperOrigin-RevId: 182808673
Diffstat (limited to 'tensorflow/contrib/py2tf')
-rw-r--r--tensorflow/contrib/py2tf/api.py134
-rw-r--r--tensorflow/contrib/py2tf/api_test.py140
-rw-r--r--tensorflow/contrib/py2tf/conversion.py54
-rw-r--r--tensorflow/contrib/py2tf/conversion_test.py6
-rw-r--r--tensorflow/contrib/py2tf/convert/BUILD1
-rw-r--r--tensorflow/contrib/py2tf/convert/call_trees.py109
-rw-r--r--tensorflow/contrib/py2tf/convert/call_trees_test.py6
-rw-r--r--tensorflow/contrib/py2tf/convert/decorators.py56
-rw-r--r--tensorflow/contrib/py2tf/naming.py9
-rw-r--r--tensorflow/contrib/py2tf/naming_test.py12
10 files changed, 477 insertions, 50 deletions
diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/api.py
index 3a36720969..9a2b70c53c 100644
--- a/tensorflow/contrib/py2tf/api.py
+++ b/tensorflow/contrib/py2tf/api.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from functools import wraps
+
import gast
import six
@@ -32,7 +34,111 @@ from tensorflow.python.util import tf_inspect
# (currently we require (module + class name, type))
-def to_graph(o, arg_value_hints=None):
+def graph_ready(f):
+ """No-op decorator that explicitly marks a function as graph-ready.
+
+ Graph-ready functions are assumed to not need any conversion.
+
+ Args:
+ f: Any callable.
+ Returns:
+ f itself.
+ """
+ setattr(f, '__pyct_is_compile_decorator', True)
+ return f
+
+
+def convert_inline(f, *args, **kwargs):
+ """Shorthand to convert and call a function.
+
+ For example, the following two statements are equivalent:
+
+ @convert()
+ def foo():
+ ...
+ foo(bar)
+
+ def foo():
+ ...
+ convert_inline(foo, bar)
+
+ Args:
+ f: Function to convert. Only this call will be converted.
+ *args: Passed through to f.
+ **kwargs: Passed through to f, with the following exceptions:
+ * arg_value_hints: A dict mapping parameter names to objects that can
+ hint at the type of those parameters.
+
+ Returns:
+ The result of the converted f applied to args and kwargs.
+ """
+ if 'arg_value_hints' in kwargs:
+ arg_value_hints = kwargs['arg_value_hints']
+ del kwargs['arg_value_hints']
+ else:
+ arg_value_hints = None
+ if tf_inspect.ismethod(f):
+ # When converting methods, the result is still an unbound function.
+ args = (f.__self__,) + args
+ return convert(arg_value_hints)(f)(*args, **kwargs)
+
+
+def convert(recursive=False, arg_value_hints=None):
+ """Decorator that compiles a function to graph mode.
+
+ The decorator is dynamic - invoking compilation whenever the decorated fuction
+ is called. This means the parameter values are known at compilation.
+
+ Args:
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
+ arg_value_hints: A dict mapping parameter names to objects that can hint
+ at the type of those parameters.
+
+ Returns:
+ A decorator that compiles the given function to graph mode.
+
+ Raises:
+ ValueError: If any of the arguments are illegal.
+ """
+ if arg_value_hints is None:
+ arg_value_hints = {}
+
+ def decorator(f):
+ """Decorator implementation."""
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ """Wrapper that calls the compiled version of the wrapped function."""
+ partial_types = ()
+ arg_names = tf_inspect.getargspec(f)[0]
+ for name, arg in zip(arg_names, args):
+ arg_class = arg.__class__
+ if tf_inspect.isclass(arg_class):
+ # If arg_value_hints specifies any name, use that instead.
+ # TODO(mdan): Shouldn't this just be in the func's globals?
+ if name not in arg_value_hints:
+ arg_value_hints[name] = (arg_class.__name__, arg_class)
+ # Annotated methods need to specify that their owner type is partial,
+ # otherwise other members they call will not be converted.
+ if name == 'self':
+ partial_types = (arg_class,)
+ wrapped = to_graph(
+ f,
+ recursive=recursive,
+ arg_value_hints=arg_value_hints,
+ partial_types=partial_types)
+ return wrapped(*args, **kwargs)
+
+ # Sometimes the decorator is just desugared, making it impossible to detect.
+ # This attribute makes detection easier.
+ setattr(wrapper, '__pyct_is_compile_decorator', True)
+ return wrapper
+
+ return decorator
+
+
+def to_graph(o, recursive=True, arg_value_hints=None, partial_types=None):
"""Compile a Python entity into equivalent TensorFlow code.
Currently supported entities:
@@ -43,14 +149,22 @@ def to_graph(o, arg_value_hints=None):
Args:
o: A Python function or class.
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
arg_value_hints: A dict mapping parameter names to objects that can hint
at the type of those parameters.
+ partial_types: A set of types (e.g. classes) that will not be converted
+ entirely. Calls to member functions for these types will be renamed
+ independently.
Returns:
A function with a signature identical to `o`, but which when executed it
creates TF a graph that has the same functionality as the original entity.
"""
- conversion_map = conversion.ConversionMap()
+ conversion_map = conversion.ConversionMap(
+ recursive=recursive,
+ nocompile_decorators=(convert, graph_ready, convert_inline),
+ partial_types=partial_types)
_, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)
module = gast.Module([])
@@ -69,21 +183,33 @@ def to_graph(o, arg_value_hints=None):
return compiled_fn
-def to_code(o, arg_value_hints=None, indentation=' '):
+def to_code(o,
+ recursive=True,
+ arg_value_hints=None,
+ partial_types=None,
+ indentation=' '):
"""Return the equivalent of an entity in TensorFlow code.
See `to_graph` for more details.
Args:
o: A Python function or class.
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
arg_value_hints: A dict mapping parameter names to objects that can hint
at the type of those parameters.
+ partial_types: A set of types (e.g. classes) that will not be converted
+ entirely. Calls to member functions for these types will be renamed
+ independently.
indentation: String, when to use for each level of indentation.
Returns:
String.
"""
- conversion_map = conversion.ConversionMap()
+ conversion_map = conversion.ConversionMap(
+ recursive=recursive,
+ nocompile_decorators=(convert, graph_ready, convert_inline),
+ partial_types=partial_types)
conversion.object_to_graph(o, conversion_map, arg_value_hints)
imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)
diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/api_test.py
index 225b6d305f..2384447708 100644
--- a/tensorflow/contrib/py2tf/api_test.py
+++ b/tensorflow/contrib/py2tf/api_test.py
@@ -28,17 +28,146 @@ from tensorflow.python.platform import test
class ApiTest(test.TestCase):
+ def setUp(self):
+ config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,))
+ config.COMPILED_IMPORT_STATEMENTS = (
+ 'from tensorflow.python.ops '
+ 'import control_flow_ops as tf',)
+
+ def test_decorator_recurses(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_does_not_recurse(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ return math_ops.negative(a)
+
+ @api.convert(recursive=False)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_converted(self):
+
+ class TestClass(object):
+
+ @api.graph_ready
+ def called_member(self, a):
+ return math_ops.negative(a)
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_decorated(self):
+
+ class TestClass(object):
+
+ @api.convert()
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_convert_call_site_decorator(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= api.convert_inline(self.called_member, a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_graph_ready_call_site_decorator(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ return math_ops.negative(a)
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= api.graph_ready(self.called_member(a))
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
def test_to_graph_basic(self):
def test_fn(x, s):
while math_ops.reduce_sum(x) > s:
x //= 2
return x
- config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,))
- config.COMPILED_IMPORT_STATEMENTS = (
- 'from tensorflow.python.ops '
- 'import control_flow_ops as tf',
- )
compiled_fn = api.to_graph(test_fn)
with self.test_session() as sess:
@@ -51,7 +180,6 @@ class ApiTest(test.TestCase):
x /= 2
return x
- config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,))
compiled_code = api.to_code(test_fn)
# Just check for some key words and that it is parseable Python code.
diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py
index 43bccae953..3bdbc66a99 100644
--- a/tensorflow/contrib/py2tf/conversion.py
+++ b/tensorflow/contrib/py2tf/conversion.py
@@ -28,6 +28,7 @@ from tensorflow.contrib.py2tf.convert import builtin_functions
from tensorflow.contrib.py2tf.convert import call_trees
from tensorflow.contrib.py2tf.convert import continue_canonicalization
from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.convert import decorators
from tensorflow.contrib.py2tf.convert import for_canonicalization
from tensorflow.contrib.py2tf.convert import logical_expressions
from tensorflow.contrib.py2tf.convert import print_functions
@@ -39,22 +40,35 @@ from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.util import tf_inspect
+# TODO(mdan): Might we not need any renaming at all?
+
+
class ConversionMap(object):
"""ConversionMaps keep track of converting function hierarchies.
Attributes:
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
+ nocompile_decorators: tuple of decorator functions that toggle compilation
+ off.
dependency_cache: dict[object]: ast; maps original objects to their
converted AST
name_map: dict[string]: string; maps original objects to the name of
their converted counterparts
"""
- def __init__(self):
+ # TODO(mdan): Rename to ConversionContext, and pull in additional flags.
+
+ def __init__(self, recursive, nocompile_decorators, partial_types):
+ self.recursive = recursive
+ self.nocompile_decorators = nocompile_decorators
+ self.partial_types = partial_types if partial_types else ()
self.dependency_cache = {}
self.name_map = {}
def new_namer(self, global_symbols):
- return naming.Namer(global_symbols, self.name_map)
+ return naming.Namer(global_symbols, self.recursive, self.name_map,
+ self.partial_types)
def update_name_map(self, namer):
for o, name in namer.renamed_calls.items():
@@ -102,19 +116,23 @@ def object_to_graph(o, conversion_map, value_hints):
node, new_name = class_to_graph(o, conversion_map, value_hints)
elif tf_inspect.isfunction(o):
node, new_name = function_to_graph(o, conversion_map, value_hints)
+ elif tf_inspect.ismethod(o):
+ node, new_name = function_to_graph(o, conversion_map, value_hints)
else:
raise ValueError(
- 'Unsupported object type %s. Only functions and classes are supported'
- ' for now.')
+ 'Entity "%s" has unsupported type "%s". Only functions and classes are '
+ 'supported for now.' % (o, type(o)))
conversion_map.add_to_cache(o, node)
- # Recursively convert remaining dependencies.
- for obj in conversion_map.name_map.keys():
- if obj not in conversion_map.dependency_cache:
- if hasattr(obj, 'im_class'):
- # Class members are converted with their objects.
- continue
- object_to_graph(obj, conversion_map, None)
+ if conversion_map.recursive:
+ for obj in conversion_map.name_map.keys():
+ if obj not in conversion_map.dependency_cache:
+ if (hasattr(obj, 'im_class') and
+ getattr(obj, 'im_class') not in conversion_map.partial_types):
+ # Class members are converted with their objects, unless they're
+ # only converted partially.
+ continue
+ object_to_graph(obj, conversion_map, None)
return node, new_name
@@ -163,7 +181,8 @@ def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
node_globals[fn.__name__] = fn
namer = conversion_map.new_namer(node_globals)
- node = node_to_graph(node, namer, node_globals, param_value_hints)
+ node = node_to_graph(node, namer, node_globals, param_value_hints,
+ conversion_map.nocompile_decorators)
# Simulate a rename to ensure the top level is in the name map. This is needed
# for top level functions, and it also helps the consistency verification made
@@ -184,7 +203,7 @@ def _static_analysis_pass(node, namespace, value_hints):
return node
-def node_to_graph(node, namer, namespace, value_hints):
+def node_to_graph(node, namer, namespace, value_hints, nocompile_decorators):
"""Convert Python code to equivalent TF graph mode code.
Args:
@@ -193,6 +212,8 @@ def node_to_graph(node, namer, namespace, value_hints):
namespace: Dict mapping symbol names to their corresponding live objects.
value_hints: A dict containing value hints for symbols like function
parameters.
+ nocompile_decorators: A tuple containing decorators to be stripped from
+ functions during conversion.
Returns:
A tuple (node, deps):
@@ -200,6 +221,8 @@ def node_to_graph(node, namer, namespace, value_hints):
* deps: A set of strings, the fully qualified names of object
dependencies that this node has.
"""
+ # TODO(mdan): Verify arguments for correctness.
+
# TODO(mdan): Factor out common elements.
# These include:
# * keeping track of symbols that have been created
@@ -213,6 +236,7 @@ def node_to_graph(node, namer, namespace, value_hints):
# to re-run the analysis.
node = _static_analysis_pass(node, namespace, value_hints)
+ node = decorators.transform(node, nocompile_decorators)
node = break_canonicalization.transform(node, namer)
# Note: sequencing continue canonicalization before for loop one avoids
@@ -230,7 +254,9 @@ def node_to_graph(node, namer, namespace, value_hints):
node = _static_analysis_pass(node, namespace, value_hints)
node = print_functions.transform(node)
- node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES)
+ node = call_trees.transform(node, namer, namespace,
+ config.DEFAULT_UNCOMPILED_MODULES,
+ nocompile_decorators)
node = control_flow.transform(node, namer)
node = logical_expressions.transform(node)
node = side_effect_guards.transform(node, namer)
diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/conversion_test.py
index d76f141809..e48bfe4464 100644
--- a/tensorflow/contrib/py2tf/conversion_test.py
+++ b/tensorflow/contrib/py2tf/conversion_test.py
@@ -28,13 +28,13 @@ class ConversionTest(test.TestCase):
def test_object_to_graph_unsupported_types(self):
with self.assertRaises(ValueError):
- conversion.object_to_graph('dummy', {}, {})
+ conversion.object_to_graph('dummy', None, {})
def test_object_to_graph_callable(self):
def f(a):
return a
- conversion_map = conversion.ConversionMap()
+ conversion_map = conversion.ConversionMap(True, (), ())
ast, new_name = conversion.object_to_graph(f, conversion_map, {})
self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
self.assertEqual('tf__f', new_name)
@@ -46,7 +46,7 @@ class ConversionTest(test.TestCase):
def f(a):
return g(a)
- conversion_map = conversion.ConversionMap()
+ conversion_map = conversion.ConversionMap(True, (), ())
conversion.object_to_graph(f, conversion_map, {})
self.assertTrue(f in conversion_map.dependency_cache)
diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/convert/BUILD
index 0eb7998dc4..050e2ef108 100644
--- a/tensorflow/contrib/py2tf/convert/BUILD
+++ b/tensorflow/contrib/py2tf/convert/BUILD
@@ -22,6 +22,7 @@ py_library(
"call_trees.py",
"continue_canonicalization.py",
"control_flow.py",
+ "decorators.py",
"for_canonicalization.py",
"logical_expressions.py",
"print_functions.py",
diff --git a/tensorflow/contrib/py2tf/convert/call_trees.py b/tensorflow/contrib/py2tf/convert/call_trees.py
index 92c3439101..df071f596f 100644
--- a/tensorflow/contrib/py2tf/convert/call_trees.py
+++ b/tensorflow/contrib/py2tf/convert/call_trees.py
@@ -27,6 +27,7 @@ import types
import gast
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import templates
@@ -64,16 +65,75 @@ class FunctionNamer(object):
class CallTreeTransformer(gast.NodeTransformer):
"""Transforms the call tree by renaming transformed symbols."""
- def __init__(self, namer, uncompiled_modules):
+ def __init__(self, namer, namespace, uncompiled_modules,
+ nocompile_decorators):
self.namer = namer
+ self.namespace = namespace
self.uncompiled_modules = uncompiled_modules
+ self.nocompile_decorators = nocompile_decorators
# pylint:disable=invalid-name
- def _should_compile(self, fqn):
+ def _resolve_name(self, node):
+ if isinstance(node, gast.Call):
+ return self._resolve_name(node.func)
+ if isinstance(node, gast.Name):
+ return self.namespace.get(node.id)
+ if isinstance(node, gast.Attribute):
+ parent = self._resolve_name(node.value)
+ if parent is not None:
+ return getattr(parent, node.attr)
+ return None
+ raise ValueError(node)
+
+ def _try_resolve_target(self, node):
+ """Works for methods of objects of known type."""
+ if anno.hasanno(node, 'live_val'):
+ return anno.getanno(node, 'live_val')
+ if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
+ member = getattr(anno.getanno(node, 'type'), node.attr)
+ return member
+ return None
+
+ def _should_compile(self, node, fqn):
for i in range(1, len(fqn)):
if fqn[:i] in self.uncompiled_modules:
return False
+
+ # Check for local decorations
+ if anno.hasanno(node, 'graph_ready'):
+ return False
+
+ # The decorators themselves are not to be converted.
+ # If present, the decorators should appear as static functions.
+ target_obj = self._try_resolve_target(node.func)
+ if target_obj is not None:
+ # This attribute is set by the decorator itself.
+ # TODO(mdan): This may not play nicely with other wrapping decorators.
+ if hasattr(target_obj, '__pyct_is_compile_decorator'):
+ return False
+
+ if target_obj in self.nocompile_decorators:
+ return False
+
+ # Inspect the target function decorators. If any include a @convert
+ # or @graph_ready annotation, then they must be called as they are.
+ # TODO(mdan): This may be quite heavy.
+ # To parse and re-analize each function for every call site could be quite
+ # wasteful. Maybe we could cache the parsed AST?
+ try:
+ target_node = parser.parse_object(target_obj).body[0]
+ except TypeError:
+ # Functions whose source we cannot access are compilable (e.g. wrapped
+ # to py_func).
+ return True
+
+ for dec in target_node.decorator_list:
+ decorator_fn = self._resolve_name(dec)
+ if (decorator_fn is not None and
+ decorator_fn in self.nocompile_decorators):
+ return False
+
return True
def _rename_compilable_function(self, node):
@@ -82,15 +142,15 @@ class CallTreeTransformer(gast.NodeTransformer):
target_obj = anno.getanno(node.func, 'live_val')
target_fqn = anno.getanno(node.func, 'fqn')
- if not self._should_compile(target_fqn):
+ if not self._should_compile(node, target_fqn):
return node
if anno.hasanno(node, 'is_constructor'):
new_name = self.namer.compiled_class_name(
- '.'.join(target_fqn), live_object=target_obj)
+ '__'.join(target_fqn), live_object=target_obj)
else:
new_name = self.namer.compiled_function_name(
- '.'.join(target_fqn), live_object=target_obj)
+ '__'.join(target_fqn), live_object=target_obj)
node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
return node
@@ -101,15 +161,24 @@ class CallTreeTransformer(gast.NodeTransformer):
assert anno.hasanno(node.func, 'type')
target_type = anno.getanno(node.func, 'type')
- if not self._should_compile(type_fqn):
+ if not self._should_compile(node, type_fqn):
return node
# TODO(mdan): We should not assume that the namer only needs the
# member function name.
+ method_name = node.func.attr
+ method_object = getattr(target_type, method_name)
new_name = self.namer.compiled_function_name(
- node.func.attr, live_object=None, owner_type=target_type)
- node.func.attr = new_name
-
+ method_name, live_object=method_object, owner_type=target_type)
+ if new_name != node.func.attr:
+ # If a member function call is renamed, then the new function is no
+ # longer bound to the target object. We then refactor the call from:
+ # foo.bar(...)
+ # to:
+ # renamed_foo(bar, ...)
+ # TODO(mdan): This risks causing duplication, if target_type is renamed.
+ node.args = [node.func.value] + node.args
+ node.func = gast.Name(new_name, gast.Load(), None)
return node
def _wrap_to_py_func_no_return(self, node):
@@ -136,6 +205,7 @@ class CallTreeTransformer(gast.NodeTransformer):
wrapper=gast.Name(wrapper_name, gast.Load(), None),
args=args)
anno.setanno(call_expr.value, 'args_scope', args_scope)
+ # TODO(mdan): Rename this annotation to 'graph_ready'
anno.setanno(wrapper_def, 'skip_processing', True)
return (wrapper_def, call_expr)
@@ -151,7 +221,7 @@ class CallTreeTransformer(gast.NodeTransformer):
if not self._function_is_compilable(target_obj):
if anno.hasanno(node.value.func, 'fqn'):
target_fqn = anno.getanno(node.value.func, 'fqn')
- if not self._should_compile(target_fqn):
+ if not self._should_compile(node.value, target_fqn):
return node
node = self._wrap_to_py_func_no_return(node.value)
return node
@@ -163,6 +233,17 @@ class CallTreeTransformer(gast.NodeTransformer):
return node
def visit_Call(self, node):
+ # If the function is wrapped by one of the marker decorators,
+ # consider it graph ready.
+ if anno.hasanno(node.func, 'live_val'):
+ target_obj = anno.getanno(node.func, 'live_val')
+ if target_obj in self.nocompile_decorators:
+ if len(node.args) < 1:
+ raise ValueError(
+ 'Found call to decorator function "%s", but it had no arguments. '
+ 'A decorator needs at least an argument.')
+ anno.setanno(node.args[0], 'graph_ready', True)
+
self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
target_obj = anno.getanno(node.func, 'live_val')
@@ -180,20 +261,24 @@ class CallTreeTransformer(gast.NodeTransformer):
# pylint:enable=invalid-name
-def transform(node, namer, uncompiled_modules):
+def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators):
"""Transform function call to the compiled counterparts.
Args:
node: AST to transform.
namer: FunctionNamer-like.
+ namespace: Dict mapping symbol names to their corresponding live objects.
uncompiled_modules: set of string tuples, each tuple represents the fully
qualified name of a package containing functions that will not be
compiled.
+ nocompile_decorators: A tuple containing decorators to be stripped from
+ functions during conversion.
Returns:
A tuple (node, new_names):
node: The transformed AST
new_names: set(string), containing any newly-generated names
"""
- transformer = CallTreeTransformer(namer, uncompiled_modules)
+ transformer = CallTreeTransformer(namer, namespace, uncompiled_modules,
+ nocompile_decorators)
node = transformer.visit(node)
return node
diff --git a/tensorflow/contrib/py2tf/convert/call_trees_test.py b/tensorflow/contrib/py2tf/convert/call_trees_test.py
index 38c701eaad..3367d41db3 100644
--- a/tensorflow/contrib/py2tf/convert/call_trees_test.py
+++ b/tensorflow/contrib/py2tf/convert/call_trees_test.py
@@ -56,7 +56,7 @@ class CallTreesTest(test.TestCase):
return test_fn_1(a) + 1
node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
- node = call_trees.transform(node, TestNamer(), set())
+ node = call_trees.transform(node, TestNamer(), {}, (), ())
result = compiler.ast_to_object(node)
# Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
@@ -74,9 +74,9 @@ class CallTreesTest(test.TestCase):
'math_ops': math_ops,
'constant_op': constant_op
})
- node = call_trees.transform(node, TestNamer(),
+ node = call_trees.transform(node, TestNamer(), {},
set(((math_ops.__name__,),
- (constant_op.__name__,))))
+ (constant_op.__name__,))), ())
result = compiler.ast_to_object(node)
setattr(result, 'math_ops', math_ops)
setattr(result, 'constant_op', constant_op)
diff --git a/tensorflow/contrib/py2tf/convert/decorators.py b/tensorflow/contrib/py2tf/convert/decorators.py
new file mode 100644
index 0000000000..a4313bfa51
--- /dev/null
+++ b/tensorflow/contrib/py2tf/convert/decorators.py
@@ -0,0 +1,56 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles decorators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import pretty_printer
+
+
+class DecoratorsTransformer(gast.NodeTransformer):
+ """Converts or removes decorators."""
+
+ def __init__(self, remove_decorators):
+ self.remove_decorators = remove_decorators
+
+ # pylint:disable=invalid-name
+
+ def visit_FunctionDef(self, node):
+ self.generic_visit(node)
+ for dec in node.decorator_list:
+ if isinstance(dec, gast.Call):
+ dec = dec.func
+ if not anno.hasanno(dec, 'live_val'):
+ raise ValueError(
+ 'Could not resolve decorator: %s' % pretty_printer.fmt(dec))
+ dec_value = anno.getanno(dec, 'live_val')
+ if dec_value in self.remove_decorators:
+ continue
+ raise ValueError('Dont know how to convert decorators for now.')
+ node.decorator_list = []
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node, remove_decorators):
+ transformer = DecoratorsTransformer(remove_decorators)
+ node = transformer.visit(node)
+ return node
diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/naming.py
index 61772ec07b..a90758962b 100644
--- a/tensorflow/contrib/py2tf/naming.py
+++ b/tensorflow/contrib/py2tf/naming.py
@@ -34,8 +34,10 @@ class Namer(object):
* side_effect_guards.SymbolNamer
"""
- def __init__(self, global_namespace, name_map=None):
+ def __init__(self, global_namespace, recursive, name_map, partial_types):
self.global_namespace = global_namespace
+ self.recursive = recursive
+ self.partial_types = partial_types
self.renamed_calls = {}
if name_map is not None:
@@ -54,6 +56,7 @@ class Namer(object):
while new_name in self.global_namespace:
n += 1
new_name = '%s_%d' % (new_name_root, n)
+
if live_object is not None:
self.renamed_calls[live_object] = new_name
self.generated_names.add(new_name)
@@ -67,7 +70,9 @@ class Namer(object):
if live_object is not None and live_object in self.renamed_calls:
return self.renamed_calls[live_object]
- if owner_type is None:
+ if not self.recursive:
+ new_name = original_name
+ elif owner_type is None or owner_type in self.partial_types:
# Top level functions: rename
new_name_root = 'tf__%s' % original_name
new_name = new_name_root
diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/naming_test.py
index 9403d9ae1f..7bfc9b8733 100644
--- a/tensorflow/contrib/py2tf/naming_test.py
+++ b/tensorflow/contrib/py2tf/naming_test.py
@@ -28,7 +28,7 @@ class NamerTest(test.TestCase):
def bar():
pass
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('tf__foo', namer.compiled_function_name('foo'))
self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar))
self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls)
@@ -38,7 +38,7 @@ class NamerTest(test.TestCase):
def foo():
pass
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo))
self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo))
@@ -46,22 +46,22 @@ class NamerTest(test.TestCase):
def foo():
pass
- namer = naming.Namer(set(('tf__foo',)))
+ namer = naming.Namer({'tf__foo': 1}, True, None, ())
self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo))
def test_new_symbol_tracks_names(self):
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('temp', namer.new_symbol('temp', set()))
self.assertItemsEqual(('temp',), namer.generated_names)
def test_new_symbol_avoids_duplicates(self):
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('temp', namer.new_symbol('temp', set()))
self.assertEqual('temp_1', namer.new_symbol('temp', set()))
self.assertItemsEqual(('temp', 'temp_1'), namer.generated_names)
def test_new_symbol_avoids_conflicts(self):
- namer = naming.Namer(set(('temp',)))
+ namer = naming.Namer({'temp': 1}, True, None, ())
# temp is reserved in the global namespace
self.assertEqual('temp_1', namer.new_symbol('temp', set()))
# temp_2 is reserved in the local namespace