aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/break_statements_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/break_statements_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements_test.py62
1 files changed, 22 insertions, 40 deletions
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py
index dcff1c54c2..c26ca2946c 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/break_statements_test.py
@@ -25,7 +25,11 @@ from tensorflow.python.platform import test
class BreakCanonicalizationTest(converter_testing.TestCase):
- def test_basic_while(self):
+ def assertTransformedEquivalent(self, test_fn, *inputs):
+ with self.converted(test_fn, break_statements, {}) as result:
+ self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
+
+ def test_while_loop(self):
def test_fn(x):
v = []
@@ -36,15 +40,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {})
- node = break_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- self.assertEqual([], result.test_fn(0))
- self.assertEqual([], result.test_fn(1))
- self.assertEqual([3], result.test_fn(4))
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 4)
- def test_basic_for(self):
+ def test_for_loop(self):
def test_fn(a):
v = []
@@ -55,18 +55,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {})
- node = break_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
+ with self.converted(test_fn, break_statements, {}) as result:
# The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped.
- self.assertEqual([], result.test_fn([]))
- self.assertEqual([3, 3], result.test_fn([4, 4]))
- self.assertEqual([3], result.test_fn([4, 5]))
self.assertEqual([3], result.test_fn([5, 4]))
- def test_deeply_nested(self):
+ def test_nested(self):
def test_fn(x):
v = []
@@ -83,13 +77,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
- node = self.parse_and_analyze(test_fn, {})
- node = break_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- self.assertEqual(([], [], []), result.test_fn(0))
- self.assertEqual(([2, 1], [2], [0]), result.test_fn(3))
- self.assertEqual(([10, 9, 8, 7], [10, 8], [6]), result.test_fn(11))
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 11)
def test_nested_loops(self):
@@ -109,16 +99,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
- node = self.parse_and_analyze(test_fn, {})
- node = break_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- self.assertEqual(([], []), result.test_fn(0))
- self.assertEqual(([1], []), result.test_fn(2))
- self.assertEqual(([2, 1], [1]), result.test_fn(3))
- self.assertEqual(([4, 3, 2, 1], [3, 1]), result.test_fn(5))
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 5)
- def test_loop_else(self):
+ def test_loop_orelse(self):
def test_fn(x):
v = []
@@ -134,13 +120,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
- node = self.parse_and_analyze(test_fn, {})
- node = break_statements.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- self.assertEqual(([], []), result.test_fn(0))
- self.assertEqual(([], [1]), result.test_fn(2))
- self.assertEqual(([2], [1]), result.test_fn(3))
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, 3)
if __name__ == '__main__':