aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/continue_statements_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/continue_statements_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements_test.py48
1 files changed, 19 insertions, 29 deletions
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py
index 2ce1837972..3a7c7d1486 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py
@@ -25,7 +25,11 @@ from tensorflow.python.platform import test
class ContinueCanonicalizationTest(converter_testing.TestCase):
- def test_basic_continue(self):
+ def assertTransformedEquivalent(self, test_fn, *inputs):
+ with self.converted(test_fn, continue_statements, {}) as result:
+ self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
+
+ def test_basic(self):
def test_fn(x):
v = []
@@ -36,17 +40,12 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {})
- node = continue_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- self.assertEqual(test_fn(0), result.test_fn(0))
- self.assertEqual(test_fn(1), result.test_fn(1))
- self.assertEqual(test_fn(2), result.test_fn(2))
- self.assertEqual(test_fn(3), result.test_fn(3))
- self.assertEqual(test_fn(4), result.test_fn(4))
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 4)
- def test_basic_continue_for_loop(self):
+ def test_for_loop(self):
def test_fn(a):
v = []
@@ -57,16 +56,12 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {})
- node = continue_statements.transform(node, self.ctx)
+ self.assertTransformedEquivalent(test_fn, [])
+ self.assertTransformedEquivalent(test_fn, [1])
+ self.assertTransformedEquivalent(test_fn, [2])
+ self.assertTransformedEquivalent(test_fn, [1, 2, 3])
- with self.compiled(node) as result:
- self.assertEqual(test_fn([]), result.test_fn([]))
- self.assertEqual(test_fn([1]), result.test_fn([1]))
- self.assertEqual(test_fn([2]), result.test_fn([2]))
- self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3]))
-
- def test_continue_deeply_nested(self):
+ def test_nested(self):
def test_fn(x):
v = []
@@ -83,15 +78,10 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
- node = self.parse_and_analyze(test_fn, {})
- node = continue_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- self.assertEqual(test_fn(0), result.test_fn(0))
- self.assertEqual(test_fn(1), result.test_fn(1))
- self.assertEqual(test_fn(2), result.test_fn(2))
- self.assertEqual(test_fn(3), result.test_fn(3))
- self.assertEqual(test_fn(4), result.test_fn(4))
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 4)
if __name__ == '__main__':