diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-16 12:09:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-16 12:12:11 -0700 |
commit | 345ccea1ea751e426a2d2d8e8d44455c43336d8c (patch) | |
tree | f481c739b46d4b136e9ba01adc488c4c70862dae | |
parent | 0fdad03d31854ad37ad8e8a2cf5df9283a2ee050 (diff) |
Remove obsolete tests. Patch the unexpected print output in Python 3.
PiperOrigin-RevId: 193078330
-rw-r--r-- | tensorflow/contrib/autograph/converters/builtin_functions_test.py | 38 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins.py | 10 |
2 files changed, 14 insertions, 34 deletions
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index ac7e756c47..30272409df 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -26,8 +26,6 @@ from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -49,7 +47,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): self.assertEqual(3, result.test_fn([0, 0, 0])) - def test_print_with_op(self): + def test_print(self): def test_fn(a): print(a) @@ -57,14 +55,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a') + result.test_fn(constant_op.constant('a')) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a\n') finally: @@ -72,41 +68,19 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): def test_print_with_op_multiple_values(self): - def test_fn(a, b): - print(a, b) - - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) - - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: - with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn('a', 1) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_print_with_py_func(self): - def test_fn(a, b, c): print(a, b, c) node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include logging_ops.Print here, to verify - # that py_func is used. - with self.compiled(node, script_ops.py_func) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a', 1, [2, 3]) + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3]) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') finally: diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 7fbb7c09d8..349b7b6f2a 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -98,9 +98,15 @@ def dynamic_print(*values): if all(map(is_tf_print_compatible, values)): return logging_ops.Print(1, values) - def flushed_print(*vals): + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) print(*vals) + # The flush helps avoid garbled output in IPython. sys.stdout.flush() return py_func.wrap_py_func( - flushed_print, None, values, use_dummy_return=True) + print_wrapper, None, values, use_dummy_return=True) |