From 3ef35b81fd753401e3d69989b3bd1146749cc3b3 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 9 Oct 2018 09:34:47 -0700 Subject: Include live-in symbols in liveness analysis. These are required for control flow conversion. PiperOrigin-RevId: 216370439 --- tensorflow/python/autograph/pyct/anno.py | 1 + tensorflow/python/autograph/pyct/cfg.py | 10 ++- .../autograph/pyct/static_analysis/liveness.py | 36 ++++++--- .../pyct/static_analysis/liveness_test.py | 86 ++++++++++++++++++++-- 4 files changed, 112 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py index 1a52110ef3..5392e6ea03 100644 --- a/tensorflow/python/autograph/pyct/anno.py +++ b/tensorflow/python/autograph/pyct/anno.py @@ -91,6 +91,7 @@ class Static(NoValue): DEFINED_VARS_IN = ( 'Symbols defined when entering the node. See reaching_definitions.py.') LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.') + LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.') FAIL = object() diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index fca0eb62e4..ec733ea38f 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -22,6 +22,10 @@ Once built, the CFG itself is immutable, but the values it holds need not be; they are usually annotated with information extracted by walking the graph. """ +# TODO(mdan): The notion of 'statements' below is inaccurate. +# They should rather be called 'block statements', because they include +# statements that may have a body, e.g. if and while. + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -763,9 +767,9 @@ class AstToCfg(gast.NodeVisitor): self.builder.enter_section(node) - # TODO(mdan): Strictly speaking, this should be node.target + node.iter. - # A blind dataflow analysis would have to process both node.target and - # node.iter to properly process read and write access. + # Note: Strictly speaking, this should be node.target + node.iter. + # However, the activity analysis accounts for this inconsistency, + # so dataflow analysis produces the correct values. self.builder.enter_loop_section(node, node.iter) for stmt in node.body: self.visit(stmt) diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py index 41c903beb9..36960d0103 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py @@ -14,8 +14,13 @@ # ============================================================================== """Live variable analysis. -This analysis attaches a set containing the live symbols that are live at the -exit of control flow statements. +See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of +the following idioms: live variable, live in, live out, which are used +throughout this file. + +This analysis attaches the following: + * symbols that are live at the exit of control flow statements + * symbols that are live at the entry of control flow statements Requires activity analysis. """ @@ -164,23 +169,34 @@ class Annotator(transformer.Base): self.current_analyzer = parent_analyzer return node - def _aggregate_successors_live_in(self, node): + def _block_statement_live_out(self, node): successors = self.current_analyzer.graph.stmt_next[node] - node_live_out = set() + stmt_live_out = set() for s in successors: - node_live_out.update(self.current_analyzer.in_[s]) - anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(node_live_out)) - node = self.generic_visit(node) + stmt_live_out.update(self.current_analyzer.in_[s]) + anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(stmt_live_out)) + return node + + def _block_statement_live_in(self, node, entry_node): + cfg_node = self.current_analyzer.graph.index[entry_node] + stmt_live_in = frozenset(self.current_analyzer.in_[cfg_node]) + anno.setanno(node, anno.Static.LIVE_VARS_IN, stmt_live_in) return node def visit_If(self, node): - return self._aggregate_successors_live_in(node) + node = self.generic_visit(node) + node = self._block_statement_live_out(node) + return self._block_statement_live_in(node, node.test) def visit_For(self, node): - return self._aggregate_successors_live_in(node) + node = self.generic_visit(node) + node = self._block_statement_live_out(node) + return self._block_statement_live_in(node, node.iter) def visit_While(self, node): - return self._aggregate_successors_live_in(node) + node = self.generic_visit(node) + node = self._block_statement_live_out(node) + return self._block_statement_live_in(node, node.test) def resolve(node, source_info, graphs): diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py index 0d5f369e92..7b67f8f608 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py @@ -47,14 +47,23 @@ class LivenessTest(test.TestCase): def assertHasLiveOut(self, node, expected): live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - live_out_str = set(str(v) for v in live_out) + live_out_strs = set(str(v) for v in live_out) if not expected: expected = () if not isinstance(expected, tuple): expected = (expected,) - self.assertSetEqual(live_out_str, set(expected)) + self.assertSetEqual(live_out_strs, set(expected)) - def test_stacked_if(self): + def assertHasLiveIn(self, node, expected): + live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) + live_in_strs = set(str(v) for v in live_in) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(live_in_strs, set(expected)) + + def test_live_out_stacked_if(self): def test_fn(x, a): if a > 0: @@ -69,7 +78,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], ('a', 'x')) self.assertHasLiveOut(fn_body[1], 'x') - def test_stacked_if_else(self): + def test_live_out_stacked_if_else(self): def test_fn(x, a): if a > 0: @@ -86,7 +95,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'a') self.assertHasLiveOut(fn_body[1], 'x') - def test_for_basic(self): + def test_live_out_for_basic(self): def test_fn(x, a): for i in range(a): @@ -98,7 +107,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'x') - def test_attributes(self): + def test_live_out_attributes(self): def test_fn(x, a): if a > 0: @@ -110,7 +119,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], ('x.y', 'x')) - def test_nested_functions(self): + def test_live_out_nested_functions(self): def test_fn(a, b): if b: @@ -126,7 +135,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'a') - def test_nested_functions_isolation(self): + def test_live_out_nested_functions_isolation(self): def test_fn(b): if b: @@ -144,6 +153,67 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'max') + def test_live_in_stacked_if(self): + + def test_fn(x, a, b, c): + if a > 0: + x = b + if c > 1: + x = 0 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'x')) + self.assertHasLiveIn(fn_body[1], ('c', 'x')) + + def test_live_in_stacked_if_else(self): + + def test_fn(x, a, b, c, d): + if a > 1: + x = b + else: + x = c + if d > 0: + x = 0 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'd')) + self.assertHasLiveIn(fn_body[1], ('d', 'x')) + + def test_live_in_for_basic(self): + + def test_fn(x, y, a): + for i in a: + x = i + y += x + z = 0 + return y, z + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z')) + + def test_live_in_for_nested(self): + + def test_fn(x, y, a): + for i in a: + for j in i: + x = i + y += x + z = j + return y, z + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z')) + if __name__ == '__main__': test.main() -- cgit v1.2.3