diff options
author | Dan Moldovan <mdan@google.com> | 2018-07-16 08:58:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 09:02:09 -0700 |
commit | cde36bc1667d80c9569bfa09b1cb6e71a77700b9 (patch) | |
tree | 6edb64a2804ac50176bfb6334eec0834f04349d2 /tensorflow/contrib/autograph | |
parent | 70b89c7eb28f3a2e87168a55d0f2c3f46f1e8add (diff) |
Fix reaching_definitions to correctly mark the definition of modified symbols in the statement that replaces them, e.g. a = a.
PiperOrigin-RevId: 204749753
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py | 69 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py | 66 |
2 files changed, 102 insertions, 33 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py index 4ea7fd93cd..9a84f1231c 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py @@ -112,7 +112,6 @@ class Analyzer(cfg.GraphVisitor): def __init__(self, graph, definition_factory): self._definition_factory = definition_factory super(Analyzer, self).__init__(graph) - self.defs_by_ast_node = {} # This allows communicating that nodes have extra reaching definitions, # e.g. those that a function closes over. self.extra_in = {} @@ -160,13 +159,12 @@ class Analyzer(cfg.GraphVisitor): self.in_[node] = defs_in self.out[node] = defs_out - self.defs_by_ast_node[node.ast_node] = defs_out.value # TODO(mdan): Move this to the superclass? return prev_defs_out != defs_out -class WholeTreeAnalyzer(transformer.Base): +class TreeAnnotator(transformer.Base): """AST visitor that annotates each symbol name with its reaching definitions. Simultaneously, the visitor runs the dataflow analysis on each function node, @@ -179,12 +177,11 @@ class WholeTreeAnalyzer(transformer.Base): """ def __init__(self, source_info, graphs, definition_factory): - super(WholeTreeAnalyzer, self).__init__(source_info) - self.stmt_reaching_defs_info = None + super(TreeAnnotator, self).__init__(source_info) + self.definition_factory = definition_factory self.graphs = graphs self.current_analyzer = None - self.definition_factory = definition_factory - self.current_stmt_defs = None + self.current_cfg_node = None def visit_FunctionDef(self, node): parent_analyzer = self.current_analyzer @@ -209,7 +206,11 @@ class WholeTreeAnalyzer(transformer.Base): # Recursively process any remaining subfunctions. self.current_analyzer = analyzer - node = self.generic_visit(node) + # Note: not visiting name, decorator_list and returns because they don't + # apply to this anlysis. + # TODO(mdan): Should we still process the function name? + node.args = self.visit(node.args) + node.body = self.visit_block(node.body) self.current_analyzer = parent_analyzer return node @@ -226,11 +227,19 @@ class WholeTreeAnalyzer(transformer.Base): # definitions. return node + analyzer = self.current_analyzer + cfg_node = self.current_cfg_node + + assert cfg_node is not None, 'name node outside of any statement?' + qn = anno.getanno(node, anno.Basic.QN) - assert self.current_stmt_defs is not None, ( - 'name node outside of any statement?') - anno.setanno(node, anno.Static.DEFINITIONS, - tuple(self.current_stmt_defs.get(qn, ()))) + if isinstance(node.ctx, gast.Load): + anno.setanno(node, anno.Static.DEFINITIONS, + tuple(analyzer.in_[cfg_node].value.get(qn, ()))) + else: + anno.setanno(node, anno.Static.DEFINITIONS, + tuple(analyzer.out[cfg_node].value.get(qn, ()))) + return node def _aggregate_predecessors_defined_in(self, node): @@ -239,23 +248,41 @@ class WholeTreeAnalyzer(transformer.Base): for p in preds: node_defined_in |= set(self.current_analyzer.out[p].value.keys()) anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in)) - node = self.generic_visit(node) - return node def visit_If(self, node): - return self._aggregate_predecessors_defined_in(node) + self._aggregate_predecessors_defined_in(node) + return self.generic_visit(node) def visit_For(self, node): - return self._aggregate_predecessors_defined_in(node) + self._aggregate_predecessors_defined_in(node) + + # Manually accounting for the shortcoming described in + # cfg.AstToCfg.visit_For. + parent = self.current_cfg_node + self.current_cfg_node = self.current_analyzer.graph.index[node.iter] + node.target = self.visit(node.target) + self.current_cfg_node = parent + + node.iter = self.visit(node.iter) + node.body = self.visit_block(node.body) + node.orelse = self.visit_block(node.orelse) + + return node def visit_While(self, node): - return self._aggregate_predecessors_defined_in(node) + self._aggregate_predecessors_defined_in(node) + return self.generic_visit(node) def visit(self, node): + parent = self.current_cfg_node + if (self.current_analyzer is not None and - node in self.current_analyzer.defs_by_ast_node): - self.current_stmt_defs = self.current_analyzer.defs_by_ast_node[node] - return super(WholeTreeAnalyzer, self).visit(node) + node in self.current_analyzer.graph.index): + self.current_cfg_node = self.current_analyzer.graph.index[node] + node = super(TreeAnnotator, self).visit(node) + + self.current_cfg_node = parent + return node def resolve(node, source_info, graphs, definition_factory): @@ -269,6 +296,6 @@ def resolve(node, source_info, graphs, definition_factory): Returns: ast.AST """ - visitor = WholeTreeAnalyzer(source_info, graphs, definition_factory) + visitor = TreeAnnotator(source_info, graphs, definition_factory) node = visitor.visit(node) return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py index 0410bb2a35..243fe804b2 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -61,6 +61,20 @@ class DefinitionInfoTest(test.TestCase): expected = (expected,) self.assertSetEqual(defined_in_str, set(expected)) + def assertSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIs( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def assertNotSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIsNot( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + def test_conditional(self): def test_fn(a, b): @@ -93,10 +107,10 @@ class DefinitionInfoTest(test.TestCase): self.assertHasDefs(fn_body[0].value.args[0], 1) self.assertHasDefs(fn_body[1].body[0].targets[0], 1) - self.assertHasDefs(fn_body[1].body[0].value, 1) self.assertHasDefs(fn_body[1].body[1].targets[0], 1) self.assertHasDefs(fn_body[1].body[1].value, 1) # The loop does have an invariant test, but the CFG doesn't know that. + self.assertHasDefs(fn_body[1].body[0].value, 2) self.assertHasDefs(fn_body[2].value, 2) def test_while_else(self): @@ -171,10 +185,7 @@ class DefinitionInfoTest(test.TestCase): self.assertHasDefs(fn_body[2].value, 2) inner_fn_body = fn_body[1].body[1].body - self.assertHasDefs(inner_fn_body[0].value, 1) - self.assertTrue( - anno.getanno(inner_fn_body[0].value, anno.Static.DEFINITIONS)[0] is - anno.getanno(def_of_a_in_if, anno.Static.DEFINITIONS)[0]) + self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if) def test_nested_functions_isolation(self): @@ -191,17 +202,12 @@ class DefinitionInfoTest(test.TestCase): node = self._parse_and_analyze(test_fn) fn_body = node.body[0].body - self.assertHasDefs(fn_body[3].value, 1) - self.assertHasDefs(fn_body[1].body[1].value, 1) - parent_return = fn_body[3] child_return = fn_body[1].body[1] # The assignment `a = 1` makes `a` local to `child`. - self.assertFalse( - anno.getanno(parent_return.value, anno.Static.DEFINITIONS)[0] is - anno.getanno(child_return.value, anno.Static.DEFINITIONS)[0]) + self.assertNotSameDef(parent_return.value, child_return.value) - def test_debug(self): + def test_function_call_in_with(self): def foo(_): pass @@ -216,6 +222,42 @@ class DefinitionInfoTest(test.TestCase): self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0) self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1) + def test_mutation_subscript(self): + + def test_fn(a): + l = [] + l[0] = a + return l + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + creation = fn_body[0].targets[0] + mutation = fn_body[1].targets[0].value + use = fn_body[2].value + self.assertSameDef(creation, mutation) + self.assertSameDef(creation, use) + + def test_replacement(self): + + def foo(a): + return a + + def test_fn(a): + a = foo(a) + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + param = node.body[0].args.args[0] + source = fn_body[0].value.args[0] + target = fn_body[0].targets[0] + retval = fn_body[1].value + self.assertSameDef(param, source) + self.assertNotSameDef(source, target) + self.assertSameDef(target, retval) + if __name__ == '__main__': test.main() |