aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/builtin_functions_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/builtin_functions_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions_test.py60
1 files changed, 21 insertions, 39 deletions
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index e9000e518c..d5c3e2c250 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import sys
-
import six
from tensorflow.contrib.autograph.converters import builtin_functions
@@ -36,55 +34,39 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
def test_fn(a):
return len(a)
- node = self.parse_and_analyze(test_fn, {'len': len})
- node = builtin_functions.transform(node, self.ctx)
-
- with self.compiled(node, array_ops.shape) as result:
+ with self.converted(test_fn, builtin_functions, {'len': len},
+ array_ops.shape) as result:
with self.test_session() as sess:
- self.assertEqual(3,
- sess.run(
- result.test_fn(constant_op.constant([0, 0, 0]))))
-
- self.assertEqual(3, result.test_fn([0, 0, 0]))
+ ops = result.test_fn(constant_op.constant([0, 0, 0]))
+ self.assertEqual(sess.run(ops), 3)
def test_print(self):
- def test_fn(a):
- print(a)
+ if six.PY2:
+ return
- node = self.parse_and_analyze(test_fn, {'print': print})
- node = builtin_functions.transform(node, self.ctx)
+ def test_fn(a):
+ return print(a)
- with self.compiled(node) as result:
+ with self.converted(test_fn, builtin_functions, {'print': print}) as result:
with self.test_session() as sess:
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- result.test_fn(constant_op.constant('a'))
- sess.run(sess.graph.get_operations())
- self.assertEqual(out_capturer.getvalue(), 'a\n')
- finally:
- sys.stdout = sys.__stdout__
+ with self.assertPrints('a\n'):
+ sess.run(result.test_fn('a'))
- def test_print_with_op_multiple_values(self):
+ def test_print_multiple_values(self):
- def test_fn(a, b, c):
- print(a, b, c)
+ if six.PY2:
+ return
- node = self.parse_and_analyze(test_fn, {'print': print})
- node = builtin_functions.transform(node, self.ctx)
+ def test_fn(a, b, c):
+ return print(a, b, c)
- with self.compiled(node) as result:
+ with self.converted(test_fn, builtin_functions, {'print': print}) as result:
with self.test_session() as sess:
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- result.test_fn(
- constant_op.constant('a'), constant_op.constant(1), [2, 3])
- sess.run(sess.graph.get_operations())
- self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
- finally:
- sys.stdout = sys.__stdout__
+ with self.assertPrints('a 1 [2, 3]\n'):
+ sess.run(
+ result.test_fn(
+ constant_op.constant('a'), constant_op.constant(1), [2, 3]))
if __name__ == '__main__':