diff options
author | Alexandre Passos <apassos@google.com> | 2017-10-23 09:59:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-23 10:03:23 -0700 |
commit | 4ec6f2b07c08ddab479541cad0c61f169c1f816f (patch) | |
tree | 2143df2738a7d5affa73a8ba7a124f839bab9e4e /tensorflow/contrib/summary | |
parent | 03b02ffc9e542a7f40d98debd711e537f7f3bb04 (diff) |
Switching contrib.summaries API to be context-manager-centric
PiperOrigin-RevId: 173129793
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 33 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 89 |
2 files changed, 79 insertions, 43 deletions
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index ba3619bfc9..30a9398ee5 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import utils from tensorflow.python.ops import summary_op_util from tensorflow.python.training import training_util +from tensorflow.python.util import tf_contextlib # Name for a collection which is expected to have at most a single boolean # Tensor. If this tensor is True the summary ops will record summaries. @@ -46,22 +47,50 @@ def should_record_summaries(): # TODO(apassos) consider how to handle local step here. +@tf_contextlib.contextmanager def record_summaries_every_n_global_steps(n): """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + old = collection_ref[:] collection_ref[:] = [training_util.get_global_step() % n == 0] + yield + collection_ref[:] = old +@tf_contextlib.contextmanager def always_record_summaries(): """Sets the should_record_summaries Tensor to always true.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + old = collection_ref[:] collection_ref[:] = [True] + yield + collection_ref[:] = old +@tf_contextlib.contextmanager def never_record_summaries(): """Sets the should_record_summaries Tensor to always false.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + old = collection_ref[:] collection_ref[:] = [False] + yield + collection_ref[:] = old + + +class SummaryWriter(object): + + def __init__(self, resource): + self._resource = resource + + def set_as_default(self): + context.context().summary_writer_resource = self._resource + + @tf_contextlib.contextmanager + def as_default(self): + old = context.context().summary_writer_resource + context.context().summary_writer_resource = self._resource + yield + context.context().summary_writer_resource = old def create_summary_file_writer(logdir, @@ -77,9 +106,11 @@ def create_summary_file_writer(logdir, if filename_suffix is None: filename_suffix = constant_op.constant("") resource = gen_summary_ops.summary_writer(shared_name=name) + # TODO(apassos) ensure the initialization op runs when in graph mode; consider + # calling session.run here. gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, flush_secs, filename_suffix) - context.context().summary_writer_resource = resource + return SummaryWriter(resource) def _nothing(): diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 2cd4fce5b3..405a92a726 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase): def testShouldRecordSummary(self): self.assertFalse(summary_ops.should_record_summaries()) - summary_ops.always_record_summaries() - self.assertTrue(summary_ops.should_record_summaries()) + with summary_ops.always_record_summaries(): + self.assertTrue(summary_ops.should_record_summaries()) def testSummaryOps(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') - summary_ops.always_record_summaries() - summary_ops.generic('tensor', 1, '') - summary_ops.scalar('scalar', 2.0) - summary_ops.histogram('histogram', [1.0]) - summary_ops.image('image', [[[[1.0]]]]) - summary_ops.audio('audio', [[1.0]], 1.0, 1) - # The working condition of the ops is tested in the C++ test so we just - # test here that we're calling them correctly. - self.assertTrue(gfile.Exists(logdir)) + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + # The working condition of the ops is tested in the C++ test so we just + # test here that we're calling them correctly. + self.assertTrue(gfile.Exists(logdir)) def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') - summary_ops.always_record_summaries() - - @function.defun - def write(): - summary_ops.scalar('scalar', 2.0) - - write() - - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].simple_value, 2.0) + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t1').as_default(), summary_ops.always_record_summaries(): + + @function.defun + def write(): + summary_ops.scalar('scalar', 2.0) + + write() + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list( + tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].simple_value, 2.0) def testSummaryName(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2') - summary_ops.always_record_summaries() - - summary_ops.scalar('scalar', 2.0) - - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].tag, 'scalar') + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + + summary_ops.scalar('scalar', 2.0) + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list( + tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].tag, 'scalar') if __name__ == '__main__': |