aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-11 16:20:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 16:32:19 -0700
commit668c079f4e6020131978b7a812c3b92eea9c47b9 (patch)
tree269836fd98f37b3a099e6b4cceeb3256416705fa /tensorflow/python/autograph/pyct/static_analysis/activity_test.py
parentefd9e0d073a6632f7632f7fe43ae4364cc2c834b (diff)
Move AutoGraph to core. This CL moves the entirety of the code base, keeping the frontend autograph module in contrib for backward compatibility. Certain files, like notebooks and the readme file may be referenced from the outside, so a copy of those is kept as well. In addition, the notebooks subdirectory of examples is also kept in contrib because the extension the build file relies on is not available in the PIP package.
PiperOrigin-RevId: 212543067
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py508
1 files changed, 508 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
new file mode 100644
index 0000000000..d4a6ce8ac3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -0,0 +1,508 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for activity module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.platform import test
+
+
+class ScopeTest(test.TestCase):
+
+ def test_basic(self):
+ scope = activity.Scope(None)
+ self.assertFalse(scope.has(QN('foo')))
+
+ scope.mark_read(QN('foo'))
+ self.assertFalse(scope.has(QN('foo')))
+
+ scope.mark_write(QN('foo'))
+ self.assertTrue(scope.has(QN('foo')))
+
+ scope.mark_read(QN('bar'))
+ self.assertFalse(scope.has(QN('bar')))
+
+ def test_copy_from(self):
+ scope = activity.Scope(None)
+ scope.mark_write(QN('foo'))
+
+ other = activity.Scope(None)
+ other.copy_from(scope)
+
+ self.assertTrue(QN('foo') in other.modified)
+
+ scope.mark_write(QN('bar'))
+ scope.copy_from(other)
+
+ self.assertFalse(QN('bar') in scope.modified)
+
+ scope.mark_write(QN('bar'))
+ scope.merge_from(other)
+
+ self.assertTrue(QN('bar') in scope.modified)
+ self.assertFalse(QN('bar') in other.modified)
+
+ def test_copy_of(self):
+ scope = activity.Scope(None)
+ scope.mark_read(QN('foo'))
+
+ self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used)
+
+ child_scope = activity.Scope(scope)
+ child_scope.mark_read(QN('bar'))
+
+ self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used)
+
+ def test_nesting(self):
+ scope = activity.Scope(None)
+ scope.mark_write(QN('foo'))
+ scope.mark_read(QN('bar'))
+
+ child = activity.Scope(scope)
+ self.assertTrue(child.has(QN('foo')))
+ self.assertTrue(scope.has(QN('foo')))
+
+ child.mark_write(QN('bar'))
+ self.assertTrue(child.has(QN('bar')))
+ self.assertFalse(scope.has(QN('bar')))
+
+ def test_referenced(self):
+ scope = activity.Scope(None)
+ scope.mark_read(QN('a'))
+
+ child = activity.Scope(scope)
+ child.mark_read(QN('b'))
+
+ child2 = activity.Scope(child, isolated=False)
+ child2.mark_read(QN('c'))
+
+ self.assertTrue(QN('c') in child2.referenced)
+ self.assertTrue(QN('b') in child2.referenced)
+ self.assertFalse(QN('a') in child2.referenced)
+
+ self.assertTrue(QN('c') in child.referenced)
+ self.assertTrue(QN('b') in child.referenced)
+ self.assertFalse(QN('a') in child.referenced)
+
+
+class ActivityAnalyzerTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn):
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace={},
+ arg_values=None,
+ arg_types=None,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, entity_info)
+ return node, entity_info
+
+ def test_local_markers(self):
+
+ def test_fn(a): # pylint:disable=unused-argument
+ b = c # pylint:disable=undefined-variable
+ while b > 0:
+ b -= 1
+ return b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ self.assertFalse(
+ anno.getanno(node.body[0].body[0].value,
+ NodeAnno.IS_LOCAL)) # c in b = c
+ self.assertTrue(
+ anno.getanno(node.body[0].body[1].test.left,
+ NodeAnno.IS_LOCAL)) # b in b > 0
+ self.assertTrue(
+ anno.getanno(node.body[0].body[2].value,
+ NodeAnno.IS_LOCAL)) # b in return b
+
+ def assertSymbolSetsAre(self, expected, actual, name):
+ expected = set(expected)
+ actual = set(str(s) for s in actual)
+ self.assertSetEqual(
+ expected, actual, 'for symbol set: %s\n'
+ ' Expected: %s\n'
+ ' Got: %s\n'
+ ' Missing: %s\n'
+ ' Extra: %s\n' % (name.upper(), expected, actual,
+ expected - actual, actual - expected))
+
+ def assertScopeIsRmc(self, scope, used, modified, created):
+ """Assert the scope contains specific used, modified & created variables."""
+ self.assertSymbolSetsAre(used, scope.used, 'read')
+ self.assertSymbolSetsAre(modified, scope.modified, 'modified')
+ # Created is deprecated, we're no longer verifying it.
+ # self.assertSymbolSetsAre(created, scope.created, 'created')
+
+ def test_print_statement(self):
+
+ def test_fn(a):
+ b = 0
+ c = 1
+ print(a, b)
+ return c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ print_node = node.body[0].body[2]
+ if isinstance(print_node, gast.Print):
+ # Python 2
+ print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+ else:
+ # Python 3
+ assert isinstance(print_node, gast.Expr)
+ # The call node should be the one being annotated.
+ print_node = print_node.value
+ print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+ # We basically need to detect which variables are captured by the call
+ # arguments.
+ self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ())
+
+ def test_call_args(self):
+
+ def test_fn(a):
+ b = 0
+ c = 1
+ foo(a, b) # pylint:disable=undefined-variable
+ return c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[2].value
+ # We basically need to detect which variables are captured by the call
+ # arguments.
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+
+ def test_call_args_attributes(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ a.c = 0
+ foo(a.b, a.c)
+ return a.d
+
+ node, _ = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[1].value
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
+ ('a', 'a.b', 'a.c'),
+ (),
+ (),
+ )
+
+ def test_call_args_subscripts(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ b = 1
+ c = 2
+ foo(a[0], a[b])
+ return a[c]
+
+ node, _ = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[2].value
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
+ ('a', 'a[0]', 'a[b]', 'b'),
+ (),
+ (),
+ )
+
+ def test_while(self):
+
+ def test_fn(a):
+ b = a
+ while b > 0:
+ c = b
+ b -= 1
+ return b, c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ while_node = node.body[0].body[1]
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
+ ('c',))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+ ('b', 'c'), ('a', 'b', 'c'))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
+
+ def test_for(self):
+
+ def test_fn(a):
+ b = a
+ for _ in a:
+ c = b
+ b -= 1
+ return b, c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ for_node = node.body[0].body[1]
+ self.assertScopeIsRmc(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
+ self.assertScopeIsRmc(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+ ('b', 'c', '_'), ('a', 'b', 'c', '_'))
+
+ def test_if(self):
+
+ def test_fn(x):
+ if x > 0:
+ x = -x
+ y = 2 * x
+ z = -y
+ else:
+ x = 2 * x
+ y = -x
+ u = -y
+ return z, u
+
+ node, _ = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
+ ('y', 'z'))
+ # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
+ ('x', 'y', 'u'), ('y', 'u'))
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+
+ def test_if_attributes(self):
+
+ def test_fn(a):
+ if a > 0:
+ a.b = -a.c
+ d = 2 * a
+ else:
+ a.b = a.c
+ d = 1
+ return d
+
+ node, _ = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE),
+ ('a', 'a.c'),
+ ('a.b', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'a.c'),
+ ('a.b', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent,
+ ('a', 'a.c', 'd'),
+ ('a.b', 'd'),
+ ('a', 'd'),
+ )
+
+ def test_if_subscripts(self):
+
+ def test_fn(a, b, c, e):
+ if a > 0:
+ a[b] = -a[c]
+ d = 2 * a
+ else:
+ a[0] = e
+ d = 1
+ return d
+
+ node, _ = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE),
+ ('a', 'b', 'c', 'a[c]'),
+ ('a[b]', 'd'),
+ ('d',),
+ )
+ # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'e'),
+ ('a[0]', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
+ ('a', 'b', 'c', 'd', 'e', 'a[c]'),
+ ('d', 'a[b]', 'a[0]'),
+ ('a', 'b', 'c', 'd', 'e'),
+ )
+
+ def test_nested_if(self):
+
+ def test_fn(b):
+ if b > 0:
+ if b < 5:
+ a = b
+ else:
+ a = b * b
+ return a
+
+ node, _ = self._parse_and_analyze(test_fn)
+ inner_if_node = node.body[0].body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',),
+ ('a',))
+ self.assertScopeIsRmc(
+ anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',),
+ ('a',))
+
+ def test_nested_function(self):
+
+ def test_fn(a):
+
+ def f(x):
+ y = x * x
+ return y
+
+ b = a
+ for i in a:
+ c = b
+ b -= f(i)
+ return b, c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_def_node = node.body[0].body[0]
+
+ self.assertScopeIsRmc(
+ anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), (
+ 'x',
+ 'y',
+ ))
+
+ def test_constructor_attributes(self):
+
+ class TestClass(object):
+
+ def __init__(self, a):
+ self.b = a
+ self.b.c = 1
+
+ node, _ = self._parse_and_analyze(TestClass)
+ init_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(init_node, NodeAnno.BODY_SCOPE),
+ ('self', 'a', 'self.b'),
+ ('self', 'self.b', 'self.b.c'),
+ ('self', 'a', 'self.b'),
+ )
+
+ def test_aug_assign_subscripts(self):
+
+ def test_fn(a):
+ a[0] += 1
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('a', 'a[0]'),
+ ('a[0]',),
+ ('a',),
+ )
+
+ def test_return_vars_are_read(self):
+
+ def test_fn(a, b, c): # pylint: disable=unused-argument
+ return c
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('c',),
+ (),
+ (
+ 'a',
+ 'b',
+ 'c',
+ ),
+ )
+
+ def test_aug_assign(self):
+
+ def test_fn(a, b):
+ a += b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('a', 'b'),
+ ('a'),
+ ('a', 'b'),
+ )
+
+ def test_aug_assign_rvalues(self):
+
+ a = dict(bar=3)
+
+ def foo():
+ return a
+
+ def test_fn(x):
+ foo()['bar'] += x
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('foo', 'x'),
+ (),
+ ('x',),
+ )
+
+ def test_params_created(self):
+
+ def test_fn(a, b): # pylint: disable=unused-argument
+ return b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')),
+ (('a', 'b')))
+
+
+if __name__ == '__main__':
+ test.main()