aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 15:40:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 15:45:12 -0700
commitc96841dbd199d3c1a15a89e8c44c7c1d164968b9 (patch)
tree392511a31e3565f862c7257da6a35c31b6587968 /tensorflow/python/framework
parentb2b98a5ad1b647b77cb42761671cd9b3cf0e87b6 (diff)
This CL adds a new `tf.print` operator that more closely aligns with the standard python `print` method, and deprecates the old `tf.Print` operator (to be removed in in v2.0).
It follows the design doc specified in https://github.com/tensorflow/community/pull/14 and additionally incorporates the community feedback and design review decisions. This CL adds two new internal graph operators: a StringFormat operator that formats a template string with a list of input tensors to insert into the string and outputs a string scalar containing the result, and a PrintV2 operator that prints a string scalar to a specified output stream or logging level. The formatting op is exposed at `tf.strings.Format`. A new python method is exposed at `tf.print` that takes a list of inputs that may be nested structures and may contain tensors, formats them nicely using the formatting op, and returns a PrintV2 operator that prints them. In Eager mode and inside defuns this PrintV2 operator will automatically be executed, but in graph mode it will need to be either added to `sess.run`, or used as a control dependency for other operators being executed. As compared to the previous print function, the new print function: - Has an API that more closely aligns with the standard python3 print - Supports changing the print logging level/output stream - allows printing arbitrary (optionally nested) data structures as opposed to just flat lists of tensors - support printing sparse tensors - changes printed tensor format to show more meaningful summary (recursively print the first and last elements of each tensor dimension, instead of just the first few elements of the tensor irregardless of dimension). PiperOrigin-RevId: 213709924
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/test_util.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b7398238f5..c302072aa1 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,6 +24,7 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
+import os
import math
import random
import re
@@ -868,6 +869,19 @@ def device(use_gpu):
yield
+class CapturedWrites(object):
+ """A utility class to load the captured writes made to a stream."""
+
+ def __init__(self, capture_location):
+ self.capture_location = capture_location
+
+ def contents(self):
+ """Get the captured writes as a single string."""
+ with open(self.capture_location) as tmp_file:
+ output_data = "".join(tmp_file.readlines())
+ return output_data
+
+
class ErrorLoggingSession(session.Session):
"""Wrapper around a Session that logs errors in run().
"""
@@ -934,6 +948,52 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
+ @contextlib.contextmanager
+ def captureWritesToStream(self, stream):
+ """A context manager that captures the writes to a given stream.
+
+ This context manager captures all writes to a given stream inside of a
+ `CapturedWrites` object. When this context manager is created, it yields
+ the `CapturedWrites` object. The captured contents can be accessed by
+ calling `.contents()` on the `CapturedWrites`.
+
+ For this function to work, the stream must have a file descriptor that
+ can be modified using `os.dup` and `os.dup2`, and the stream must support
+ a `.flush()` method. The default python sys.stdout and sys.stderr are
+ examples of this. Note that this does not work in Colab or Jupyter
+ notebooks, because those use alternate stdout streams.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ with self.captureWritesToStream(sys.stdout) as captured:
+ result = MyOperator(input).eval()
+ self.assertStartsWith(captured.contents(), "This was printed.")
+ ```
+
+ Args:
+ stream: The stream whose writes should be captured. This
+ stream must have a file descriptor, support writing via using that
+ file descriptor, and must have a `.flush()` method.
+
+ Yields:
+ A `CapturedWrites` object that contains all writes to the specified stream
+ made during this context.
+ """
+ stream.flush()
+ fd = stream.fileno()
+ tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+ tmp_file = open(tmp_file_path, "w")
+ orig_fd = os.dup(fd)
+ os.dup2(tmp_file.fileno(), fd)
+ try:
+ yield CapturedWrites(tmp_file_path)
+ finally:
+ tmp_file.close()
+ os.dup2(orig_fd, fd)
+
def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.