aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-07-16 08:58:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 09:02:09 -0700
commitcde36bc1667d80c9569bfa09b1cb6e71a77700b9 (patch)
tree6edb64a2804ac50176bfb6334eec0834f04349d2 /tensorflow/contrib/autograph
parent70b89c7eb28f3a2e87168a55d0f2c3f46f1e8add (diff)
Fix reaching_definitions to correctly mark the definition of modified symbols in the statement that replaces them, e.g. a = a.
PiperOrigin-RevId: 204749753
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py69
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py66
2 files changed, 102 insertions, 33 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
index 4ea7fd93cd..9a84f1231c 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
@@ -112,7 +112,6 @@ class Analyzer(cfg.GraphVisitor):
def __init__(self, graph, definition_factory):
self._definition_factory = definition_factory
super(Analyzer, self).__init__(graph)
- self.defs_by_ast_node = {}
# This allows communicating that nodes have extra reaching definitions,
# e.g. those that a function closes over.
self.extra_in = {}
@@ -160,13 +159,12 @@ class Analyzer(cfg.GraphVisitor):
self.in_[node] = defs_in
self.out[node] = defs_out
- self.defs_by_ast_node[node.ast_node] = defs_out.value
# TODO(mdan): Move this to the superclass?
return prev_defs_out != defs_out
-class WholeTreeAnalyzer(transformer.Base):
+class TreeAnnotator(transformer.Base):
"""AST visitor that annotates each symbol name with its reaching definitions.
Simultaneously, the visitor runs the dataflow analysis on each function node,
@@ -179,12 +177,11 @@ class WholeTreeAnalyzer(transformer.Base):
"""
def __init__(self, source_info, graphs, definition_factory):
- super(WholeTreeAnalyzer, self).__init__(source_info)
- self.stmt_reaching_defs_info = None
+ super(TreeAnnotator, self).__init__(source_info)
+ self.definition_factory = definition_factory
self.graphs = graphs
self.current_analyzer = None
- self.definition_factory = definition_factory
- self.current_stmt_defs = None
+ self.current_cfg_node = None
def visit_FunctionDef(self, node):
parent_analyzer = self.current_analyzer
@@ -209,7 +206,11 @@ class WholeTreeAnalyzer(transformer.Base):
# Recursively process any remaining subfunctions.
self.current_analyzer = analyzer
- node = self.generic_visit(node)
+ # Note: not visiting name, decorator_list and returns because they don't
+ # apply to this anlysis.
+ # TODO(mdan): Should we still process the function name?
+ node.args = self.visit(node.args)
+ node.body = self.visit_block(node.body)
self.current_analyzer = parent_analyzer
return node
@@ -226,11 +227,19 @@ class WholeTreeAnalyzer(transformer.Base):
# definitions.
return node
+ analyzer = self.current_analyzer
+ cfg_node = self.current_cfg_node
+
+ assert cfg_node is not None, 'name node outside of any statement?'
+
qn = anno.getanno(node, anno.Basic.QN)
- assert self.current_stmt_defs is not None, (
- 'name node outside of any statement?')
- anno.setanno(node, anno.Static.DEFINITIONS,
- tuple(self.current_stmt_defs.get(qn, ())))
+ if isinstance(node.ctx, gast.Load):
+ anno.setanno(node, anno.Static.DEFINITIONS,
+ tuple(analyzer.in_[cfg_node].value.get(qn, ())))
+ else:
+ anno.setanno(node, anno.Static.DEFINITIONS,
+ tuple(analyzer.out[cfg_node].value.get(qn, ())))
+
return node
def _aggregate_predecessors_defined_in(self, node):
@@ -239,23 +248,41 @@ class WholeTreeAnalyzer(transformer.Base):
for p in preds:
node_defined_in |= set(self.current_analyzer.out[p].value.keys())
anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))
- node = self.generic_visit(node)
- return node
def visit_If(self, node):
- return self._aggregate_predecessors_defined_in(node)
+ self._aggregate_predecessors_defined_in(node)
+ return self.generic_visit(node)
def visit_For(self, node):
- return self._aggregate_predecessors_defined_in(node)
+ self._aggregate_predecessors_defined_in(node)
+
+ # Manually accounting for the shortcoming described in
+ # cfg.AstToCfg.visit_For.
+ parent = self.current_cfg_node
+ self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
+ node.target = self.visit(node.target)
+ self.current_cfg_node = parent
+
+ node.iter = self.visit(node.iter)
+ node.body = self.visit_block(node.body)
+ node.orelse = self.visit_block(node.orelse)
+
+ return node
def visit_While(self, node):
- return self._aggregate_predecessors_defined_in(node)
+ self._aggregate_predecessors_defined_in(node)
+ return self.generic_visit(node)
def visit(self, node):
+ parent = self.current_cfg_node
+
if (self.current_analyzer is not None and
- node in self.current_analyzer.defs_by_ast_node):
- self.current_stmt_defs = self.current_analyzer.defs_by_ast_node[node]
- return super(WholeTreeAnalyzer, self).visit(node)
+ node in self.current_analyzer.graph.index):
+ self.current_cfg_node = self.current_analyzer.graph.index[node]
+ node = super(TreeAnnotator, self).visit(node)
+
+ self.current_cfg_node = parent
+ return node
def resolve(node, source_info, graphs, definition_factory):
@@ -269,6 +296,6 @@ def resolve(node, source_info, graphs, definition_factory):
Returns:
ast.AST
"""
- visitor = WholeTreeAnalyzer(source_info, graphs, definition_factory)
+ visitor = TreeAnnotator(source_info, graphs, definition_factory)
node = visitor.visit(node)
return node
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
index 0410bb2a35..243fe804b2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -61,6 +61,20 @@ class DefinitionInfoTest(test.TestCase):
expected = (expected,)
self.assertSetEqual(defined_in_str, set(expected))
+ def assertSameDef(self, first, second):
+ self.assertHasDefs(first, 1)
+ self.assertHasDefs(second, 1)
+ self.assertIs(
+ anno.getanno(first, anno.Static.DEFINITIONS)[0],
+ anno.getanno(second, anno.Static.DEFINITIONS)[0])
+
+ def assertNotSameDef(self, first, second):
+ self.assertHasDefs(first, 1)
+ self.assertHasDefs(second, 1)
+ self.assertIsNot(
+ anno.getanno(first, anno.Static.DEFINITIONS)[0],
+ anno.getanno(second, anno.Static.DEFINITIONS)[0])
+
def test_conditional(self):
def test_fn(a, b):
@@ -93,10 +107,10 @@ class DefinitionInfoTest(test.TestCase):
self.assertHasDefs(fn_body[0].value.args[0], 1)
self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
- self.assertHasDefs(fn_body[1].body[0].value, 1)
self.assertHasDefs(fn_body[1].body[1].targets[0], 1)
self.assertHasDefs(fn_body[1].body[1].value, 1)
# The loop does have an invariant test, but the CFG doesn't know that.
+ self.assertHasDefs(fn_body[1].body[0].value, 2)
self.assertHasDefs(fn_body[2].value, 2)
def test_while_else(self):
@@ -171,10 +185,7 @@ class DefinitionInfoTest(test.TestCase):
self.assertHasDefs(fn_body[2].value, 2)
inner_fn_body = fn_body[1].body[1].body
- self.assertHasDefs(inner_fn_body[0].value, 1)
- self.assertTrue(
- anno.getanno(inner_fn_body[0].value, anno.Static.DEFINITIONS)[0] is
- anno.getanno(def_of_a_in_if, anno.Static.DEFINITIONS)[0])
+ self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if)
def test_nested_functions_isolation(self):
@@ -191,17 +202,12 @@ class DefinitionInfoTest(test.TestCase):
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
- self.assertHasDefs(fn_body[3].value, 1)
- self.assertHasDefs(fn_body[1].body[1].value, 1)
-
parent_return = fn_body[3]
child_return = fn_body[1].body[1]
# The assignment `a = 1` makes `a` local to `child`.
- self.assertFalse(
- anno.getanno(parent_return.value, anno.Static.DEFINITIONS)[0] is
- anno.getanno(child_return.value, anno.Static.DEFINITIONS)[0])
+ self.assertNotSameDef(parent_return.value, child_return.value)
- def test_debug(self):
+ def test_function_call_in_with(self):
def foo(_):
pass
@@ -216,6 +222,42 @@ class DefinitionInfoTest(test.TestCase):
self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0)
self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1)
+ def test_mutation_subscript(self):
+
+ def test_fn(a):
+ l = []
+ l[0] = a
+ return l
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ creation = fn_body[0].targets[0]
+ mutation = fn_body[1].targets[0].value
+ use = fn_body[2].value
+ self.assertSameDef(creation, mutation)
+ self.assertSameDef(creation, use)
+
+ def test_replacement(self):
+
+ def foo(a):
+ return a
+
+ def test_fn(a):
+ a = foo(a)
+ return a
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ param = node.body[0].args.args[0]
+ source = fn_body[0].value.args[0]
+ target = fn_body[0].targets[0]
+ retval = fn_body[1].value
+ self.assertSameDef(param, source)
+ self.assertNotSameDef(source, target)
+ self.assertSameDef(target, retval)
+
if __name__ == '__main__':
test.main()