diff options
Diffstat (limited to 'tensorflow/contrib/autograph/converters/builtin_functions_test.py')
-rw-r--r-- | tensorflow/contrib/autograph/converters/builtin_functions_test.py | 60 |
1 files changed, 21 insertions, 39 deletions
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index e9000e518c..d5c3e2c250 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - import six from tensorflow.contrib.autograph.converters import builtin_functions @@ -36,55 +34,39 @@ class BuiltinFunctionsTest(converter_testing.TestCase): def test_fn(a): return len(a) - node = self.parse_and_analyze(test_fn, {'len': len}) - node = builtin_functions.transform(node, self.ctx) - - with self.compiled(node, array_ops.shape) as result: + with self.converted(test_fn, builtin_functions, {'len': len}, + array_ops.shape) as result: with self.test_session() as sess: - self.assertEqual(3, - sess.run( - result.test_fn(constant_op.constant([0, 0, 0])))) - - self.assertEqual(3, result.test_fn([0, 0, 0])) + ops = result.test_fn(constant_op.constant([0, 0, 0])) + self.assertEqual(sess.run(ops), 3) def test_print(self): - def test_fn(a): - print(a) + if six.PY2: + return - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) + def test_fn(a): + return print(a) - with self.compiled(node) as result: + with self.converted(test_fn, builtin_functions, {'print': print}) as result: with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn(constant_op.constant('a')) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a\n') - finally: - sys.stdout = sys.__stdout__ + with self.assertPrints('a\n'): + sess.run(result.test_fn('a')) - def test_print_with_op_multiple_values(self): + def test_print_multiple_values(self): - def test_fn(a, b, c): - print(a, b, c) + if six.PY2: + return - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) + def test_fn(a, b, c): + return print(a, b, c) - with self.compiled(node) as result: + with self.converted(test_fn, builtin_functions, {'print': print}) as result: with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn( - constant_op.constant('a'), constant_op.constant(1), [2, 3]) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') - finally: - sys.stdout = sys.__stdout__ + with self.assertPrints('a 1 [2, 3]\n'): + sess.run( + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3])) if __name__ == '__main__': |