Diffstat (limited to 'tensorflow/python/summary/impl')
 -rwxr-xr-x  tensorflow/python/summary/impl/__init__.py                |   0
 -rw-r--r--  tensorflow/python/summary/impl/directory_watcher.py       | 115
 -rw-r--r--  tensorflow/python/summary/impl/directory_watcher_test.py  | 102
 -rw-r--r--  tensorflow/python/summary/impl/event_file_loader.py       |  49
 -rw-r--r--  tensorflow/python/summary/impl/event_file_loader_test.py  |  59
 -rw-r--r--  tensorflow/python/summary/impl/reservoir.py               | 164
 -rw-r--r--  tensorflow/python/summary/impl/reservoir_test.py          | 178
7 files changed, 667 insertions, 0 deletions
diff --git a/tensorflow/python/summary/impl/__init__.py b/tensorflow/python/summary/impl/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/summary/impl/__init__.py
diff --git a/tensorflow/python/summary/impl/directory_watcher.py b/tensorflow/python/summary/impl/directory_watcher.py
new file mode 100644
index 0000000000..830e538cb6
--- /dev/null
+++ b/tensorflow/python/summary/impl/directory_watcher.py
@@ -0,0 +1,115 @@
+"""Contains the implementation for the DirectoryWatcher class."""
+import os
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+
+
+class DirectoryWatcher(object):
+ """A DirectoryWatcher wraps a loader to load from a directory.
+
+  A loader reads a file on disk and produces some kind of values as an
+  iterator. A DirectoryWatcher takes a directory in which files are written
+  one at a time, plus a factory for loaders, and presents the values from
+  all of those files as a single stream.
+
+ This class is *only* valid under the assumption that files are never removed
+ and the only file ever changed is whichever one is lexicographically last.
+ """
+
+ def __init__(self, directory, loader_factory, path_filter=lambda x: True):
+ """Constructs a new DirectoryWatcher.
+
+ Args:
+ directory: The directory to watch. The directory doesn't have to exist.
+ loader_factory: A factory for creating loaders. The factory should take a
+ file path and return an object that has a Load method returning an
+ iterator that will yield all events that have not been yielded yet.
+ path_filter: Only files whose full path matches this predicate will be
+ loaded. If not specified, all files are loaded.
+
+ Raises:
+ ValueError: If directory or loader_factory is None.
+ """
+ if directory is None:
+ raise ValueError('A directory is required')
+ if loader_factory is None:
+ raise ValueError('A loader factory is required')
+ self._directory = directory
+ self._loader_factory = loader_factory
+ self._loader = None
+ self._path = None
+ self._path_filter = path_filter
+
+ def Load(self):
+ """Loads new values from disk.
+
+ The watcher will load from one file at a time; as soon as that file stops
+ yielding events, it will move on to the next file. We assume that old files
+ are never modified after a newer file has been written. As a result, Load()
+ can be called multiple times in a row without losing events that have not
+ been yielded yet. In other words, we guarantee that every event will be
+ yielded exactly once.
+
+ Yields:
+ All values that were written to disk that have not been yielded yet.
+ """
+
+    # If the loader doesn't exist yet, create one pointed at the first file.
+ if not self._loader:
+ self._InitializeLoader()
+
+ while True:
+ # Yield all the new events in the file we're currently loading from.
+ for event in self._loader.Load():
+ yield event
+
+ next_path = self._GetNextPath()
+ if not next_path:
+ logging.info('No more files in %s', self._directory)
+ # Current file is empty and there are no new files, so we're done.
+ return
+
+ # There's a new file, so check to make sure there weren't any events
+ # written between when we finished reading the current file and when we
+ # checked for the new one. The sequence of events might look something
+ # like this:
+ #
+ # 1. Event #1 written to file #1.
+ # 2. We check for events and yield event #1 from file #1
+ # 3. We check for events and see that there are no more events in file #1.
+ # 4. Event #2 is written to file #1.
+ # 5. Event #3 is written to file #2.
+ # 6. We check for a new file and see that file #2 exists.
+ #
+      # Without this extra read, we would miss event #2. We're also
+      # guaranteed by the loader contract that no more events will be written
+      # to file #1 after events start being written to file #2, so we don't
+      # have to worry about that.
+ for event in self._loader.Load():
+ yield event
+
+ logging.info('Directory watcher for %s advancing to file %s',
+ self._directory, next_path)
+
+ # Advance to the next file and start over.
+ self._SetPath(next_path)
+
+ def _InitializeLoader(self):
+ path = self._GetNextPath()
+ if path:
+ self._SetPath(path)
+ else:
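+      # In Python 2, raising StopIteration here propagates out of the Load()
+      # generator and simply ends the iteration with no values.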
+ raise StopIteration
+
+ def _SetPath(self, path):
+ self._path = path
+ self._loader = self._loader_factory(path)
+
+ def _GetNextPath(self):
+ """Returns the path of the next file to use or None if no file exists."""
+ sorted_paths = [os.path.join(self._directory, path)
+ for path in sorted(gfile.ListDirectory(self._directory))]
+ # We filter here so the filter gets the full directory name.
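+    # Note: on the first call self._path is None; in Python 2 any string
+    # compares greater than None, so every path passes this check.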
+ filtered_paths = (path for path in sorted_paths
+ if self._path_filter(path) and path > self._path)
+ return next(filtered_paths, None)
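For illustration, a minimal usage sketch for DirectoryWatcher with a toy
line-based loader. The _LineLoader class, the temporary directory, and the
'events.1' filename are hypothetical, and the sketch assumes a TensorFlow
checkout where this module imports; it is not part of this change.

    import os
    import tempfile

    from tensorflow.python.summary.impl import directory_watcher


    class _LineLoader(object):
      """A toy loader that yields complete lines from a file."""

      def __init__(self, path):
        self._f = open(path)

      def Load(self):
        # Each call resumes from the previous file offset, so every line is
        # yielded exactly once, as the loader contract requires.
        while True:
          line = self._f.readline()
          if not line:
            return
          yield line.rstrip('\n')


    logdir = tempfile.mkdtemp()
    with open(os.path.join(logdir, 'events.1'), 'w') as f:
      f.write('first\nsecond\n')

    watcher = directory_watcher.DirectoryWatcher(logdir, _LineLoader)
    print(list(watcher.Load()))  # ['first', 'second']
    print(list(watcher.Load()))  # [] -- nothing new was written.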
diff --git a/tensorflow/python/summary/impl/directory_watcher_test.py b/tensorflow/python/summary/impl/directory_watcher_test.py
new file mode 100644
index 0000000000..a22e3f2922
--- /dev/null
+++ b/tensorflow/python/summary/impl/directory_watcher_test.py
@@ -0,0 +1,102 @@
+"""Tests for directory_watcher."""
+
+import os
+import shutil
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import directory_watcher
+
+
+class _ByteLoader(object):
+ """A loader that loads individual bytes from a file."""
+
+ def __init__(self, path):
+ self._f = open(path)
+
+ def Load(self):
+ while True:
+ byte = self._f.read(1)
+ if byte:
+ yield byte
+ else:
+ return
+
+
+class DirectoryWatcherTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ # Put everything in a directory so it's easier to delete.
+ self._directory = os.path.join(self.get_temp_dir(), 'monitor_dir')
+ os.mkdir(self._directory)
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader)
+
+ def tearDown(self):
+ shutil.rmtree(self._directory)
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self._directory, filename)
+ with open(path, 'a') as f:
+ f.write(data)
+
+ def assertWatcherYields(self, values):
+ self.assertEqual(list(self._watcher.Load()), values)
+
+ def testRaisesWithBadArguments(self):
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher(None, lambda x: [])
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher('asdf', None)
+
+ def testEmptyDirectory(self):
+ self.assertWatcherYields([])
+
+ def testSingleWrite(self):
+ self._WriteToFile('a', 'abc')
+ self.assertWatcherYields(['a', 'b', 'c'])
+
+ def testMultipleWrites(self):
+ self._WriteToFile('a', 'abc')
+ self.assertWatcherYields(['a', 'b', 'c'])
+ self._WriteToFile('a', 'xyz')
+ self.assertWatcherYields(['x', 'y', 'z'])
+
+ def testMultipleLoads(self):
+ self._WriteToFile('a', 'a')
+ self._watcher.Load()
+ self._watcher.Load()
+ self.assertWatcherYields(['a'])
+
+ def testMultipleFilesAtOnce(self):
+ self._WriteToFile('b', 'b')
+ self._WriteToFile('a', 'a')
+ self.assertWatcherYields(['a', 'b'])
+
+ def testFinishesLoadingFileWhenSwitchingToNewFile(self):
+ self._WriteToFile('a', 'a')
+ # Empty the iterator.
+    self.assertEqual(['a'], list(self._watcher.Load()))
+ self._WriteToFile('a', 'b')
+ self._WriteToFile('b', 'c')
+ # The watcher should finish its current file before starting a new one.
+ self.assertWatcherYields(['b', 'c'])
+
+ def testIntermediateEmptyFiles(self):
+ self._WriteToFile('a', 'a')
+ self._WriteToFile('b', '')
+ self._WriteToFile('c', 'c')
+ self.assertWatcherYields(['a', 'c'])
+
+ def testFileFilter(self):
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader,
+ path_filter=lambda path: 'do_not_watch_me' not in path)
+
+ self._WriteToFile('a', 'a')
+ self._WriteToFile('do_not_watch_me', 'b')
+ self._WriteToFile('c', 'c')
+ self.assertWatcherYields(['a', 'c'])
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
new file mode 100644
index 0000000000..0571bc84cb
--- /dev/null
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -0,0 +1,49 @@
+"""Functionality for loading events from a record file."""
+
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.platform import app
+from tensorflow.python.platform import logging
+
+
+class EventFileLoader(object):
+  """An EventFileLoader reads an event file and yields Event protos."""
+
+ def __init__(self, file_path):
+ if file_path is None:
+ raise ValueError('A file path is required')
+ logging.debug('Opening a record reader pointing at %s', file_path)
+ self._reader = pywrap_tensorflow.PyRecordReader_New(file_path, 0)
+ # Store it for logging purposes.
+ self._file_path = file_path
+ if not self._reader:
+ raise IOError('Failed to open a record reader pointing to %s' % file_path)
+
+ def Load(self):
+ """Loads all new values from disk.
+
+ Calling Load multiple times in a row will not 'drop' events as long as the
+ return value is not iterated over.
+
+ Yields:
+ All values that were written to disk that have not been yielded yet.
+ """
+ while self._reader.GetNext():
+ logging.debug('Got an event from %s', self._file_path)
+ event = event_pb2.Event()
+ event.ParseFromString(self._reader.record())
+ yield event
+ logging.debug('No more events in %s', self._file_path)
+
+
+def main(argv):
+ if len(argv) != 2:
+ print 'Usage: event_file_loader <path-to-the-recordio-file>'
+ return 1
+ loader = EventFileLoader(argv[1])
+ for event in loader.Load():
+ print event
+
+
+if __name__ == '__main__':
+ app.run()
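The two classes are presumably meant to compose: DirectoryWatcher supplies
file paths and EventFileLoader parses the records in each file. A minimal
sketch; the log directory path and the 'tfevents' filename filter are
assumptions for illustration, not part of this change.

    from tensorflow.python.summary.impl import directory_watcher
    from tensorflow.python.summary.impl import event_file_loader

    watcher = directory_watcher.DirectoryWatcher(
        '/tmp/my_logdir',                             # hypothetical logdir
        event_file_loader.EventFileLoader,
        path_filter=lambda path: 'tfevents' in path)  # assumed naming scheme

    # Each call to Load() yields only events not yielded before, across all
    # event files in the directory, in file order.
    for event in watcher.Load():
      print(event.wall_time)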
diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py
new file mode 100644
index 0000000000..1dc29d85d5
--- /dev/null
+++ b/tensorflow/python/summary/impl/event_file_loader_test.py
@@ -0,0 +1,59 @@
+"""Tests for event_file_loader."""
+
+import os
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import event_file_loader
+
+
+class EventFileLoaderTest(test_util.TensorFlowTestCase):
+ # A record containing a simple event.
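+  # TFRecord framing: an 8-byte little-endian length, a 4-byte masked CRC of
+  # the length, the serialized Event proto, and a 4-byte masked CRC of the
+  # data.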
+ RECORD = ('\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu'
+ '\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d')
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self.get_temp_dir(), filename)
+ with open(path, 'ab') as f:
+ f.write(data)
+
+ def _LoaderForTestFile(self, filename):
+ return event_file_loader.EventFileLoader(
+ os.path.join(self.get_temp_dir(), filename))
+
+ def testEmptyEventFile(self):
+ self._WriteToFile('empty_event_file', '')
+ loader = self._LoaderForTestFile('empty_event_file')
+ self.assertEquals(len(list(loader.Load())), 0)
+
+ def testSingleWrite(self):
+ self._WriteToFile('single_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('single_event_file')
+ events = list(loader.Load())
+ self.assertEquals(len(events), 1)
+ self.assertEquals(events[0].wall_time, 1440183447.0)
+ self.assertEquals(len(list(loader.Load())), 0)
+
+ def testMultipleWrites(self):
+ self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('staggered_event_file')
+ self.assertEquals(len(list(loader.Load())), 1)
+ self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
+ self.assertEquals(len(list(loader.Load())), 1)
+
+ def testMultipleLoads(self):
+ self._WriteToFile('multiple_loads_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('multiple_loads_event_file')
+ loader.Load()
+ loader.Load()
+ self.assertEquals(len(list(loader.Load())), 1)
+
+ def testMultipleWritesAtOnce(self):
+ self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
+ self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
+    loader = self._LoaderForTestFile('multiple_event_file')
+ self.assertEquals(len(list(loader.Load())), 2)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/reservoir.py b/tensorflow/python/summary/impl/reservoir.py
new file mode 100644
index 0000000000..2c9b294841
--- /dev/null
+++ b/tensorflow/python/summary/impl/reservoir.py
@@ -0,0 +1,164 @@
+"""A key-value[] store that implements reservoir sampling on the values."""
+
+import collections
+import random
+import threading
+
+
+class Reservoir(object):
+ """A map-to-arrays container, with deterministic Reservoir Sampling.
+
+ Items are added with an associated key. Items may be retrieved by key, and
+ a list of keys can also be retrieved. If size is not zero, then it dictates
+ the maximum number of items that will be stored with each key. Once there are
+  the maximum number of items that will be stored with each key. Once a key
+  has more items than that, items are replaced via reservoir sampling, such
+  that each item has an equal probability of being included in the sample.
+ Deterministic means that for any given seed and bucket size, the sequence of
+ values that are kept for any given tag will always be the same, and that this
+ is independent of any insertions on other tags. That is:
+
+ >>> separate_reservoir = reservoir.Reservoir(10)
+ >>> interleaved_reservoir = reservoir.Reservoir(10)
+  >>> for i in xrange(100):
+  ...   separate_reservoir.AddItem('key1', i)
+  >>> for i in xrange(100):
+  ...   separate_reservoir.AddItem('key2', i)
+  >>> for i in xrange(100):
+  ...   interleaved_reservoir.AddItem('key1', i)
+  ...   interleaved_reservoir.AddItem('key2', i)
+
+ separate_reservoir and interleaved_reservoir will be in identical states.
+
+ See: https://en.wikipedia.org/wiki/Reservoir_sampling
+
+ Adding items has amortized O(1) runtime.
+
+ """
+
+ def __init__(self, size, seed=0):
+ """Creates a new reservoir.
+
+ Args:
+ size: The number of values to keep in the reservoir for each tag. If 0,
+ all values will be kept.
+ seed: The seed of the random number generator to use when sampling.
+ Different values for |seed| will produce different samples from the same
+ input items.
+
+ Raises:
+ ValueError: If size is negative or not an integer.
+ """
+ if size < 0 or size != round(size):
+      raise ValueError('size must be a nonnegative integer, was %s' % size)
+ self._buckets = collections.defaultdict(
+ lambda: _ReservoirBucket(size, random.Random(seed)))
+    # _mutex guards the keys: creating new keys, retrieving by key, etc. The
+    # items within each key are guarded by that ReservoirBucket's own mutex.
+ self._mutex = threading.Lock()
+
+ def Keys(self):
+ """Return all the keys in the reservoir.
+
+ Returns:
+ ['list', 'of', 'keys'] in the Reservoir.
+ """
+ with self._mutex:
+ return self._buckets.keys()
+
+ def Items(self, key):
+ """Return items associated with given key.
+
+ Args:
+ key: The key for which we are finding associated items.
+
+ Raises:
+      KeyError: If the key is not found in the reservoir.
+
+ Returns:
+ [list, of, items] associated with that key.
+ """
+ with self._mutex:
+ if key not in self._buckets:
+ raise KeyError('Key %s was not found in Reservoir' % key)
+ bucket = self._buckets[key]
+ return bucket.Items()
+
+ def AddItem(self, key, item):
+ """Add a new item to the Reservoir with the given tag.
+
+ The new item is guaranteed to be kept in the Reservoir. One other item might
+ be replaced.
+
+ Args:
+ key: The key to store the item under.
+ item: The item to add to the reservoir.
+ """
+ with self._mutex:
+ bucket = self._buckets[key]
+ bucket.AddItem(item)
+
+
+class _ReservoirBucket(object):
+  """A container that implements reservoir sampling on a stream of items.
+
+ It always stores the most recent item as its final item.
+ """
+
+ def __init__(self, _max_size, _random=None):
+ """Create the _ReservoirBucket.
+
+ Args:
+ _max_size: The maximum size the reservoir bucket may grow to. If size is
+ zero, the bucket has unbounded size.
+ _random: The random number generator to use. If not specified, defaults to
+ random.Random(0).
+
+ Raises:
+ ValueError: if the size is not a nonnegative integer.
+ """
+ if _max_size < 0 or _max_size != round(_max_size):
+      raise ValueError(
+          '_max_size must be a nonnegative int, was %s' % _max_size)
+ self.items = []
+ # This mutex protects the internal items, ensuring that calls to Items and
+ # AddItem are thread-safe
+ self._mutex = threading.Lock()
+ self._max_size = _max_size
+ self._count = 0
+ if _random is not None:
+ self._random = _random
+ else:
+ self._random = random.Random(0)
+
+ def AddItem(self, item):
+ """Add an item to the ReservoirBucket, replacing an old item if necessary.
+
+ The new item is guaranteed to be added to the bucket, and to be the last
+ element in the bucket. If the bucket has reached capacity, then an old item
+    will be replaced. With probability (_max_size / (_count + 1)) a uniformly
+    random item in the bucket is popped and the new item is appended to the
+    end; with the remaining probability the last item in the bucket is
+    replaced.
+
+    Since the O(n) replacements occur with O(1/_count) likelihood, the
+    amortized runtime is O(1).
+
+ Args:
+ item: The item to add to the bucket.
+ """
+ with self._mutex:
+ if len(self.items) < self._max_size or self._max_size == 0:
+ self.items.append(item)
+ else:
+ r = self._random.randint(0, self._count)
+ if r < self._max_size:
+ self.items.pop(r)
+ self.items.append(item)
+ else:
+ self.items[-1] = item
+ self._count += 1
+
+  def Items(self):
+    """Get a copy of the items in the bucket, in insertion order."""
+    with self._mutex:
+      return list(self.items)
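The equal-probability claim in AddItem's docstring can be checked empirically
from outside the class. A minimal sketch; the trial count and sizes are
arbitrary illustrative choices, not part of this change.

    import collections

    from tensorflow.python.summary.impl import reservoir

    counts = collections.Counter()
    trials = 1000
    for seed in range(trials):
      r = reservoir.Reservoir(5, seed=seed)
      for i in range(50):
        r.AddItem('tag', i)
      counts.update(r.Items('tag'))

    # Item 49 is the always-kept most recent item, so it appears in every
    # trial. Each of the other 49 items competes for the remaining four
    # slots, so it should appear roughly trials * 4 / 49 ~= 82 times.
    for item in sorted(counts):
      print('%2d: %d' % (item, counts[item]))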
diff --git a/tensorflow/python/summary/impl/reservoir_test.py b/tensorflow/python/summary/impl/reservoir_test.py
new file mode 100644
index 0000000000..46cbde5940
--- /dev/null
+++ b/tensorflow/python/summary/impl/reservoir_test.py
@@ -0,0 +1,178 @@
+"""Tests for reservoir."""
+
+import tensorflow.python.platform
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import reservoir
+
+
+class ReservoirTest(googletest.TestCase):
+
+ def testEmptyReservoir(self):
+ r = reservoir.Reservoir(1)
+ self.assertFalse(r.Keys())
+
+ def testRespectsSize(self):
+ r = reservoir.Reservoir(42)
+ self.assertEqual(r._buckets['meaning of life']._max_size, 42)
+
+ def testItemsAndKeys(self):
+ r = reservoir.Reservoir(42)
+ r.AddItem('foo', 4)
+ r.AddItem('bar', 9)
+ r.AddItem('foo', 19)
+ self.assertItemsEqual(r.Keys(), ['foo', 'bar'])
+ self.assertEqual(r.Items('foo'), [4, 19])
+ self.assertEqual(r.Items('bar'), [9])
+
+ def testExceptions(self):
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(-1)
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(13.3)
+
+ r = reservoir.Reservoir(12)
+ with self.assertRaises(KeyError):
+ r.Items('missing key')
+
+ def testDeterminism(self):
+ """Tests that the reservoir is deterministic."""
+ key = 'key'
+ r1 = reservoir.Reservoir(10)
+ r2 = reservoir.Reservoir(10)
+ for i in xrange(100):
+ r1.AddItem('key', i)
+ r2.AddItem('key', i)
+
+ self.assertEqual(r1.Items(key), r2.Items(key))
+
+ def testBucketDeterminism(self):
+ """Tests that reservoirs are deterministic at a bucket level.
+
+    This means that only the order in which elements are added within a
+    bucket matters.
+ """
+ separate_reservoir = reservoir.Reservoir(10)
+ interleaved_reservoir = reservoir.Reservoir(10)
+ for i in xrange(100):
+ separate_reservoir.AddItem('key1', i)
+ for i in xrange(100):
+ separate_reservoir.AddItem('key2', i)
+ for i in xrange(100):
+ interleaved_reservoir.AddItem('key1', i)
+ interleaved_reservoir.AddItem('key2', i)
+
+ for key in ['key1', 'key2']:
+ self.assertEqual(separate_reservoir.Items(key),
+ interleaved_reservoir.Items(key))
+
+ def testUsesSeed(self):
+ """Tests that reservoirs with different seeds keep different samples."""
+ key = 'key'
+ r1 = reservoir.Reservoir(10, seed=0)
+ r2 = reservoir.Reservoir(10, seed=1)
+ for i in xrange(100):
+ r1.AddItem('key', i)
+ r2.AddItem('key', i)
+ self.assertNotEqual(r1.Items(key), r2.Items(key))
+
+
+class ReservoirBucketTest(googletest.TestCase):
+
+ def testEmptyBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ self.assertFalse(b.Items())
+
+ def testFillToSize(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(100):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), range(100))
+
+ def testDoesntOverfill(self):
+ b = reservoir._ReservoirBucket(10)
+ for i in xrange(1000):
+ b.AddItem(i)
+ self.assertEqual(len(b.Items()), 10)
+
+ def testMaintainsOrder(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(10000):
+ b.AddItem(i)
+ items = b.Items()
+ prev = None
+ for item in items:
+ self.assertTrue(item > prev)
+ prev = item
+
+ def testKeepsLatestItem(self):
+ b = reservoir._ReservoirBucket(5)
+ for i in xrange(100):
+ b.AddItem(i)
+ last = b.Items()[-1]
+ self.assertEqual(last, i)
+
+ def testSizeOneBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), [i])
+
+ def testSizeZeroBucket(self):
+ b = reservoir._ReservoirBucket(0)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), range(i+1))
+
+ def testSizeRequirement(self):
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(-1)
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(10.3)
+
+
+class ReservoirBucketStatisticalDistributionTest(googletest.TestCase):
+
+ def setUp(self):
+ self.total = 1000000
+ self.samples = 10000
+ self.n_buckets = 100
+ self.total_per_bucket = self.total / self.n_buckets
+ self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly '
+ 'divisible by the number of buckets')
+ self.assertTrue(self.total > self.samples, 'need to have more items '
+ 'than samples')
+
+ def AssertBinomialQuantity(self, measured):
+    # Each sampled item lands in any given bin with probability 1/n_buckets,
+    # so each bin's count is Binomial(samples, p).
+    p = 1.0 / self.n_buckets
+    mean = p * self.samples
+    variance = p * (1 - p) * self.samples
+ error = measured - mean
+    # The tolerance is six standard deviations: error^2 <= (6 sigma)^2. Given
+    # that the bin counts are actually binomially distributed, this check
+    # fails with probability ~2E-9.
+    passed = error * error <= 36.0 * variance
+ self.assertTrue(passed, 'found a bucket with measured %d '
+ 'too far from expected %d' % (measured, mean))
+
+ def testBucketReservoirSamplingViaStatisticalProperties(self):
+    # Here 'bucket' means a histogram bin used to test the shape of the
+    # sampled distribution, not a _ReservoirBucket.
+    b = reservoir._ReservoirBucket(_max_size=self.samples)
+    # Add one extra item because the bucket always keeps the most recent
+    # item, which would skew the distribution; we slice it off the end
+    # instead.
+ for i in xrange(self.total + 1):
+ b.AddItem(i)
+
+ divbins = [0] * self.n_buckets
+ modbins = [0] * self.n_buckets
+ # Slice off the last item when we iterate.
+ for item in b.Items()[0:-1]:
+ divbins[item / self.total_per_bucket] += 1
+ modbins[item % self.n_buckets] += 1
+
+ for bucket_index in xrange(self.n_buckets):
+ divbin = divbins[bucket_index]
+ modbin = modbins[bucket_index]
+ self.AssertBinomialQuantity(divbin)
+ self.AssertBinomialQuantity(modbin)
+
+
+if __name__ == '__main__':
+ googletest.main()
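The '~2E-9' figure in AssertBinomialQuantity is the two-sided normal tail
beyond six standard deviations, which the binomial closely approaches at
these sample sizes. A quick standard-library check, included only for
illustration:

    import math

    def two_sided_tail(k):
      # P(|Z| > k) for a standard normal Z, via the complementary error
      # function: P(|Z| > k) = erfc(k / sqrt(2)).
      return math.erfc(k / math.sqrt(2))

    print('%.1e' % two_sided_tail(6.0))  # ~2.0e-09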