aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-09 09:34:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 09:53:26 -0700
commit3ef35b81fd753401e3d69989b3bd1146749cc3b3 (patch)
tree3db4c4f9fdfaa85f5f7907f5fb0a2ad5bafb45ad
parent5d6adc910b8323b73a61d3089f3a3028be411e90 (diff)
Include live-in symbols in liveness analysis. These are required for control flow conversion.
PiperOrigin-RevId: 216370439
-rw-r--r--tensorflow/python/autograph/pyct/anno.py1
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py10
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness.py36
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness_test.py86
4 files changed, 112 insertions, 21 deletions
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index 1a52110ef3..5392e6ea03 100644
--- a/tensorflow/python/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -91,6 +91,7 @@ class Static(NoValue):
DEFINED_VARS_IN = (
'Symbols defined when entering the node. See reaching_definitions.py.')
LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
+ LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.')
FAIL = object()
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index fca0eb62e4..ec733ea38f 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -22,6 +22,10 @@ Once built, the CFG itself is immutable, but the values it holds need not be;
they are usually annotated with information extracted by walking the graph.
"""
+# TODO(mdan): The notion of 'statements' below is inaccurate.
+# They should rather be called 'block statements', because they include
+# statements that may have a body, e.g. if and while.
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -763,9 +767,9 @@ class AstToCfg(gast.NodeVisitor):
self.builder.enter_section(node)
- # TODO(mdan): Strictly speaking, this should be node.target + node.iter.
- # A blind dataflow analysis would have to process both node.target and
- # node.iter to properly process read and write access.
+ # Note: Strictly speaking, this should be node.target + node.iter.
+ # However, the activity analysis accounts for this inconsistency,
+ # so dataflow analysis produces the correct values.
self.builder.enter_loop_section(node, node.iter)
for stmt in node.body:
self.visit(stmt)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
index 41c903beb9..36960d0103 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/liveness.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
@@ -14,8 +14,13 @@
# ==============================================================================
"""Live variable analysis.
-This analysis attaches a set containing the live symbols that are live at the
-exit of control flow statements.
+See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of
+the following idioms: live variable, live in, live out, which are used
+throughout this file.
+
+This analysis attaches the following:
+ * symbols that are live at the exit of control flow statements
+ * symbols that are live at the entry of control flow statements
Requires activity analysis.
"""
@@ -164,23 +169,34 @@ class Annotator(transformer.Base):
self.current_analyzer = parent_analyzer
return node
- def _aggregate_successors_live_in(self, node):
+ def _block_statement_live_out(self, node):
successors = self.current_analyzer.graph.stmt_next[node]
- node_live_out = set()
+ stmt_live_out = set()
for s in successors:
- node_live_out.update(self.current_analyzer.in_[s])
- anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(node_live_out))
- node = self.generic_visit(node)
+ stmt_live_out.update(self.current_analyzer.in_[s])
+ anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(stmt_live_out))
+ return node
+
+ def _block_statement_live_in(self, node, entry_node):
+ cfg_node = self.current_analyzer.graph.index[entry_node]
+ stmt_live_in = frozenset(self.current_analyzer.in_[cfg_node])
+ anno.setanno(node, anno.Static.LIVE_VARS_IN, stmt_live_in)
return node
def visit_If(self, node):
- return self._aggregate_successors_live_in(node)
+ node = self.generic_visit(node)
+ node = self._block_statement_live_out(node)
+ return self._block_statement_live_in(node, node.test)
def visit_For(self, node):
- return self._aggregate_successors_live_in(node)
+ node = self.generic_visit(node)
+ node = self._block_statement_live_out(node)
+ return self._block_statement_live_in(node, node.iter)
def visit_While(self, node):
- return self._aggregate_successors_live_in(node)
+ node = self.generic_visit(node)
+ node = self._block_statement_live_out(node)
+ return self._block_statement_live_in(node, node.test)
def resolve(node, source_info, graphs):
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
index 0d5f369e92..7b67f8f608 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -47,14 +47,23 @@ class LivenessTest(test.TestCase):
def assertHasLiveOut(self, node, expected):
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
- live_out_str = set(str(v) for v in live_out)
+ live_out_strs = set(str(v) for v in live_out)
if not expected:
expected = ()
if not isinstance(expected, tuple):
expected = (expected,)
- self.assertSetEqual(live_out_str, set(expected))
+ self.assertSetEqual(live_out_strs, set(expected))
- def test_stacked_if(self):
+ def assertHasLiveIn(self, node, expected):
+ live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
+ live_in_strs = set(str(v) for v in live_in)
+ if not expected:
+ expected = ()
+ if not isinstance(expected, tuple):
+ expected = (expected,)
+ self.assertSetEqual(live_in_strs, set(expected))
+
+ def test_live_out_stacked_if(self):
def test_fn(x, a):
if a > 0:
@@ -69,7 +78,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], ('a', 'x'))
self.assertHasLiveOut(fn_body[1], 'x')
- def test_stacked_if_else(self):
+ def test_live_out_stacked_if_else(self):
def test_fn(x, a):
if a > 0:
@@ -86,7 +95,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'a')
self.assertHasLiveOut(fn_body[1], 'x')
- def test_for_basic(self):
+ def test_live_out_for_basic(self):
def test_fn(x, a):
for i in range(a):
@@ -98,7 +107,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'x')
- def test_attributes(self):
+ def test_live_out_attributes(self):
def test_fn(x, a):
if a > 0:
@@ -110,7 +119,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], ('x.y', 'x'))
- def test_nested_functions(self):
+ def test_live_out_nested_functions(self):
def test_fn(a, b):
if b:
@@ -126,7 +135,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'a')
- def test_nested_functions_isolation(self):
+ def test_live_out_nested_functions_isolation(self):
def test_fn(b):
if b:
@@ -144,6 +153,67 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'max')
+ def test_live_in_stacked_if(self):
+
+ def test_fn(x, a, b, c):
+ if a > 0:
+ x = b
+ if c > 1:
+ x = 0
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'x'))
+ self.assertHasLiveIn(fn_body[1], ('c', 'x'))
+
+ def test_live_in_stacked_if_else(self):
+
+ def test_fn(x, a, b, c, d):
+ if a > 1:
+ x = b
+ else:
+ x = c
+ if d > 0:
+ x = 0
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'd'))
+ self.assertHasLiveIn(fn_body[1], ('d', 'x'))
+
+ def test_live_in_for_basic(self):
+
+ def test_fn(x, y, a):
+ for i in a:
+ x = i
+ y += x
+ z = 0
+ return y, z
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z'))
+
+ def test_live_in_for_nested(self):
+
+ def test_fn(x, y, a):
+ for i in a:
+ for j in i:
+ x = i
+ y += x
+ z = j
+ return y, z
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z'))
+
if __name__ == '__main__':
test.main()