diff options
author | Alexandre Passos <apassos@google.com> | 2017-10-25 18:46:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-25 18:50:51 -0700 |
commit | ff7b9a6c496823c1bffdd0d74bf68aafacb8caca (patch) | |
tree | 266b5d9dff2b14abd959ab0efef2fe56bfb7c30c /tensorflow/contrib/summary | |
parent | 6149fecbdba96fea5460915cf2fad5ac163de091 (diff) |
Adding summaries to the resnet example.
Also utilities to use summaries in graph mode.
PiperOrigin-RevId: 173483424
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 25 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary.py | 39 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 82 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_test_util.py | 41 |
5 files changed, 175 insertions, 41 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index bcb2d74b4a..8cb5c3f381 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -25,6 +25,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + ":summary_test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:lib", @@ -52,6 +53,16 @@ py_library( ], ) +py_library( + name = "summary", + srcs = ["summary.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":summary_ops", + ], +) + filegroup( name = "all_files", srcs = glob( @@ -63,3 +74,17 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# NOTE: target cannot be testonly because it needs to be in the pip +# package. Sigh. +py_library( + name = "summary_test_util", + srcs = ["summary_test_util.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + ], +) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py new file mode 100644 index 0000000000..89031caadc --- /dev/null +++ b/tensorflow/contrib/summary/summary.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Contrib summary package. + +The operations in this package are safe to use with eager execution turned on or +off. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.summary.summary_ops import all_summary_ops +from tensorflow.contrib.summary.summary_ops import always_record_summaries +from tensorflow.contrib.summary.summary_ops import audio +from tensorflow.contrib.summary.summary_ops import create_summary_file_writer +from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import histogram +from tensorflow.contrib.summary.summary_ops import image +from tensorflow.contrib.summary.summary_ops import never_record_summaries +from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps +from tensorflow.contrib.summary.summary_ops import scalar +from tensorflow.contrib.summary.summary_ops import should_record_summaries +from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 30a9398ee5..b32b093675 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -25,6 +25,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import summary_op_util from tensorflow.python.training import training_util from tensorflow.python.util import tf_contextlib @@ -33,6 +35,9 @@ from tensorflow.python.util import tf_contextlib # Tensor. 
If this tensor is True the summary ops will record summaries. _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" +_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2" +_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" + def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" @@ -78,10 +83,15 @@ def never_record_summaries(): class SummaryWriter(object): + """Encapsulates a summary writer.""" def __init__(self, resource): self._resource = resource + def __del__(self): + if context.in_eager_mode(): + resource_variable_ops.destroy_resource_op(self._resource) + def set_as_default(self): context.context().summary_writer_resource = self._resource @@ -90,6 +100,9 @@ class SummaryWriter(object): old = context.context().summary_writer_resource context.context().summary_writer_resource = self._resource yield + # Flushes the summary writer in eager mode or in graph functions, but not in + # legacy graph mode (you're on your own there). + gen_summary_ops.flush_summary_writer(self._resource) context.context().summary_writer_resource = old @@ -108,14 +121,33 @@ def create_summary_file_writer(logdir, resource = gen_summary_ops.summary_writer(shared_name=name) # TODO(apassos) ensure the initialization op runs when in graph mode; consider # calling session.run here. - gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, - flush_secs, filename_suffix) + ops.add_to_collection( + _SUMMARY_WRITER_INIT_COLLECTION_NAME, + gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, + flush_secs, filename_suffix)) return SummaryWriter(resource) def _nothing(): """Convenient else branch for when summaries do not record.""" - return False + return constant_op.constant(False) + + +def all_summary_ops(): + """Graph-mode only. 
Returns all summary ops.""" + if context.in_eager_mode(): + raise RuntimeError( + "tf.contrib.summary.all_summary_ops is only supported in graph mode.") + return ops.get_collection(_SUMMARY_COLLECTION_NAME) + + +def summary_writer_initializer_op(): + """Graph-mode only. Returns the list of ops to create all summary writers.""" + if context.in_eager_mode(): + raise RuntimeError( + "tf.contrib.summary.summary_writer_initializer_op is only " + "supported in graph mode.") + return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME) def summary_writer_function(name, tensor, function, family=None): @@ -133,20 +165,25 @@ def summary_writer_function(name, tensor, function, family=None): def record(): with summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True + with ops.control_dependencies([function(tag, scope)]): + return constant_op.constant(True) - return utils.smart_cond( - should_record_summaries(), record, _nothing, name="") + with ops.device("cpu:0"): + op = utils.smart_cond( + should_record_summaries(), record, _nothing, name="") + ops.add_to_collection(_SUMMARY_COLLECTION_NAME, op) + return op def generic(name, tensor, metadata, family=None): """Writes a tensor summary if possible.""" def function(tag, scope): - gen_summary_ops.write_summary(context.context().summary_writer_resource, - training_util.get_global_step(), tensor, - tag, metadata, name=scope) + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), array_ops.identity(tensor), + tag, metadata, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -154,9 +191,11 @@ def scalar(name, tensor, family=None): """Writes a scalar summary if possible.""" def function(tag, scope): - gen_summary_ops.write_scalar_summary( + # Note the identity to move the tensor to the CPU. 
+ return gen_summary_ops.write_scalar_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -165,9 +204,11 @@ def histogram(name, tensor, family=None): """Writes a histogram summary if possible.""" def function(tag, scope): - gen_summary_ops.write_histogram_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_histogram_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -178,10 +219,12 @@ def image(name, tensor, bad_color=None, max_images=3, family=None): def function(tag, scope): if bad_color is None: bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - gen_summary_ops.write_image_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, bad_color_, max_images, - name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + bad_color_, + max_images, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -190,11 +233,12 @@ def audio(name, tensor, sample_rate, max_outputs, family=None): """Writes an audio summary if possible.""" def function(tag, scope): - gen_summary_ops.write_audio_summary( + # Note the identity to move the tensor to the CPU. 
+ return gen_summary_ops.write_audio_summary( context.context().summary_writer_resource, training_util.get_global_step(), tag, - tensor, + array_ops.identity(tensor), sample_rate=sample_rate, max_outputs=max_outputs, name=scope) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 405a92a726..de7ae6ec27 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,16 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import tempfile from tensorflow.contrib.summary import summary_ops -from tensorflow.core.util import event_pb2 +from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import errors from tensorflow.python.framework import test_util -from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -71,16 +69,9 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) write() - - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].simple_value, 2.0) + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 2.0) def testSummaryName(self): training_util.get_or_create_global_step() @@ -91,15 +82,9 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - 
self.assertEqual(len(files), 1) - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].tag, 'scalar') + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scalar') if __name__ == '__main__': diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py new file mode 100644 index 0000000000..37b546d3ab --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Utilities to test summaries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.util import event_pb2 +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import gfile + + +def events_from_file(logdir): + """Returns all events in the single eventfile in logdir.""" + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, "Found more than one file in logdir: %s" % files + records = list( + tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + result = [] + for r in records: + event = event_pb2.Event() + event.ParseFromString(r) + result.append(event) + return result |