diff options
author | 2017-09-01 16:33:24 -0700 | |
---|---|---|
committer | 2017-09-01 16:40:33 -0700 | |
commit | e83d8ab0959b6a0e50d14acbfd39c5ab79e449d5 (patch) | |
tree | f3343d92548ebdd77a2c84f5e9fd0393828d235a /tensorflow/contrib/summary | |
parent | b541b4812289b19cd4d4ba68916ab386878f8f17 (diff) |
Contrib ops and kernels for summary ops which write without touching python.
PiperOrigin-RevId: 167340103
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 59 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 159 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 52 |
3 files changed, 270 insertions, 0 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD new file mode 100644 index 0000000000..bc30502264 --- /dev/null +++ b/tensorflow/contrib/summary/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "LICENSE", +]) + +load( + "//tensorflow:tensorflow.bzl", + "py_test", + "tf_gen_op_wrapper_py", +) + +tf_gen_op_wrapper_py( + name = "gen_summary_ops", + out = "gen_summary_ops.py", + deps = ["//tensorflow/core:summary_ops_op_lib"], +) + +py_test( + name = "summary_ops_test", + srcs = ["summary_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":summary_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "summary_ops", + srcs = ["summary_ops.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":gen_summary_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:summary_op_util", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py new file mode 100644 index 0000000000..05e627adf1 --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops.py @@ -0,0 +1,159 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Operations to emit summaries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import summary_op_util +from tensorflow.python.training import training_util + + +# Name for a collection which is expected to have at most a single boolean +# Tensor. If this tensor is True the summary ops will record summaries. +_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" + + +def should_record_summaries(): + """Returns boolean Tensor which is true if summaries should be recorded.""" + should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME) + if not should_record_collection: + return constant_op.constant(False) + if len(should_record_collection) != 1: + raise ValueError( + "More than one tensor specified for whether summaries " + "should be recorded: %s" % should_record_collection) + return should_record_collection[0] + + +# TODO(apassos) consider how to handle local step here. +def record_summaries_every_n_global_steps(n): + """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" + collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + collection_ref[:] = [training_util.get_global_step() % n == 0] + + +def always_record_summaries(): + """Sets the should_record_summaries Tensor to always true.""" + collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + collection_ref[:] = [constant_op.constant(True)] + + +def never_record_summaries(): + """Sets the should_record_summaries Tensor to always false.""" + collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + collection_ref[:] = [constant_op.constant(False)] + + +def create_summary_file_writer(logdir, + max_queue=None, + flush_secs=None, + filename_suffix=None): + """Creates a summary file writer in the current context.""" + if max_queue is None: + max_queue = constant_op.constant(10) + if flush_secs is None: + flush_secs = constant_op.constant(120) + if filename_suffix is None: + filename_suffix = constant_op.constant("") + resource = gen_summary_ops.summary_writer() + gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, + flush_secs, filename_suffix) + context.context().summary_writer_resource = resource + + +def _nothing(): + """Convenient else branch for when summaries do not record.""" + return + + +def generic(name, tensor, metadata, family=None): + """Writes a tensor summary if possible.""" + + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_summary(context.context().summary_writer_resource, + training_util.get_global_step(), tensor, + tag, metadata, name=scope) + return control_flow_ops.cond(should_record_summaries(), record, _nothing) + + +def scalar(name, tensor, family=None): + """Writes a scalar summary if possible.""" + + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_scalar_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) + + return control_flow_ops.cond(should_record_summaries(), record, _nothing) + + +def histogram(name, tensor, family=None): + """Writes a histogram summary if possible.""" + + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_histogram_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) + + return control_flow_ops.cond(should_record_summaries(), record, _nothing) + + +def image(name, tensor, bad_color=None, max_images=3, family=None): + """Writes an image summary if possible.""" + + def record(): + if bad_color is None: + bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_image_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, bad_color_, max_images, + name=scope) + + return control_flow_ops.cond(should_record_summaries(), record, _nothing) + + +def audio(name, tensor, sample_rate, max_outputs, family=None): + """Writes an audio summary if possible.""" + + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_audio_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), + tag, + tensor, + sample_rate=sample_rate, + max_outputs=max_outputs, + name=scope) + + return control_flow_ops.cond(should_record_summaries(), record, _nothing) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py new file mode 100644 index 0000000000..56c1a16f7f --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile +from tensorflow.python.training import training_util + + +class TargetTest(test_util.TensorFlowTestCase): + + def testShouldRecordSummary(self): + self.assertFalse(summary_ops.should_record_summaries().numpy()) + summary_ops.always_record_summaries() + self.assertTrue(summary_ops.should_record_summaries().numpy()) + + def testSummaryOps(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + summary_ops.create_summary_file_writer(logdir, max_queue=0) + summary_ops.always_record_summaries() + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + # The working condition of the ops is tested in the C++ test so we just + # test here that we're calling them correctly. + self.assertTrue(gfile.Exists(logdir)) + + +if __name__ == '__main__': + test.main() |