diff options
author | Sergio Guadarrama <sguada@google.com> | 2017-11-07 16:30:33 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:34 -0800 |
commit | 091ec068a932944a0cd0792a01aa127ed36647e6 (patch) | |
tree | 6ff487527ae37b6ddb223b673d50847901a342a3 /tensorflow/contrib/summary | |
parent | 23bf184564e7842432efb8a66d6d22db4b79205e (diff) |
Allow passing other global_steps to summaries.
PiperOrigin-RevId: 174931874
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 43 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 12 |
2 files changed, 38 insertions, 17 deletions
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 56e3198593..9238671c4a 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -57,12 +57,14 @@ def should_record_summaries(): # TODO(apassos) consider how to handle local step here. @tf_contextlib.contextmanager -def record_summaries_every_n_global_steps(n): +def record_summaries_every_n_global_steps(n, global_step=None): """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" + if global_step is None: + global_step = training_util.get_global_step() collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) old = collection_ref[:] with ops.device("cpu:0"): - collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)] + collection_ref[:] = [math_ops.equal(global_step % n, 0)] yield collection_ref[:] = old @@ -204,68 +206,75 @@ def summary_writer_function(name, tensor, function, family=None): return op -def generic(name, tensor, metadata, family=None): +def generic(name, tensor, metadata, family=None, global_step=None): """Writes a tensor summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_summary( context.context().summary_writer_resource, - training_util.get_global_step(), array_ops.identity(tensor), + global_step, array_ops.identity(tensor), tag, metadata, name=scope) return summary_writer_function(name, tensor, function, family=family) -def scalar(name, tensor, family=None): +def scalar(name, tensor, family=None, global_step=None): """Writes a scalar summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_scalar_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, array_ops.identity(tensor), + global_step, tag, array_ops.identity(tensor), name=scope) return summary_writer_function(name, tensor, function, family=family) -def histogram(name, tensor, family=None): +def histogram(name, tensor, family=None, global_step=None): """Writes a histogram summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_histogram_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, array_ops.identity(tensor), + global_step, tag, array_ops.identity(tensor), name=scope) return summary_writer_function(name, tensor, function, family=family) -def image(name, tensor, bad_color=None, max_images=3, family=None): +def image(name, tensor, bad_color=None, max_images=3, family=None, + global_step=None): """Writes an image summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) if bad_color is None else bad_color) # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, array_ops.identity(tensor), + global_step, tag, array_ops.identity(tensor), bad_color_, max_images, name=scope) return summary_writer_function(name, tensor, function, family=family) -def audio(name, tensor, sample_rate, max_outputs, family=None): +def audio(name, tensor, sample_rate, max_outputs, family=None, + global_step=None): """Writes an audio summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_audio_summary( context.context().summary_writer_resource, - training_util.get_global_step(), + global_step, tag, array_ops.identity(tensor), sample_rate=sample_rate, diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index de7ae6ec27..466e194096 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -86,6 +86,18 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + def testSummaryGlobalStep(self): + global_step = training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + + summary_ops.scalar('scalar', 2.0, global_step=global_step) + + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scalar') if __name__ == '__main__': test.main() |