aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-29 12:12:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 12:16:30 -0800
commit95a8af24058c168ce8a5327451e1cfcbc56461eb (patch)
treeb3764f80ff693a29584cc09461ce0b5f3c5c69bd
parent8bfaa9213b640201b6886f3f245a1ad1a7461030 (diff)
Ensure that non-recursive conversion is identity transformation wrt all types of function calls by only failing on unresolved symbols if they're needed.
Simplify code structure all around. Remove the awkward activity analysis that deemed a function parameter as "modified". Consolidate activity analysis by tracking function parameters and returned symbols separately. Strengthen the type inference a little by using more interpret-like constructs. PiperOrigin-RevId: 183705547
-rw-r--r--tensorflow/contrib/py2tf/conversion.py33
-rw-r--r--tensorflow/contrib/py2tf/converters/call_trees.py159
-rw-r--r--tensorflow/contrib/py2tf/converters/call_trees_test.py46
-rw-r--r--tensorflow/contrib/py2tf/converters/converter_test_base.py19
-rw-r--r--tensorflow/contrib/py2tf/converters/side_effect_guards.py5
-rw-r--r--tensorflow/contrib/py2tf/naming.py68
-rw-r--r--tensorflow/contrib/py2tf/naming_test.py14
-rw-r--r--tensorflow/contrib/py2tf/pyct/context.py3
-rw-r--r--tensorflow/contrib/py2tf/pyct/parser.py10
-rw-r--r--tensorflow/contrib/py2tf/pyct/parser_test.py11
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/access.py67
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py42
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py52
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py56
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py66
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py60
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates.py3
-rw-r--r--tensorflow/contrib/py2tf/pyct/transformer.py17
18 files changed, 442 insertions, 289 deletions
diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py
index b484eebbd5..e277eadec4 100644
--- a/tensorflow/contrib/py2tf/conversion.py
+++ b/tensorflow/contrib/py2tf/conversion.py
@@ -171,7 +171,8 @@ def class_to_graph(c, conversion_map):
def function_to_graph(f, conversion_map, arg_values, arg_types,
owner_type=None):
"""Specialization of `entity_to_graph` for callable functions."""
- node = parser.parse_object(f).body[0]
+ node, source = parser.parse_entity(f)
+ node = node.body[0]
namespace = six.get_function_globals(f)
# This is needed for non-global functions.
@@ -185,28 +186,29 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
namer = conversion_map.new_namer(namespace)
ctx = context.EntityContext(
namer=namer,
- source_code=tf_inspect.getsource(f),
- source_file=tf_inspect.getfile(f),
+ source_code=source,
+ source_file='<fragment>',
namespace=namespace,
arg_values=arg_values,
- arg_types=arg_types)
+ arg_types=arg_types,
+ recursive=conversion_map.recursive)
node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
- # Simulate a rename to ensure the top level is in the name map. This is needed
- # for top level functions, and it also helps the consistency verification made
- # by update_name_map.
- if owner_type is not None:
- new_name = namer.compiled_function_name(f.__name__, f, owner_type)
- else:
- new_name = namer.compiled_function_name(f.__name__, f)
+ # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
+ new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
+ if not did_rename:
+ new_name = f.__name__
+ if node.name != f.__name__:
+ raise NotImplementedError('Strange corner case. Send us offending code!')
+
node.name = new_name
conversion_map.update_name_map(namer)
- return node, conversion_map.name_map[f]
+ return node, new_name
def _static_analysis_pass(node, ctx):
- node = access.resolve(node)
- node = live_values.resolve(node, ctx.namespace, config.PYTHON_LITERALS)
+ node = access.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, config.PYTHON_LITERALS)
node = type_info.resolve(node, ctx)
return node
@@ -259,8 +261,7 @@ def node_to_graph(node, ctx, nocompile_decorators):
node = _static_analysis_pass(node, ctx)
node = print_functions.transform(node)
- node = call_trees.transform(node, ctx.namer, ctx.namespace,
- config.DEFAULT_UNCOMPILED_MODULES,
+ node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
nocompile_decorators)
node = control_flow.transform(node, ctx.namer)
node = logical_expressions.transform(node)
diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py
index 0aae030450..4c238b7fb9 100644
--- a/tensorflow/contrib/py2tf/converters/call_trees.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees.py
@@ -29,46 +29,46 @@ import gast
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.python.util import tf_inspect
class FunctionNamer(object):
"""Describes the interface for CallTreeTransformer's namer."""
def compiled_function_name(self,
- original_name,
- live_object=None,
+ original_fqn,
+ live_entity=None,
owner_type=None):
"""Generate the name corresponding to the compiled version of a function.
Args:
- original_name: String
- live_object: Callable, the actual target function, if known.
+ original_fqn: string or tuple(string)
+ live_entity: Callable, the actual target function, if known.
owner_type: Optional object. If present, it indicates that the function is
a member of the given type.
Returns:
- String.
+ string, bool
"""
raise NotImplementedError()
- def compiled_class_name(self, original_name, live_object=None):
+ def compiled_class_name(self, original_fqn, live_entity=None):
"""Generate the name corresponding to the compiled version of a class.
Args:
- original_name: String
- live_object: The actual target class, if known.
+ original_fqn: string or tuple(string)
+ live_entity: The actual target class, if known.
Returns:
- String.
+ string
"""
raise NotImplementedError()
-class CallTreeTransformer(gast.NodeTransformer):
+class CallTreeTransformer(transformer.Base):
"""Transforms the call tree by renaming transformed symbols."""
- def __init__(self, namer, namespace, uncompiled_modules,
- nocompile_decorators):
- self.namer = namer
- self.namespace = namespace
+ def __init__(self, context, uncompiled_modules, nocompile_decorators):
+ super(CallTreeTransformer, self).__init__(context)
self.uncompiled_modules = uncompiled_modules
self.nocompile_decorators = nocompile_decorators
@@ -78,7 +78,7 @@ class CallTreeTransformer(gast.NodeTransformer):
if isinstance(node, gast.Call):
return self._resolve_name(node.func)
if isinstance(node, gast.Name):
- return self.namespace.get(node.id)
+ return self.context.namespace.get(node.id)
if isinstance(node, gast.Attribute):
parent = self._resolve_name(node.value)
if parent is not None:
@@ -91,8 +91,12 @@ class CallTreeTransformer(gast.NodeTransformer):
if anno.hasanno(node, 'live_val'):
return anno.getanno(node, 'live_val')
if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
- member = getattr(anno.getanno(node, 'type'), node.attr)
- return member
+ owner_type = anno.getanno(node, 'type')
+ if hasattr(owner_type, node.attr):
+ return getattr(owner_type, node.attr)
+ else:
+ raise ValueError('Type "%s" has not attribute "%s". Is it dynamic?' %
+ (owner_type, node.attr))
return None
def _should_compile(self, node, fqn):
@@ -106,14 +110,14 @@ class CallTreeTransformer(gast.NodeTransformer):
# The decorators themselves are not to be converted.
# If present, the decorators should appear as static functions.
- target_obj = self._try_resolve_target(node.func)
- if target_obj is not None:
+ target_entity = self._try_resolve_target(node.func)
+ if target_entity is not None:
# This attribute is set by the decorator itself.
# TODO(mdan): This may not play nicely with other wrapping decorators.
- if hasattr(target_obj, '__pyct_is_compile_decorator'):
+ if hasattr(target_entity, '__pyct_is_compile_decorator'):
return False
- if target_obj in self.nocompile_decorators:
+ if target_entity in self.nocompile_decorators:
return False
# Inspect the target function decorators. If any include a @convert
@@ -122,7 +126,8 @@ class CallTreeTransformer(gast.NodeTransformer):
# To parse and re-analize each function for every call site could be quite
# wasteful. Maybe we could cache the parsed AST?
try:
- target_node = parser.parse_object(target_obj).body[0]
+ target_node, _ = parser.parse_entity(target_entity)
+ target_node = target_node.body[0]
except TypeError:
# Functions whose source we cannot access are compilable (e.g. wrapped
# to py_func).
@@ -136,48 +141,57 @@ class CallTreeTransformer(gast.NodeTransformer):
return True
+ def _determine_function_owner(self, m):
+ # TODO(mdan): The parent type should be known at analysis. Use that instead.
+ if hasattr(m, 'im_class'): # Python 2
+ return m.im_class
+ if hasattr(m, '__qualname__'): # Python 3
+ # Object attributes: should be bound to "self".
+ if hasattr(m, '__self__'):
+ return type(m.__self__)
+
+ # Class attributes: should have the owner name in their namespace.
+ qn = m.__qualname__.split('.')
+ if len(qn) < 2:
+ return None
+ owner_name, func_name = qn[-2:]
+ if func_name != m.__name__:
+ raise ValueError('Inconsistent names detected '
+ '(__qualname__[1] = "%s", __name__ = "%s") for %s.' %
+ (func_name, m.__name__, m))
+ if owner_name == '<locals>':
+ return None
+ if owner_name not in self.context.namespace:
+ raise ValueError(
+ 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' %
+ (owner_name, m, self.context.namespace))
+ return self.context.namespace[owner_name]
+ return None
+
def _rename_compilable_function(self, node):
assert anno.hasanno(node.func, 'live_val')
assert anno.hasanno(node.func, 'fqn')
- target_obj = anno.getanno(node.func, 'live_val')
+ target_entity = anno.getanno(node.func, 'live_val')
target_fqn = anno.getanno(node.func, 'fqn')
if not self._should_compile(node, target_fqn):
return node
if anno.hasanno(node, 'is_constructor'):
- new_name = self.namer.compiled_class_name(
- '__'.join(target_fqn), live_object=target_obj)
+ new_name = self.context.namer.compiled_class_name(
+ target_fqn, live_entity=target_entity)
+ do_rename = True
else:
- new_name = self.namer.compiled_function_name(
- '__'.join(target_fqn), live_object=target_obj)
- node.func = gast.Name(new_name, gast.Load(), None)
- return node
-
- def _rename_member_function_of_known_type(self, node):
- assert isinstance(node.func, gast.Attribute)
-
- type_fqn = anno.getanno(node.func, 'type_fqn')
- assert anno.hasanno(node.func, 'type')
- target_type = anno.getanno(node.func, 'type')
-
- if not self._should_compile(node, type_fqn):
- return node
-
- # TODO(mdan): We should not assume that the namer only needs the
- # member function name.
- method_name = node.func.attr
- method_object = getattr(target_type, method_name)
- new_name = self.namer.compiled_function_name(
- method_name, live_object=method_object, owner_type=target_type)
- if new_name != node.func.attr:
- # If a member function call is renamed, then the new function is no
- # longer bound to the target object. We then refactor the call from:
- # foo.bar(...)
- # to:
- # renamed_foo(bar, ...)
- # TODO(mdan): This risks causing duplication, if target_type is renamed.
- node.args = [node.func.value] + node.args
+ owner_type = self._determine_function_owner(target_entity)
+ new_name, do_rename = self.context.namer.compiled_function_name(
+ target_fqn, live_entity=target_entity, owner_type=owner_type)
+
+ if do_rename:
+ if target_entity is not None:
+ if tf_inspect.ismethod(target_entity):
+ # The renaming process will transform it into a regular function.
+ # TODO(mdan): Is this complete? How does it work with nested members?
+ node.args = [node.func.value] + node.args
node.func = gast.Name(new_name, gast.Load(), None)
return node
@@ -193,7 +207,7 @@ class CallTreeTransformer(gast.NodeTransformer):
wrapper_def, call_expr = templates.replace(
template,
call=node.func,
- wrapper=self.namer.compiled_function_name(node.func.id),
+ wrapper=self.context.namer.compiled_function_name(node.func.id)[0],
args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
anno.setanno(call_expr.value, 'args_scope', args_scope)
# TODO(mdan): Rename this annotation to 'graph_ready'
@@ -201,15 +215,15 @@ class CallTreeTransformer(gast.NodeTransformer):
return (wrapper_def, call_expr)
- def _function_is_compilable(self, target_obj):
+ def _function_is_compilable(self, target_entity):
# TODO(mdan): This is just a placeholder. Implement.
- return not isinstance(target_obj, types.BuiltinFunctionType)
+ return not isinstance(target_entity, types.BuiltinFunctionType)
def visit_Expr(self, node):
if isinstance(node.value, gast.Call):
if anno.hasanno(node.value.func, 'live_val'):
- target_obj = anno.getanno(node.value.func, 'live_val')
- if not self._function_is_compilable(target_obj):
+ target_entity = anno.getanno(node.value.func, 'live_val')
+ if not self._function_is_compilable(target_entity):
if anno.hasanno(node.value.func, 'fqn'):
target_fqn = anno.getanno(node.value.func, 'fqn')
if not self._should_compile(node.value, target_fqn):
@@ -227,8 +241,8 @@ class CallTreeTransformer(gast.NodeTransformer):
# If the function is wrapped by one of the marker decorators,
# consider it graph ready.
if anno.hasanno(node.func, 'live_val'):
- target_obj = anno.getanno(node.func, 'live_val')
- if target_obj in self.nocompile_decorators:
+ target_entity = anno.getanno(node.func, 'live_val')
+ if target_entity in self.nocompile_decorators:
if len(node.args) < 1:
raise ValueError(
'Found call to decorator function "%s", but it had no arguments. '
@@ -237,28 +251,28 @@ class CallTreeTransformer(gast.NodeTransformer):
self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
- target_obj = anno.getanno(node.func, 'live_val')
- if self._function_is_compilable(target_obj):
+ target_entity = anno.getanno(node.func, 'live_val')
+ if self._function_is_compilable(target_entity):
node = self._rename_compilable_function(node)
else:
raise NotImplementedError('py_func with return values')
- elif anno.hasanno(node.func, 'type_fqn'):
- node = self._rename_member_function_of_known_type(node)
else:
- raise NotImplementedError(
- 'Member function call (of unknown type): %s.' % node.func.id)
+ if self.context.recursive:
+ raise NotImplementedError('Could not resolve target function.')
+ else:
+ # TODO(mdan): Double check. Is this reachable code?
+ pass
return node
# pylint:enable=invalid-name
-def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators):
+def transform(node, context, uncompiled_modules, nocompile_decorators):
"""Transform function call to the compiled counterparts.
Args:
node: AST to transform.
- namer: FunctionNamer-like.
- namespace: Dict mapping symbol names to their corresponding live objects.
+ context: An EntityContext object.
uncompiled_modules: set of string tuples, each tuple represents the fully
qualified name of a package containing functions that will not be
compiled.
@@ -269,7 +283,6 @@ def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators):
node: The transformed AST
new_names: set(string), containing any newly-generated names
"""
- transformer = CallTreeTransformer(namer, namespace, uncompiled_modules,
- nocompile_decorators)
- node = transformer.visit(node)
+ t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators)
+ node = t.visit(node)
return node
diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/py2tf/converters/call_trees_test.py
index 8cb8d7be0f..e63c10de0f 100644
--- a/tensorflow/contrib/py2tf/converters/call_trees_test.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees_test.py
@@ -28,8 +28,13 @@ from tensorflow.python.platform import test
class TestNamer(call_trees.FunctionNamer):
- def compiled_function_name(self, original_name, live_object=None):
- return 'renamed_%s' % original_name
+ def compiled_function_name(self,
+ original_fqn,
+ live_entity=None,
+ owner_type=None):
+ if owner_type is not None:
+ return None, False
+ return ('renamed_%s' % '_'.join(original_fqn)), True
class CallTreesTest(converter_test_base.TestCase):
@@ -45,14 +50,35 @@ class CallTreesTest(converter_test_base.TestCase):
def test_fn_2(a):
return test_fn_1(a) + 1
- node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
- node = call_trees.transform(node, TestNamer(), {}, (), ())
+ node = self.parse_and_analyze(
+ test_fn_2, {'test_fn_1': test_fn_1}, namer=TestNamer())
+ node = call_trees.transform(node, self.ctx, (), ())
result = compiler.ast_to_object(node)
# Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
self.assertEquals(3, result.test_fn_2(1))
+ def test_simple_methods(self):
+
+ class TestClass(object):
+
+ def test_fn_1(self, a):
+ return a + 1
+
+ def test_fn_2(self, a):
+ return self.test_fn_1(a) + 1
+
+ node = self.parse_and_analyze(
+ TestClass.test_fn_2, {'TestClass': TestClass},
+ namer=TestNamer(),
+ arg_types={'self': (TestClass.__name__, TestClass)})
+ node = call_trees.transform(node, self.ctx, (), ())
+ result = compiler.ast_to_object(node)
+
+ tc = TestClass()
+ self.assertEquals(3, result.test_fn_2(tc, 1))
+
def test_uncompiled_modules(self):
def test_fn(a):
@@ -60,11 +86,13 @@ class CallTreesTest(converter_test_base.TestCase):
a = math_ops.add(a, constant_op.constant(1))
return a
- node = self.parse_and_analyze(test_fn, {
- 'math_ops': math_ops,
- 'constant_op': constant_op
- })
- node = call_trees.transform(node, TestNamer(), {},
+ node = self.parse_and_analyze(
+ test_fn, {
+ 'math_ops': math_ops,
+ 'constant_op': constant_op
+ },
+ namer=TestNamer())
+ node = call_trees.transform(node, self.ctx,
set(((math_ops.__name__,),
(constant_op.__name__,))), ())
result = compiler.ast_to_object(node)
diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/py2tf/converters/converter_test_base.py
index ed006bad6d..6bfa55443c 100644
--- a/tensorflow/contrib/py2tf/converters/converter_test_base.py
+++ b/tensorflow/contrib/py2tf/converters/converter_test_base.py
@@ -31,18 +31,23 @@ class TestCase(test.TestCase):
def parse_and_analyze(self,
test_fn,
namespace,
+ namer=None,
arg_types=None,
- include_type_analysis=True):
+ include_type_analysis=True,
+ recursive=True):
+ node, source = parser.parse_entity(test_fn)
ctx = context.EntityContext(
- namer=None,
- source_code=None,
+ namer=namer,
+ source_code=source,
source_file=None,
namespace=namespace,
arg_values=None,
- arg_types=arg_types)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
+ arg_types=arg_types,
+ recursive=recursive)
+ node = access.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, {})
if include_type_analysis:
node = type_info.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, {})
+ self.ctx = ctx
return node
diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
index a88828ff80..46a2269c20 100644
--- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py
+++ b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
@@ -94,6 +94,7 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
return node
def _gate_symbols(self, guard_statement, guarded_args):
+ # TODO(mdan): This won't work for variables.
template = """
(args,) = (tf.identity(a) for a in (args,))
"""
@@ -133,8 +134,8 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
# First, attempt to gate future evaluation of args. If that's not
# possible, gate all remaining statements (and that may fail too, see
# _visit_and_reindent.
- guarded_args = tuple(
- n for n in args_scope.used if n in args_scope.parent.modified)
+ guarded_args = tuple(args_scope.used & (args_scope.parent.modified
+ | args_scope.parent.returned))
if guarded_args:
node = tuple(statements[:-1]) + (
self._gate_symbols(control_deps_guard, guarded_args),)
diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/naming.py
index a90758962b..5c7e4c5f95 100644
--- a/tensorflow/contrib/py2tf/naming.py
+++ b/tensorflow/contrib/py2tf/naming.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.util import tf_inspect
-
class Namer(object):
"""Implementation of the namer interfaces required by various converters.
@@ -45,10 +43,15 @@ class Namer(object):
self.generated_names = set()
- def compiled_class_name(self, original_name, live_object=None):
+ def compiled_class_name(self, original_fqn, live_entity=None):
"""See call_trees.FunctionNamer.compiled_class_name."""
- if live_object is not None and live_object in self.renamed_calls:
- return self.renamed_calls[live_object]
+ if live_entity is not None and live_entity in self.renamed_calls:
+ return self.renamed_calls[live_entity]
+
+ if isinstance(original_fqn, tuple):
+ original_name = '__'.join(original_fqn)
+ else:
+ original_name = original_fqn
new_name_root = 'Tf%s' % original_name
new_name = new_name_root
@@ -57,41 +60,46 @@ class Namer(object):
n += 1
new_name = '%s_%d' % (new_name_root, n)
- if live_object is not None:
- self.renamed_calls[live_object] = new_name
+ if live_entity is not None:
+ self.renamed_calls[live_entity] = new_name
self.generated_names.add(new_name)
+ if live_entity is not None:
+ self.renamed_calls[live_entity] = new_name
return new_name
def compiled_function_name(self,
- original_name,
- live_object=None,
+ original_fqn,
+ live_entity=None,
owner_type=None):
"""See call_trees.FunctionNamer.compiled_function_name."""
- if live_object is not None and live_object in self.renamed_calls:
- return self.renamed_calls[live_object]
if not self.recursive:
- new_name = original_name
- elif owner_type is None or owner_type in self.partial_types:
- # Top level functions: rename
- new_name_root = 'tf__%s' % original_name
- new_name = new_name_root
- n = 0
- while new_name in self.global_namespace:
- n += 1
- new_name = '%s_%d' % (new_name_root, n)
+ return None, False
+
+ if owner_type is not None and owner_type not in self.partial_types:
+ # Members are not renamed when part of an entire converted class.
+ return None, False
+
+ if isinstance(original_fqn, tuple):
+ original_name = '__'.join(original_fqn)
else:
- if tf_inspect.isclass(owner_type):
- # Class members: do not rename (the entire class will be renamed)
- new_name = original_name
- else:
- raise NotImplementedError('Member function "%s" of non-class type: %s' %
- (original_name, owner_type))
-
- if live_object is not None:
- self.renamed_calls[live_object] = new_name
+ original_name = original_fqn
+
+ if live_entity is not None and live_entity in self.renamed_calls:
+ return self.renamed_calls[live_entity], True
+
+ new_name_root = 'tf__%s' % original_name
+ new_name = new_name_root
+ n = 0
+ while new_name in self.global_namespace:
+ n += 1
+ new_name = '%s_%d' % (new_name_root, n)
+
+ if live_entity is not None:
+ self.renamed_calls[live_entity] = new_name
self.generated_names.add(new_name)
- return new_name
+
+ return new_name, True
def new_symbol(self, name_root, reserved_locals):
"""See control_flow.SymbolNamer.new_symbol."""
diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/naming_test.py
index 7bfc9b8733..5cf0a3da2c 100644
--- a/tensorflow/contrib/py2tf/naming_test.py
+++ b/tensorflow/contrib/py2tf/naming_test.py
@@ -29,8 +29,9 @@ class NamerTest(test.TestCase):
pass
namer = naming.Namer({}, True, None, ())
- self.assertEqual('tf__foo', namer.compiled_function_name('foo'))
- self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar))
+ self.assertEqual(('tf__foo', True), namer.compiled_function_name('foo'))
+ self.assertEqual(('tf__bar', True), namer.compiled_function_name(
+ 'bar', bar))
self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls)
self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names)
@@ -39,15 +40,18 @@ class NamerTest(test.TestCase):
pass
namer = naming.Namer({}, True, None, ())
- self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo))
- self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo))
+ self.assertEqual(('tf__foo', True), namer.compiled_function_name(
+ 'foo', foo))
+ self.assertEqual(('tf__foo', True), namer.compiled_function_name(
+ 'foo', foo))
def test_compiled_function_name_avoids_global_conflicts(self):
def foo():
pass
namer = naming.Namer({'tf__foo': 1}, True, None, ())
- self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo))
+ self.assertEqual(('tf__foo_1', True),
+ namer.compiled_function_name('foo', foo))
def test_new_symbol_tracks_names(self):
namer = naming.Namer({}, True, None, ())
diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py
index 73f3613d09..fef74ebefa 100644
--- a/tensorflow/contrib/py2tf/pyct/context.py
+++ b/tensorflow/contrib/py2tf/pyct/context.py
@@ -33,10 +33,11 @@ class EntityContext(object):
"""
def __init__(self, namer, source_code, source_file, namespace, arg_values,
- arg_types):
+ arg_types, recursive):
self.namer = namer
self.source_code = source_code
self.source_file = source_file
self.namespace = namespace
self.arg_values = {} if arg_values is None else arg_values
self.arg_types = {} if arg_types is None else arg_types
+ self.recursive = recursive
diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/py2tf/pyct/parser.py
index 3daa69b9ce..dc7df883b3 100644
--- a/tensorflow/contrib/py2tf/pyct/parser.py
+++ b/tensorflow/contrib/py2tf/pyct/parser.py
@@ -28,11 +28,13 @@ import gast
from tensorflow.python.util import tf_inspect
-def parse_object(obj):
- """Return the AST of given object."""
- return parse_str(tf_inspect.getsource(obj))
+def parse_entity(entity):
+ """Return the AST of given entity."""
+ source = tf_inspect.getsource(entity)
+ source = textwrap.dedent(source)
+ return parse_str(source), source
def parse_str(src):
"""Return the AST of given piece of code."""
- return gast.parse(textwrap.dedent(src))
+ return gast.parse(src)
diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/py2tf/pyct/parser_test.py
index 46f9aa8207..f35dfa04c7 100644
--- a/tensorflow/contrib/py2tf/pyct/parser_test.py
+++ b/tensorflow/contrib/py2tf/pyct/parser_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import textwrap
+
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.python.platform import test
@@ -28,15 +30,16 @@ def f(x):
class ParserTest(test.TestCase):
- def test_parse_object(self):
- mod = parser.parse_object(f)
+ def test_parse_entity(self):
+ mod, _ = parser.parse_entity(f)
self.assertEqual('f', mod.body[0].name)
def test_parse_str(self):
- mod = parser.parse_str("""
+ mod = parser.parse_str(
+ textwrap.dedent("""
def f(x):
return x + 1
- """)
+ """))
self.assertEqual('f', mod.body[0].name)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
index 8f3ac48b68..33629f87d1 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
@@ -23,6 +23,7 @@ import copy
import gast
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import transformer
# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
@@ -53,6 +54,8 @@ class Scope(object):
self.modified = set()
self.created = set()
self.used = set()
+ self.params = set()
+ self.returned = set()
# TODO(mdan): Rename to `locals`
@property
@@ -69,42 +72,73 @@ class Scope(object):
self.modified = copy.copy(other.modified)
self.created = copy.copy(other.created)
self.used = copy.copy(other.used)
+ self.params = copy.copy(other.params)
+ self.returned = copy.copy(other.returned)
def merge_from(self, other):
self.modified |= other.modified
self.created |= other.created
self.used |= other.used
+ self.params |= other.params
+ self.returned |= other.returned
def has(self, name):
- if name in self.modified:
+ if name in self.modified or name in self.params:
return True
elif self.parent is not None:
return self.parent.has(name)
return False
+ def is_modified_since_entry(self, name):
+ if name in self.modified:
+ return True
+ elif self.parent is not None and not self.isolated:
+ return self.parent.is_modified_since_entry(name)
+ return False
+
+ def is_param(self, name):
+ if name in self.params:
+ return True
+ elif self.parent is not None and not self.isolated:
+ return self.parent.is_param(name)
+ return False
+
def mark_read(self, name):
self.used.add(name)
if self.parent is not None and name not in self.created:
self.parent.mark_read(name)
+ def mark_param(self, name):
+ self.params.add(name)
+
+ def mark_creation(self, name):
+ self.created.add(name)
+
def mark_write(self, name):
self.modified.add(name)
if self.isolated:
- self.created.add(name)
+ self.mark_creation(name)
else:
if self.parent is None:
- self.created.add(name)
+ self.mark_creation(name)
else:
if not self.parent.has(name):
- self.created.add(name)
+ self.mark_creation(name)
self.parent.mark_write(name)
+ def mark_returned(self, name):
+ self.returned.add(name)
+ if not self.isolated and self.parent is not None:
+ self.parent.mark_returned(name)
+
-class AccessResolver(gast.NodeTransformer):
+class AccessResolver(transformer.Base):
"""Annotates nodes with local scope information. See Scope."""
- def __init__(self):
+ def __init__(self, context):
+ super(AccessResolver, self).__init__(context)
self.scope = Scope(None)
+ self._in_return_statement = False
def visit_Name(self, node):
# TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate.
@@ -120,10 +154,17 @@ class AccessResolver(gast.NodeTransformer):
# TODO(mdan): This bay be incorrect with nested functions.
# For nested functions, we'll have to add the notion of hiding args from
# the parent scope, not writing to them.
- self.scope.mark_write(node.id)
+ self.scope.mark_creation(node.id)
+ self.scope.mark_param(node.id)
else:
raise ValueError('Unknown context %s for node %s.' % (type(node.ctx),
node.id))
+ anno.setanno(node, 'is_modified_since_entry',
+ self.scope.is_modified_since_entry(node.id))
+ anno.setanno(node, 'is_param', self.scope.is_param(node.id))
+
+ if self._in_return_statement:
+ self.scope.mark_returned(node.id)
return node
def visit_Print(self, node):
@@ -138,7 +179,7 @@ class AccessResolver(gast.NodeTransformer):
def visit_Call(self, node):
current_scope = self.scope
- args_scope = Scope(current_scope)
+ args_scope = Scope(current_scope, isolated=False)
self.scope = args_scope
for n in node.args:
self.visit(n)
@@ -200,6 +241,12 @@ class AccessResolver(gast.NodeTransformer):
node, ((node.body, 'body'), (node.orelse, 'orelse')))
return node
+ def visit_Return(self, node):
+ self._in_return_statement = True
+ node = self.generic_visit(node)
+ self._in_return_statement = False
+ return node
+
-def resolve(node):
- return AccessResolver().visit(node)
+def resolve(node, context):
+ return AccessResolver(context).visit(node)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
index 0912ebb4c3..df0283b54d 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.python.platform import test
@@ -95,6 +96,19 @@ class ScopeTest(test.TestCase):
class AccessResolverTest(test.TestCase):
+ def _parse_and_analyze(self, test_fn):
+ node, source = parser.parse_entity(test_fn)
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=source,
+ source_file=None,
+ namespace={},
+ arg_values=None,
+ arg_types=None,
+ recursive=True)
+ node = access.resolve(node, ctx)
+ return node
+
def test_local_markers(self):
def test_fn(a): # pylint:disable=unused-argument
@@ -103,9 +117,7 @@ class AccessResolverTest(test.TestCase):
b -= 1
return b
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
-
+ node = self._parse_and_analyze(test_fn)
self.assertFalse(anno.getanno(node.body[0].body[0].value,
'is_local')) # c in b = c
self.assertTrue(anno.getanno(node.body[0].body[1].test.left,
@@ -126,9 +138,7 @@ class AccessResolverTest(test.TestCase):
print(a, b)
return c
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
-
+ node = self._parse_and_analyze(test_fn)
print_node = node.body[0].body[2]
if isinstance(print_node, gast.Print):
# Python 2
@@ -151,9 +161,7 @@ class AccessResolverTest(test.TestCase):
foo(a, b) # pylint:disable=undefined-variable
return c
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
-
+ node = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[2].value
# We basically need to detect which variables are captured by the call
# arguments.
@@ -169,15 +177,13 @@ class AccessResolverTest(test.TestCase):
b -= 1
return b, c
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
-
+ node = self._parse_and_analyze(test_fn)
while_node = node.body[0].body[1]
self.assertScopeIs(
anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
self.assertScopeIs(
anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'),
- ('a', 'b', 'c'), ('a', 'b', 'c'))
+ ('b', 'c'), ('a', 'b', 'c'))
def test_for(self):
@@ -188,15 +194,13 @@ class AccessResolverTest(test.TestCase):
b -= 1
return b, c
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
-
+ node = self._parse_and_analyze(test_fn)
for_node = node.body[0].body[1]
self.assertScopeIs(
anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
self.assertScopeIs(
anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'),
- ('a', 'b', 'c', '_'), ('a', 'b', 'c', '_'))
+ ('b', 'c', '_'), ('a', 'b', 'c', '_'))
def test_if(self):
@@ -211,9 +215,7 @@ class AccessResolverTest(test.TestCase):
u = -y
return z, u
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
-
+ node = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
self.assertScopeIs(
anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
index 242e544b52..5a2903e6b5 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
@@ -26,26 +26,19 @@ from __future__ import print_function
import gast
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import transformer
-class LiveValueResolver(gast.NodeTransformer):
+class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
- def __init__(self, namespace, literals):
- """Create a new resolver.
-
- Args:
- namespace: A dict representing the namespace visible to the AST in the
- intended execution context.
- literals: A dict mapping literal lymbol names to their value. An example
- literal is "None".
- """
- self.namespace = namespace
+ def __init__(self, context, literals):
+ super(LiveValueResolver, self).__init__(context)
self.literals = literals
def visit_ClassDef(self, node):
self.generic_visit(node)
- anno.setanno(node, 'live_val', self.namespace[node.name])
+ anno.setanno(node, 'live_val', self.context.namespace[node.name])
return node
def visit_Name(self, node):
@@ -53,20 +46,31 @@ class LiveValueResolver(gast.NodeTransformer):
if isinstance(node.ctx, gast.Load):
assert anno.hasanno(node, 'is_local'), node
symbol_is_local = anno.getanno(node, 'is_local')
- if not symbol_is_local:
+ assert anno.hasanno(node, 'is_modified_since_entry'), node
+ symbol_is_modified = anno.getanno(node, 'is_modified_since_entry')
+ assert anno.hasanno(node, 'is_param'), node
+ symbol_is_param = anno.getanno(node, 'is_param')
+
+ if not symbol_is_local and not symbol_is_param:
if node.id in self.literals:
anno.setanno(node, 'live_val', self.literals[node.id])
# TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
- elif node.id in self.namespace:
- obj = self.namespace[node.id]
+ elif node.id in self.context.namespace:
+ obj = self.context.namespace[node.id]
anno.setanno(node, 'live_val', obj)
anno.setanno(node, 'fqn', (obj.__name__,))
else:
- raise ValueError('Could not find global symbol %s.' % node.id)
+ raise ValueError('Could not resolve symbol "%s".' % node.id)
else:
pass
# TODO(mdan): Attempt to trace its value through the local chain.
# TODO(mdan): Use type annotations as fallback.
+
+ if not symbol_is_modified:
+ if node.id in self.context.arg_values:
+ obj = self.context.arg_values[node.id]
+ anno.setanno(node, 'live_val', obj)
+ anno.setanno(node, 'fqn', (obj.__class__.__name__,))
return node
def visit_Attribute(self, node):
@@ -79,15 +83,25 @@ class LiveValueResolver(gast.NodeTransformer):
node.attr))
anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
+ # TODO(mdan): Investigate the role built-in annotations can play here.
+ elif anno.hasanno(node.value, 'type'):
+ parent_type = anno.getanno(node.value, 'type')
+ if hasattr(parent_type, node.attr):
+ # This should hold for static members like methods.
+ # This would not hold for dynamic members like function attributes.
+ # For the dynamic case, we simply leave the node without an annotation,
+ # and let downstream consumers figure out what to do.
+ anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
+ anno.setanno(node, 'fqn',
+ anno.getanno(node.value, 'type_fqn') + (node.attr,))
elif isinstance(node.value, gast.Name):
stem_name = node.value
# All nonlocal symbols should be fully resolved.
assert anno.hasanno(stem_name, 'is_local'), stem_name
- assert anno.getanno(stem_name, 'is_local'), stem_name
# TODO(mdan): Figure out what to do when calling attribute on local object
# Maybe just leave as-is?
return node
-def resolve(node, namespace, literals):
- return LiveValueResolver(namespace, literals).visit(node)
+def resolve(node, context, literals):
+ return LiveValueResolver(context, literals).visit(node)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
index e77497654a..f3057b3466 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
@@ -19,24 +19,45 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class LiveValuesResolverTest(test.TestCase):
+ def _parse_and_analyze(self,
+ test_fn,
+ namespace,
+ literals=None,
+ arg_types=None):
+ literals = literals or {}
+ arg_types = arg_types or {}
+ node, source = parser.parse_entity(test_fn)
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=source,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ recursive=True)
+ node = access.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, literals)
+ node = type_info.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, literals)
+ return node
+
def test_literals(self):
def test_fn():
return Foo # pylint: disable=undefined-variable
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {}, {'Foo': 'bar'})
-
+ node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'})
retval_node = node.body[0].body[0].value
self.assertEquals('bar', anno.getanno(retval_node, 'live_val'))
@@ -48,10 +69,7 @@ class LiveValuesResolverTest(test.TestCase):
def test_fn():
return foo()
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'foo': foo}, {})
-
+ node = self._parse_and_analyze(test_fn, {'foo': foo})
func_node = node.body[0].body[0].value.func
self.assertEquals(foo, anno.getanno(func_node, 'live_val'))
self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
@@ -61,15 +79,29 @@ class LiveValuesResolverTest(test.TestCase):
def test_fn():
return constant_op.constant(0)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'constant_op': constant_op}, {})
-
+ node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
func_node = node.body[0].body[0].value.func
self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val'))
self.assertEquals((constant_op.__name__, 'constant'),
anno.getanno(func_node, 'fqn'))
+ def test_attributes_with_type_hints(self):
+
+ class TestClass(object):
+
+ def member(self):
+ pass
+
+ def test_fn(self):
+ return self.member()
+
+ node = self._parse_and_analyze(
+ TestClass.test_fn, {'constant_op': constant_op},
+ arg_types={'self': (TestClass.__name__, TestClass)})
+ func_node = node.body[0].body[0].value.func
+ self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val'))
+ self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
index 0042aa90ed..cf74142cbe 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
@@ -36,8 +36,6 @@ class Scope(object):
most recently assigned to the symbol.
"""
- # TODO(mdan): Should rather use a CFG here?
-
def __init__(self, parent):
"""Create a new scope.
@@ -117,18 +115,32 @@ class TypeInfoResolver(transformer.Base):
node.orelse = self._visit_block(node.orelse)
return node
+ def _process_function_arg(self, arg_name):
+ if self.function_level == 1 and arg_name in self.context.arg_types:
+ # Forge a node to hold the type information, so that method calls on
+ # it can resolve the type.
+ type_holder = gast.Name(arg_name, gast.Load(), None)
+ type_string, type_obj = self.context.arg_types[arg_name]
+ anno.setanno(type_holder, 'type', type_obj)
+ anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
+ self.scope.setval(arg_name, type_holder)
+
+ def visit_arg(self, node):
+ self._process_function_arg(node.arg)
+ return node
+
def visit_Name(self, node):
self.generic_visit(node)
if isinstance(node.ctx, gast.Param):
- self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None))
- if self.function_level == 1 and node.id in self.context.arg_types:
- # Forge a node to hold the type information, so that method calls on
- # it can resolve the type.
- type_holder = gast.Name(node.id, gast.Load(), None)
- type_string, type_obj = self.context.arg_types[node.id]
- anno.setanno(type_holder, 'type', type_obj)
- anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
- self.scope.setval(node.id, type_holder)
+ self._process_function_arg(node.id)
+ elif isinstance(node.ctx, gast.Load) and self.scope.hasval(node.id):
+ # E.g. if we had
+ # a = b
+ # then for future references to `a` we should have traced_source = `b`
+ traced_source = self.scope.getval(node.id)
+ if anno.hasanno(traced_source, 'type'):
+ anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
+ anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
return node
def _process_variable_assignment(self, source, targets):
@@ -172,38 +184,6 @@ class TypeInfoResolver(transformer.Base):
self._process_variable_assignment(node.value, node.targets)
return node
- def visit_Call(self, node):
- target = node.func
- if not anno.hasanno(target, 'live_val'):
- if not isinstance(target, gast.Attribute):
- # Suspecting this pattern would reach here:
- # foo = bar
- # foo()
- raise ValueError('Dont know how to handle dynamic functions.')
- if not isinstance(target.value, gast.Name):
- # Possible example of this kind:
- # foo = module.Foo()
- # foo.bar.baz()
- # TODO(mdan): This should be doable by using the FQN.
- raise ValueError('Dont know how to handle object properties yet.')
- # In the example below, object_source is 'tr.train.Optimizer()':
- # opt = tf.train.Optimizer()
- # opt.foo()
- if self.scope.hasval(target.value.id):
- object_source = self.scope.getval(target.value.id)
- if not anno.hasanno(object_source, 'type'):
- raise ValueError('Could not determine type of "%s". Is it dynamic?' %
- (target.value.id))
- anno.setanno(target, 'type', anno.getanno(object_source, 'type'))
- anno.setanno(target, 'type_fqn', anno.getanno(object_source,
- 'type_fqn'))
- else:
- # TODO(mdan): Figure out what could the user do to get past this.
- raise ValueError('No info on "%s". Is it dynamically built?' %
- (target.value.id))
- self.generic_visit(node)
- return node
-
def resolve(node, context):
return TypeInfoResolver(context).visit(node)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
index a491f49ca3..68fa1ee92a 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
@@ -57,17 +56,19 @@ class ScopeTest(test.TestCase):
class TypeInfoResolverTest(test.TestCase):
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
+ node, source = parser.parse_entity(test_fn)
ctx = context.EntityContext(
namer=None,
- source_code=None,
+ source_code=source,
source_file=None,
namespace=namespace,
arg_values=None,
- arg_types=arg_types)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
+ arg_types=arg_types,
+ recursive=True)
+ node = access.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, {})
node = type_info.resolve(node, ctx)
+ node = live_values.resolve(node, ctx, {})
return node
def test_constructor_detection(self):
@@ -83,16 +84,16 @@ class TypeInfoResolverTest(test.TestCase):
self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
anno.getanno(call_node, 'type_fqn'))
- def test_class_members(self):
+ def test_class_members_of_detected_constructor(self):
def test_fn():
opt = training.GradientDescentOptimizer(0.1)
opt.minimize(0)
node = self._parse_and_analyze(test_fn, {'training': training})
- attr_call_node = node.body[0].body[1].value.func
- self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
- anno.getanno(attr_call_node, 'type_fqn'))
+ method_call = node.body[0].body[1].value.func
+ self.assertEquals(training.GradientDescentOptimizer.minimize,
+ anno.getanno(method_call, 'live_val'))
def test_class_members_in_with_stmt(self):
@@ -106,11 +107,11 @@ class TypeInfoResolverTest(test.TestCase):
self.assertEquals((session.__name__, 'Session'),
anno.getanno(constructor_call, 'type_fqn'))
- member_call = node.body[0].body[0].body[0].value.func
- self.assertEquals((session.__name__, 'Session'),
- anno.getanno(member_call, 'type_fqn'))
+ method_call = node.body[0].body[0].body[0].value.func
+ self.assertEquals(session.Session.run, anno.getanno(method_call,
+ 'live_val'))
- def test_constructor_deta_dependent(self):
+ def test_constructor_data_dependent(self):
def test_fn(x):
if x > 0:
@@ -119,16 +120,18 @@ class TypeInfoResolverTest(test.TestCase):
opt = training.GradientDescentOptimizer(0.01)
opt.minimize(0)
- with self.assertRaises(transformer.PyFlowParseError):
- self._parse_and_analyze(test_fn, {'training': training})
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_parameter_class_members(self):
def test_fn(opt):
opt.minimize(0)
- with self.assertRaises(transformer.PyFlowParseError):
- self._parse_and_analyze(test_fn, {'training': training})
+ node = self._parse_and_analyze(test_fn, {})
+ method_call = node.body[0].body[0].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_parameter_class_members_with_value_hints(self):
@@ -138,14 +141,13 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(
test_fn, {'training': training},
arg_types={
- 'opt': (('%s.GradientDescentOptimizer' % training.__name__),
- training.GradientDescentOptimizer(0.1))
+ 'opt': (training.GradientDescentOptimizer.__name__,
+ training.GradientDescentOptimizer)
})
- attr_call_node = node.body[0].body[0].value.func
- self.assertEquals(
- tuple(training.__name__.split('.')) + ('GradientDescentOptimizer',),
- anno.getanno(attr_call_node, 'type_fqn'))
+ method_call = node.body[0].body[0].value.func
+ self.assertEquals(training.GradientDescentOptimizer.minimize,
+ anno.getanno(method_call, 'live_val'))
def test_function_variables(self):
@@ -156,8 +158,9 @@ class TypeInfoResolverTest(test.TestCase):
foo = bar
foo()
- with self.assertRaises(transformer.PyFlowParseError):
- self._parse_and_analyze(test_fn, {'bar': bar})
+ node = self._parse_and_analyze(test_fn, {'bar': bar})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_nested_members(self):
@@ -165,8 +168,9 @@ class TypeInfoResolverTest(test.TestCase):
foo = training.GradientDescentOptimizer(0.1)
foo.bar.baz()
- with self.assertRaises(transformer.PyFlowParseError):
- self._parse_and_analyze(test_fn, {'training': training})
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
index 77c5fbe02a..6be526f20d 100644
--- a/tensorflow/contrib/py2tf/pyct/templates.py
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -23,6 +23,7 @@ from __future__ import print_function
import ast
import copy
+import textwrap
import gast
@@ -119,7 +120,7 @@ def replace(template, **replacements):
"""
if not isinstance(template, str):
raise ValueError('Expected string template, got %s' % type(template))
- tree = parser.parse_str(template)
+ tree = parser.parse_str(textwrap.dedent(template))
for k in replacements:
replacements[k] = _strings_to_names(replacements[k])
return ReplaceTransformer(replacements).visit(tree).body
diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/py2tf/pyct/transformer.py
index d5aa23eaeb..8a836b7c1b 100644
--- a/tensorflow/contrib/py2tf/pyct/transformer.py
+++ b/tensorflow/contrib/py2tf/pyct/transformer.py
@@ -18,7 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
import gast
+import six
from tensorflow.contrib.py2tf.pyct import pretty_printer
@@ -48,11 +51,15 @@ class Base(gast.NodeTransformer):
self._lineno = node.lineno
self._col_offset = node.col_offset
return super(Base, self).visit(node)
- except ValueError as e:
- msg = '%s\nOccurred at node:\n%s' % (str(e), pretty_printer.fmt(node))
+ except (ValueError, AttributeError, NotImplementedError) as e:
+ msg = '%s: %s\nOccurred at node:\n%s' % (e.__class__.__name__, str(e),
+ pretty_printer.fmt(node))
if source_code:
- line = self._source.splitlines()[self._lineno - 1]
+ line = source_code.splitlines()[self._lineno - 1]
else:
line = '<no source available>'
- raise PyFlowParseError(
- msg, (source_file, self._lineno, self._col_offset + 1, line))
+ six.reraise(PyFlowParseError,
+ PyFlowParseError(
+ msg,
+ (source_file, self._lineno, self._col_offset + 1, line)),
+ sys.exc_info()[2])