diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/liveness_test.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/static_analysis/liveness_test.py | 86 |
1 files changed, 78 insertions, 8 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py index 0d5f369e92..7b67f8f608 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py @@ -47,14 +47,23 @@ class LivenessTest(test.TestCase): def assertHasLiveOut(self, node, expected): live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - live_out_str = set(str(v) for v in live_out) + live_out_strs = set(str(v) for v in live_out) if not expected: expected = () if not isinstance(expected, tuple): expected = (expected,) - self.assertSetEqual(live_out_str, set(expected)) + self.assertSetEqual(live_out_strs, set(expected)) - def test_stacked_if(self): + def assertHasLiveIn(self, node, expected): + live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) + live_in_strs = set(str(v) for v in live_in) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(live_in_strs, set(expected)) + + def test_live_out_stacked_if(self): def test_fn(x, a): if a > 0: @@ -69,7 +78,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], ('a', 'x')) self.assertHasLiveOut(fn_body[1], 'x') - def test_stacked_if_else(self): + def test_live_out_stacked_if_else(self): def test_fn(x, a): if a > 0: @@ -86,7 +95,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'a') self.assertHasLiveOut(fn_body[1], 'x') - def test_for_basic(self): + def test_live_out_for_basic(self): def test_fn(x, a): for i in range(a): @@ -98,7 +107,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'x') - def test_attributes(self): + def test_live_out_attributes(self): def test_fn(x, a): if a > 0: @@ -110,7 +119,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], ('x.y', 'x')) - def test_nested_functions(self): + def test_live_out_nested_functions(self): def test_fn(a, b): if b: @@ -126,7 +135,7 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'a') - def test_nested_functions_isolation(self): + def test_live_out_nested_functions_isolation(self): def test_fn(b): if b: @@ -144,6 +153,67 @@ class LivenessTest(test.TestCase): self.assertHasLiveOut(fn_body[0], 'max') + def test_live_in_stacked_if(self): + + def test_fn(x, a, b, c): + if a > 0: + x = b + if c > 1: + x = 0 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'x')) + self.assertHasLiveIn(fn_body[1], ('c', 'x')) + + def test_live_in_stacked_if_else(self): + + def test_fn(x, a, b, c, d): + if a > 1: + x = b + else: + x = c + if d > 0: + x = 0 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'd')) + self.assertHasLiveIn(fn_body[1], ('d', 'x')) + + def test_live_in_for_basic(self): + + def test_fn(x, y, a): + for i in a: + x = i + y += x + z = 0 + return y, z + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z')) + + def test_live_in_for_nested(self): + + def test_fn(x, y, a): + for i in a: + for j in i: + x = i + y += x + z = j + return y, z + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z')) + if __name__ == '__main__': test.main() |