aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/liveness_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness_test.py86
1 files changed, 78 insertions, 8 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
index 0d5f369e92..7b67f8f608 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -47,14 +47,23 @@ class LivenessTest(test.TestCase):
def assertHasLiveOut(self, node, expected):
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
- live_out_str = set(str(v) for v in live_out)
+ live_out_strs = set(str(v) for v in live_out)
if not expected:
expected = ()
if not isinstance(expected, tuple):
expected = (expected,)
- self.assertSetEqual(live_out_str, set(expected))
+ self.assertSetEqual(live_out_strs, set(expected))
- def test_stacked_if(self):
+ def assertHasLiveIn(self, node, expected):
+ live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
+ live_in_strs = set(str(v) for v in live_in)
+ if not expected:
+ expected = ()
+ if not isinstance(expected, tuple):
+ expected = (expected,)
+ self.assertSetEqual(live_in_strs, set(expected))
+
+ def test_live_out_stacked_if(self):
def test_fn(x, a):
if a > 0:
@@ -69,7 +78,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], ('a', 'x'))
self.assertHasLiveOut(fn_body[1], 'x')
- def test_stacked_if_else(self):
+ def test_live_out_stacked_if_else(self):
def test_fn(x, a):
if a > 0:
@@ -86,7 +95,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'a')
self.assertHasLiveOut(fn_body[1], 'x')
- def test_for_basic(self):
+ def test_live_out_for_basic(self):
def test_fn(x, a):
for i in range(a):
@@ -98,7 +107,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'x')
- def test_attributes(self):
+ def test_live_out_attributes(self):
def test_fn(x, a):
if a > 0:
@@ -110,7 +119,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], ('x.y', 'x'))
- def test_nested_functions(self):
+ def test_live_out_nested_functions(self):
def test_fn(a, b):
if b:
@@ -126,7 +135,7 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'a')
- def test_nested_functions_isolation(self):
+ def test_live_out_nested_functions_isolation(self):
def test_fn(b):
if b:
@@ -144,6 +153,67 @@ class LivenessTest(test.TestCase):
self.assertHasLiveOut(fn_body[0], 'max')
+ def test_live_in_stacked_if(self):
+
+ def test_fn(x, a, b, c):
+ if a > 0:
+ x = b
+ if c > 1:
+ x = 0
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'x'))
+ self.assertHasLiveIn(fn_body[1], ('c', 'x'))
+
+ def test_live_in_stacked_if_else(self):
+
+ def test_fn(x, a, b, c, d):
+ if a > 1:
+ x = b
+ else:
+ x = c
+ if d > 0:
+ x = 0
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'd'))
+ self.assertHasLiveIn(fn_body[1], ('d', 'x'))
+
+ def test_live_in_for_basic(self):
+
+ def test_fn(x, y, a):
+ for i in a:
+ x = i
+ y += x
+ z = 0
+ return y, z
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z'))
+
+ def test_live_in_for_nested(self):
+
+ def test_fn(x, y, a):
+ for i in a:
+ for j in i:
+ x = i
+ y += x
+ z = j
+ return y, z
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z'))
+
if __name__ == '__main__':
test.main()