aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-09-01 16:33:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-01 16:40:33 -0700
commite83d8ab0959b6a0e50d14acbfd39c5ab79e449d5 (patch)
treef3343d92548ebdd77a2c84f5e9fd0393828d235a /tensorflow/contrib/summary
parentb541b4812289b19cd4d4ba68916ab386878f8f17 (diff)
Contrib ops and kernels for summary ops which write without touching python.
PiperOrigin-RevId: 167340103
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/BUILD59
-rw-r--r--tensorflow/contrib/summary/summary_ops.py159
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py52
3 files changed, 270 insertions, 0 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
new file mode 100644
index 0000000000..bc30502264
--- /dev/null
+++ b/tensorflow/contrib/summary/BUILD
@@ -0,0 +1,59 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files([
+ "LICENSE",
+])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "py_test",
+ "tf_gen_op_wrapper_py",
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_summary_ops",
+ out = "gen_summary_ops.py",
+ deps = ["//tensorflow/core:summary_ops_op_lib"],
+)
+
+py_test(
+ name = "summary_ops_test",
+ srcs = ["summary_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":summary_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+py_library(
+ name = "summary_ops",
+ srcs = ["summary_ops.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":gen_summary_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:summary_op_util",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
new file mode 100644
index 0000000000..05e627adf1
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -0,0 +1,159 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Operations to emit summaries."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.summary import gen_summary_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import summary_op_util
+from tensorflow.python.training import training_util
+
+
+# Name for a collection which is expected to have at most a single boolean
+# Tensor. If this tensor is True the summary ops will record summaries.
+_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
+
+
+def should_record_summaries():
+ """Returns boolean Tensor which is true if summaries should be recorded."""
+ should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME)
+ if not should_record_collection:
+ return constant_op.constant(False)
+ if len(should_record_collection) != 1:
+ raise ValueError(
+ "More than one tensor specified for whether summaries "
+ "should be recorded: %s" % should_record_collection)
+ return should_record_collection[0]
+
+
+# TODO(apassos) consider how to handle local step here.
+def record_summaries_every_n_global_steps(n):
+ """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
+ collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+ collection_ref[:] = [training_util.get_global_step() % n == 0]
+
+
+def always_record_summaries():
+ """Sets the should_record_summaries Tensor to always true."""
+ collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+ collection_ref[:] = [constant_op.constant(True)]
+
+
+def never_record_summaries():
+ """Sets the should_record_summaries Tensor to always false."""
+ collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+ collection_ref[:] = [constant_op.constant(False)]
+
+
+def create_summary_file_writer(logdir,
+ max_queue=None,
+ flush_secs=None,
+ filename_suffix=None):
+ """Creates a summary file writer in the current context."""
+ if max_queue is None:
+ max_queue = constant_op.constant(10)
+ if flush_secs is None:
+ flush_secs = constant_op.constant(120)
+ if filename_suffix is None:
+ filename_suffix = constant_op.constant("")
+ resource = gen_summary_ops.summary_writer()
+ gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
+ flush_secs, filename_suffix)
+ context.context().summary_writer_resource = resource
+
+
+def _nothing():
+ """Convenient else branch for when summaries do not record."""
+ return
+
+
+def generic(name, tensor, metadata, family=None):
+ """Writes a tensor summary if possible."""
+
+ def record():
+ with summary_op_util.summary_scope(
+ name, family, values=[tensor]) as (tag, scope):
+ gen_summary_ops.write_summary(context.context().summary_writer_resource,
+ training_util.get_global_step(), tensor,
+ tag, metadata, name=scope)
+ return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+
+
+def scalar(name, tensor, family=None):
+ """Writes a scalar summary if possible."""
+
+ def record():
+ with summary_op_util.summary_scope(
+ name, family, values=[tensor]) as (tag, scope):
+ gen_summary_ops.write_scalar_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(), tag, tensor, name=scope)
+
+ return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+
+
+def histogram(name, tensor, family=None):
+ """Writes a histogram summary if possible."""
+
+ def record():
+ with summary_op_util.summary_scope(
+ name, family, values=[tensor]) as (tag, scope):
+ gen_summary_ops.write_histogram_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(), tag, tensor, name=scope)
+
+ return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+
+
+def image(name, tensor, bad_color=None, max_images=3, family=None):
+ """Writes an image summary if possible."""
+
+ def record():
+ if bad_color is None:
+ bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
+ with summary_op_util.summary_scope(
+ name, family, values=[tensor]) as (tag, scope):
+ gen_summary_ops.write_image_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(), tag, tensor, bad_color_, max_images,
+ name=scope)
+
+ return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+
+
+def audio(name, tensor, sample_rate, max_outputs, family=None):
+ """Writes an audio summary if possible."""
+
+ def record():
+ with summary_op_util.summary_scope(
+ name, family, values=[tensor]) as (tag, scope):
+ gen_summary_ops.write_audio_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(),
+ tag,
+ tensor,
+ sample_rate=sample_rate,
+ max_outputs=max_outputs,
+ name=scope)
+
+ return control_flow_ops.cond(should_record_summaries(), record, _nothing)
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
new file mode 100644
index 0000000000..56c1a16f7f
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+from tensorflow.python.training import training_util
+
+
+class TargetTest(test_util.TensorFlowTestCase):
+
+ def testShouldRecordSummary(self):
+ self.assertFalse(summary_ops.should_record_summaries().numpy())
+ summary_ops.always_record_summaries()
+ self.assertTrue(summary_ops.should_record_summaries().numpy())
+
+ def testSummaryOps(self):
+ training_util.get_or_create_global_step()
+ logdir = tempfile.mkdtemp()
+ summary_ops.create_summary_file_writer(logdir, max_queue=0)
+ summary_ops.always_record_summaries()
+ summary_ops.generic('tensor', 1, '')
+ summary_ops.scalar('scalar', 2.0)
+ summary_ops.histogram('histogram', [1.0])
+ summary_ops.image('image', [[[[1.0]]]])
+ summary_ops.audio('audio', [[1.0]], 1.0, 1)
+ # The working condition of the ops is tested in the C++ test so we just
+ # test here that we're calling them correctly.
+ self.assertTrue(gfile.Exists(logdir))
+
+
+if __name__ == '__main__':
+ test.main()