author     Alexandre Passos <apassos@google.com>          2017-10-23 09:59:21 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-10-23 10:03:23 -0700
commit     4ec6f2b07c08ddab479541cad0c61f169c1f816f (patch)
tree       2143df2738a7d5affa73a8ba7a124f839bab9e4e /tensorflow/contrib/summary
parent     03b02ffc9e542a7f40d98debd711e537f7f3bb04 (diff)
Switching contrib.summaries API to be context-manager-centric
PiperOrigin-RevId: 173129793
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--  tensorflow/contrib/summary/summary_ops.py      | 33
-rw-r--r--  tensorflow/contrib/summary/summary_ops_test.py | 89
2 files changed, 79 insertions(+), 43 deletions(-)
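
For reference, a minimal usage sketch of the new context-manager-centric API, pieced together from the updated tests below; the import paths, log directory, writer name, and eager-mode setup are assumptions for illustration, not part of this commit.

    # Usage sketch only: assumes eager execution is enabled and that summary_ops
    # is importable from tensorflow.contrib.summary (as in the test file below).
    from tensorflow.contrib.summary import summary_ops
    from tensorflow.python.training import training_util

    training_util.get_or_create_global_step()

    # create_summary_file_writer() now returns a SummaryWriter instead of
    # installing itself globally; activate it with as_default() (or
    # set_as_default()) and gate recording with a context manager.
    writer = summary_ops.create_summary_file_writer(
        '/tmp/logdir', max_queue=0, name='example')  # hypothetical logdir/name

    with writer.as_default(), summary_ops.always_record_summaries():
      summary_ops.scalar('loss', 0.5)

    # Or only record every n global steps:
    with writer.as_default(), \
        summary_ops.record_summaries_every_n_global_steps(100):
      summary_ops.scalar('loss', 0.5)
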
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index ba3619bfc9..30a9398ee5 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.layers import utils
from tensorflow.python.ops import summary_op_util
from tensorflow.python.training import training_util
+from tensorflow.python.util import tf_contextlib
# Name for a collection which is expected to have at most a single boolean
# Tensor. If this tensor is True the summary ops will record summaries.
@@ -46,22 +47,50 @@ def should_record_summaries():
# TODO(apassos) consider how to handle local step here.
+@tf_contextlib.contextmanager
def record_summaries_every_n_global_steps(n):
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+ old = collection_ref[:]
collection_ref[:] = [training_util.get_global_step() % n == 0]
+ yield
+ collection_ref[:] = old
+@tf_contextlib.contextmanager
def always_record_summaries():
"""Sets the should_record_summaries Tensor to always true."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+ old = collection_ref[:]
collection_ref[:] = [True]
+ yield
+ collection_ref[:] = old
+@tf_contextlib.contextmanager
def never_record_summaries():
"""Sets the should_record_summaries Tensor to always false."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+ old = collection_ref[:]
collection_ref[:] = [False]
+ yield
+ collection_ref[:] = old
+
+
+class SummaryWriter(object):
+
+ def __init__(self, resource):
+ self._resource = resource
+
+ def set_as_default(self):
+ context.context().summary_writer_resource = self._resource
+
+ @tf_contextlib.contextmanager
+ def as_default(self):
+ old = context.context().summary_writer_resource
+ context.context().summary_writer_resource = self._resource
+ yield
+ context.context().summary_writer_resource = old
def create_summary_file_writer(logdir,
@@ -77,9 +106,11 @@ def create_summary_file_writer(logdir,
if filename_suffix is None:
filename_suffix = constant_op.constant("")
resource = gen_summary_ops.summary_writer(shared_name=name)
+ # TODO(apassos) ensure the initialization op runs when in graph mode; consider
+ # calling session.run here.
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix)
- context.context().summary_writer_resource = resource
+ return SummaryWriter(resource)
def _nothing():
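
The recording helpers above all follow the same save-then-restore pattern around the yield. A standalone sketch of that pattern, using plain contextlib and a local list as stand-ins for tf_contextlib and the graph collection; the try/finally (restore even on exceptions) is an addition in this sketch, while the committed code restores unconditionally after the yield.

    # Sketch of the save/restore context-manager pattern used by the recording
    # helpers above; names and the try/finally are illustrative, not the commit's.
    import contextlib

    _SHOULD_RECORD = []  # stand-in for the _SHOULD_RECORD_SUMMARIES_NAME collection

    @contextlib.contextmanager
    def always_record_summaries_sketch():
      old = _SHOULD_RECORD[:]      # remember whatever was there before
      _SHOULD_RECORD[:] = [True]   # force recording inside the with-block
      try:
        yield
      finally:
        _SHOULD_RECORD[:] = old    # put the previous contents back on exit

    with always_record_summaries_sketch():
      assert _SHOULD_RECORD == [True]
    assert _SHOULD_RECORD == []    # restored after the block
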
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 2cd4fce5b3..405a92a726 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase):
def testShouldRecordSummary(self):
self.assertFalse(summary_ops.should_record_summaries())
- summary_ops.always_record_summaries()
- self.assertTrue(summary_ops.should_record_summaries())
+ with summary_ops.always_record_summaries():
+ self.assertTrue(summary_ops.should_record_summaries())
def testSummaryOps(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
- summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
- summary_ops.always_record_summaries()
- summary_ops.generic('tensor', 1, '')
- summary_ops.scalar('scalar', 2.0)
- summary_ops.histogram('histogram', [1.0])
- summary_ops.image('image', [[[[1.0]]]])
- summary_ops.audio('audio', [[1.0]], 1.0, 1)
- # The working condition of the ops is tested in the C++ test so we just
- # test here that we're calling them correctly.
- self.assertTrue(gfile.Exists(logdir))
+ with summary_ops.create_summary_file_writer(
+ logdir, max_queue=0,
+ name='t0').as_default(), summary_ops.always_record_summaries():
+ summary_ops.generic('tensor', 1, '')
+ summary_ops.scalar('scalar', 2.0)
+ summary_ops.histogram('histogram', [1.0])
+ summary_ops.image('image', [[[[1.0]]]])
+ summary_ops.audio('audio', [[1.0]], 1.0, 1)
+ # The working condition of the ops is tested in the C++ test so we just
+ # test here that we're calling them correctly.
+ self.assertTrue(gfile.Exists(logdir))
def testDefunSummarys(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
- summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1')
- summary_ops.always_record_summaries()
-
- @function.defun
- def write():
- summary_ops.scalar('scalar', 2.0)
-
- write()
-
- self.assertTrue(gfile.Exists(logdir))
- files = gfile.ListDirectory(logdir)
- self.assertEqual(len(files), 1)
- records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
- self.assertEqual(len(records), 2)
- event = event_pb2.Event()
- event.ParseFromString(records[1])
- self.assertEqual(event.summary.value[0].simple_value, 2.0)
+ with summary_ops.create_summary_file_writer(
+ logdir, max_queue=0,
+ name='t1').as_default(), summary_ops.always_record_summaries():
+
+ @function.defun
+ def write():
+ summary_ops.scalar('scalar', 2.0)
+
+ write()
+
+ self.assertTrue(gfile.Exists(logdir))
+ files = gfile.ListDirectory(logdir)
+ self.assertEqual(len(files), 1)
+ records = list(
+ tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+ self.assertEqual(len(records), 2)
+ event = event_pb2.Event()
+ event.ParseFromString(records[1])
+ self.assertEqual(event.summary.value[0].simple_value, 2.0)
def testSummaryName(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
- summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
- summary_ops.always_record_summaries()
-
- summary_ops.scalar('scalar', 2.0)
-
- self.assertTrue(gfile.Exists(logdir))
- files = gfile.ListDirectory(logdir)
- self.assertEqual(len(files), 1)
- records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
- self.assertEqual(len(records), 2)
- event = event_pb2.Event()
- event.ParseFromString(records[1])
- self.assertEqual(event.summary.value[0].tag, 'scalar')
+ with summary_ops.create_summary_file_writer(
+ logdir, max_queue=0,
+ name='t2').as_default(), summary_ops.always_record_summaries():
+
+ summary_ops.scalar('scalar', 2.0)
+
+ self.assertTrue(gfile.Exists(logdir))
+ files = gfile.ListDirectory(logdir)
+ self.assertEqual(len(files), 1)
+ records = list(
+ tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+ self.assertEqual(len(records), 2)
+ event = event_pb2.Event()
+ event.ParseFromString(records[1])
+ self.assertEqual(event.summary.value[0].tag, 'scalar')
if __name__ == '__main__':