aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py508
1 files changed, 508 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
new file mode 100644
index 0000000000..d4a6ce8ac3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -0,0 +1,508 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for activity module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.platform import test
+
+
+class ScopeTest(test.TestCase):
+
+ def test_basic(self):
+ scope = activity.Scope(None)
+ self.assertFalse(scope.has(QN('foo')))
+
+ scope.mark_read(QN('foo'))
+ self.assertFalse(scope.has(QN('foo')))
+
+ scope.mark_write(QN('foo'))
+ self.assertTrue(scope.has(QN('foo')))
+
+ scope.mark_read(QN('bar'))
+ self.assertFalse(scope.has(QN('bar')))
+
+ def test_copy_from(self):
+ scope = activity.Scope(None)
+ scope.mark_write(QN('foo'))
+
+ other = activity.Scope(None)
+ other.copy_from(scope)
+
+ self.assertTrue(QN('foo') in other.modified)
+
+ scope.mark_write(QN('bar'))
+ scope.copy_from(other)
+
+ self.assertFalse(QN('bar') in scope.modified)
+
+ scope.mark_write(QN('bar'))
+ scope.merge_from(other)
+
+ self.assertTrue(QN('bar') in scope.modified)
+ self.assertFalse(QN('bar') in other.modified)
+
+ def test_copy_of(self):
+ scope = activity.Scope(None)
+ scope.mark_read(QN('foo'))
+
+ self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used)
+
+ child_scope = activity.Scope(scope)
+ child_scope.mark_read(QN('bar'))
+
+ self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used)
+
+ def test_nesting(self):
+ scope = activity.Scope(None)
+ scope.mark_write(QN('foo'))
+ scope.mark_read(QN('bar'))
+
+ child = activity.Scope(scope)
+ self.assertTrue(child.has(QN('foo')))
+ self.assertTrue(scope.has(QN('foo')))
+
+ child.mark_write(QN('bar'))
+ self.assertTrue(child.has(QN('bar')))
+ self.assertFalse(scope.has(QN('bar')))
+
+ def test_referenced(self):
+ scope = activity.Scope(None)
+ scope.mark_read(QN('a'))
+
+ child = activity.Scope(scope)
+ child.mark_read(QN('b'))
+
+ child2 = activity.Scope(child, isolated=False)
+ child2.mark_read(QN('c'))
+
+ self.assertTrue(QN('c') in child2.referenced)
+ self.assertTrue(QN('b') in child2.referenced)
+ self.assertFalse(QN('a') in child2.referenced)
+
+ self.assertTrue(QN('c') in child.referenced)
+ self.assertTrue(QN('b') in child.referenced)
+ self.assertFalse(QN('a') in child.referenced)
+
+
+class ActivityAnalyzerTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn):
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace={},
+ arg_values=None,
+ arg_types=None,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, entity_info)
+ return node, entity_info
+
+ def test_local_markers(self):
+
+ def test_fn(a): # pylint:disable=unused-argument
+ b = c # pylint:disable=undefined-variable
+ while b > 0:
+ b -= 1
+ return b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ self.assertFalse(
+ anno.getanno(node.body[0].body[0].value,
+ NodeAnno.IS_LOCAL)) # c in b = c
+ self.assertTrue(
+ anno.getanno(node.body[0].body[1].test.left,
+ NodeAnno.IS_LOCAL)) # b in b > 0
+ self.assertTrue(
+ anno.getanno(node.body[0].body[2].value,
+ NodeAnno.IS_LOCAL)) # b in return b
+
+ def assertSymbolSetsAre(self, expected, actual, name):
+ expected = set(expected)
+ actual = set(str(s) for s in actual)
+ self.assertSetEqual(
+ expected, actual, 'for symbol set: %s\n'
+ ' Expected: %s\n'
+ ' Got: %s\n'
+ ' Missing: %s\n'
+ ' Extra: %s\n' % (name.upper(), expected, actual,
+ expected - actual, actual - expected))
+
+ def assertScopeIsRmc(self, scope, used, modified, created):
+ """Assert the scope contains specific used, modified & created variables."""
+ self.assertSymbolSetsAre(used, scope.used, 'read')
+ self.assertSymbolSetsAre(modified, scope.modified, 'modified')
+ # Created is deprecated, we're no longer verifying it.
+ # self.assertSymbolSetsAre(created, scope.created, 'created')
+
+ def test_print_statement(self):
+
+ def test_fn(a):
+ b = 0
+ c = 1
+ print(a, b)
+ return c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ print_node = node.body[0].body[2]
+ if isinstance(print_node, gast.Print):
+ # Python 2
+ print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+ else:
+ # Python 3
+ assert isinstance(print_node, gast.Expr)
+ # The call node should be the one being annotated.
+ print_node = print_node.value
+ print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+ # We basically need to detect which variables are captured by the call
+ # arguments.
+ self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ())
+
+ def test_call_args(self):
+
+ def test_fn(a):
+ b = 0
+ c = 1
+ foo(a, b) # pylint:disable=undefined-variable
+ return c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[2].value
+ # We basically need to detect which variables are captured by the call
+ # arguments.
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+
+ def test_call_args_attributes(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ a.c = 0
+ foo(a.b, a.c)
+ return a.d
+
+ node, _ = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[1].value
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
+ ('a', 'a.b', 'a.c'),
+ (),
+ (),
+ )
+
+ def test_call_args_subscripts(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ b = 1
+ c = 2
+ foo(a[0], a[b])
+ return a[c]
+
+ node, _ = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[2].value
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
+ ('a', 'a[0]', 'a[b]', 'b'),
+ (),
+ (),
+ )
+
+ def test_while(self):
+
+ def test_fn(a):
+ b = a
+ while b > 0:
+ c = b
+ b -= 1
+ return b, c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ while_node = node.body[0].body[1]
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
+ ('c',))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+ ('b', 'c'), ('a', 'b', 'c'))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
+
+ def test_for(self):
+
+ def test_fn(a):
+ b = a
+ for _ in a:
+ c = b
+ b -= 1
+ return b, c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ for_node = node.body[0].body[1]
+ self.assertScopeIsRmc(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
+ self.assertScopeIsRmc(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+ ('b', 'c', '_'), ('a', 'b', 'c', '_'))
+
+ def test_if(self):
+
+ def test_fn(x):
+ if x > 0:
+ x = -x
+ y = 2 * x
+ z = -y
+ else:
+ x = 2 * x
+ y = -x
+ u = -y
+ return z, u
+
+ node, _ = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
+ ('y', 'z'))
+ # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
+ ('x', 'y', 'u'), ('y', 'u'))
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+
+ def test_if_attributes(self):
+
+ def test_fn(a):
+ if a > 0:
+ a.b = -a.c
+ d = 2 * a
+ else:
+ a.b = a.c
+ d = 1
+ return d
+
+ node, _ = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE),
+ ('a', 'a.c'),
+ ('a.b', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'a.c'),
+ ('a.b', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent,
+ ('a', 'a.c', 'd'),
+ ('a.b', 'd'),
+ ('a', 'd'),
+ )
+
+ def test_if_subscripts(self):
+
+ def test_fn(a, b, c, e):
+ if a > 0:
+ a[b] = -a[c]
+ d = 2 * a
+ else:
+ a[0] = e
+ d = 1
+ return d
+
+ node, _ = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE),
+ ('a', 'b', 'c', 'a[c]'),
+ ('a[b]', 'd'),
+ ('d',),
+ )
+ # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'e'),
+ ('a[0]', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
+ ('a', 'b', 'c', 'd', 'e', 'a[c]'),
+ ('d', 'a[b]', 'a[0]'),
+ ('a', 'b', 'c', 'd', 'e'),
+ )
+
+ def test_nested_if(self):
+
+ def test_fn(b):
+ if b > 0:
+ if b < 5:
+ a = b
+ else:
+ a = b * b
+ return a
+
+ node, _ = self._parse_and_analyze(test_fn)
+ inner_if_node = node.body[0].body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',),
+ ('a',))
+ self.assertScopeIsRmc(
+ anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',),
+ ('a',))
+
+ def test_nested_function(self):
+
+ def test_fn(a):
+
+ def f(x):
+ y = x * x
+ return y
+
+ b = a
+ for i in a:
+ c = b
+ b -= f(i)
+ return b, c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_def_node = node.body[0].body[0]
+
+ self.assertScopeIsRmc(
+ anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), (
+ 'x',
+ 'y',
+ ))
+
+ def test_constructor_attributes(self):
+
+ class TestClass(object):
+
+ def __init__(self, a):
+ self.b = a
+ self.b.c = 1
+
+ node, _ = self._parse_and_analyze(TestClass)
+ init_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(init_node, NodeAnno.BODY_SCOPE),
+ ('self', 'a', 'self.b'),
+ ('self', 'self.b', 'self.b.c'),
+ ('self', 'a', 'self.b'),
+ )
+
+ def test_aug_assign_subscripts(self):
+
+ def test_fn(a):
+ a[0] += 1
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('a', 'a[0]'),
+ ('a[0]',),
+ ('a',),
+ )
+
+ def test_return_vars_are_read(self):
+
+ def test_fn(a, b, c): # pylint: disable=unused-argument
+ return c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('c',),
+ (),
+ (
+ 'a',
+ 'b',
+ 'c',
+ ),
+ )
+
+ def test_aug_assign(self):
+
+ def test_fn(a, b):
+ a += b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('a', 'b'),
+ ('a'),
+ ('a', 'b'),
+ )
+
+ def test_aug_assign_rvalues(self):
+
+ a = dict(bar=3)
+
+ def foo():
+ return a
+
+ def test_fn(x):
+ foo()['bar'] += x
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('foo', 'x'),
+ (),
+ ('x',),
+ )
+
+ def test_params_created(self):
+
+ def test_fn(a, b): # pylint: disable=unused-argument
+ return b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')),
+ (('a', 'b')))
+
+
+if __name__ == '__main__':
+ test.main()