aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-11-16 20:50:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-16 20:54:12 -0800
commit10581c8afee392f2455acb700ece8217a3a19a4b (patch)
treec274dd0a0c90b378b387ac6c63e558a0959353ec
parenta764ec152ce8a4ebe6faf42c55a3177182389c9f (diff)
Rename global_step -> step in contrib/summary API
Since it's more succinct and the API doesn't actually care if the provided step is the one true global step. PiperOrigin-RevId: 176063779
-rw-r--r--tensorflow/contrib/summary/summary_ops.py72
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py4
-rw-r--r--tensorflow/core/kernels/summary_kernels.cc40
-rw-r--r--tensorflow/core/ops/summary_ops.cc24
4 files changed, 73 insertions, 67 deletions
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index bf810744a1..3e65f83051 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -344,10 +344,9 @@ def summary_writer_function(name, tensor, function, family=None):
return op
-def generic(name, tensor, metadata=None, family=None, global_step=None):
+def generic(name, tensor, metadata=None, family=None, step=None):
"""Writes a tensor summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
if metadata is None:
serialized_metadata = constant_op.constant("")
@@ -358,12 +357,15 @@ def generic(name, tensor, metadata=None, family=None, global_step=None):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_summary(
context.context().summary_writer_resource,
- global_step, array_ops.identity(tensor),
- tag, serialized_metadata, name=scope)
+ _choose_step(step),
+ array_ops.identity(tensor),
+ tag,
+ serialized_metadata,
+ name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def scalar(name, tensor, family=None, global_step=None):
+def scalar(name, tensor, family=None, step=None):
"""Writes a scalar summary if possible.
Unlike @{tf.contrib.summary.generic} this op may change the dtype
@@ -375,68 +377,68 @@ def scalar(name, tensor, family=None, global_step=None):
`float32`, `float64`, `int32`, `int64`, `uint8`, `int16`,
`int8`, `uint16`, `half`, `uint32`, `uint64`.
family: Optional, the summary's family.
- global_step: The `int64` monotonic step variable, which defaults
+ step: The `int64` monotonic step variable, which defaults
to @{tf.train.get_global_step}.
Returns:
The created @{tf.Operation} or a @{tf.no_op} if summary writing has
not been enabled for this context.
"""
- if global_step is None:
- global_step = training_util.get_global_step()
- else:
- global_step = ops.convert_to_tensor(global_step, dtypes.int64)
+
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_scalar_summary(
context.context().summary_writer_resource,
- global_step, tag, array_ops.identity(tensor),
+ _choose_step(step),
+ tag,
+ array_ops.identity(tensor),
name=scope)
+
return summary_writer_function(name, tensor, function, family=family)
-def histogram(name, tensor, family=None, global_step=None):
+def histogram(name, tensor, family=None, step=None):
"""Writes a histogram summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_histogram_summary(
context.context().summary_writer_resource,
- global_step, tag, array_ops.identity(tensor),
+ _choose_step(step),
+ tag,
+ array_ops.identity(tensor),
name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def image(name, tensor, bad_color=None, max_images=3, family=None,
- global_step=None):
+def image(name, tensor, bad_color=None, max_images=3, family=None, step=None):
"""Writes an image summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
if bad_color is None else bad_color)
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_image_summary(
context.context().summary_writer_resource,
- global_step, tag, array_ops.identity(tensor),
+ _choose_step(step),
+ tag,
+ array_ops.identity(tensor),
bad_color_,
- max_images, name=scope)
+ max_images,
+ name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def audio(name, tensor, sample_rate, max_outputs, family=None,
- global_step=None):
+def audio(name, tensor, sample_rate, max_outputs, family=None, step=None):
"""Writes an audio summary if possible."""
- if global_step is None:
- global_step = training_util.get_global_step()
+
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_audio_summary(
context.context().summary_writer_resource,
- global_step,
+ _choose_step(step),
tag,
array_ops.identity(tensor),
sample_rate=sample_rate,
@@ -483,15 +485,13 @@ def graph(param, step=None, name=None):
if writer is None:
return control_flow_ops.no_op()
with ops.device("cpu:0"):
- if step is None:
- step = training_util.get_global_step()
- else:
- step = ops.convert_to_tensor(step, dtypes.int64)
if isinstance(param, (ops.Graph, graph_pb2.GraphDef)):
tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string)
else:
tensor = array_ops.identity(param)
- return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name)
+ return gen_summary_ops.write_graph_summary(
+ writer, _choose_step(step), tensor, name=name)
+
_graph = graph # for functions with a graph parameter
@@ -527,3 +527,11 @@ def _serialize_graph(arbitrary_graph):
return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString()
else:
return arbitrary_graph.SerializeToString()
+
+
+def _choose_step(step):
+ if step is None:
+ return training_util.get_global_step()
+ if not isinstance(step, ops.Tensor):
+ return ops.convert_to_tensor(step, dtypes.int64)
+ return step
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index c5ca054f77..ad89c0c36a 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -97,13 +97,13 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
def testSummaryGlobalStep(self):
- global_step = training_util.get_or_create_global_step()
+ step = training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_summary_file_writer(
logdir, max_queue=0,
name='t2').as_default(), summary_ops.always_record_summaries():
- summary_ops.scalar('scalar', 2.0, global_step=global_step)
+ summary_ops.scalar('scalar', 2.0, step=step)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index 3706f51cf4..7487e70acc 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -111,8 +111,8 @@ class WriteSummaryOp : public OpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
- OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp));
- const int64 global_step = tmp->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
+ const int64 step = tmp->scalar<int64>()();
OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
const string& tag = tmp->scalar<string>()();
OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp));
@@ -121,8 +121,7 @@ class WriteSummaryOp : public OpKernel {
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
- OP_REQUIRES_OK(ctx,
- s->WriteTensor(global_step, *t, tag, serialized_metadata));
+ OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata));
}
};
REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
@@ -158,15 +157,15 @@ class WriteScalarSummaryOp : public OpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
- OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp));
- const int64 global_step = tmp->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
+ const int64 step = tmp->scalar<int64>()();
OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
const string& tag = tmp->scalar<string>()();
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("value", &t));
- OP_REQUIRES_OK(ctx, s->WriteScalar(global_step, *t, tag));
+ OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag));
}
};
REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU),
@@ -181,15 +180,15 @@ class WriteHistogramSummaryOp : public OpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
- OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp));
- const int64 global_step = tmp->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
+ const int64 step = tmp->scalar<int64>()();
OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
const string& tag = tmp->scalar<string>()();
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("values", &t));
- OP_REQUIRES_OK(ctx, s->WriteHistogram(global_step, *t, tag));
+ OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag));
}
};
REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU),
@@ -210,8 +209,8 @@ class WriteImageSummaryOp : public OpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
- OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp));
- const int64 global_step = tmp->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
+ const int64 step = tmp->scalar<int64>()();
OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
const string& tag = tmp->scalar<string>()();
const Tensor* bad_color;
@@ -224,8 +223,7 @@ class WriteImageSummaryOp : public OpKernel {
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
- OP_REQUIRES_OK(
- ctx, s->WriteImage(global_step, *t, tag, max_images_, *bad_color));
+ OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color));
}
private:
@@ -247,8 +245,8 @@ class WriteAudioSummaryOp : public OpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
- OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp));
- const int64 global_step = tmp->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
+ const int64 step = tmp->scalar<int64>()();
OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
const string& tag = tmp->scalar<string>()();
OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp));
@@ -257,8 +255,8 @@ class WriteAudioSummaryOp : public OpKernel {
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
- OP_REQUIRES_OK(
- ctx, s->WriteAudio(global_step, *t, tag, max_outputs_, sample_rate));
+ OP_REQUIRES_OK(ctx,
+ s->WriteAudio(step, *t, tag, max_outputs_, sample_rate));
}
private:
@@ -278,8 +276,8 @@ class WriteGraphSummaryOp : public OpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* t;
- OP_REQUIRES_OK(ctx, ctx->input("global_step", &t));
- const int64 global_step = t->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("step", &t));
+ const int64 step = t->scalar<int64>()();
OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
std::unique_ptr<GraphDef> graph{new GraphDef};
if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) {
@@ -287,7 +285,7 @@ class WriteGraphSummaryOp : public OpKernel {
errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
return;
}
- OP_REQUIRES_OK(ctx, s->WriteGraph(global_step, std::move(graph)));
+ OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph)));
}
};
REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU),
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
index 7f6d8b06cd..029ff09906 100644
--- a/tensorflow/core/ops/summary_ops.cc
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -99,7 +99,7 @@ writer: A handle to the summary writer resource.
REGISTER_OP("WriteSummary")
.Input("writer: resource")
- .Input("global_step: int64")
+ .Input("step: int64")
.Input("tensor: T")
.Input("tag: string")
.Input("summary_metadata: string")
@@ -109,7 +109,7 @@ REGISTER_OP("WriteSummary")
Outputs a `Summary` protocol buffer with a tensor.
writer: A handle to a summary writer.
-global_step: The step to write the summary for.
+step: The step to write the summary for.
tensor: A tensor to serialize.
tag: The summary's tag.
summary_metadata: Serialized SummaryMetadata protocol buffer containing
@@ -132,7 +132,7 @@ event: A string containing a binary-encoded tf.Event proto.
REGISTER_OP("WriteScalarSummary")
.Input("writer: resource")
- .Input("global_step: int64")
+ .Input("step: int64")
.Input("tag: string")
.Input("value: T")
.Attr("T: realnumbertype")
@@ -143,14 +143,14 @@ Writes a `Summary` protocol buffer with scalar values.
The input `tag` and `value` must have the scalars.
writer: A handle to a summary writer.
-global_step: The step to write the summary for.
+step: The step to write the summary for.
tag: Tag for the summary.
value: Value for the summary.
)doc");
REGISTER_OP("WriteHistogramSummary")
.Input("writer: resource")
- .Input("global_step: int64")
+ .Input("step: int64")
.Input("tag: string")
.Input("values: T")
.Attr("T: realnumbertype = DT_FLOAT")
@@ -165,14 +165,14 @@ has one summary value containing a histogram for `values`.
This op reports an `InvalidArgument` error if any value is not finite.
writer: A handle to a summary writer.
-global_step: The step to write the summary for.
+step: The step to write the summary for.
tag: Scalar. Tag to use for the `Summary.Value`.
values: Any shape. Values to use to build the histogram.
)doc");
REGISTER_OP("WriteImageSummary")
.Input("writer: resource")
- .Input("global_step: int64")
+ .Input("step: int64")
.Input("tag: string")
.Input("tensor: T")
.Input("bad_color: uint8")
@@ -217,7 +217,7 @@ replaced by this tensor in the output image. The default value is the color
red.
writer: A handle to a summary writer.
-global_step: The step to write the summary for.
+step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 4-D of shape `[batch_size, height, width, channels]` where
`channels` is 1, 3, or 4.
@@ -227,7 +227,7 @@ bad_color: Color to use for pixels with non-finite values.
REGISTER_OP("WriteAudioSummary")
.Input("writer: resource")
- .Input("global_step: int64")
+ .Input("step: int64")
.Input("tag: string")
.Input("tensor: float")
.Input("sample_rate: float")
@@ -249,7 +249,7 @@ build the `tag` of the summary values:
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
writer: A handle to a summary writer.
-global_step: The step to write the summary for.
+step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 2-D of shape `[batch_size, frames]`.
sample_rate: The sample rate of the signal in hertz.
@@ -258,14 +258,14 @@ max_outputs: Max number of batch elements to generate audio for.
REGISTER_OP("WriteGraphSummary")
.Input("writer: resource")
- .Input("global_step: int64")
+ .Input("step: int64")
.Input("tensor: string")
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Writes a `GraphDef` protocol buffer to a `SummaryWriter`.
writer: Handle of `SummaryWriter`.
-global_step: The step to write the summary for.
+step: The step to write the summary for.
tensor: A scalar string of the serialized tf.GraphDef proto.
)doc");