diff options
Diffstat (limited to 'tensorflow/contrib/autograph/converters/continue_statements_test.py')
-rw-r--r-- | tensorflow/contrib/autograph/converters/continue_statements_test.py | 48 |
1 files changed, 19 insertions, 29 deletions
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index 2ce1837972..3a7c7d1486 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -25,7 +25,11 @@ from tensorflow.python.platform import test class ContinueCanonicalizationTest(converter_testing.TestCase): - def test_basic_continue(self): + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, continue_statements, {}) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_basic(self): def test_fn(x): v = [] @@ -36,17 +40,12 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) - def test_basic_continue_for_loop(self): + def test_for_loop(self): def test_fn(a): v = [] @@ -57,16 +56,12 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1]) + self.assertTransformedEquivalent(test_fn, [2]) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) - with self.compiled(node) as result: - self.assertEqual(test_fn([]), result.test_fn([])) - self.assertEqual(test_fn([1]), result.test_fn([1])) - self.assertEqual(test_fn([2]), result.test_fn([2])) - self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3])) - - def test_continue_deeply_nested(self): + def test_nested(self): def test_fn(x): v = [] @@ -83,15 +78,10 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) if __name__ == '__main__': |