aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-16 12:09:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-16 12:12:11 -0700
commit345ccea1ea751e426a2d2d8e8d44455c43336d8c (patch)
treef481c739b46d4b136e9ba01adc488c4c70862dae
parent0fdad03d31854ad37ad8e8a2cf5df9283a2ee050 (diff)
Remove obsolete tests. Patch the unexpected print output in Python 3.
PiperOrigin-RevId: 193078330
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions_test.py38
-rw-r--r--tensorflow/contrib/autograph/utils/builtins.py10
2 files changed, 14 insertions, 34 deletions
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index ac7e756c47..30272409df 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -26,8 +26,6 @@ from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -49,7 +47,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
self.assertEqual(3, result.test_fn([0, 0, 0]))
- def test_print_with_op(self):
+ def test_print(self):
def test_fn(a):
print(a)
@@ -57,14 +55,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- # Note: it's relevant not to include script_ops.py_func here, to verify
- # that tf.Print is used.
- with self.compiled(node, logging_ops.Print) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- result.test_fn('a')
+ result.test_fn(constant_op.constant('a'))
sess.run(sess.graph.get_operations())
self.assertEqual(out_capturer.getvalue(), 'a\n')
finally:
@@ -72,41 +68,19 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
def test_print_with_op_multiple_values(self):
- def test_fn(a, b):
- print(a, b)
-
- node = self.parse_and_analyze(test_fn, {'print': print})
- node = builtin_functions.transform(node, self.ctx)
-
- # Note: it's relevant not to include script_ops.py_func here, to verify
- # that tf.Print is used.
- with self.compiled(node, logging_ops.Print) as result:
- with self.test_session() as sess:
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- result.test_fn('a', 1)
- sess.run(sess.graph.get_operations())
- self.assertEqual(out_capturer.getvalue(), 'a 1\n')
- finally:
- sys.stdout = sys.__stdout__
-
- def test_print_with_py_func(self):
-
def test_fn(a, b, c):
print(a, b, c)
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- # Note: it's relevant not to include logging_ops.Print here, to verify
- # that py_func is used.
- with self.compiled(node, script_ops.py_func) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- result.test_fn('a', 1, [2, 3])
+ result.test_fn(
+ constant_op.constant('a'), constant_op.constant(1), [2, 3])
sess.run(sess.graph.get_operations())
self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
finally:
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index 7fbb7c09d8..349b7b6f2a 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -98,9 +98,15 @@ def dynamic_print(*values):
if all(map(is_tf_print_compatible, values)):
return logging_ops.Print(1, values)
- def flushed_print(*vals):
+ def print_wrapper(*vals):
+ if six.PY3:
+ # TensorFlow doesn't seem to generate Unicode when passing strings to
+ # py_func. This causes the print to add a "b'" wrapper to the output,
+ # which is probably never what you want.
+ vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
print(*vals)
+ # The flush helps avoid garbled output in IPython.
sys.stdout.flush()
return py_func.wrap_py_func(
- flushed_print, None, values, use_dummy_return=True)
+ print_wrapper, None, values, use_dummy_return=True)