diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity_test.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/static_analysis/activity_test.py | 268 |
1 files changed, 96 insertions, 172 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py index d4a6ce8ac3..9a4f1bf09b 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py @@ -32,62 +32,63 @@ from tensorflow.python.platform import test class ScopeTest(test.TestCase): + def assertMissing(self, qn, scope): + self.assertNotIn(qn, scope.used) + self.assertNotIn(qn, scope.modified) + + def assertReadOnly(self, qn, scope): + self.assertIn(qn, scope.used) + self.assertNotIn(qn, scope.modified) + + def assertWriteOnly(self, qn, scope): + self.assertNotIn(qn, scope.used) + self.assertIn(qn, scope.modified) + + def assertReadWrite(self, qn, scope): + self.assertIn(qn, scope.used) + self.assertIn(qn, scope.modified) + def test_basic(self): scope = activity.Scope(None) - self.assertFalse(scope.has(QN('foo'))) + self.assertMissing(QN('foo'), scope) scope.mark_read(QN('foo')) - self.assertFalse(scope.has(QN('foo'))) - - scope.mark_write(QN('foo')) - self.assertTrue(scope.has(QN('foo'))) + self.assertReadOnly(QN('foo'), scope) - scope.mark_read(QN('bar')) - self.assertFalse(scope.has(QN('bar'))) + scope.mark_modified(QN('foo')) + self.assertReadWrite(QN('foo'), scope) def test_copy_from(self): scope = activity.Scope(None) - scope.mark_write(QN('foo')) - + scope.mark_modified(QN('foo')) other = activity.Scope(None) other.copy_from(scope) - self.assertTrue(QN('foo') in other.modified) + self.assertWriteOnly(QN('foo'), other) - scope.mark_write(QN('bar')) + scope.mark_modified(QN('bar')) scope.copy_from(other) - self.assertFalse(QN('bar') in scope.modified) + self.assertMissing(QN('bar'), scope) - scope.mark_write(QN('bar')) + scope.mark_modified(QN('bar')) scope.merge_from(other) - self.assertTrue(QN('bar') in scope.modified) - self.assertFalse(QN('bar') in other.modified) + self.assertWriteOnly(QN('bar'), scope) + self.assertMissing(QN('bar'), other) def test_copy_of(self): scope = activity.Scope(None) scope.mark_read(QN('foo')) + other = activity.Scope.copy_of(scope) - self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used) + self.assertReadOnly(QN('foo'), other) child_scope = activity.Scope(scope) child_scope.mark_read(QN('bar')) + other = activity.Scope.copy_of(child_scope) - self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used) - - def test_nesting(self): - scope = activity.Scope(None) - scope.mark_write(QN('foo')) - scope.mark_read(QN('bar')) - - child = activity.Scope(scope) - self.assertTrue(child.has(QN('foo'))) - self.assertTrue(scope.has(QN('foo'))) - - child.mark_write(QN('bar')) - self.assertTrue(child.has(QN('bar'))) - self.assertFalse(scope.has(QN('bar'))) + self.assertReadOnly(QN('bar'), other) def test_referenced(self): scope = activity.Scope(None) @@ -123,25 +124,6 @@ class ActivityAnalyzerTest(test.TestCase): node = activity.resolve(node, entity_info) return node, entity_info - def test_local_markers(self): - - def test_fn(a): # pylint:disable=unused-argument - b = c # pylint:disable=undefined-variable - while b > 0: - b -= 1 - return b - - node, _ = self._parse_and_analyze(test_fn) - self.assertFalse( - anno.getanno(node.body[0].body[0].value, - NodeAnno.IS_LOCAL)) # c in b = c - self.assertTrue( - anno.getanno(node.body[0].body[1].test.left, - NodeAnno.IS_LOCAL)) # b in b > 0 - self.assertTrue( - anno.getanno(node.body[0].body[2].value, - NodeAnno.IS_LOCAL)) # b in return b - def assertSymbolSetsAre(self, expected, actual, name): expected = set(expected) actual = set(str(s) for s in actual) @@ -153,12 +135,10 @@ class ActivityAnalyzerTest(test.TestCase): ' Extra: %s\n' % (name.upper(), expected, actual, expected - actual, actual - expected)) - def assertScopeIsRmc(self, scope, used, modified, created): + def assertScopeIs(self, scope, used, modified): """Assert the scope contains specific used, modified & created variables.""" self.assertSymbolSetsAre(used, scope.used, 'read') self.assertSymbolSetsAre(modified, scope.modified, 'modified') - # Created is deprecated, we're no longer verifying it. - # self.assertSymbolSetsAre(created, scope.created, 'created') def test_print_statement(self): @@ -181,7 +161,7 @@ class ActivityAnalyzerTest(test.TestCase): print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) # We basically need to detect which variables are captured by the call # arguments. - self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ()) + self.assertScopeIs(print_args_scope, ('a', 'b'), ()) def test_call_args(self): @@ -195,8 +175,8 @@ class ActivityAnalyzerTest(test.TestCase): call_node = node.body[0].body[2].value # We basically need to detect which variables are captured by the call # arguments. - self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) + self.assertScopeIs( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), ()) def test_call_args_attributes(self): @@ -210,12 +190,8 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[1].value - self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE), - ('a', 'a.b', 'a.c'), - (), - (), - ) + self.assertScopeIs( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), ()) def test_call_args_subscripts(self): @@ -230,12 +206,9 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[2].value - self.assertScopeIsRmc( + self.assertScopeIs( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), - ('a', 'a[0]', 'a[b]', 'b'), - (), - (), - ) + ('a', 'a[0]', 'a[b]', 'b'), ()) def test_while(self): @@ -248,14 +221,13 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) while_node = node.body[0].body[1] - self.assertScopeIsRmc( - anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), - ('c',)) - self.assertScopeIsRmc( + self.assertScopeIs( + anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c')) + self.assertScopeIs( anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), - ('b', 'c'), ('a', 'b', 'c')) - self.assertScopeIsRmc( - anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ()) + ('b', 'c')) + self.assertScopeIs( + anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), ()) def test_for(self): @@ -268,11 +240,11 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) for_node = node.body[0].body[1] - self.assertScopeIsRmc( - anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIsRmc( + self.assertScopeIs( + anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c')) + self.assertScopeIs( anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), - ('b', 'c', '_'), ('a', 'b', 'c', '_')) + ('b', 'c', '_')) def test_if(self): @@ -289,18 +261,16 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), - ('y', 'z')) - # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'), - ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - self.assertScopeIsRmc( + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'y', 'z', 'u'), + ('x', 'y', 'z', 'u')) + self.assertScopeIs( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'), - ('x', 'y', 'u'), ('y', 'u')) - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'u')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) def test_if_attributes(self): @@ -316,24 +286,14 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.BODY_SCOPE), - ('a', 'a.c'), - ('a.b', 'd'), - ('d',), - ) - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), - ('a', 'a.c'), - ('a.b', 'd'), - ('d',), - ) - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, - ('a', 'a.c', 'd'), - ('a.b', 'd'), - ('a', 'd'), - ) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'a.c'), ('a.b', 'd')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'a.c'), + ('a.b', 'd')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('a', 'a.c', 'd'), + ('a.b', 'd')) def test_if_subscripts(self): @@ -348,25 +308,15 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.BODY_SCOPE), - ('a', 'b', 'c', 'a[c]'), - ('a[b]', 'd'), - ('d',), - ) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'), + ('a[b]', 'd')) # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"? - self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), - ('a', 'e'), - ('a[0]', 'd'), - ('d',), - ) - self.assertScopeIsRmc( + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'e'), ('a[0]', 'd')) + self.assertScopeIs( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, - ('a', 'b', 'c', 'd', 'e', 'a[c]'), - ('d', 'a[b]', 'a[0]'), - ('a', 'b', 'c', 'd', 'e'), - ) + ('a', 'b', 'c', 'd', 'e', 'a[c]'), ('d', 'a[b]', 'a[0]')) def test_nested_if(self): @@ -380,12 +330,10 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) inner_if_node = node.body[0].body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), - ('a',)) - self.assertScopeIsRmc( - anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',), - ('a',)) + self.assertScopeIs( + anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',)) + self.assertScopeIs( + anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',)) def test_nested_function(self): @@ -404,11 +352,8 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_def_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( - 'x', - 'y', - )) + self.assertScopeIs( + anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',)) def test_constructor_attributes(self): @@ -420,12 +365,9 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(TestClass) init_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(init_node, NodeAnno.BODY_SCOPE), - ('self', 'a', 'self.b'), - ('self', 'self.b', 'self.b.c'), - ('self', 'a', 'self.b'), - ) + self.assertScopeIs( + anno.getanno(init_node, NodeAnno.BODY_SCOPE), ('self', 'a', 'self.b'), + ('self', 'self.b', 'self.b.c')) def test_aug_assign_subscripts(self): @@ -434,12 +376,8 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] - self.assertScopeIsRmc( - anno.getanno(fn_node, NodeAnno.BODY_SCOPE), - ('a', 'a[0]'), - ('a[0]',), - ('a',), - ) + self.assertScopeIs( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), ('a[0]',)) def test_return_vars_are_read(self): @@ -448,16 +386,7 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] - self.assertScopeIsRmc( - anno.getanno(fn_node, NodeAnno.BODY_SCOPE), - ('c',), - (), - ( - 'a', - 'b', - 'c', - ), - ) + self.assertScopeIs(anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('c',), ()) def test_aug_assign(self): @@ -466,12 +395,8 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] - self.assertScopeIsRmc( - anno.getanno(fn_node, NodeAnno.BODY_SCOPE), - ('a', 'b'), - ('a'), - ('a', 'b'), - ) + self.assertScopeIs( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'b'), ('a')) def test_aug_assign_rvalues(self): @@ -485,23 +410,22 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] - self.assertScopeIsRmc( - anno.getanno(fn_node, NodeAnno.BODY_SCOPE), - ('foo', 'x'), - (), - ('x',), - ) + self.assertScopeIs( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ()) - def test_params_created(self): + def test_params(self): def test_fn(a, b): # pylint: disable=unused-argument return b node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] - self.assertScopeIsRmc( - anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), - (('a', 'b'))) + body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) + self.assertScopeIs(body_scope, ('b',), ()) + self.assertScopeIs(body_scope.parent, ('b',), ('a', 'b')) + + args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE) + self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params') if __name__ == '__main__': |