aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py268
1 files changed, 96 insertions, 172 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index d4a6ce8ac3..9a4f1bf09b 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -32,62 +32,63 @@ from tensorflow.python.platform import test
class ScopeTest(test.TestCase):
+ def assertMissing(self, qn, scope):
+ self.assertNotIn(qn, scope.used)
+ self.assertNotIn(qn, scope.modified)
+
+ def assertReadOnly(self, qn, scope):
+ self.assertIn(qn, scope.used)
+ self.assertNotIn(qn, scope.modified)
+
+ def assertWriteOnly(self, qn, scope):
+ self.assertNotIn(qn, scope.used)
+ self.assertIn(qn, scope.modified)
+
+ def assertReadWrite(self, qn, scope):
+ self.assertIn(qn, scope.used)
+ self.assertIn(qn, scope.modified)
+
def test_basic(self):
scope = activity.Scope(None)
- self.assertFalse(scope.has(QN('foo')))
+ self.assertMissing(QN('foo'), scope)
scope.mark_read(QN('foo'))
- self.assertFalse(scope.has(QN('foo')))
-
- scope.mark_write(QN('foo'))
- self.assertTrue(scope.has(QN('foo')))
+ self.assertReadOnly(QN('foo'), scope)
- scope.mark_read(QN('bar'))
- self.assertFalse(scope.has(QN('bar')))
+ scope.mark_modified(QN('foo'))
+ self.assertReadWrite(QN('foo'), scope)
def test_copy_from(self):
scope = activity.Scope(None)
- scope.mark_write(QN('foo'))
-
+ scope.mark_modified(QN('foo'))
other = activity.Scope(None)
other.copy_from(scope)
- self.assertTrue(QN('foo') in other.modified)
+ self.assertWriteOnly(QN('foo'), other)
- scope.mark_write(QN('bar'))
+ scope.mark_modified(QN('bar'))
scope.copy_from(other)
- self.assertFalse(QN('bar') in scope.modified)
+ self.assertMissing(QN('bar'), scope)
- scope.mark_write(QN('bar'))
+ scope.mark_modified(QN('bar'))
scope.merge_from(other)
- self.assertTrue(QN('bar') in scope.modified)
- self.assertFalse(QN('bar') in other.modified)
+ self.assertWriteOnly(QN('bar'), scope)
+ self.assertMissing(QN('bar'), other)
def test_copy_of(self):
scope = activity.Scope(None)
scope.mark_read(QN('foo'))
+ other = activity.Scope.copy_of(scope)
- self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used)
+ self.assertReadOnly(QN('foo'), other)
child_scope = activity.Scope(scope)
child_scope.mark_read(QN('bar'))
+ other = activity.Scope.copy_of(child_scope)
- self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used)
-
- def test_nesting(self):
- scope = activity.Scope(None)
- scope.mark_write(QN('foo'))
- scope.mark_read(QN('bar'))
-
- child = activity.Scope(scope)
- self.assertTrue(child.has(QN('foo')))
- self.assertTrue(scope.has(QN('foo')))
-
- child.mark_write(QN('bar'))
- self.assertTrue(child.has(QN('bar')))
- self.assertFalse(scope.has(QN('bar')))
+ self.assertReadOnly(QN('bar'), other)
def test_referenced(self):
scope = activity.Scope(None)
@@ -123,25 +124,6 @@ class ActivityAnalyzerTest(test.TestCase):
node = activity.resolve(node, entity_info)
return node, entity_info
- def test_local_markers(self):
-
- def test_fn(a): # pylint:disable=unused-argument
- b = c # pylint:disable=undefined-variable
- while b > 0:
- b -= 1
- return b
-
- node, _ = self._parse_and_analyze(test_fn)
- self.assertFalse(
- anno.getanno(node.body[0].body[0].value,
- NodeAnno.IS_LOCAL)) # c in b = c
- self.assertTrue(
- anno.getanno(node.body[0].body[1].test.left,
- NodeAnno.IS_LOCAL)) # b in b > 0
- self.assertTrue(
- anno.getanno(node.body[0].body[2].value,
- NodeAnno.IS_LOCAL)) # b in return b
-
def assertSymbolSetsAre(self, expected, actual, name):
expected = set(expected)
actual = set(str(s) for s in actual)
@@ -153,12 +135,10 @@ class ActivityAnalyzerTest(test.TestCase):
' Extra: %s\n' % (name.upper(), expected, actual,
expected - actual, actual - expected))
- def assertScopeIsRmc(self, scope, used, modified, created):
+ def assertScopeIs(self, scope, used, modified):
"""Assert the scope contains specific used, modified & created variables."""
self.assertSymbolSetsAre(used, scope.used, 'read')
self.assertSymbolSetsAre(modified, scope.modified, 'modified')
- # Created is deprecated, we're no longer verifying it.
- # self.assertSymbolSetsAre(created, scope.created, 'created')
def test_print_statement(self):
@@ -181,7 +161,7 @@ class ActivityAnalyzerTest(test.TestCase):
print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
# We basically need to detect which variables are captured by the call
# arguments.
- self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ())
+ self.assertScopeIs(print_args_scope, ('a', 'b'), ())
def test_call_args(self):
@@ -195,8 +175,8 @@ class ActivityAnalyzerTest(test.TestCase):
call_node = node.body[0].body[2].value
# We basically need to detect which variables are captured by the call
# arguments.
- self.assertScopeIsRmc(
- anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+ self.assertScopeIs(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), ())
def test_call_args_attributes(self):
@@ -210,12 +190,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[1].value
- self.assertScopeIsRmc(
- anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
- ('a', 'a.b', 'a.c'),
- (),
- (),
- )
+ self.assertScopeIs(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), ())
def test_call_args_subscripts(self):
@@ -230,12 +206,9 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[2].value
- self.assertScopeIsRmc(
+ self.assertScopeIs(
anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
- ('a', 'a[0]', 'a[b]', 'b'),
- (),
- (),
- )
+ ('a', 'a[0]', 'a[b]', 'b'), ())
def test_while(self):
@@ -248,14 +221,13 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
while_node = node.body[0].body[1]
- self.assertScopeIsRmc(
- anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
- ('c',))
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'))
+ self.assertScopeIs(
anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
- ('b', 'c'), ('a', 'b', 'c'))
- self.assertScopeIsRmc(
- anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
+ ('b', 'c'))
+ self.assertScopeIs(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), ())
def test_for(self):
@@ -268,11 +240,11 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
for_node = node.body[0].body[1]
- self.assertScopeIsRmc(
- anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'))
+ self.assertScopeIs(
anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
- ('b', 'c', '_'), ('a', 'b', 'c', '_'))
+ ('b', 'c', '_'))
def test_if(self):
@@ -289,18 +261,16 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
- ('y', 'z'))
- # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
- ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'y', 'z', 'u'),
+ ('x', 'y', 'z', 'u'))
+ self.assertScopeIs(
anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
- ('x', 'y', 'u'), ('y', 'u'))
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'u'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def test_if_attributes(self):
@@ -316,24 +286,14 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE),
- ('a', 'a.c'),
- ('a.b', 'd'),
- ('d',),
- )
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
- ('a', 'a.c'),
- ('a.b', 'd'),
- ('d',),
- )
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent,
- ('a', 'a.c', 'd'),
- ('a.b', 'd'),
- ('a', 'd'),
- )
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'a.c'), ('a.b', 'd'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'a.c'),
+ ('a.b', 'd'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('a', 'a.c', 'd'),
+ ('a.b', 'd'))
def test_if_subscripts(self):
@@ -348,25 +308,15 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE),
- ('a', 'b', 'c', 'a[c]'),
- ('a[b]', 'd'),
- ('d',),
- )
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'),
+ ('a[b]', 'd'))
# TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
- ('a', 'e'),
- ('a[0]', 'd'),
- ('d',),
- )
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'e'), ('a[0]', 'd'))
+ self.assertScopeIs(
anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
- ('a', 'b', 'c', 'd', 'e', 'a[c]'),
- ('d', 'a[b]', 'a[0]'),
- ('a', 'b', 'c', 'd', 'e'),
- )
+ ('a', 'b', 'c', 'd', 'e', 'a[c]'), ('d', 'a[b]', 'a[0]'))
def test_nested_if(self):
@@ -380,12 +330,10 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
inner_if_node = node.body[0].body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',),
- ('a',))
- self.assertScopeIsRmc(
- anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',),
- ('a',))
+ self.assertScopeIs(
+ anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',))
+ self.assertScopeIs(
+ anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',))
def test_nested_function(self):
@@ -404,11 +352,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_def_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), (
- 'x',
- 'y',
- ))
+ self.assertScopeIs(
+ anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',))
def test_constructor_attributes(self):
@@ -420,12 +365,9 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(TestClass)
init_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(init_node, NodeAnno.BODY_SCOPE),
- ('self', 'a', 'self.b'),
- ('self', 'self.b', 'self.b.c'),
- ('self', 'a', 'self.b'),
- )
+ self.assertScopeIs(
+ anno.getanno(init_node, NodeAnno.BODY_SCOPE), ('self', 'a', 'self.b'),
+ ('self', 'self.b', 'self.b.c'))
def test_aug_assign_subscripts(self):
@@ -434,12 +376,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('a', 'a[0]'),
- ('a[0]',),
- ('a',),
- )
+ self.assertScopeIs(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), ('a[0]',))
def test_return_vars_are_read(self):
@@ -448,16 +386,7 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('c',),
- (),
- (
- 'a',
- 'b',
- 'c',
- ),
- )
+ self.assertScopeIs(anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('c',), ())
def test_aug_assign(self):
@@ -466,12 +395,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('a', 'b'),
- ('a'),
- ('a', 'b'),
- )
+ self.assertScopeIs(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'b'), ('a'))
def test_aug_assign_rvalues(self):
@@ -485,23 +410,22 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('foo', 'x'),
- (),
- ('x',),
- )
+ self.assertScopeIs(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ())
- def test_params_created(self):
+ def test_params(self):
def test_fn(a, b): # pylint: disable=unused-argument
return b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')),
- (('a', 'b')))
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b',), ())
+ self.assertScopeIs(body_scope.parent, ('b',), ('a', 'b'))
+
+ args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
+ self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
if __name__ == '__main__':