Diffstat (limited to 'tensorflow/python/summary')
 tensorflow/python/summary/README.md                      |  15 +
 tensorflow/python/summary/__init__.py                    |   0
 tensorflow/python/summary/event_accumulator.py           | 433 +
 tensorflow/python/summary/event_accumulator_test.py      | 422 +
 tensorflow/python/summary/event_multiplexer.py           | 346 +
 tensorflow/python/summary/event_multiplexer_test.py      | 244 +
 tensorflow/python/summary/impl/__init__.py               |   0
 tensorflow/python/summary/impl/directory_watcher.py      | 115 +
 tensorflow/python/summary/impl/directory_watcher_test.py | 102 +
 tensorflow/python/summary/impl/event_file_loader.py      |  49 +
 tensorflow/python/summary/impl/event_file_loader_test.py |  59 +
 tensorflow/python/summary/impl/reservoir.py              | 164 +
 tensorflow/python/summary/impl/reservoir_test.py         | 178 +
 13 files changed, 2127 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/summary/README.md b/tensorflow/python/summary/README.md
new file mode 100644
index 0000000000..8a5fea0d9a
--- /dev/null
+++ b/tensorflow/python/summary/README.md
@@ -0,0 +1,15 @@
+# TensorFlow Event Processing
+
+This folder contains classes useful for analyzing and visualizing TensorFlow
+events files. The code is primarily being developed to support TensorBoard,
+but it can be used by anyone who wishes to analyze or visualize TensorFlow
+events files.
+
+If you wish to load TensorFlow events, you should use an EventAccumulator
+(to load from a single events file) or an EventMultiplexer (to load from
+multiple events files).
+
+The API around these tools has not solidified, and we may make backwards-
+incompatible changes without warning.
+
+If you have questions or requests, please contact danmane@google.com
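For illustration, loading events with the two classes described above might look like the following sketch; the paths and the 'loss' tag are hypothetical, and the API may change as noted.

```python
from tensorflow.python.summary import event_accumulator
from tensorflow.python.summary import event_multiplexer

# Load a single run (an events file, or a directory containing one).
acc = event_accumulator.EventAccumulator('/tmp/my_run')  # hypothetical path
acc.Reload()               # activates the accumulator and reads all events so far
print acc.Tags()           # which scalar/histogram/image tags were found
print acc.Scalars('loss')  # 'loss' is a hypothetical tag name

# Load several runs at once.
mux = event_multiplexer.EventMultiplexer({'run1': '/tmp/run1',   # hypothetical
                                          'run2': '/tmp/run2'})  # paths
mux.Reload()
print mux.Runs()
```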
diff --git a/tensorflow/python/summary/__init__.py b/tensorflow/python/summary/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/summary/__init__.py
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
new file mode 100644
index 0000000000..ae067d94fe
--- /dev/null
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -0,0 +1,433 @@
+"""Takes a generator of values, and accumulates them for a frontend."""
+
+import collections
+import threading
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from tensorflow.python.summary.impl import directory_watcher
+from tensorflow.python.summary.impl import event_file_loader
+from tensorflow.python.summary.impl import reservoir
+
+namedtuple = collections.namedtuple
+ScalarEvent = namedtuple('ScalarEvent',
+ ['wall_time', 'step', 'value'])
+
+CompressedHistogramEvent = namedtuple('CompressedHistogramEvent',
+ ['wall_time', 'step',
+ 'compressed_histogram_values'])
+
+CompressedHistogramValue = namedtuple('CompressedHistogramValue',
+ ['basis_point', 'value'])
+
+HistogramEvent = namedtuple('HistogramEvent',
+ ['wall_time', 'step', 'histogram_value'])
+
+HistogramValue = namedtuple('HistogramValue',
+ ['min', 'max', 'num', 'sum', 'sum_squares',
+ 'bucket_limit', 'bucket'])
+
+ImageEvent = namedtuple('ImageEvent',
+ ['wall_time', 'step', 'encoded_image_string',
+ 'width', 'height'])
+
+## The tagTypes below are just arbitrary strings chosen to pass the type
+## information of the tag from the backend to the frontend
+COMPRESSED_HISTOGRAMS = 'compressedHistograms'
+HISTOGRAMS = 'histograms'
+IMAGES = 'images'
+SCALARS = 'scalars'
+GRAPH = 'graph'
+
+## normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf)
+## naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev,
+## and then the long tail.
+NORMAL_HISTOGRAM_BPS = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000)
+
+DEFAULT_SIZE_GUIDANCE = {
+ COMPRESSED_HISTOGRAMS: 500,
+ IMAGES: 4,
+ SCALARS: 10000,
+ HISTOGRAMS: 1,
+}
+
+STORE_EVERYTHING_SIZE_GUIDANCE = {
+ COMPRESSED_HISTOGRAMS: 0,
+ IMAGES: 0,
+ SCALARS: 0,
+ HISTOGRAMS: 0,
+}
+
+
+def IsTensorFlowEventsFile(path):
+ """Check the path name to see if it is probably a TF Events file."""
+ return 'tfevents' in path
+
+
+class EventAccumulator(object):
+ """An `EventAccumulator` takes an event generator, and accumulates the values.
+
+ The `EventAccumulator` is intended to provide a convenient Python interface
+ for loading Event data written during a TensorFlow run. TensorFlow writes out
+ `Event` protobuf objects, which have a timestamp and step number, and often
+ contain a `Summary`. Summaries can have different kinds of data like an image,
+ a scalar value, or a histogram. The Summaries also have a tag, which we use to
+ organize logically related data. The `EventAccumulator` supports retrieving
+ the `Event` and `Summary` data by its tag.
+
+ Calling `Tags()` gets a map from `tagType` (e.g. `'images'`,
+ `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those
+  data types. Then, various functional endpoints (e.g.
+ `Accumulator.Scalars(tag)`) allow for the retrieval of all data
+ associated with that tag.
+
+ Before usage, the `EventAccumulator` must be activated via `Reload()` or
+ `AutoUpdate(interval)`.
+
+ If activated via `Reload()`, it loads synchronously, so calls to `Values` or
+ `Tags` will block until all outstanding events are processed. Afterwards,
+ `Reload()` may be called again to load any new data.
+
+ If activated via `AutoUpdate(interval)`, it loads asynchronously, so calls to
+ `Values` or `Tags` will immediately return a valid subset of the outstanding
+ event data. It reloads new data every `interval` seconds.
+
+ Histograms and images are very large, so storing all of them is not
+ recommended.
+
+ @@Reload
+ @@AutoUpdate
+ @@Tags
+ @@Scalars
+ @@Graph
+ @@Histograms
+ @@CompressedHistograms
+ @@Images
+ """
+
+ def __init__(self, path, size_guidance=DEFAULT_SIZE_GUIDANCE,
+ compression_bps=NORMAL_HISTOGRAM_BPS):
+ """Construct the `EventAccumulator`.
+
+ Args:
+ path: A file path to a directory containing tf events files, or a single
+ tf events file. The accumulator will load events from this path.
+ size_guidance: Information on how much data the EventAccumulator should
+ store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
+ so as to avoid OOMing the client. The size_guidance should be a map
+ from a `tagType` string to an integer representing the number of
+ items to keep per tag for items of that `tagType`. If the size is 0,
+ all events are stored.
+ compression_bps: Information on how the `EventAccumulator` should compress
+ histogram data for the `CompressedHistograms` tag (for details see
+ `ProcessCompressedHistogram`).
+ """
+ sizes = {}
+ for key in DEFAULT_SIZE_GUIDANCE:
+ if key in size_guidance:
+ sizes[key] = size_guidance[key]
+ else:
+ sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
+
+ self._scalars = reservoir.Reservoir(size=sizes[SCALARS])
+ self._graph = None
+ self._histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS])
+ self._compressed_histograms = reservoir.Reservoir(
+ size=sizes[COMPRESSED_HISTOGRAMS])
+ self._images = reservoir.Reservoir(size=sizes[IMAGES])
+ self._generator_mutex = threading.Lock()
+ self._generator = _GeneratorFromPath(path)
+ self._is_autoupdating = False
+ self._activated = False
+ self._compression_bps = compression_bps
+
+ def Reload(self):
+ """Loads all events added since the last call to `Reload`.
+
+ If `Reload` was never called, loads all events in the file.
+ Calling `Reload` activates the `EventAccumulator`.
+
+ Returns:
+ The `EventAccumulator`.
+ """
+ self._activated = True
+ with self._generator_mutex:
+ for event in self._generator.Load():
+ if event.HasField('graph_def'):
+ if self._graph is not None:
+            logging.warn(('Found more than one graph event per run. '
+                          'Overwriting the graph with the newest event.'))
+ self._graph = event.graph_def
+ elif event.HasField('summary'):
+ for value in event.summary.value:
+ if value.HasField('simple_value'):
+ self._ProcessScalar(value.tag, event.wall_time, event.step,
+ value.simple_value)
+ elif value.HasField('histo'):
+ self._ProcessHistogram(value.tag, event.wall_time, event.step,
+ value.histo)
+ self._ProcessCompressedHistogram(value.tag, event.wall_time,
+ event.step, value.histo)
+ elif value.HasField('image'):
+ self._ProcessImage(value.tag, event.wall_time, event.step,
+ value.image)
+ return self
+
+ def AutoUpdate(self, interval=60):
+ """Asynchronously load all events, and periodically reload.
+
+ Calling this function is not thread safe.
+ Calling this function activates the `EventAccumulator`.
+
+ Args:
+ interval: how many seconds after each successful reload to load new events
+ (default 60)
+
+ Returns:
+ The `EventAccumulator`.
+ """
+ if self._is_autoupdating:
+ return
+ self._is_autoupdating = True
+ self._activated = True
+ def Update():
+ self.Reload()
+ logging.info('EventAccumulator update triggered')
+ t = threading.Timer(interval, Update)
+ t.daemon = True
+ t.start()
+ # Asynchronously start the update process, so that the accumulator can
+ # immediately serve data, even if there is a very large event file to parse
+ t = threading.Timer(0, Update)
+ t.daemon = True
+ t.start()
+ return self
+
+ def Tags(self):
+ """Return all tags found in the value stream.
+
+ Raises:
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ A `{tagType: ['list', 'of', 'tags']}` dictionary.
+ """
+ self._VerifyActivated()
+ return {IMAGES: self._images.Keys(),
+ HISTOGRAMS: self._histograms.Keys(),
+ SCALARS: self._scalars.Keys(),
+ COMPRESSED_HISTOGRAMS: self._compressed_histograms.Keys(),
+ GRAPH: self._graph is not None}
+
+ def Scalars(self, tag):
+ """Given a summary tag, return all associated `ScalarEvent`s.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `ScalarEvent`s.
+ """
+ self._VerifyActivated()
+ return self._scalars.Items(tag)
+
+ def Graph(self):
+ """Return the graph definition, if there is one.
+
+ Raises:
+ ValueError: If there is no graph for this run.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ The `graph_def` proto.
+ """
+ self._VerifyActivated()
+ if self._graph is None:
+ raise ValueError('There is no graph in this EventAccumulator')
+ return self._graph
+
+ def Histograms(self, tag):
+ """Given a summary tag, return all associated histograms.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `HistogramEvent`s.
+ """
+ self._VerifyActivated()
+ return self._histograms.Items(tag)
+
+ def CompressedHistograms(self, tag):
+ """Given a summary tag, return all associated compressed histograms.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `CompressedHistogramEvent`s.
+ """
+ self._VerifyActivated()
+ return self._compressed_histograms.Items(tag)
+
+ def Images(self, tag):
+ """Given a summary tag, return all associated images.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `ImageEvent`s.
+ """
+ self._VerifyActivated()
+ return self._images.Items(tag)
+
+ def _VerifyActivated(self):
+ if not self._activated:
+ raise RuntimeError('Accumulator must be activated before it may be used.')
+
+ def _ProcessScalar(self, tag, wall_time, step, scalar):
+ """Processes a simple value by adding it to accumulated state."""
+ sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
+ self._scalars.AddItem(tag, sv)
+
+ def _ProcessHistogram(self, tag, wall_time, step, histo):
+ """Processes a histogram by adding it to accumulated state."""
+ histogram_value = HistogramValue(
+ min=histo.min,
+ max=histo.max,
+ num=histo.num,
+ sum=histo.sum,
+ sum_squares=histo.sum_squares,
+ # convert from proto repeated to list
+ bucket_limit=list(histo.bucket_limit),
+ bucket=list(histo.bucket),
+ )
+ histogram_event = HistogramEvent(
+ wall_time=wall_time,
+ step=step,
+ histogram_value=histogram_value,
+ )
+ self._histograms.AddItem(tag, histogram_event)
+
+  def _Remap(self, x, x0, x1, y0, y1):
+    """Linearly map from [x0, x1] onto [y0, y1]."""
+ return y0 + (x - x0) * float(y1 - y0)/(x1 - x0)
+
+ def _Percentile(self, compression_bps, bucket_limit, cumsum_weights,
+ histo_min, histo_max, histo_num):
+ """Linearly interpolates a histogram weight for a particular basis point.
+
+ Uses clamping methods on `histo_min` and `histo_max` to produce tight
+ linear estimates of the histogram weight at a particular basis point.
+
+ Args:
+ compression_bps: The desired basis point at which to estimate the weight
+ bucket_limit: An array of the RHS histogram bucket limits
+ cumsum_weights: A cumulative sum of the fraction of weights in each
+ histogram bucket, represented in basis points.
+ histo_min: The minimum weight observed in the weight histogram
+ histo_max: The maximum weight observed in the weight histogram
+ histo_num: The number of items in the weight histogram
+
+ Returns:
+ A linearly interpolated value of the histogram weight estimate.
+ """
+ if histo_num == 0: return 0
+
+ for i, cumsum in enumerate(cumsum_weights):
+ if cumsum >= compression_bps:
+ cumsum_prev = cumsum_weights[i-1] if i > 0 else 0
+ # Prevent cumsum = 0, cumsum_prev = 0, lerp divide by zero.
+ if cumsum == cumsum_prev: continue
+
+ # Calculate the lower bound of interpolation
+ lhs = bucket_limit[i-1] if (i > 0 and cumsum_prev > 0) else histo_min
+ lhs = max(lhs, histo_min)
+
+ # Calculate the upper bound of interpolation
+ rhs = bucket_limit[i]
+ rhs = min(rhs, histo_max)
+
+ weight = self._Remap(compression_bps, cumsum_prev, cumsum, lhs, rhs)
+ return weight
+
+ ## We have not exceeded cumsum, so return the max observed.
+ return histo_max
+
+ def _ProcessCompressedHistogram(self, tag, wall_time, step, histo):
+ """Processes a histogram by adding a compression to accumulated state.
+
+ Adds a compressed histogram by linearly interpolating histogram buckets to
+ represent the histogram weight at multiple compression points. Uses
+ self._compression_bps (passed to EventAccumulator constructor) as the
+    compression points (represented in basis points, 1/100ths of a percent).
+
+ Args:
+ tag: A string name of the tag for which histograms are retrieved.
+ wall_time: Time in seconds since epoch
+ step: Number of steps that have passed
+ histo: proto2 histogram Object
+ """
+ def _CumulativeSum(arr):
+ return [sum(arr[:i+1]) for i in range(len(arr))]
+
+ # Convert from proto repeated field into a Python list.
+ bucket = list(histo.bucket)
+ bucket_limit = list(histo.bucket_limit)
+
+ bucket_total = sum(bucket)
+ fraction_weights = [float(10000*x)/bucket_total for x in bucket]
+ cumsum_weights = _CumulativeSum(fraction_weights)
+
+ percentiles = [
+ self._Percentile(bps, bucket_limit, cumsum_weights, histo.min,
+ histo.max, histo.num) for bps in self._compression_bps
+ ]
+
+ compressed_histogram_values = [CompressedHistogramValue(
+ basis_point=bps,
+ value=value) for bps, value in zip(self._compression_bps, percentiles)]
+ histogram_event = CompressedHistogramEvent(
+ wall_time=wall_time,
+ step=step,
+ compressed_histogram_values=compressed_histogram_values)
+
+ self._compressed_histograms.AddItem(tag, histogram_event)
+
+ def _ProcessImage(self, tag, wall_time, step, image):
+ """Processes an image by adding it to accumulated state."""
+ event = ImageEvent(
+ wall_time=wall_time,
+ step=step,
+ encoded_image_string=image.encoded_image_string,
+ width=image.width,
+ height=image.height
+ )
+ self._images.AddItem(tag, event)
+
+
+def _GeneratorFromPath(path):
+ """Create an event generator for file or directory at given path string."""
+ loader_factory = event_file_loader.EventFileLoader
+ if gfile.IsDirectory(path):
+ return directory_watcher.DirectoryWatcher(path, loader_factory,
+ IsTensorFlowEventsFile)
+ else:
+ return loader_factory(path)
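To make the compressed-histogram interpolation concrete, here is the arithmetic of `_ProcessCompressedHistogram`, `_Percentile`, and `_Remap` worked out by hand as a standalone sketch, using the `hst1` histogram from `testCompressedHistograms` in the test file below.

```python
# hst1 from the test below: min=1, max=2, num=3,
# bucket_limit=[1, 2, 3], bucket=[0, 3, 0].
bucket = [0, 3, 0]
bucket_limit = [1, 2, 3]
histo_min, histo_max = 1.0, 2.0

total = sum(bucket)                                        # 3
fraction_weights = [10000.0 * x / total for x in bucket]   # [0, 10000, 0]
cumsum_weights = [sum(fraction_weights[:i + 1])
                  for i in range(len(fraction_weights))]   # [0, 10000, 10000]

# For basis point 2500, the first cumulative weight >= 2500 is at index 1.
# The previous cumulative weight is 0, so the lower bound clamps to histo_min,
# and the upper bound (bucket_limit[1] = 2) clamps to histo_max. The weight is
# then a linear remap of 2500 from [0, 10000] onto [1.0, 2.0]:
bps = 2500.0
weight = histo_min + (bps - 0.0) * (histo_max - histo_min) / (10000.0 - 0.0)
assert weight == 1.25  # matches the expected value in the test
```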
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
new file mode 100644
index 0000000000..c8de80ccba
--- /dev/null
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -0,0 +1,422 @@
+import os
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.summary import event_accumulator as ea
+
+
+class _EventGenerator(object):
+
+ def __init__(self):
+ self.items = []
+
+ def Load(self):
+ while self.items:
+ yield self.items.pop(0)
+
+ def AddScalar(self, tag, wall_time=0, step=0, value=0):
+ event = tf.Event(
+ wall_time=wall_time, step=step,
+ summary=tf.Summary(
+ value=[tf.Summary.Value(tag=tag, simple_value=value)]
+ )
+ )
+ self.AddEvent(event)
+
+ def AddHistogram(self, tag, wall_time=0, step=0, hmin=1, hmax=2, hnum=3,
+ hsum=4, hsum_squares=5, hbucket_limit=None, hbucket=None):
+ histo = tf.HistogramProto(min=hmin, max=hmax, num=hnum, sum=hsum,
+ sum_squares=hsum_squares,
+ bucket_limit=hbucket_limit,
+ bucket=hbucket)
+ event = tf.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]))
+ self.AddEvent(event)
+
+ def AddImage(self, tag, wall_time=0, step=0, encoded_image_string='imgstr',
+ width=150, height=100):
+ image = tf.Summary.Image(encoded_image_string=encoded_image_string,
+ width=width, height=height)
+ event = tf.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=tf.Summary(
+ value=[tf.Summary.Value(tag=tag, image=image)]))
+ self.AddEvent(event)
+
+ def AddEvent(self, event):
+ self.items.append(event)
+
+
+class EventAccumulatorTest(tf.test.TestCase):
+
+ def assertTagsEqual(self, tags1, tags2):
+ # Make sure the two dictionaries have the same keys.
+ self.assertItemsEqual(tags1, tags2)
+ # Additionally, make sure each key in the dictionary maps to the same value.
+ for key in tags1:
+ if isinstance(tags1[key], list):
+ # We don't care about the order of the values in lists, thus asserting
+ # only if the items are equal.
+ self.assertItemsEqual(tags1[key], tags2[key])
+ else:
+ # Make sure the values are equal.
+ self.assertEqual(tags1[key], tags2[key])
+
+
+class MockingEventAccumulatorTest(EventAccumulatorTest):
+
+ def setUp(self):
+ super(MockingEventAccumulatorTest, self).setUp()
+ self.empty = {ea.IMAGES: [],
+ ea.SCALARS: [],
+ ea.HISTOGRAMS: [],
+ ea.COMPRESSED_HISTOGRAMS: [],
+ ea.GRAPH: False}
+ self._real_constructor = ea.EventAccumulator
+ self._real_generator = ea._GeneratorFromPath
+ def _FakeAccumulatorConstructor(generator, *args, **kwargs):
+ ea._GeneratorFromPath = lambda x: generator
+ return self._real_constructor(generator, *args, **kwargs)
+ ea.EventAccumulator = _FakeAccumulatorConstructor
+
+ def tearDown(self):
+ ea.EventAccumulator = self._real_constructor
+ ea._GeneratorFromPath = self._real_generator
+
+ def testEmptyAccumulator(self):
+ gen = _EventGenerator()
+ x = ea.EventAccumulator(gen)
+ x.Reload()
+ self.assertEqual(x.Tags(), self.empty)
+
+ def testTags(self):
+ gen = _EventGenerator()
+ gen.AddScalar('sv1')
+ gen.AddScalar('sv2')
+ gen.AddHistogram('hst1')
+ gen.AddHistogram('hst2')
+ gen.AddImage('im1')
+ gen.AddImage('im2')
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(), {
+ ea.IMAGES: ['im1', 'im2'],
+ ea.SCALARS: ['sv1', 'sv2'],
+ ea.HISTOGRAMS: ['hst1', 'hst2'],
+ ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
+ ea.GRAPH: False})
+
+ def testReload(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertEqual(acc.Tags(), self.empty)
+ gen.AddScalar('sv1')
+ gen.AddScalar('sv2')
+ gen.AddHistogram('hst1')
+ gen.AddHistogram('hst2')
+ gen.AddImage('im1')
+ gen.AddImage('im2')
+ self.assertEqual(acc.Tags(), self.empty)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {
+ ea.IMAGES: ['im1', 'im2'],
+ ea.SCALARS: ['sv1', 'sv2'],
+ ea.HISTOGRAMS: ['hst1', 'hst2'],
+ ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
+ ea.GRAPH: False})
+
+ def testScalars(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ sv1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
+ sv2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
+ gen.AddScalar('sv1', wall_time=1, step=10, value=32)
+ gen.AddScalar('sv2', wall_time=2, step=12, value=64)
+ acc.Reload()
+ self.assertEqual(acc.Scalars('sv1'), [sv1])
+ self.assertEqual(acc.Scalars('sv2'), [sv2])
+
+ def testHistograms(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+
+ val1 = ea.HistogramValue(min=1, max=2, num=3, sum=4, sum_squares=5,
+ bucket_limit=[1, 2, 3], bucket=[0, 3, 0])
+ val2 = ea.HistogramValue(min=-2, max=3, num=4, sum=5, sum_squares=6,
+ bucket_limit=[2, 3, 4], bucket=[1, 3, 0])
+
+ hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1)
+ hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2)
+ gen.AddHistogram('hst1', wall_time=1, step=10, hmin=1, hmax=2, hnum=3,
+ hsum=4, hsum_squares=5, hbucket_limit=[1, 2, 3],
+ hbucket=[0, 3, 0])
+ gen.AddHistogram('hst2', wall_time=2, step=12, hmin=-2, hmax=3, hnum=4,
+ hsum=5, hsum_squares=6, hbucket_limit=[2, 3, 4],
+ hbucket=[1, 3, 0])
+ acc.Reload()
+ self.assertEqual(acc.Histograms('hst1'), [hst1])
+ self.assertEqual(acc.Histograms('hst2'), [hst2])
+
+ def testCompressedHistograms(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000))
+
+ gen.AddHistogram('hst1', wall_time=1, step=10, hmin=1, hmax=2, hnum=3,
+ hsum=4, hsum_squares=5, hbucket_limit=[1, 2, 3],
+ hbucket=[0, 3, 0])
+ gen.AddHistogram('hst2', wall_time=2, step=12, hmin=-2, hmax=3, hnum=4,
+ hsum=5, hsum_squares=6, hbucket_limit=[2, 3, 4],
+ hbucket=[1, 3, 0])
+ acc.Reload()
+
+ # Create the expected values after compressing hst1
+ expected_vals1 = [ea.CompressedHistogramValue(bp, val) for bp, val in [(
+ 0, 1.0), (2500, 1.25), (5000, 1.5), (7500, 1.75), (10000, 2.0)]]
+ expected_cmphst1 = ea.CompressedHistogramEvent(
+ wall_time=1,
+ step=10,
+ compressed_histogram_values=expected_vals1)
+ self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1])
+
+ # Create the expected values after compressing hst2
+ expected_vals2 = [
+ ea.CompressedHistogramValue(bp, val)
+ for bp, val in [(0, -2), (2500, 2), (5000, 2 + float(1) / 3), (
+ 7500, 2 + float(2) / 3), (10000, 3)]
+ ]
+ expected_cmphst2 = ea.CompressedHistogramEvent(
+ wall_time=2,
+ step=12,
+ compressed_histogram_values=expected_vals2)
+ self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2])
+
+ def testPercentile(self):
+
+ def AssertExpectedForBps(bps, expected):
+ output = acc._Percentile(
+ bps, bucket_limit, cumsum_weights, histo_min, histo_max, histo_num)
+ self.assertAlmostEqual(expected, output)
+
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+
+ bucket_limit = [1, 2, 3, 4]
+ histo_num = 100
+
+ ## All weights in the first bucket
+ cumsum_weights = [10000, 10000, 10000, 10000]
+ histo_min = -1
+ histo_max = .9
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## All weights in second bucket
+ cumsum_weights = [0, 10000, 10000, 10000]
+ histo_min = 1.1
+ histo_max = 1.8
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## All weights in the last bucket
+ cumsum_weights = [0, 0, 0, 10000]
+ histo_min = 3.1
+ histo_max = 3.6
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## Weights distributed between two buckets
+ cumsum_weights = [0, 4000, 10000, 10000]
+ histo_min = 1.1
+ histo_max = 2.9
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 4000, histo_min,
+ bucket_limit[1]))
+ AssertExpectedForBps(5000, acc._Remap(5000, 4000, 10000, bucket_limit[1],
+ histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 4000, 10000, bucket_limit[1],
+ histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## Weights distributed between all buckets
+ cumsum_weights = [1000, 4000, 8000, 10000]
+ histo_min = -1
+ histo_max = 3.9
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 1000, 4000, bucket_limit[0],
+ bucket_limit[1]))
+ AssertExpectedForBps(5000, acc._Remap(5000, 4000, 8000, bucket_limit[1],
+ bucket_limit[2]))
+ AssertExpectedForBps(7500, acc._Remap(7500, 4000, 8000, bucket_limit[1],
+ bucket_limit[2]))
+ AssertExpectedForBps(9000, acc._Remap(9000, 8000, 10000, bucket_limit[2],
+ histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## Most weight in first bucket
+ cumsum_weights = [9000, 10000, 10000, 10000]
+ histo_min = -1
+ histo_max = 1.1
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 9000, histo_min,
+ bucket_limit[0]))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 9000, histo_min,
+ bucket_limit[0]))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 9000, histo_min,
+ bucket_limit[0]))
+ AssertExpectedForBps(9500, acc._Remap(9500, 9000, 10000, bucket_limit[0],
+ histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ def testImages(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ im1 = ea.ImageEvent(wall_time=1, step=10, encoded_image_string='big',
+ width=400, height=300)
+ im2 = ea.ImageEvent(wall_time=2, step=12, encoded_image_string='small',
+ width=40, height=30)
+ gen.AddImage('im1', wall_time=1, step=10, encoded_image_string='big',
+ width=400, height=300)
+ gen.AddImage('im2', wall_time=2, step=12, encoded_image_string='small',
+ width=40, height=30)
+ acc.Reload()
+ self.assertEqual(acc.Images('im1'), [im1])
+ self.assertEqual(acc.Images('im2'), [im2])
+
+ def testActivation(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ self.assertFalse(acc._activated)
+ with self.assertRaises(RuntimeError):
+ acc.Tags()
+ with self.assertRaises(RuntimeError):
+ acc.Scalars('sv1')
+ acc.Reload()
+ self.assertTrue(acc._activated)
+ acc._activated = False
+
+ def testKeyError(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ with self.assertRaises(KeyError):
+ acc.Scalars('sv1')
+ with self.assertRaises(KeyError):
+ acc.Scalars('hst1')
+ with self.assertRaises(KeyError):
+ acc.Scalars('im1')
+ with self.assertRaises(KeyError):
+ acc.Histograms('sv1')
+ with self.assertRaises(KeyError):
+ acc.Histograms('im1')
+ with self.assertRaises(KeyError):
+ acc.Images('sv1')
+ with self.assertRaises(KeyError):
+ acc.Images('hst1')
+
+ def testNonValueEvents(self):
+ """Tests that non-value events in the generator don't cause early exits."""
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalar('sv1', wall_time=1, step=10, value=20)
+ gen.AddEvent(tf.Event(
+ wall_time=2, step=20, file_version='notsv2'))
+ gen.AddScalar('sv3', wall_time=3, step=100, value=1)
+ gen.AddHistogram('hst1')
+ gen.AddImage('im1')
+
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {
+ ea.IMAGES: ['im1'],
+ ea.SCALARS: ['sv1', 'sv3'],
+ ea.HISTOGRAMS: ['hst1'],
+ ea.COMPRESSED_HISTOGRAMS: ['hst1'],
+ ea.GRAPH: False})
+
+
+class RealisticEventAccumulatorTest(EventAccumulatorTest):
+
+ def setUp(self):
+ super(RealisticEventAccumulatorTest, self).setUp()
+
+ def testScalarsRealistically(self):
+ """Test accumulator by writing values and then reading them."""
+ def FakeScalarSummary(tag, value):
+ value = tf.Summary.Value(tag=tag, simple_value=value)
+ summary = tf.Summary(value=[value])
+ return summary
+
+ directory = os.path.join(self.get_temp_dir(), 'values_dir')
+ if gfile.IsDirectory(directory):
+ gfile.DeleteRecursively(directory)
+ gfile.MkDir(directory)
+
+ writer = tf.train.SummaryWriter(directory, max_queue=100)
+ graph_def = tf.GraphDef(node=[tf.NodeDef(name='A', op='Mul')])
+ # Add a graph to the summary writer.
+ writer.add_graph(graph_def)
+
+ # Write a bunch of events using the writer
+ for i in xrange(30):
+ summ_id = FakeScalarSummary('id', i)
+ summ_sq = FakeScalarSummary('sq', i*i)
+ writer.add_summary(summ_id, i*5)
+ writer.add_summary(summ_sq, i*5)
+ writer.flush()
+
+ # Verify that we can load those events properly
+ acc = ea.EventAccumulator(directory)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {
+ ea.IMAGES: [],
+ ea.SCALARS: ['id', 'sq'],
+ ea.HISTOGRAMS: [],
+ ea.COMPRESSED_HISTOGRAMS: [],
+ ea.GRAPH: True})
+ id_events = acc.Scalars('id')
+ sq_events = acc.Scalars('sq')
+ self.assertEqual(30, len(id_events))
+ self.assertEqual(30, len(sq_events))
+ for i in xrange(30):
+ self.assertEqual(i*5, id_events[i].step)
+ self.assertEqual(i*5, sq_events[i].step)
+ self.assertEqual(i, id_events[i].value)
+ self.assertEqual(i*i, sq_events[i].value)
+
+ # Write a few more events to test incremental reloading
+ for i in xrange(30, 40):
+ summ_id = FakeScalarSummary('id', i)
+ summ_sq = FakeScalarSummary('sq', i*i)
+ writer.add_summary(summ_id, i*5)
+ writer.add_summary(summ_sq, i*5)
+ writer.flush()
+
+ # Verify we can now see all of the data
+ acc.Reload()
+ self.assertEqual(40, len(id_events))
+ self.assertEqual(40, len(sq_events))
+ for i in xrange(40):
+ self.assertEqual(i*5, id_events[i].step)
+ self.assertEqual(i*5, sq_events[i].step)
+ self.assertEqual(i, id_events[i].value)
+ self.assertEqual(i*i, sq_events[i].value)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py
new file mode 100644
index 0000000000..9966d76b21
--- /dev/null
+++ b/tensorflow/python/summary/event_multiplexer.py
@@ -0,0 +1,346 @@
+"""Provides an interface for working with multiple event files."""
+
+import os
+import threading
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from tensorflow.python.summary import event_accumulator
+
+
+class EventMultiplexer(object):
+ """An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
+
+ Each `EventAccumulator` is associated with a `run`, which is a self-contained
+ TensorFlow execution. The `EventMultiplexer` provides methods for extracting
+ information about events from multiple `run`s.
+
+ Example usage for loading specific runs from files:
+
+ ```python
+ x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
+ x.Reload()
+ ```
+
+  Example usage for loading a directory where each subdirectory is a run:
+
+  ```python
+  # Given a directory layout such as:
+  #
+  #   /parent/directory/path/
+  #   /parent/directory/path/run1/
+  #   /parent/directory/path/run1/events.out.tfevents.1001
+  #   /parent/directory/path/run1/events.out.tfevents.1002
+  #
+  #   /parent/directory/path/run2/
+  #   /parent/directory/path/run2/events.out.tfevents.9232
+  #
+  #   /parent/directory/path/run3/
+  #   /parent/directory/path/run3/events.out.tfevents.9232
+  x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
+  # which is equivalent to:
+  x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2': ...})
+  ```
+
+  If you would like to watch `/parent/directory/path`, waiting for it to be
+  created (if necessary) and periodically picking up new runs, use
+  `AutoloadingMultiplexer`.
+
+ @@__init__
+ @@AddRun
+ @@AddRunsFromDirectory
+ @@Reload
+ @@AutoUpdate
+ @@Runs
+ @@Scalars
+ @@Graph
+ @@Histograms
+ @@CompressedHistograms
+ @@Images
+ """
+
+ def __init__(self, run_path_map=None,
+ size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE):
+ """Constructor for the `EventMultiplexer`.
+
+ Args:
+ run_path_map: Dict `{run: path}` which specifies the
+ name of a run, and the path to find the associated events. If it is
+ None, then the EventMultiplexer initializes without any runs.
+ size_guidance: A dictionary mapping from `tagType` to the number of items
+ to store for each tag of that type. See
+        `event_accumulator.EventAccumulator` for details.
+ """
+ self._accumulators_mutex = threading.Lock()
+ self._accumulators = {}
+ self._paths = {}
+ self._reload_called = False
+ self._autoupdate_called = False
+ self._autoupdate_interval = None
+ self._size_guidance = size_guidance
+ if run_path_map is not None:
+ for (run, path) in run_path_map.iteritems():
+ self.AddRun(path, run)
+
+ def AddRun(self, path, name=None):
+ """Add a run to the multiplexer.
+
+ If the name is not specified, it is the same as the path.
+
+ If a run by that name exists, and we are already watching the right path,
+ do nothing. If we are watching a different path, replace the event
+ accumulator.
+
+ If `AutoUpdate` or `Reload` have been called, it will `AutoUpdate` or
+ `Reload` the newly created accumulators. This maintains the invariant that
+    once the Multiplexer is activated, all of its accumulators are active.
+
+ Args:
+ path: Path to the event files (or event directory) for given run.
+ name: Name of the run to add. If not provided, is set to path.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+    if name is None or name == '':
+ name = path
+ accumulator = None
+ with self._accumulators_mutex:
+ if name not in self._accumulators or self._paths[name] != path:
+ if name in self._paths and self._paths[name] != path:
+ # TODO(danmane) - Make it impossible to overwrite an old path with
+ # a new path (just give the new path a distinct name)
+ logging.warning('Conflict for name %s: old path %s, new path %s' %
+ (name, self._paths[name], path))
+ logging.info('Constructing EventAccumulator for %s', path)
+ accumulator = event_accumulator.EventAccumulator(path,
+ self._size_guidance)
+ self._accumulators[name] = accumulator
+ self._paths[name] = path
+ if accumulator:
+ if self._reload_called:
+ accumulator.Reload()
+ if self._autoupdate_called:
+ accumulator.AutoUpdate(self._autoupdate_interval)
+ return self
+
+ def AddRunsFromDirectory(self, path, name=None):
+ """Load runs from a directory, assuming each subdirectory is a run.
+
+ If path doesn't exist, no-op. This ensures that it is safe to call
+ `AddRunsFromDirectory` multiple times, even before the directory is made.
+
+ If the directory contains TensorFlow event files, it is itself treated as a
+ run.
+
+ If the `EventMultiplexer` is already loaded or autoupdating, this will cause
+ the newly created accumulators to also `Reload()` or `AutoUpdate()`.
+
+ Args:
+ path: A string path to a directory to load runs from.
+ name: Optionally, what name to apply to the runs. If name is provided
+ and the directory contains run subdirectories, the name of each subrun
+ is the concatenation of the parent name and the subdirectory name. If
+        name is provided and the directory contains event files, then a run
+        named "name" is added, containing the events from that path.
+
+ Raises:
+ ValueError: If the path exists and isn't a directory.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ if not gfile.Exists(path):
+ return # Maybe it hasn't been created yet, fail silently to retry later
+ if not gfile.IsDirectory(path):
+ raise ValueError('Path exists and is not a directory, %s' % path)
+ paths = gfile.ListDirectory(path)
+ is_directory = lambda x: gfile.IsDirectory(os.path.join(path, x))
+ subdirectories = filter(is_directory, paths)
+ for s in subdirectories:
+ if name:
+ subname = '/'.join([name, s])
+ else:
+ subname = s
+ self.AddRun(os.path.join(path, s), subname)
+
+ if filter(event_accumulator.IsTensorFlowEventsFile, paths):
+ directory_name = os.path.split(path)[1]
+ logging.info('Directory %s has event files; loading' % directory_name)
+ if name:
+ dname = name
+ else:
+ dname = directory_name
+ self.AddRun(path, dname)
+ return self
+
+ def Reload(self):
+ """Call `Reload` on every `EventAccumulator`."""
+ self._reload_called = True
+ with self._accumulators_mutex:
+ loaders = self._accumulators.values()
+
+ for l in loaders:
+ l.Reload()
+ return self
+
+ def AutoUpdate(self, interval=60):
+ """Call `AutoUpdate(interval)` on every `EventAccumulator`."""
+ self._autoupdate_interval = interval
+ self._autoupdate_called = True
+ with self._accumulators_mutex:
+ loaders = self._accumulators.values()
+ for l in loaders:
+ l.AutoUpdate(interval)
+ return self
+
+ def Scalars(self, run, tag):
+ """Retrieve the scalar events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `event_accumulator.ScalarEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Scalars(tag)
+
+ def Graph(self, run):
+ """Retrieve the graphs associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+ RuntimeError: If the run's EventAccumulator has not been activated.
+
+ Returns:
+ The `graph_def` protobuf data structure.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Graph()
+
+ def Histograms(self, run, tag):
+ """Retrieve the histogram events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `event_accumulator.HistogramEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Histograms(tag)
+
+ def CompressedHistograms(self, run, tag):
+ """Retrieve the compressed histogram events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's EventAccumulator has not been activated.
+
+ Returns:
+ An array of `event_accumulator.CompressedHistogramEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.CompressedHistograms(tag)
+
+ def Images(self, run, tag):
+ """Retrieve the image events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `event_accumulator.ImageEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Images(tag)
+
+ def Runs(self):
+ """Return all the run names in the `EventMultiplexer`.
+
+ Returns:
+ ```
+      {runName: { images: [tag1, tag2, tag3],
+                  scalars: [tagA, tagB, tagC],
+                  histograms: [tagX, tagY, tagZ],
+                  compressedHistograms: [tagX, tagY, tagZ],
+                  graph: true}}
+ ```
+ """
+ with self._accumulators_mutex:
+ # To avoid nested locks, we construct a copy of the run-accumulator map
+ items = list(self._accumulators.iteritems())
+ return {
+ run_name: accumulator.Tags()
+ for run_name, accumulator in items
+ }
+
+ def _GetAccumulator(self, run):
+ with self._accumulators_mutex:
+ return self._accumulators[run]
+
+
+def AutoloadingMultiplexer(path_to_run, interval_secs=60,
+ size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE):
+ """Create an `EventMultiplexer` that automatically loads runs in directories.
+
+ Args:
+ path_to_run: Dict `{path: name}` which specifies the path to a directory,
+ and its name (or `None`). The path may contain tfevents files (in which
+ case they are loaded, with name as the name of the run) and subdirectories
+ containing tfevents files (in which case each subdirectory is added as a
+ run, named `'name/subdirectory'`).
+
+ interval_secs: How often to poll the directory for new runs.
+ size_guidance: How much data to store for each tag of various types - see
+ `event_accumulator.EventAccumulator`.
+
+ Returns:
+ The multiplexer which will automatically load from the directories.
+
+ Raises:
+ ValueError: if `path_to_run` is `None`
+ TypeError: if `path_to_run` is not a dict
+ """
+ multiplexer = EventMultiplexer(size_guidance=size_guidance)
+ if path_to_run is None:
+    raise ValueError(
+        'Cannot construct an autoloading multiplexer without runs.')
+ if not isinstance(path_to_run, dict):
+    raise TypeError('path_to_run should be a dict, was %s' % path_to_run)
+ def Load():
+ for (path, name) in path_to_run.iteritems():
+ logging.info('Checking for new runs in %s', path)
+ multiplexer.AddRunsFromDirectory(path, name)
+ t = threading.Timer(interval_secs, Load)
+ t.daemon = True
+ t.start()
+ t = threading.Timer(0, Load)
+ t.daemon = True
+ t.start()
+ return multiplexer
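A rough usage sketch for `AutoloadingMultiplexer`; the directory paths and run names are hypothetical.

```python
from tensorflow.python.summary import event_multiplexer

# Poll two (hypothetical) log directories for new run subdirectories every
# 30 seconds; runs under the first directory are named 'expA/<subdir>', runs
# under the second keep their own subdirectory names.
mux = event_multiplexer.AutoloadingMultiplexer(
    {'/tmp/logs/experiment_a': 'expA',
     '/tmp/logs/experiment_b': None},
    interval_secs=30)

# Also keep every run's accumulator reloading in the background, so that
# newly discovered runs are activated automatically.
mux.AutoUpdate(30)
```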
diff --git a/tensorflow/python/summary/event_multiplexer_test.py b/tensorflow/python/summary/event_multiplexer_test.py
new file mode 100644
index 0000000000..35a8aed266
--- /dev/null
+++ b/tensorflow/python/summary/event_multiplexer_test.py
@@ -0,0 +1,244 @@
+import os
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary import event_accumulator
+from tensorflow.python.summary import event_multiplexer
+
+
+class _FakeAccumulator(object):
+
+ def __init__(self, path):
+ self._path = path
+ self.autoupdate_called = False
+ self.autoupdate_interval = None
+ self.reload_called = False
+
+ def Tags(self):
+ return {event_accumulator.IMAGES: ['im1', 'im2'],
+ event_accumulator.HISTOGRAMS: ['hst1', 'hst2'],
+ event_accumulator.COMPRESSED_HISTOGRAMS: ['cmphst1', 'cmphst2'],
+ event_accumulator.SCALARS: ['sv1', 'sv2']}
+
+ def Scalars(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.SCALARS]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def Histograms(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.HISTOGRAMS]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def CompressedHistograms(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.COMPRESSED_HISTOGRAMS]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def Images(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.IMAGES]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def AutoUpdate(self, interval):
+ self.autoupdate_called = True
+ self.autoupdate_interval = interval
+
+ def Reload(self):
+ self.reload_called = True
+
+
+def _GetFakeAccumulator(path, size_guidance): # pylint: disable=unused-argument
+ return _FakeAccumulator(path)
+
+
+class EventMultiplexerTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ super(EventMultiplexerTest, self).setUp()
+ event_accumulator.EventAccumulator = _GetFakeAccumulator
+
+ def testEmptyLoader(self):
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+
+ def testRunNamesRespected(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertItemsEqual(x.Runs().keys(), ['run1', 'run2'])
+ self.assertEqual(x._GetAccumulator('run1')._path, 'path1')
+ self.assertEqual(x._GetAccumulator('run2')._path, 'path2')
+
+ def testReload(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertFalse(x._GetAccumulator('run1').reload_called)
+ self.assertFalse(x._GetAccumulator('run2').reload_called)
+ x.Reload()
+ self.assertTrue(x._GetAccumulator('run1').reload_called)
+ self.assertTrue(x._GetAccumulator('run2').reload_called)
+
+ def testAutoUpdate(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ x.AutoUpdate(5)
+ self.assertTrue(x._GetAccumulator('run1').autoupdate_called)
+ self.assertEqual(x._GetAccumulator('run1').autoupdate_interval, 5)
+ self.assertTrue(x._GetAccumulator('run2').autoupdate_called)
+ self.assertEqual(x._GetAccumulator('run2').autoupdate_interval, 5)
+
+ def testScalars(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+
+ run1_actual = x.Scalars('run1', 'sv1')
+ run1_expected = ['path1/sv1']
+
+ self.assertEqual(run1_expected, run1_actual)
+
+ def testExceptions(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ with self.assertRaises(KeyError):
+ x.Scalars('sv1', 'xxx')
+
+ def testInitialization(self):
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertItemsEqual(x.Runs(), ['run1', 'run2'])
+ self.assertEqual(x._GetAccumulator('run1')._path, 'path1')
+ self.assertEqual(x._GetAccumulator('run2')._path, 'path2')
+
+ def testAddRunsFromDirectory(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ fakedir = join(tmpdir, 'fake_accumulator_directory')
+ realdir = join(tmpdir, 'real_accumulator_directory')
+ self.assertEqual(x.Runs(), {})
+ x.AddRunsFromDirectory(fakedir)
+ self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect')
+
+ if gfile.IsDirectory(realdir):
+ gfile.DeleteRecursively(realdir)
+ gfile.MkDir(realdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect')
+
+ path1 = join(realdir, 'path1')
+ gfile.MkDir(path1)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(x.Runs().keys(), ['path1'], 'loaded run: path1')
+ loader1 = x._GetAccumulator('path1')
+ self.assertEqual(loader1._path, path1, 'has the correct path')
+
+ path2 = join(realdir, 'path2')
+ gfile.MkDir(path2)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs().keys(), ['path1', 'path2'])
+ self.assertEqual(x._GetAccumulator('path1'), loader1,
+ 'loader1 not regenerated')
+ loader2 = x._GetAccumulator('path2')
+
+ path2_2 = join(path2, 'path2')
+ gfile.MkDir(path2_2)
+ x.AddRunsFromDirectory(path2)
+ self.assertItemsEqual(x.Runs().keys(), ['path1', 'path2'])
+ self.assertNotEqual(loader2, x._GetAccumulator('path2'),
+ 'loader2 regenerated')
+ self.assertEqual(x._GetAccumulator('path2')._path, path2_2,
+ 'loader2 path correct')
+
+ def testAddRunsFromDirectoryThatContainsEvents(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, 'event_containing_directory')
+
+ if gfile.IsDirectory(realdir):
+ gfile.DeleteRecursively(realdir)
+ gfile.MkDir(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'):
+ pass
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ['event_containing_directory'])
+
+ subdir = join(realdir, 'subdir')
+ gfile.MkDir(subdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ['event_containing_directory', 'subdir'])
+
+ def testAddRunsFromDirectoryWithRunNames(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, 'event_containing_directory')
+
+ if gfile.IsDirectory(realdir):
+ gfile.DeleteRecursively(realdir)
+ gfile.MkDir(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'):
+ pass
+ x.AddRunsFromDirectory(realdir, 'foo')
+ self.assertItemsEqual(x.Runs(), ['foo'])
+
+ subdir = join(realdir, 'subdir')
+ gfile.MkDir(subdir)
+ x.AddRunsFromDirectory(realdir, 'foo')
+ self.assertItemsEqual(x.Runs(), ['foo', 'foo/subdir'])
+
+ def testAddRunsFromDirectoryThrowsException(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+
+ filepath = os.path.join(tmpdir, 'bad_file')
+ with gfile.GFile(filepath, 'w'):
+ pass
+
+ with self.assertRaises(ValueError):
+ x.AddRunsFromDirectory(filepath)
+
+ def testAddRun(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.AddRun('run1_path', 'run1')
+ run1 = x._GetAccumulator('run1')
+ self.assertEqual(x.Runs().keys(), ['run1'])
+ self.assertEqual(run1._path, 'run1_path')
+
+ x.AddRun('run1_path', 'run1')
+ self.assertEqual(run1, x._GetAccumulator('run1'), 'loader not recreated')
+
+ x.AddRun('run2_path', 'run1')
+ new_run1 = x._GetAccumulator('run1')
+ self.assertEqual(new_run1._path, 'run2_path')
+ self.assertNotEqual(run1, new_run1)
+
+ x.AddRun('runName3')
+ self.assertItemsEqual(x.Runs().keys(), ['run1', 'runName3'])
+ self.assertEqual(x._GetAccumulator('runName3')._path, 'runName3')
+
+ def testAddRunMaintainsLoading(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.Reload()
+ x.AddRun('run1')
+ x.AddRun('run2')
+ self.assertTrue(x._GetAccumulator('run1').reload_called)
+ self.assertTrue(x._GetAccumulator('run2').reload_called)
+
+ def testAddRunMaintainsAutoUpdate(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.AutoUpdate(5)
+ x.AddRun('run1')
+ x.AddRun('run2')
+ self.assertTrue(x._GetAccumulator('run1').autoupdate_called)
+ self.assertTrue(x._GetAccumulator('run2').autoupdate_called)
+ self.assertEqual(x._GetAccumulator('run1').autoupdate_interval, 5)
+ self.assertEqual(x._GetAccumulator('run2').autoupdate_interval, 5)
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/__init__.py b/tensorflow/python/summary/impl/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/summary/impl/__init__.py
diff --git a/tensorflow/python/summary/impl/directory_watcher.py b/tensorflow/python/summary/impl/directory_watcher.py
new file mode 100644
index 0000000000..830e538cb6
--- /dev/null
+++ b/tensorflow/python/summary/impl/directory_watcher.py
@@ -0,0 +1,115 @@
+"""Contains the implementation for the DirectoryWatcher class."""
+import os
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+
+
+class DirectoryWatcher(object):
+ """A DirectoryWatcher wraps a loader to load from a directory.
+
+ A loader reads a file on disk and produces some kind of values as an
+  iterator. A DirectoryWatcher takes a directory (whose files are written to
+  one at a time) and a loader factory, and watches all of its files at once.
+
+ This class is *only* valid under the assumption that files are never removed
+ and the only file ever changed is whichever one is lexicographically last.
+ """
+
+ def __init__(self, directory, loader_factory, path_filter=lambda x: True):
+ """Constructs a new DirectoryWatcher.
+
+ Args:
+ directory: The directory to watch. The directory doesn't have to exist.
+ loader_factory: A factory for creating loaders. The factory should take a
+ file path and return an object that has a Load method returning an
+ iterator that will yield all events that have not been yielded yet.
+ path_filter: Only files whose full path matches this predicate will be
+ loaded. If not specified, all files are loaded.
+
+ Raises:
+ ValueError: If directory or loader_factory is None.
+ """
+ if directory is None:
+ raise ValueError('A directory is required')
+ if loader_factory is None:
+ raise ValueError('A loader factory is required')
+ self._directory = directory
+ self._loader_factory = loader_factory
+ self._loader = None
+ self._path = None
+ self._path_filter = path_filter
+
+ def Load(self):
+ """Loads new values from disk.
+
+ The watcher will load from one file at a time; as soon as that file stops
+ yielding events, it will move on to the next file. We assume that old files
+ are never modified after a newer file has been written. As a result, Load()
+ can be called multiple times in a row without losing events that have not
+ been yielded yet. In other words, we guarantee that every event will be
+ yielded exactly once.
+
+ Yields:
+ All values that were written to disk that have not been yielded yet.
+ """
+
+    # Create a loader if we do not already have one.
+ if not self._loader:
+ self._InitializeLoader()
+
+ while True:
+ # Yield all the new events in the file we're currently loading from.
+ for event in self._loader.Load():
+ yield event
+
+ next_path = self._GetNextPath()
+ if not next_path:
+ logging.info('No more files in %s', self._directory)
+ # Current file is empty and there are no new files, so we're done.
+ return
+
+ # There's a new file, so check to make sure there weren't any events
+ # written between when we finished reading the current file and when we
+ # checked for the new one. The sequence of events might look something
+ # like this:
+ #
+ # 1. Event #1 written to file #1.
+ # 2. We check for events and yield event #1 from file #1
+ # 3. We check for events and see that there are no more events in file #1.
+ # 4. Event #2 is written to file #1.
+ # 5. Event #3 is written to file #2.
+ # 6. We check for a new file and see that file #2 exists.
+ #
+ # Without this loop, we would miss event #2. We're also guaranteed by the
+ # loader contract that no more events will be written to file #1 after
+ # events start being written to file #2, so we don't have to worry about
+ # that.
+ for event in self._loader.Load():
+ yield event
+
+ logging.info('Directory watcher for %s advancing to file %s',
+ self._directory, next_path)
+
+ # Advance to the next file and start over.
+ self._SetPath(next_path)
+
+ def _InitializeLoader(self):
+ path = self._GetNextPath()
+ if path:
+ self._SetPath(path)
+ else:
+ raise StopIteration
+
+ def _SetPath(self, path):
+ self._path = path
+ self._loader = self._loader_factory(path)
+
+ def _GetNextPath(self):
+ """Returns the path of the next file to use or None if no file exists."""
+ sorted_paths = [os.path.join(self._directory, path)
+ for path in sorted(gfile.ListDirectory(self._directory))]
+ # We filter here so the filter gets the full directory name.
+ filtered_paths = (path for path in sorted_paths
+ if self._path_filter(path) and path > self._path)
+ return next(filtered_paths, None)
diff --git a/tensorflow/python/summary/impl/directory_watcher_test.py b/tensorflow/python/summary/impl/directory_watcher_test.py
new file mode 100644
index 0000000000..a22e3f2922
--- /dev/null
+++ b/tensorflow/python/summary/impl/directory_watcher_test.py
@@ -0,0 +1,102 @@
+"""Tests for directory_watcher."""
+
+import os
+import shutil
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import directory_watcher
+
+
+class _ByteLoader(object):
+ """A loader that loads individual bytes from a file."""
+
+ def __init__(self, path):
+ self._f = open(path)
+
+ def Load(self):
+ while True:
+ byte = self._f.read(1)
+ if byte:
+ yield byte
+ else:
+ return
+
+
+class DirectoryWatcherTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ # Put everything in a directory so it's easier to delete.
+ self._directory = os.path.join(self.get_temp_dir(), 'monitor_dir')
+ os.mkdir(self._directory)
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader)
+
+ def tearDown(self):
+ shutil.rmtree(self._directory)
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self._directory, filename)
+ with open(path, 'a') as f:
+ f.write(data)
+
+ def assertWatcherYields(self, values):
+ self.assertEqual(list(self._watcher.Load()), values)
+
+ def testRaisesWithBadArguments(self):
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher(None, lambda x: [])
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher('asdf', None)
+
+ def testEmptyDirectory(self):
+ self.assertWatcherYields([])
+
+ def testSingleWrite(self):
+ self._WriteToFile('a', 'abc')
+ self.assertWatcherYields(['a', 'b', 'c'])
+
+ def testMultipleWrites(self):
+ self._WriteToFile('a', 'abc')
+ self.assertWatcherYields(['a', 'b', 'c'])
+ self._WriteToFile('a', 'xyz')
+ self.assertWatcherYields(['x', 'y', 'z'])
+
+ def testMultipleLoads(self):
+ self._WriteToFile('a', 'a')
+ self._watcher.Load()
+ self._watcher.Load()
+ self.assertWatcherYields(['a'])
+
+ def testMultipleFilesAtOnce(self):
+ self._WriteToFile('b', 'b')
+ self._WriteToFile('a', 'a')
+ self.assertWatcherYields(['a', 'b'])
+
+ def testFinishesLoadingFileWhenSwitchingToNewFile(self):
+ self._WriteToFile('a', 'a')
+ # Empty the iterator.
+ self.assertEquals(['a'], list(self._watcher.Load()))
+ self._WriteToFile('a', 'b')
+ self._WriteToFile('b', 'c')
+ # The watcher should finish its current file before starting a new one.
+ self.assertWatcherYields(['b', 'c'])
+
+ def testIntermediateEmptyFiles(self):
+ self._WriteToFile('a', 'a')
+ self._WriteToFile('b', '')
+ self._WriteToFile('c', 'c')
+ self.assertWatcherYields(['a', 'c'])
+
+ def testFileFilter(self):
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader,
+ path_filter=lambda path: 'do_not_watch_me' not in path)
+
+ self._WriteToFile('a', 'a')
+ self._WriteToFile('do_not_watch_me', 'b')
+ self._WriteToFile('c', 'c')
+ self.assertWatcherYields(['a', 'c'])
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
new file mode 100644
index 0000000000..0571bc84cb
--- /dev/null
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -0,0 +1,49 @@
+"""Functionality for loading events from a record file."""
+
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.platform import app
+from tensorflow.python.platform import logging
+
+
+class EventFileLoader(object):
+ """An EventLoader is an iterator that yields Event protos."""
+
+ def __init__(self, file_path):
+ if file_path is None:
+ raise ValueError('A file path is required')
+ logging.debug('Opening a record reader pointing at %s', file_path)
+ self._reader = pywrap_tensorflow.PyRecordReader_New(file_path, 0)
+ # Store it for logging purposes.
+ self._file_path = file_path
+ if not self._reader:
+ raise IOError('Failed to open a record reader pointing to %s' % file_path)
+
+ def Load(self):
+ """Loads all new values from disk.
+
+    Calling Load multiple times in a row does not 'drop' events: events are
+    only consumed from disk once the returned generator is actually iterated
+    over (a short usage sketch appears at the end of this module).
+
+ Yields:
+ All values that were written to disk that have not been yielded yet.
+ """
+ while self._reader.GetNext():
+ logging.debug('Got an event from %s', self._file_path)
+ event = event_pb2.Event()
+ event.ParseFromString(self._reader.record())
+ yield event
+ logging.debug('No more events in %s', self._file_path)
+
+
+def main(argv):
+ if len(argv) != 2:
+ print 'Usage: event_file_loader <path-to-the-recordio-file>'
+ return 1
+ loader = EventFileLoader(argv[1])
+ for event in loader.Load():
+ print event
+
+
+if __name__ == '__main__':
+ app.run()
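+
+# A minimal sketch of the incremental behaviour documented on Load(): each
+# call yields only records appended since the previous call was exhausted.
+# `events_path` is a hypothetical path to an existing events file.
+#
+#   loader = EventFileLoader(events_path)
+#   for event in loader.Load():   # First pass: everything currently on disk.
+#     print event.step
+#   # ... another process appends more events to the same file ...
+#   for event in loader.Load():   # Second pass: only the newly written events.
+#     print event.step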
diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py
new file mode 100644
index 0000000000..1dc29d85d5
--- /dev/null
+++ b/tensorflow/python/summary/impl/event_file_loader_test.py
@@ -0,0 +1,59 @@
+"""Tests for event_file_loader."""
+
+import os
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import event_file_loader
+
+
+class EventFileLoaderTest(test_util.TensorFlowTestCase):
+ # A record containing a simple event.
+ RECORD = ('\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu'
+ '\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d')
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self.get_temp_dir(), filename)
+ with open(path, 'ab') as f:
+ f.write(data)
+
+ def _LoaderForTestFile(self, filename):
+ return event_file_loader.EventFileLoader(
+ os.path.join(self.get_temp_dir(), filename))
+
+ def testEmptyEventFile(self):
+ self._WriteToFile('empty_event_file', '')
+ loader = self._LoaderForTestFile('empty_event_file')
+ self.assertEquals(len(list(loader.Load())), 0)
+
+ def testSingleWrite(self):
+ self._WriteToFile('single_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('single_event_file')
+ events = list(loader.Load())
+ self.assertEquals(len(events), 1)
+ self.assertEquals(events[0].wall_time, 1440183447.0)
+ self.assertEquals(len(list(loader.Load())), 0)
+
+ def testMultipleWrites(self):
+ self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('staggered_event_file')
+ self.assertEquals(len(list(loader.Load())), 1)
+ self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
+ self.assertEquals(len(list(loader.Load())), 1)
+
+ def testMultipleLoads(self):
+ self._WriteToFile('multiple_loads_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('multiple_loads_event_file')
+ loader.Load()
+ loader.Load()
+ self.assertEquals(len(list(loader.Load())), 1)
+
+ def testMultipleWritesAtOnce(self):
+ self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
+ self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
+    loader = self._LoaderForTestFile('multiple_event_file')
+ self.assertEquals(len(list(loader.Load())), 2)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/reservoir.py b/tensorflow/python/summary/impl/reservoir.py
new file mode 100644
index 0000000000..2c9b294841
--- /dev/null
+++ b/tensorflow/python/summary/impl/reservoir.py
@@ -0,0 +1,164 @@
+"""A key-value[] store that implements reservoir sampling on the values."""
+
+import collections
+import random
+import threading
+
+
+class Reservoir(object):
+ """A map-to-arrays container, with deterministic Reservoir Sampling.
+
+ Items are added with an associated key. Items may be retrieved by key, and
+  a list of keys can also be retrieved. If size is not zero, it dictates the
+  maximum number of items that will be stored for each key. Once more items
+  have been added for a given key than that size allows, items are replaced
+  via reservoir sampling, such that each item has an equal probability of
+  being included in the sample.
+
+ Deterministic means that for any given seed and bucket size, the sequence of
+ values that are kept for any given tag will always be the same, and that this
+ is independent of any insertions on other tags. That is:
+
+ >>> separate_reservoir = reservoir.Reservoir(10)
+ >>> interleaved_reservoir = reservoir.Reservoir(10)
+ >>> for i in xrange(100):
+ >>> separate_reservoir.AddItem('key1', i)
+ >>> for i in xrange(100):
+ >>> separate_reservoir.AddItem('key2', i)
+ >>> for i in xrange(100):
+ >>> interleaved_reservoir.AddItem('key1', i)
+ >>> interleaved_reservoir.AddItem('key2', i)
+
+ separate_reservoir and interleaved_reservoir will be in identical states.
+
+ See: https://en.wikipedia.org/wiki/Reservoir_sampling
+
+ Adding items has amortized O(1) runtime.
+
+ """
+
+ def __init__(self, size, seed=0):
+ """Creates a new reservoir.
+
+ Args:
+ size: The number of values to keep in the reservoir for each tag. If 0,
+ all values will be kept.
+ seed: The seed of the random number generator to use when sampling.
+ Different values for |seed| will produce different samples from the same
+ input items.
+
+ Raises:
+ ValueError: If size is negative or not an integer.
+ """
+ if size < 0 or size != round(size):
+      raise ValueError('size must be a nonnegative integer, was %s' % size)
+ self._buckets = collections.defaultdict(
+ lambda: _ReservoirBucket(size, random.Random(seed)))
+    # _mutex guards the keys: creating new keys, retrieving by key, etc.
+    # The internal items are guarded by the _ReservoirBuckets' own mutexes.
+ self._mutex = threading.Lock()
+
+ def Keys(self):
+ """Return all the keys in the reservoir.
+
+ Returns:
+ ['list', 'of', 'keys'] in the Reservoir.
+ """
+ with self._mutex:
+ return self._buckets.keys()
+
+ def Items(self, key):
+ """Return items associated with given key.
+
+ Args:
+ key: The key for which we are finding associated items.
+
+ Raises:
+      KeyError: If the key is not found in the reservoir.
+
+ Returns:
+ [list, of, items] associated with that key.
+ """
+ with self._mutex:
+ if key not in self._buckets:
+ raise KeyError('Key %s was not found in Reservoir' % key)
+ bucket = self._buckets[key]
+ return bucket.Items()
+
+ def AddItem(self, key, item):
+ """Add a new item to the Reservoir with the given tag.
+
+ The new item is guaranteed to be kept in the Reservoir. One other item might
+ be replaced.
+
+ Args:
+ key: The key to store the item under.
+ item: The item to add to the reservoir.
+ """
+ with self._mutex:
+ bucket = self._buckets[key]
+ bucket.AddItem(item)
+
+
+class _ReservoirBucket(object):
+ """A container for items from a stream, that implements reservoir sampling.
+
+ It always stores the most recent item as its final item.
+ """
+
+ def __init__(self, _max_size, _random=None):
+ """Create the _ReservoirBucket.
+
+ Args:
+      _max_size: The maximum size the reservoir bucket may grow to. If
+        _max_size is zero, the bucket has unbounded size.
+ _random: The random number generator to use. If not specified, defaults to
+ random.Random(0).
+
+ Raises:
+ ValueError: if the size is not a nonnegative integer.
+ """
+ if _max_size < 0 or _max_size != round(_max_size):
+      raise ValueError('_max_size must be a nonnegative int, was %s' % _max_size)
+ self.items = []
+ # This mutex protects the internal items, ensuring that calls to Items and
+ # AddItem are thread-safe
+ self._mutex = threading.Lock()
+ self._max_size = _max_size
+ self._count = 0
+ if _random is not None:
+ self._random = _random
+ else:
+ self._random = random.Random(0)
+
+ def AddItem(self, item):
+ """Add an item to the ReservoirBucket, replacing an old item if necessary.
+
+ The new item is guaranteed to be added to the bucket, and to be the last
+ element in the bucket. If the bucket has reached capacity, then an old item
+    will be replaced. With probability (_max_size divided by the number of
+    items seen so far, including this one) a random item in the bucket will be
+    popped out and the new item will be appended to the end; otherwise the
+    last item in the bucket will be replaced.
+
+    Since the O(n) replacements occur with O(1/_count) likelihood, the amortized
+ runtime is O(1).
+
+ Args:
+ item: The item to add to the bucket.
+ """
+ with self._mutex:
+ if len(self.items) < self._max_size or self._max_size == 0:
+ self.items.append(item)
+ else:
+ r = self._random.randint(0, self._count)
+ if r < self._max_size:
+ self.items.pop(r)
+ self.items.append(item)
+ else:
+ self.items[-1] = item
+ self._count += 1
+
+ def Items(self):
+ """Get all the items in the bucket."""
+ with self._mutex:
+ return self.items
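+
+# A minimal usage sketch of the Reservoir (the key 'loss' and the item values
+# below are arbitrary examples):
+#
+#   >>> r = Reservoir(10)
+#   >>> for step in xrange(1000):
+#   ...   r.AddItem('loss', step)
+#   >>> r.Keys()
+#   ['loss']
+#   >>> len(r.Items('loss'))
+#   10
+#   >>> r.Items('loss')[-1]   # The most recent item is always kept.
+#   999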
diff --git a/tensorflow/python/summary/impl/reservoir_test.py b/tensorflow/python/summary/impl/reservoir_test.py
new file mode 100644
index 0000000000..46cbde5940
--- /dev/null
+++ b/tensorflow/python/summary/impl/reservoir_test.py
@@ -0,0 +1,178 @@
+import tensorflow.python.platform
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import reservoir
+
+
+class ReservoirTest(googletest.TestCase):
+
+ def testEmptyReservoir(self):
+ r = reservoir.Reservoir(1)
+ self.assertFalse(r.Keys())
+
+ def testRespectsSize(self):
+ r = reservoir.Reservoir(42)
+ self.assertEqual(r._buckets['meaning of life']._max_size, 42)
+
+ def testItemsAndKeys(self):
+ r = reservoir.Reservoir(42)
+ r.AddItem('foo', 4)
+ r.AddItem('bar', 9)
+ r.AddItem('foo', 19)
+ self.assertItemsEqual(r.Keys(), ['foo', 'bar'])
+ self.assertEqual(r.Items('foo'), [4, 19])
+ self.assertEqual(r.Items('bar'), [9])
+
+ def testExceptions(self):
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(-1)
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(13.3)
+
+ r = reservoir.Reservoir(12)
+ with self.assertRaises(KeyError):
+ r.Items('missing key')
+
+ def testDeterminism(self):
+ """Tests that the reservoir is deterministic."""
+ key = 'key'
+ r1 = reservoir.Reservoir(10)
+ r2 = reservoir.Reservoir(10)
+ for i in xrange(100):
+ r1.AddItem('key', i)
+ r2.AddItem('key', i)
+
+ self.assertEqual(r1.Items(key), r2.Items(key))
+
+ def testBucketDeterminism(self):
+ """Tests that reservoirs are deterministic at a bucket level.
+
+    This means that only the order in which elements are added within a bucket
+    matters.
+ """
+ separate_reservoir = reservoir.Reservoir(10)
+ interleaved_reservoir = reservoir.Reservoir(10)
+ for i in xrange(100):
+ separate_reservoir.AddItem('key1', i)
+ for i in xrange(100):
+ separate_reservoir.AddItem('key2', i)
+ for i in xrange(100):
+ interleaved_reservoir.AddItem('key1', i)
+ interleaved_reservoir.AddItem('key2', i)
+
+ for key in ['key1', 'key2']:
+ self.assertEqual(separate_reservoir.Items(key),
+ interleaved_reservoir.Items(key))
+
+ def testUsesSeed(self):
+ """Tests that reservoirs with different seeds keep different samples."""
+ key = 'key'
+ r1 = reservoir.Reservoir(10, seed=0)
+ r2 = reservoir.Reservoir(10, seed=1)
+ for i in xrange(100):
+ r1.AddItem('key', i)
+ r2.AddItem('key', i)
+ self.assertNotEqual(r1.Items(key), r2.Items(key))
+
+
+class ReservoirBucketTest(googletest.TestCase):
+
+ def testEmptyBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ self.assertFalse(b.Items())
+
+ def testFillToSize(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(100):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), range(100))
+
+ def testDoesntOverfill(self):
+ b = reservoir._ReservoirBucket(10)
+ for i in xrange(1000):
+ b.AddItem(i)
+ self.assertEqual(len(b.Items()), 10)
+
+ def testMaintainsOrder(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(10000):
+ b.AddItem(i)
+ items = b.Items()
+ prev = None
+ for item in items:
+ self.assertTrue(item > prev)
+ prev = item
+
+ def testKeepsLatestItem(self):
+ b = reservoir._ReservoirBucket(5)
+ for i in xrange(100):
+ b.AddItem(i)
+ last = b.Items()[-1]
+ self.assertEqual(last, i)
+
+ def testSizeOneBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), [i])
+
+ def testSizeZeroBucket(self):
+ b = reservoir._ReservoirBucket(0)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), range(i+1))
+
+ def testSizeRequirement(self):
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(-1)
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(10.3)
+
+
+class ReservoirBucketStatisticalDistributionTest(googletest.TestCase):
+
+ def setUp(self):
+ self.total = 1000000
+ self.samples = 10000
+ self.n_buckets = 100
+ self.total_per_bucket = self.total / self.n_buckets
+ self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly '
+ 'divisible by the number of buckets')
+ self.assertTrue(self.total > self.samples, 'need to have more items '
+ 'than samples')
+
+ def AssertBinomialQuantity(self, measured):
+ p = 1.0 * self.n_buckets / self.samples
+ mean = p * self.samples
+ variance = p * (1 - p) * self.samples
+ error = measured - mean
+ # Given that the buckets were actually binomially distributed, this
+ # fails with probability ~2E-9
+ passed = error * error <= 36.0 * variance
+ self.assertTrue(passed, 'found a bucket with measured %d '
+ 'too far from expected %d' % (measured, mean))
+
+ def testBucketReservoirSamplingViaStatisticalProperties(self):
+    # The 'buckets' here are not _ReservoirBuckets; they are the histogram bins
+    # we sort the sampled items into to test the shape of the distribution.
+ b = reservoir._ReservoirBucket(_max_size=self.samples)
+    # Add one extra item because we always keep the most recent item, which
+    # would skew the distribution; we can just slice it off the end instead.
+ for i in xrange(self.total + 1):
+ b.AddItem(i)
+
+ divbins = [0] * self.n_buckets
+ modbins = [0] * self.n_buckets
+ # Slice off the last item when we iterate.
+ for item in b.Items()[0:-1]:
+ divbins[item / self.total_per_bucket] += 1
+ modbins[item % self.n_buckets] += 1
+
+ for bucket_index in xrange(self.n_buckets):
+ divbin = divbins[bucket_index]
+ modbin = modbins[bucket_index]
+ self.AssertBinomialQuantity(divbin)
+ self.AssertBinomialQuantity(modbin)
+
+
+if __name__ == '__main__':
+ googletest.main()