aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/summary
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-02-22 13:31:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-22 13:46:21 -0800
commit4b5be2239be05cc5def888f6ce3e484cf68be5b3 (patch)
tree63a09cea318c061c0df3a8a619de47f866e28a01 /tensorflow/python/summary
parentf2843b0831653020ca6e1ceb7011f916e2e8f706 (diff)
Add support for Tensors to the EventAccumulator.
After this CL, it is possible to access TensorSummary data via the EventAccumulator and EventMultiplexer classes. This is a first step towards adding full TensorSummary support to TensorBoard. Change: 148261489
Diffstat (limited to 'tensorflow/python/summary')
-rw-r--r--tensorflow/python/summary/event_accumulator.py132
-rw-r--r--tensorflow/python/summary/event_accumulator_test.py120
-rw-r--r--tensorflow/python/summary/event_multiplexer.py18
-rw-r--r--tensorflow/python/summary/event_multiplexer_test.py26
4 files changed, 149 insertions, 147 deletions
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
index 309f2b4e15..23408705bd 100644
--- a/tensorflow/python/summary/event_accumulator.py
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -28,9 +28,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf.config_pb2 import RunMetadata
from tensorflow.core.util.event_pb2 import SessionLog
-from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.summary import summary
from tensorflow.python.summary.impl import directory_watcher
from tensorflow.python.summary.impl import event_file_loader
from tensorflow.python.summary.impl import reservoir
@@ -65,11 +63,16 @@ AudioEvent = namedtuple('AudioEvent', ['wall_time', 'step',
'encoded_audio_string', 'content_type',
'sample_rate', 'length_frames'])
+TensorEvent = namedtuple('TensorEvent', ['wall_time', 'step', 'tensor_proto'])
+
## Different types of summary events handled by the event_accumulator
-SUMMARY_TYPES = {'simple_value': '_ProcessScalar',
- 'histo': '_ProcessHistogram',
- 'image': '_ProcessImage',
- 'audio': '_ProcessAudio'}
+SUMMARY_TYPES = {
+ 'simple_value': '_ProcessScalar',
+ 'histo': '_ProcessHistogram',
+ 'image': '_ProcessImage',
+ 'audio': '_ProcessAudio',
+ 'tensor': '_ProcessTensor',
+}
## The tagTypes below are just arbitrary strings chosen to pass the type
## information of the tag from the backend to the frontend
@@ -78,6 +81,7 @@ HISTOGRAMS = 'histograms'
IMAGES = 'images'
AUDIO = 'audio'
SCALARS = 'scalars'
+TENSORS = 'tensors'
HEALTH_PILLS = 'health_pills'
GRAPH = 'graph'
META_GRAPH = 'meta_graph'
@@ -96,6 +100,7 @@ DEFAULT_SIZE_GUIDANCE = {
# We store this many health pills per op.
HEALTH_PILLS: 100,
HISTOGRAMS: 1,
+ TENSORS: 10,
}
STORE_EVERYTHING_SIZE_GUIDANCE = {
@@ -105,6 +110,7 @@ STORE_EVERYTHING_SIZE_GUIDANCE = {
SCALARS: 0,
HEALTH_PILLS: 0,
HISTOGRAMS: 0,
+ TENSORS: 0,
}
# The tag that values containing health pills have. Health pill data is stored
@@ -151,6 +157,7 @@ class EventAccumulator(object):
Histograms, audio, and images are very large, so storing all of them is not
recommended.
+ @@Tensors
"""
def __init__(self,
@@ -199,6 +206,7 @@ class EventAccumulator(object):
size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False)
self._images = reservoir.Reservoir(size=sizes[IMAGES])
self._audio = reservoir.Reservoir(size=sizes[AUDIO])
+ self._tensors = reservoir.Reservoir(size=sizes[TENSORS])
self._generator_mutex = threading.Lock()
self._generator = _GeneratorFromPath(path)
@@ -285,7 +293,6 @@ class EventAccumulator(object):
'newest event.'))
self._graph = event.graph_def
self._graph_from_metagraph = False
- self._UpdateTensorSummaries()
elif event.HasField('meta_graph_def'):
if self._meta_graph is not None:
logging.warn(('Found more than one metagraph event per run. '
@@ -303,7 +310,6 @@ class EventAccumulator(object):
'graph with the newest metagraph version.'))
self._graph_from_metagraph = True
self._graph = meta_graph.graph_def.SerializeToString()
- self._UpdateTensorSummaries()
elif event.HasField('tagged_run_metadata'):
tag = event.tagged_run_metadata.tag
if tag in self._tagged_metadata:
@@ -312,66 +318,15 @@ class EventAccumulator(object):
self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
elif event.HasField('summary'):
for value in event.summary.value:
- if value.HasField('tensor'):
- if value.tag == _HEALTH_PILL_EVENT_TAG:
- self._ProcessHealthPillSummary(value, event)
- else:
- self._ProcessTensorSummary(value, event)
+ if value.HasField('tensor') and value.tag == _HEALTH_PILL_EVENT_TAG:
+ self._ProcessHealthPillSummary(value, event)
else:
for summary_type, summary_func in SUMMARY_TYPES.items():
if value.HasField(summary_type):
datum = getattr(value, summary_type)
- getattr(self, summary_func)(value.tag, event.wall_time,
- event.step, datum)
-
- def _ProcessTensorSummary(self, value, event):
- """Process summaries generated by the TensorSummary op.
-
- These summaries are distinguished by the fact that they have a Tensor field,
- rather than one of the old idiosyncratic per-summary data fields.
-
- Processing Tensor summaries is complicated by the fact that Tensor summaries
- are not self-descriptive; you need to read the NodeDef of the corresponding
- TensorSummary op to know the summary_type, the tag, etc.
-
- This method emits ERROR-level messages to the logs if it encounters Tensor
- summaries that it cannot process.
-
- Args:
- value: A summary_pb2.Summary.Value with a Tensor field.
- event: The event_pb2.Event containing that value.
- """
-
- def LogErrorOnce(msg):
- logging.log_first_n(logging.ERROR, msg, 1)
-
- name = value.node_name
- if self._graph is None:
- LogErrorOnce('Attempting to process TensorSummary output, but '
- 'no graph is present, so processing is impossible. '
- 'All TensorSummary output will be ignored.')
- return
-
- if name not in self._tensor_summaries:
- LogErrorOnce('No node_def for TensorSummary {}; skipping this sequence.'.
- format(name))
- return
-
- summary_description = self._tensor_summaries[name]
- type_hint = summary_description.type_hint
-
- if not type_hint:
- LogErrorOnce('No type_hint for TensorSummary {}; skipping this sequence.'.
- format(name))
- return
-
- if type_hint == 'scalar':
- scalar = float(tensor_util.MakeNdarray(value.tensor))
- self._ProcessScalar(name, event.wall_time, event.step, scalar)
- else:
- LogErrorOnce(
- 'Unsupported type {} for TensorSummary {}; skipping this sequence.'.
- format(type_hint, name))
+ tag = value.node_name if summary_type == 'tensor' else value.tag
+ getattr(self, summary_func)(tag, event.wall_time, event.step,
+ datum)
def _ProcessHealthPillSummary(self, value, event):
"""Process summaries containing health pills.
@@ -405,30 +360,25 @@ class EventAccumulator(object):
self._ProcessHealthPill(
event.wall_time, event.step, node_name, output_slot, elements)
- def _UpdateTensorSummaries(self):
- g = self.Graph()
- for node in g.node:
- if node.op == 'TensorSummary':
- d = summary.get_summary_description(node)
-
- self._tensor_summaries[node.name] = d
-
def Tags(self):
"""Return all tags found in the value stream.
Returns:
A `{tagType: ['list', 'of', 'tags']}` dictionary.
"""
- return {IMAGES: self._images.Keys(),
- AUDIO: self._audio.Keys(),
- HISTOGRAMS: self._histograms.Keys(),
- SCALARS: self._scalars.Keys(),
- COMPRESSED_HISTOGRAMS: self._compressed_histograms.Keys(),
- # Use a heuristic: if the metagraph is available, but
- # graph is not, then we assume the metagraph contains the graph.
- GRAPH: self._graph is not None,
- META_GRAPH: self._meta_graph is not None,
- RUN_METADATA: list(self._tagged_metadata.keys())}
+ return {
+ IMAGES: self._images.Keys(),
+ AUDIO: self._audio.Keys(),
+ HISTOGRAMS: self._histograms.Keys(),
+ SCALARS: self._scalars.Keys(),
+ COMPRESSED_HISTOGRAMS: self._compressed_histograms.Keys(),
+ TENSORS: self._tensors.Keys(),
+ # Use a heuristic: if the metagraph is available, but
+ # graph is not, then we assume the metagraph contains the graph.
+ GRAPH: self._graph is not None,
+ META_GRAPH: self._meta_graph is not None,
+ RUN_METADATA: list(self._tagged_metadata.keys())
+ }
def Scalars(self, tag):
"""Given a summary tag, return all associated `ScalarEvent`s.
@@ -566,6 +516,20 @@ class EventAccumulator(object):
"""
return self._audio.Items(tag)
+ def Tensors(self, tag):
+ """Given a summary tag, return all associated tensors.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `TensorEvent`s.
+ """
+ return self._tensors.Items(tag)
+
def _MaybePurgeOrphanedData(self, event):
"""Maybe purge orphaned data due to a TensorFlow crash.
@@ -668,6 +632,10 @@ class EventAccumulator(object):
sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
self._scalars.AddItem(tag, sv)
+ def _ProcessTensor(self, tag, wall_time, step, tensor):
+ tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
+ self._tensors.AddItem(tag, tv)
+
def _ProcessHealthPill(self, wall_time, step, node_name, output_slot,
elements):
"""Processes a health pill value by adding it to accumulated state.
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
index 2c4ee558ec..49b1115624 100644
--- a/tensorflow/python/summary/event_accumulator_test.py
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
import numpy as np
+import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
@@ -29,6 +30,7 @@ from tensorflow.core.util import event_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
@@ -156,26 +158,19 @@ class _EventGenerator(object):
class EventAccumulatorTest(test.TestCase):
- def assertTagsEqual(self, tags1, tags2):
- # Make sure the two dictionaries have the same keys.
- self.assertItemsEqual(tags1, tags2)
- # Additionally, make sure each key in the dictionary maps to the same value.
- for key in tags1:
- if isinstance(tags1[key], list):
- # We don't care about the order of the values in lists, thus asserting
- # only if the items are equal.
- self.assertItemsEqual(tags1[key], tags2[key])
- else:
- # Make sure the values are equal.
- self.assertEqual(tags1[key], tags2[key])
+ def assertTagsEqual(self, actual, expected):
+ """Utility method for checking the return value of the Tags() call.
+ It fills out the `expected` arg with the default (empty) values for every
+ tag type, so that the author needs only specify the non-empty values they
+ are interested in testing.
-class MockingEventAccumulatorTest(EventAccumulatorTest):
+ Args:
+ actual: The actual Accumulator tags response.
+ expected: The expected tags response (empty fields may be omitted)
+ """
- def setUp(self):
- super(MockingEventAccumulatorTest, self).setUp()
- self.stubs = googletest.StubOutForTesting()
- self.empty = {
+ empty_tags = {
ea.IMAGES: [],
ea.AUDIO: [],
ea.SCALARS: [],
@@ -183,8 +178,28 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
ea.COMPRESSED_HISTOGRAMS: [],
ea.GRAPH: False,
ea.META_GRAPH: False,
- ea.RUN_METADATA: []
+ ea.RUN_METADATA: [],
+ ea.TENSORS: [],
}
+
+ # Verifies that there are no unexpected keys in the actual response.
+ # If this line fails, likely you added a new tag type, and need to update
+ # the empty_tags dictionary above.
+ self.assertItemsEqual(actual.keys(), empty_tags.keys())
+
+ for key in actual:
+ expected_value = expected.get(key, empty_tags[key])
+ if isinstance(expected_value, list):
+ self.assertItemsEqual(actual[key], expected_value)
+ else:
+ self.assertEqual(actual[key], expected_value)
+
+
+class MockingEventAccumulatorTest(EventAccumulatorTest):
+
+ def setUp(self):
+ super(MockingEventAccumulatorTest, self).setUp()
+ self.stubs = googletest.StubOutForTesting()
self._real_constructor = ea.EventAccumulator
self._real_generator = ea._GeneratorFromPath
@@ -203,7 +218,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
gen = _EventGenerator()
x = ea.EventAccumulator(gen)
x.Reload()
- self.assertEqual(x.Tags(), self.empty)
+ self.assertTagsEqual(x.Tags(), {})
def testTags(self):
gen = _EventGenerator()
@@ -223,16 +238,13 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
ea.SCALARS: ['s1', 's2'],
ea.HISTOGRAMS: ['hst1', 'hst2'],
ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
- ea.GRAPH: False,
- ea.META_GRAPH: False,
- ea.RUN_METADATA: []
})
def testReload(self):
gen = _EventGenerator()
acc = ea.EventAccumulator(gen)
acc.Reload()
- self.assertEqual(acc.Tags(), self.empty)
+ self.assertTagsEqual(acc.Tags(), {})
gen.AddScalar('s1')
gen.AddScalar('s2')
gen.AddHistogram('hst1')
@@ -248,9 +260,6 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
ea.SCALARS: ['s1', 's2'],
ea.HISTOGRAMS: ['hst1', 'hst2'],
ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
- ea.GRAPH: False,
- ea.META_GRAPH: False,
- ea.RUN_METADATA: []
})
def testScalars(self):
@@ -572,9 +581,6 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
ea.SCALARS: ['s1', 's3'],
ea.HISTOGRAMS: ['hst1'],
ea.COMPRESSED_HISTOGRAMS: ['hst1'],
- ea.GRAPH: False,
- ea.META_GRAPH: False,
- ea.RUN_METADATA: []
})
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
@@ -751,7 +757,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(acc.file_version, 2.0)
def testTFSummaryScalar(self):
- """Verify processing of tf.summary.scalar, which uses TensorSummary op."""
+ """Verify processing of tf.summary.scalar."""
event_sink = _EventGenerator(zero_out_timestamps=True)
writer = SummaryToEventTransformer(event_sink)
with self.test_session() as sess:
@@ -774,14 +780,9 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
]
self.assertTagsEqual(accumulator.Tags(), {
- ea.IMAGES: [],
- ea.AUDIO: [],
ea.SCALARS: ['scalar1', 'scalar2'],
- ea.HISTOGRAMS: [],
- ea.COMPRESSED_HISTOGRAMS: [],
ea.GRAPH: True,
ea.META_GRAPH: False,
- ea.RUN_METADATA: []
})
self.assertEqual(accumulator.Scalars('scalar1'), seq1)
@@ -821,15 +822,42 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertTagsEqual(accumulator.Tags(), {
ea.IMAGES: tags,
- ea.AUDIO: [],
- ea.SCALARS: [],
- ea.HISTOGRAMS: [],
- ea.COMPRESSED_HISTOGRAMS: [],
ea.GRAPH: True,
ea.META_GRAPH: False,
- ea.RUN_METADATA: []
})
+ def testTFSummaryTensor(self):
+ """Verify processing of tf.summary.tensor."""
+ event_sink = _EventGenerator(zero_out_timestamps=True)
+ writer = SummaryToEventTransformer(event_sink)
+ with self.test_session() as sess:
+ summary_lib.tensor_summary('scalar', constant_op.constant(1.0))
+ summary_lib.tensor_summary('vector', constant_op.constant(
+ [1.0, 2.0, 3.0]))
+ summary_lib.tensor_summary('string',
+ constant_op.constant(six.b('foobar')))
+ merged = summary_lib.merge_all()
+ summ = sess.run(merged)
+ writer.add_summary(summ, 0)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ self.assertTagsEqual(accumulator.Tags(), {
+ ea.TENSORS: ['scalar', 'vector', 'string'],
+ })
+
+ scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto
+ scalar = tensor_util.MakeNdarray(scalar_proto)
+ vector_proto = accumulator.Tensors('vector')[0].tensor_proto
+ vector = tensor_util.MakeNdarray(vector_proto)
+ string_proto = accumulator.Tensors('string')[0].tensor_proto
+ string = tensor_util.MakeNdarray(string_proto)
+
+ self.assertTrue(np.array_equal(scalar, 1.0))
+ self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0]))
+ self.assertTrue(np.array_equal(string, six.b('foobar')))
+
class RealisticEventAccumulatorTest(EventAccumulatorTest):
@@ -876,14 +904,10 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest):
acc = ea.EventAccumulator(directory)
acc.Reload()
self.assertTagsEqual(acc.Tags(), {
- ea.IMAGES: [],
- ea.AUDIO: [],
ea.SCALARS: ['id', 'sq'],
- ea.HISTOGRAMS: [],
- ea.COMPRESSED_HISTOGRAMS: [],
ea.GRAPH: True,
ea.META_GRAPH: True,
- ea.RUN_METADATA: ['test run']
+ ea.RUN_METADATA: ['test run'],
})
id_events = acc.Scalars('id')
sq_events = acc.Scalars('sq')
@@ -940,14 +964,8 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest):
acc = ea.EventAccumulator(directory)
acc.Reload()
self.assertTagsEqual(acc.Tags(), {
- ea.IMAGES: [],
- ea.AUDIO: [],
- ea.SCALARS: [],
- ea.HISTOGRAMS: [],
- ea.COMPRESSED_HISTOGRAMS: [],
ea.GRAPH: True,
ea.META_GRAPH: True,
- ea.RUN_METADATA: []
})
self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph())
self.assertProtoEquals(meta_graph_def, acc.MetaGraph())
diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py
index e41cfe8c3e..18176e10fe 100644
--- a/tensorflow/python/summary/event_multiplexer.py
+++ b/tensorflow/python/summary/event_multiplexer.py
@@ -65,6 +65,7 @@ class EventMultiplexer(object):
If you would like to watch `/parent/directory/path`, wait for it to be created
(if necessary) and then periodically pick up new runs, use
`AutoloadingMultiplexer`
+ @@Tensors
"""
def __init__(self,
@@ -370,6 +371,23 @@ class EventMultiplexer(object):
accumulator = self._GetAccumulator(run)
return accumulator.Audio(tag)
+ def Tensors(self, run, tag):
+ """Retrieve the tensor events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.TensorEvent`s.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Tensors(tag)
+
def Runs(self):
"""Return all the run names in the `EventMultiplexer`.
diff --git a/tensorflow/python/summary/event_multiplexer_test.py b/tensorflow/python/summary/event_multiplexer_test.py
index fa4290cccd..8f78c6c547 100644
--- a/tensorflow/python/summary/event_multiplexer_test.py
+++ b/tensorflow/python/summary/event_multiplexer_test.py
@@ -60,11 +60,14 @@ class _FakeAccumulator(object):
def FirstEventTimestamp(self):
return 0
- def Scalars(self, tag_name):
- if tag_name not in self.Tags()[event_accumulator.SCALARS]:
+ def _TagHelper(self, tag_name, enum):
+ if tag_name not in self.Tags()[enum]:
raise KeyError
return ['%s/%s' % (self._path, tag_name)]
+ def Scalars(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.SCALARS)
+
def HealthPills(self, node_name):
if node_name not in self._node_names_to_health_pills:
raise KeyError
@@ -72,24 +75,19 @@ class _FakeAccumulator(object):
return [self._path + '/' + health_pill for health_pill in health_pills]
def Histograms(self, tag_name):
- if tag_name not in self.Tags()[event_accumulator.HISTOGRAMS]:
- raise KeyError
- return ['%s/%s' % (self._path, tag_name)]
+ return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS)
def CompressedHistograms(self, tag_name):
- if tag_name not in self.Tags()[event_accumulator.COMPRESSED_HISTOGRAMS]:
- raise KeyError
- return ['%s/%s' % (self._path, tag_name)]
+ return self._TagHelper(tag_name, event_accumulator.COMPRESSED_HISTOGRAMS)
def Images(self, tag_name):
- if tag_name not in self.Tags()[event_accumulator.IMAGES]:
- raise KeyError
- return ['%s/%s' % (self._path, tag_name)]
+ return self._TagHelper(tag_name, event_accumulator.IMAGES)
def Audio(self, tag_name):
- if tag_name not in self.Tags()[event_accumulator.AUDIO]:
- raise KeyError
- return ['%s/%s' % (self._path, tag_name)]
+ return self._TagHelper(tag_name, event_accumulator.AUDIO)
+
+ def Tensors(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.TENSORS)
def Reload(self):
self.reload_called = True