aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/transformer_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/transformer_test.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/transformer_test.py159
1 files changed, 159 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py
index baf04653ae..a37e922a1d 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/contrib/autograph/pyct/transformer_test.py
@@ -93,6 +93,83 @@ class TransformerTest(test.TestCase):
inner_function, lambda_node),
anno.getanno(lambda_expr, 'enclosing_entities'))
+ def assertSameAnno(self, first, second, key):
+ self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
+
+ def assertDifferentAnno(self, first, second, key):
+ self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key))
+
+ def test_state_tracking(self):
+
+ class LoopState(object):
+ pass
+
+ class CondState(object):
+ pass
+
+ class TestTransformer(transformer.Base):
+
+ def visit(self, node):
+ anno.setanno(node, 'loop_state', self.state[LoopState].value)
+ anno.setanno(node, 'cond_state', self.state[CondState].value)
+ return super(TestTransformer, self).visit(node)
+
+ def visit_While(self, node):
+ self.state[LoopState].enter()
+ node = self.generic_visit(node)
+ self.state[LoopState].exit()
+ return node
+
+ def visit_If(self, node):
+ self.state[CondState].enter()
+ node = self.generic_visit(node)
+ self.state[CondState].exit()
+ return node
+
+ tr = TestTransformer(self._simple_source_info())
+
+ def test_function(a):
+ a = 1
+ while a:
+ _ = 'a'
+ if a > 2:
+ _ = 'b'
+ while True:
+ raise '1'
+ if a > 3:
+ _ = 'c'
+ while True:
+ raise '1'
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+
+ fn_body = node.body[0].body
+ outer_while_body = fn_body[1].body
+ self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
+ self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')
+
+ first_if_body = outer_while_body[1].body
+ self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
+ 'cond_state')
+ self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')
+
+ first_inner_while_body = first_if_body[1].body
+ self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
+ 'cond_state')
+ self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
+ 'loop_state')
+
+ second_if_body = outer_while_body[2].body
+ self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
+ self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')
+
+ second_inner_while_body = second_if_body[1].body
+ self.assertDifferentAnno(first_inner_while_body[0],
+ second_inner_while_body[0], 'cond_state')
+ self.assertDifferentAnno(first_inner_while_body[0],
+ second_inner_while_body[0], 'loop_state')
+
def test_local_scope_info_stack(self):
class TestTransformer(transformer.Base):
@@ -205,6 +282,88 @@ class TransformerTest(test.TestCase):
self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
self.assertTrue(isinstance(node.body[1].body[1], gast.Return))
+ def test_robust_error_on_list_visit(self):
+
+ class BrokenTransformer(transformer.Base):
+
+ def visit_If(self, node):
+ # This is broken because visit expects a single node, not a list, and
+ # the body of an if is a list.
+ # Importantly, the default error handling in visit also expects a single
+ # node. Therefore, mistakes like this need to trigger a type error
+ # before the visit called here installs its error handler.
+ # That type error can then be caught by the enclosing call to visit,
+ # and correctly blame the If node.
+ self.visit(node.body)
+ return node
+
+ def test_function(x):
+ if x > 0:
+ return x
+
+ tr = BrokenTransformer(self._simple_source_info())
+
+ node, _ = parser.parse_entity(test_function)
+ with self.assertRaises(transformer.AutographParseError) as cm:
+ node = tr.visit(node)
+ obtained_message = str(cm.exception)
+ expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
+ self.assertRegexpMatches(obtained_message, expected_message)
+ # The exception should point at the if statement, not any place else. Could
+ # also check the stack trace.
+ self.assertTrue(
+ 'Occurred at node:\nIf' in obtained_message, obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nFunctionDef' not in obtained_message,
+ obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nReturn' not in obtained_message, obtained_message)
+
+ def test_robust_error_on_ast_corruption(self):
+ # A child class should not be able to be so broken that it causes the error
+ # handling in `transformer.Base` to raise an exception. Why not? Because
+ # then the original error location is dropped, and an error handler higher
+ # up in the call stack gives misleading information.
+
+ # Here we test that the error handling in `visit` completes, and blames the
+ # correct original exception, even if the AST gets corrupted.
+
+ class NotANode(object):
+ pass
+
+ class BrokenTransformer(transformer.Base):
+
+ def visit_If(self, node):
+ node.body = NotANode()
+ raise ValueError('I blew up')
+
+ def test_function(x):
+ if x > 0:
+ return x
+
+ tr = BrokenTransformer(self._simple_source_info())
+
+ node, _ = parser.parse_entity(test_function)
+ with self.assertRaises(transformer.AutographParseError) as cm:
+ node = tr.visit(node)
+ obtained_message = str(cm.exception)
+ # The message should reference the exception actually raised, not anything
+ # from the exception handler.
+ expected_substring = 'I blew up'
+ self.assertTrue(expected_substring in obtained_message, obtained_message)
+ # Expect the exception to have failed to parse the corrupted AST
+ self.assertTrue(
+ '<could not convert AST to source>' in obtained_message,
+ obtained_message)
+ # The exception should point at the if statement, not any place else. Could
+ # also check the stack trace.
+ self.assertTrue(
+ 'Occurred at node:\nIf' in obtained_message, obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nFunctionDef' not in obtained_message,
+ obtained_message)
+ self.assertTrue(
+ 'Occurred at node:\nReturn' not in obtained_message, obtained_message)
if __name__ == '__main__':
test.main()