author     Sergio Guadarrama <sguada@google.com>  2017-11-07 16:30:33 -0800
committer  Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:34 -0800
commit     091ec068a932944a0cd0792a01aa127ed36647e6 (patch)
tree       6ff487527ae37b6ddb223b673d50847901a342a3 /tensorflow/contrib/summary
parent     23bf184564e7842432efb8a66d6d22db4b79205e (diff)
Allow passing other global_steps to summaries.
PiperOrigin-RevId: 174931874
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--  tensorflow/contrib/summary/summary_ops.py       43
-rw-r--r--  tensorflow/contrib/summary/summary_ops_test.py  12
2 files changed, 38 insertions, 17 deletions
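
For context, every summary op in this diff gains an optional global_step argument and falls back to training_util.get_global_step() when it is omitted. The following is a minimal usage sketch, not part of the commit: it assumes the TF 1.4-era tf.contrib.eager and tf.contrib.summary APIs, and custom_step, demo_logdir, and the writer name 'demo' are illustrative.

# Sketch only: pass an explicit step tensor to a summary op instead of
# relying on the default global step.
import tempfile

import tensorflow as tf
import tensorflow.contrib.eager as tfe

from tensorflow.contrib.summary import summary_ops

tfe.enable_eager_execution()

demo_logdir = tempfile.mkdtemp()
writer = summary_ops.create_summary_file_writer(
    demo_logdir, max_queue=0, name='demo')

# Any int64 step tensor can be supplied; it no longer has to come from
# training_util.get_global_step().
custom_step = tf.constant(42, dtype=tf.int64)

with writer.as_default(), summary_ops.always_record_summaries():
  summary_ops.scalar('loss', 2.0, global_step=custom_step)
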
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 56e3198593..9238671c4a 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -57,12 +57,14 @@ def should_record_summaries():
 # TODO(apassos) consider how to handle local step here.
 @tf_contextlib.contextmanager
-def record_summaries_every_n_global_steps(n):
+def record_summaries_every_n_global_steps(n, global_step=None):
   """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
+  if global_step is None:
+    global_step = training_util.get_global_step()
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
   old = collection_ref[:]
   with ops.device("cpu:0"):
-    collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)]
+    collection_ref[:] = [math_ops.equal(global_step % n, 0)]
   yield
   collection_ref[:] = old
@@ -204,68 +206,75 @@ def summary_writer_function(name, tensor, function, family=None):
   return op
-def generic(name, tensor, metadata, family=None):
+def generic(name, tensor, metadata, family=None, global_step=None):
   """Writes a tensor summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), array_ops.identity(tensor),
+        global_step, array_ops.identity(tensor),
         tag, metadata, name=scope)
   return summary_writer_function(name, tensor, function, family=family)
-def scalar(name, tensor, family=None):
+def scalar(name, tensor, family=None, global_step=None):
   """Writes a scalar summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_scalar_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), tag, array_ops.identity(tensor),
+        global_step, tag, array_ops.identity(tensor),
         name=scope)
   return summary_writer_function(name, tensor, function, family=family)
-def histogram(name, tensor, family=None):
+def histogram(name, tensor, family=None, global_step=None):
   """Writes a histogram summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_histogram_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), tag, array_ops.identity(tensor),
+        global_step, tag, array_ops.identity(tensor),
         name=scope)
   return summary_writer_function(name, tensor, function, family=family)
-def image(name, tensor, bad_color=None, max_images=3, family=None):
+def image(name, tensor, bad_color=None, max_images=3, family=None,
+          global_step=None):
   """Writes an image summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
                   if bad_color is None else bad_color)
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_image_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), tag, array_ops.identity(tensor),
+        global_step, tag, array_ops.identity(tensor),
         bad_color_,
         max_images, name=scope)
   return summary_writer_function(name, tensor, function, family=family)
-def audio(name, tensor, sample_rate, max_outputs, family=None):
+def audio(name, tensor, sample_rate, max_outputs, family=None,
+          global_step=None):
   """Writes an audio summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_audio_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(),
+        global_step,
         tag,
         array_ops.identity(tensor),
         sample_rate=sample_rate,
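
The same optional-step pattern applies to the record_summaries_every_n_global_steps context manager changed in the first hunk. A minimal sketch under the same eager-mode assumptions as the earlier example; my_step is an illustrative constant and any int64 step tensor could be passed:

# Sketch only: drive the every-n recording gate with an explicitly supplied
# step instead of the default global step.
import tensorflow as tf
import tensorflow.contrib.eager as tfe

from tensorflow.contrib.summary import summary_ops

tfe.enable_eager_execution()

my_step = tf.constant(20, dtype=tf.int64)

with summary_ops.record_summaries_every_n_global_steps(10, global_step=my_step):
  # True here because 20 % 10 == 0; with global_step omitted the gate falls
  # back to training_util.get_global_step().
  print(summary_ops.should_record_summaries())
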
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index de7ae6ec27..466e194096 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -86,6 +86,18 @@ class TargetTest(test_util.TensorFlowTestCase):
       self.assertEqual(len(events), 2)
       self.assertEqual(events[1].summary.value[0].tag, 'scalar')
+  def testSummaryGlobalStep(self):
+    global_step = training_util.get_or_create_global_step()
+    logdir = tempfile.mkdtemp()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t2').as_default(), summary_ops.always_record_summaries():
+
+      summary_ops.scalar('scalar', 2.0, global_step=global_step)
+
+      events = summary_test_util.events_from_file(logdir)
+      self.assertEqual(len(events), 2)
+      self.assertEqual(events[1].summary.value[0].tag, 'scalar')
 if __name__ == '__main__':
   test.main()
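
For completeness, a hedged sketch of the behavior the new test asserts, using the stock tf.train.summary_iterator instead of the test-only summary_test_util helper (same TF 1.4-era eager assumptions; the writer name and step value are illustrative). The first event in the file is the writer's file-version record, which is why the test expects two events:

# Sketch only: write one scalar at an explicit step, then read the events
# file back and print the step recorded alongside the scalar.
import glob
import os
import tempfile

import tensorflow as tf
import tensorflow.contrib.eager as tfe

from tensorflow.contrib.summary import summary_ops

tfe.enable_eager_execution()

logdir = tempfile.mkdtemp()
with summary_ops.create_summary_file_writer(
    logdir, max_queue=0,
    name='verify').as_default(), summary_ops.always_record_summaries():
  summary_ops.scalar('scalar', 2.0, global_step=tf.constant(7, dtype=tf.int64))

for path in glob.glob(os.path.join(logdir, 'events*')):
  for event in tf.train.summary_iterator(path):
    for value in event.summary.value:
      print(event.step, value.tag)  # expected output: 7 scalar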