aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-11-16 20:50:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-16 20:54:12 -0800
commit10581c8afee392f2455acb700ece8217a3a19a4b (patch)
treec274dd0a0c90b378b387ac6c63e558a0959353ec /tensorflow/contrib/summary
parenta764ec152ce8a4ebe6faf42c55a3177182389c9f (diff)
Rename global_step -> step in contrib/summary API
Since it's more succinct and the API doesn't actually care if the provided step is the one true global step. PiperOrigin-RevId: 176063779
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/summary_ops.py72
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py4
2 files changed, 42 insertions, 34 deletions
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index bf810744a1..3e65f83051 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -344,10 +344,9 @@ def summary_writer_function(name, tensor, function, family=None):
return op
-def generic(name, tensor, metadata=None, family=None, global_step=None):
+def generic(name, tensor, metadata=None, family=None, step=None):
"""Writes a tensor summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
if metadata is None:
serialized_metadata = constant_op.constant("")
@@ -358,12 +357,15 @@ def generic(name, tensor, metadata=None, family=None, global_step=None):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_summary(
context.context().summary_writer_resource,
- global_step, array_ops.identity(tensor),
- tag, serialized_metadata, name=scope)
+ _choose_step(step),
+ array_ops.identity(tensor),
+ tag,
+ serialized_metadata,
+ name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def scalar(name, tensor, family=None, global_step=None):
+def scalar(name, tensor, family=None, step=None):
"""Writes a scalar summary if possible.
Unlike @{tf.contrib.summary.generic} this op may change the dtype
@@ -375,68 +377,68 @@ def scalar(name, tensor, family=None, global_step=None):
`float32`, `float64`, `int32`, `int64`, `uint8`, `int16`,
`int8`, `uint16`, `half`, `uint32`, `uint64`.
family: Optional, the summary's family.
- global_step: The `int64` monotonic step variable, which defaults
+ step: The `int64` monotonic step variable, which defaults
to @{tf.train.get_global_step}.
Returns:
The created @{tf.Operation} or a @{tf.no_op} if summary writing has
not been enabled for this context.
"""
- if global_step is None:
- global_step = training_util.get_global_step()
- else:
- global_step = ops.convert_to_tensor(global_step, dtypes.int64)
+
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_scalar_summary(
context.context().summary_writer_resource,
- global_step, tag, array_ops.identity(tensor),
+ _choose_step(step),
+ tag,
+ array_ops.identity(tensor),
name=scope)
+
return summary_writer_function(name, tensor, function, family=family)
-def histogram(name, tensor, family=None, global_step=None):
+def histogram(name, tensor, family=None, step=None):
"""Writes a histogram summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_histogram_summary(
context.context().summary_writer_resource,
- global_step, tag, array_ops.identity(tensor),
+ _choose_step(step),
+ tag,
+ array_ops.identity(tensor),
name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def image(name, tensor, bad_color=None, max_images=3, family=None,
- global_step=None):
+def image(name, tensor, bad_color=None, max_images=3, family=None, step=None):
"""Writes an image summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
if bad_color is None else bad_color)
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_image_summary(
context.context().summary_writer_resource,
- global_step, tag, array_ops.identity(tensor),
+ _choose_step(step),
+ tag,
+ array_ops.identity(tensor),
bad_color_,
- max_images, name=scope)
+ max_images,
+ name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def audio(name, tensor, sample_rate, max_outputs, family=None,
- global_step=None):
+def audio(name, tensor, sample_rate, max_outputs, family=None, step=None):
"""Writes an audio summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_audio_summary(
context.context().summary_writer_resource,
- global_step,
+ _choose_step(step),
tag,
array_ops.identity(tensor),
sample_rate=sample_rate,
@@ -483,15 +485,13 @@ def graph(param, step=None, name=None):
if writer is None:
return control_flow_ops.no_op()
with ops.device("cpu:0"):
- if step is None:
- step = training_util.get_global_step()
- else:
- step = ops.convert_to_tensor(step, dtypes.int64)
if isinstance(param, (ops.Graph, graph_pb2.GraphDef)):
tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string)
else:
tensor = array_ops.identity(param)
- return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name)
+ return gen_summary_ops.write_graph_summary(
+ writer, _choose_step(step), tensor, name=name)
+
_graph = graph # for functions with a graph parameter
@@ -527,3 +527,11 @@ def _serialize_graph(arbitrary_graph):
return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString()
else:
return arbitrary_graph.SerializeToString()
+
+
+def _choose_step(step):
+ if step is None:
+ return training_util.get_global_step()
+ if not isinstance(step, ops.Tensor):
+ return ops.convert_to_tensor(step, dtypes.int64)
+ return step
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index c5ca054f77..ad89c0c36a 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -97,13 +97,13 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
def testSummaryGlobalStep(self):
- global_step = training_util.get_or_create_global_step()
+ step = training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_summary_file_writer(
logdir, max_queue=0,
name='t2').as_default(), summary_ops.always_record_summaries():
- summary_ops.scalar('scalar', 2.0, global_step=global_step)
+ summary_ops.scalar('scalar', 2.0, step=step)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)