diff options
author | Justine Tunney <jart@google.com> | 2017-11-16 20:50:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-16 20:54:12 -0800 |
commit | 10581c8afee392f2455acb700ece8217a3a19a4b (patch) | |
tree | c274dd0a0c90b378b387ac6c63e558a0959353ec /tensorflow/contrib/summary | |
parent | a764ec152ce8a4ebe6faf42c55a3177182389c9f (diff) |
Rename global_step -> step in contrib/summary API
Since it's more succinct and the API doesn't actually care if the provided step
is the one true global step.
PiperOrigin-RevId: 176063779
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 72 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 4 |
2 files changed, 42 insertions, 34 deletions
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index bf810744a1..3e65f83051 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -344,10 +344,9 @@ def summary_writer_function(name, tensor, function, family=None): return op -def generic(name, tensor, metadata=None, family=None, global_step=None): +def generic(name, tensor, metadata=None, family=None, step=None): """Writes a tensor summary if possible.""" - if global_step is None: - global_step = training_util.get_global_step() + def function(tag, scope): if metadata is None: serialized_metadata = constant_op.constant("") @@ -358,12 +357,15 @@ def generic(name, tensor, metadata=None, family=None, global_step=None): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_summary( context.context().summary_writer_resource, - global_step, array_ops.identity(tensor), - tag, serialized_metadata, name=scope) + _choose_step(step), + array_ops.identity(tensor), + tag, + serialized_metadata, + name=scope) return summary_writer_function(name, tensor, function, family=family) -def scalar(name, tensor, family=None, global_step=None): +def scalar(name, tensor, family=None, step=None): """Writes a scalar summary if possible. Unlike @{tf.contrib.summary.generic} this op may change the dtype @@ -375,68 +377,68 @@ def scalar(name, tensor, family=None, global_step=None): `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`, `uint16`, `half`, `uint32`, `uint64`. family: Optional, the summary's family. - global_step: The `int64` monotonic step variable, which defaults + step: The `int64` monotonic step variable, which defaults to @{tf.train.get_global_step}. Returns: The created @{tf.Operation} or a @{tf.no_op} if summary writing has not been enabled for this context. """ - if global_step is None: - global_step = training_util.get_global_step() - else: - global_step = ops.convert_to_tensor(global_step, dtypes.int64) + def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_scalar_summary( context.context().summary_writer_resource, - global_step, tag, array_ops.identity(tensor), + _choose_step(step), + tag, + array_ops.identity(tensor), name=scope) + return summary_writer_function(name, tensor, function, family=family) -def histogram(name, tensor, family=None, global_step=None): +def histogram(name, tensor, family=None, step=None): """Writes a histogram summary if possible.""" - if global_step is None: - global_step = training_util.get_global_step() + def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_histogram_summary( context.context().summary_writer_resource, - global_step, tag, array_ops.identity(tensor), + _choose_step(step), + tag, + array_ops.identity(tensor), name=scope) return summary_writer_function(name, tensor, function, family=family) -def image(name, tensor, bad_color=None, max_images=3, family=None, - global_step=None): +def image(name, tensor, bad_color=None, max_images=3, family=None, step=None): """Writes an image summary if possible.""" - if global_step is None: - global_step = training_util.get_global_step() + def function(tag, scope): bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) if bad_color is None else bad_color) # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, - global_step, tag, array_ops.identity(tensor), + _choose_step(step), + tag, + array_ops.identity(tensor), bad_color_, - max_images, name=scope) + max_images, + name=scope) return summary_writer_function(name, tensor, function, family=family) -def audio(name, tensor, sample_rate, max_outputs, family=None, - global_step=None): +def audio(name, tensor, sample_rate, max_outputs, family=None, step=None): """Writes an audio summary if possible.""" - if global_step is None: - global_step = training_util.get_global_step() + def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_audio_summary( context.context().summary_writer_resource, - global_step, + _choose_step(step), tag, array_ops.identity(tensor), sample_rate=sample_rate, @@ -483,15 +485,13 @@ def graph(param, step=None, name=None): if writer is None: return control_flow_ops.no_op() with ops.device("cpu:0"): - if step is None: - step = training_util.get_global_step() - else: - step = ops.convert_to_tensor(step, dtypes.int64) if isinstance(param, (ops.Graph, graph_pb2.GraphDef)): tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string) else: tensor = array_ops.identity(param) - return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name) + return gen_summary_ops.write_graph_summary( + writer, _choose_step(step), tensor, name=name) + _graph = graph # for functions with a graph parameter @@ -527,3 +527,11 @@ def _serialize_graph(arbitrary_graph): return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString() else: return arbitrary_graph.SerializeToString() + + +def _choose_step(step): + if step is None: + return training_util.get_global_step() + if not isinstance(step, ops.Tensor): + return ops.convert_to_tensor(step, dtypes.int64) + return step diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index c5ca054f77..ad89c0c36a 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -97,13 +97,13 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(events[1].summary.value[0].tag, 'scalar') def testSummaryGlobalStep(self): - global_step = training_util.get_or_create_global_step() + step = training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() with summary_ops.create_summary_file_writer( logdir, max_queue=0, name='t2').as_default(), summary_ops.always_record_summaries(): - summary_ops.scalar('scalar', 2.0, global_step=global_step) + summary_ops.scalar('scalar', 2.0, step=step) events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) |