aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/transformer_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/transformer_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/transformer_test.py369
1 files changed, 369 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py
new file mode 100644
index 0000000000..23bf9a8e16
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/transformer_test.py
@@ -0,0 +1,369 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for templates module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.platform import test
+
+
+class TransformerTest(test.TestCase):
+
+ def _simple_source_info(self):
+ return transformer.EntityInfo(
+ source_code=None,
+ source_file=None,
+ namespace=None,
+ arg_values=None,
+ arg_types=None,
+ owner_type=None)
+
+ def test_entity_scope_tracking(self):
+
+ class TestTransformer(transformer.Base):
+
+ # The choice of note to assign to is arbitrary. Using Assign because it's
+ # easy to find in the tree.
+ def visit_Assign(self, node):
+ anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
+ return self.generic_visit(node)
+
+ # This will show up in the lambda function.
+ def visit_BinOp(self, node):
+ anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
+ return self.generic_visit(node)
+
+ tr = TestTransformer(self._simple_source_info())
+
+ def test_function():
+ a = 0
+
+ class TestClass(object):
+
+ def test_method(self):
+ b = 0
+ def inner_function(x):
+ c = 0
+ d = lambda y: (x + y)
+ return c, d
+ return b, inner_function
+ return a, TestClass
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+
+ test_function_node = node.body[0]
+ test_class = test_function_node.body[1]
+ test_method = test_class.body[0]
+ inner_function = test_method.body[1]
+ lambda_node = inner_function.body[1].value
+
+ a = test_function_node.body[0]
+ b = test_method.body[0]
+ c = inner_function.body[0]
+ lambda_expr = lambda_node.body
+
+ self.assertEqual(
+ (test_function_node,), anno.getanno(a, 'enclosing_entities'))
+ self.assertEqual((test_function_node, test_class, test_method),
+ anno.getanno(b, 'enclosing_entities'))
+ self.assertEqual(
+ (test_function_node, test_class, test_method, inner_function),
+ anno.getanno(c, 'enclosing_entities'))
+ self.assertEqual((test_function_node, test_class, test_method,
+ inner_function, lambda_node),
+ anno.getanno(lambda_expr, 'enclosing_entities'))
+
+ def assertSameAnno(self, first, second, key):
+ self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
+
+ def assertDifferentAnno(self, first, second, key):
+ self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key))
+
+ def test_state_tracking(self):
+
+ class LoopState(object):
+ pass
+
+ class CondState(object):
+ pass
+
+ class TestTransformer(transformer.Base):
+
+ def visit(self, node):
+ anno.setanno(node, 'loop_state', self.state[LoopState].value)
+ anno.setanno(node, 'cond_state', self.state[CondState].value)
+ return super(TestTransformer, self).visit(node)
+
+ def visit_While(self, node):
+ self.state[LoopState].enter()
+ node = self.generic_visit(node)
+ self.state[LoopState].exit()
+ return node
+
+ def visit_If(self, node):
+ self.state[CondState].enter()
+ node = self.generic_visit(node)
+ self.state[CondState].exit()
+ return node
+
+ tr = TestTransformer(self._simple_source_info())
+
+ def test_function(a):
+ a = 1
+ while a:
+ _ = 'a'
+ if a > 2:
+ _ = 'b'
+ while True:
+ raise '1'
+ if a > 3:
+ _ = 'c'
+ while True:
+ raise '1'
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+
+ fn_body = node.body[0].body
+ outer_while_body = fn_body[1].body
+ self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
+ self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')
+
+ first_if_body = outer_while_body[1].body
+ self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
+ 'cond_state')
+ self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')
+
+ first_inner_while_body = first_if_body[1].body
+ self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
+ 'cond_state')
+ self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
+ 'loop_state')
+
+ second_if_body = outer_while_body[2].body
+ self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
+ self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')
+
+ second_inner_while_body = second_if_body[1].body
+ self.assertDifferentAnno(first_inner_while_body[0],
+ second_inner_while_body[0], 'cond_state')
+ self.assertDifferentAnno(first_inner_while_body[0],
+ second_inner_while_body[0], 'loop_state')
+
+ def test_local_scope_info_stack(self):
+
+ class TestTransformer(transformer.Base):
+
+ # Extract all string constants from the block.
+ def visit_Str(self, node):
+ self.set_local('string', self.get_local('string', default='') + node.s)
+ return self.generic_visit(node)
+
+ def _annotate_result(self, node):
+ self.enter_local_scope()
+ node = self.generic_visit(node)
+ anno.setanno(node, 'test', self.get_local('string'))
+ self.exit_local_scope()
+ return node
+
+ def visit_While(self, node):
+ return self._annotate_result(node)
+
+ def visit_For(self, node):
+ return self._annotate_result(node)
+
+ tr = TestTransformer(self._simple_source_info())
+
+ def test_function(a):
+ """Docstring."""
+ assert a == 'This should not be counted'
+ for i in range(3):
+ _ = 'a'
+ if i > 2:
+ return 'b'
+ else:
+ _ = 'c'
+ while True:
+ raise '1'
+ return 'nor this'
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+
+ for_node = node.body[0].body[2]
+ while_node = for_node.body[1].orelse[1]
+
+ self.assertFalse(anno.hasanno(for_node, 'string'))
+ self.assertEqual('abc', anno.getanno(for_node, 'test'))
+ self.assertFalse(anno.hasanno(while_node, 'string'))
+ self.assertEqual('1', anno.getanno(while_node, 'test'))
+
+ def test_local_scope_info_stack_checks_integrity(self):
+
+ class TestTransformer(transformer.Base):
+
+ def visit_If(self, node):
+ self.enter_local_scope()
+ return self.generic_visit(node)
+
+ def visit_For(self, node):
+ node = self.generic_visit(node)
+ self.exit_local_scope()
+ return node
+
+ tr = TestTransformer(self._simple_source_info())
+
+ def no_exit(a):
+ if a > 0:
+ print(a)
+ return None
+
+ node, _ = parser.parse_entity(no_exit)
+ with self.assertRaises(AssertionError):
+ tr.visit(node)
+
+ def no_entry(a):
+ for _ in a:
+ print(a)
+
+ node, _ = parser.parse_entity(no_entry)
+ with self.assertRaises(AssertionError):
+ tr.visit(node)
+
+ def test_visit_block_postprocessing(self):
+
+ class TestTransformer(transformer.Base):
+
+ def _process_body_item(self, node):
+ if isinstance(node, gast.Assign) and (node.value.id == 'y'):
+ if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
+ return if_node, if_node.body
+ return node, None
+
+ def visit_FunctionDef(self, node):
+ node.body = self.visit_block(
+ node.body, after_visit=self._process_body_item)
+ return node
+
+ def test_function(x, y):
+ z = x
+ z = y
+ return z
+
+ tr = TestTransformer(self._simple_source_info())
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+ node = node.body[0]
+
+ self.assertEqual(len(node.body), 2)
+ self.assertTrue(isinstance(node.body[0], gast.Assign))
+ self.assertTrue(isinstance(node.body[1], gast.If))
+ self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
+ self.assertTrue(isinstance(node.body[1].body[1], gast.Return))
+
+ def test_robust_error_on_list_visit(self):
+
+ class BrokenTransformer(transformer.Base):
+
+ def visit_If(self, node):
+ # This is broken because visit expects a single node, not a list, and
+ # the body of an if is a list.
+ # Importantly, the default error handling in visit also expects a single
+ # node. Therefore, mistakes like this need to trigger a type error
+ # before the visit called here installs its error handler.
+ # That type error can then be caught by the enclosing call to visit,
+ # and correctly blame the If node.
+ self.visit(node.body)
+ return node
+
+ def test_function(x):
+ if x > 0:
+ return x
+
+ tr = BrokenTransformer(self._simple_source_info())
+
+ node, _ = parser.parse_entity(test_function)
+ with self.assertRaises(transformer.AutographParseError) as cm:
+ node = tr.visit(node)
+ obtained_message = str(cm.exception)
+ expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
+ self.assertRegexpMatches(obtained_message, expected_message)
+ # The exception should point at the if statement, not any place else. Could
+ # also check the stack trace.
+ self.assertTrue(
+ 'Occurred at node:\nIf' in obtained_message, obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nFunctionDef' not in obtained_message,
+ obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nReturn' not in obtained_message, obtained_message)
+
+ def test_robust_error_on_ast_corruption(self):
+ # A child class should not be able to be so broken that it causes the error
+ # handling in `transformer.Base` to raise an exception. Why not? Because
+ # then the original error location is dropped, and an error handler higher
+ # up in the call stack gives misleading information.
+
+ # Here we test that the error handling in `visit` completes, and blames the
+ # correct original exception, even if the AST gets corrupted.
+
+ class NotANode(object):
+ pass
+
+ class BrokenTransformer(transformer.Base):
+
+ def visit_If(self, node):
+ node.body = NotANode()
+ raise ValueError('I blew up')
+
+ def test_function(x):
+ if x > 0:
+ return x
+
+ tr = BrokenTransformer(self._simple_source_info())
+
+ node, _ = parser.parse_entity(test_function)
+ with self.assertRaises(transformer.AutographParseError) as cm:
+ node = tr.visit(node)
+ obtained_message = str(cm.exception)
+ # The message should reference the exception actually raised, not anything
+ # from the exception handler.
+ expected_substring = 'I blew up'
+ self.assertTrue(expected_substring in obtained_message, obtained_message)
+ # Expect the exception to have failed to parse the corrupted AST
+ self.assertTrue(
+ '<could not convert AST to source>' in obtained_message,
+ obtained_message)
+ # The exception should point at the if statement, not any place else. Could
+ # also check the stack trace.
+ self.assertTrue(
+ 'Occurred at node:\nIf' in obtained_message, obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nFunctionDef' not in obtained_message,
+ obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nReturn' not in obtained_message, obtained_message)
+
+if __name__ == '__main__':
+ test.main()