aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/summary/impl/reservoir.py
blob: 44b3b2a58cea111d5d52e909b4fa58a8bb285900 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A key-value[] store that implements reservoir sampling on the values."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import random
import threading


class Reservoir(object):
  """A map-to-arrays container, with deterministic Reservoir Sampling.

  Items are added with an associated key. Items may be retrieved by key, and
  a list of keys can also be retrieved. If size is not zero, then it dictates
  the maximum number of items that will be stored with each key. Once there are
  more items for a given key, they are replaced via reservoir sampling, such
  that each item has an equal probability of being included in the sample.

  Deterministic means that for any given seed and bucket size, the sequence of
  values that are kept for any given key will always be the same, and that this
  is independent of any insertions on other keys. (Each key gets its own
  random number generator, seeded with the same seed.) That is:

  >>> separate_reservoir = reservoir.Reservoir(10)
  >>> interleaved_reservoir = reservoir.Reservoir(10)
  >>> for i in range(100):
  ...   separate_reservoir.AddItem('key1', i)
  >>> for i in range(100):
  ...   separate_reservoir.AddItem('key2', i)
  >>> for i in range(100):
  ...   interleaved_reservoir.AddItem('key1', i)
  ...   interleaved_reservoir.AddItem('key2', i)

  separate_reservoir and interleaved_reservoir will be in identical states.

  See: https://en.wikipedia.org/wiki/Reservoir_sampling

  Adding items has amortized O(1) runtime.

  """

  def __init__(self, size, seed=0):
    """Creates a new reservoir.

    Args:
      size: The number of values to keep in the reservoir for each key. If 0,
        all values will be kept.
      seed: The seed of the random number generator to use when sampling.
        Different values for |seed| will produce different samples from the same
        input items.

    Raises:
      ValueError: If size is negative or not an integer.
    """
    if size < 0 or size != round(size):
      raise ValueError('size must be nonnegative integer, was %s' % size)
    # Buckets are created lazily on first access to a key. Every bucket gets a
    # fresh random.Random(seed), so the samples kept for one key do not depend
    # on how insertions into other keys are interleaved.
    self._buckets = collections.defaultdict(
        lambda: _ReservoirBucket(size, random.Random(seed)))
    # _mutex guards the keys (creating new keys, retrieving by key, etc.).
    # The items inside each bucket are guarded by that bucket's own mutex.
    self._mutex = threading.Lock()

  def Keys(self):
    """Return all the keys in the reservoir.

    Returns:
      ['list', 'of', 'keys'] in the Reservoir.
    """
    with self._mutex:
      return list(self._buckets.keys())

  def Items(self, key):
    """Return items associated with given key.

    Args:
      key: The key for which we are finding associated items.

    Raises:
      KeyError: If the key is not found in the reservoir.

    Returns:
      [list, of, items] associated with that key.
    """
    with self._mutex:
      # Check membership explicitly: indexing the defaultdict would silently
      # create an empty bucket for an unknown key.
      if key not in self._buckets:
        raise KeyError('Key %s was not found in Reservoir' % key)
      bucket = self._buckets[key]
    # The bucket guards its own items, so the key lock is not needed here.
    return bucket.Items()

  def AddItem(self, key, item):
    """Add a new item to the Reservoir under the given key.

    The new item is guaranteed to be kept in the Reservoir. One other item might
    be replaced.

    Args:
      key: The key to store the item under. A new bucket is created for a key
        that has not been seen before.
      item: The item to add to the reservoir.
    """
    with self._mutex:
      bucket = self._buckets[key]
    bucket.AddItem(item)

  def FilterItems(self, filterFn):
    """Filter items within a Reservoir, using a filtering function.

    Args:
      filterFn: A function that returns True for the items to be kept.

    Returns:
      The number of items removed, summed over all keys.
    """
    with self._mutex:
      return sum(bucket.FilterItems(filterFn)
                 for bucket in self._buckets.values())


class _ReservoirBucket(object):
  """A container for items from a stream, that implements reservoir sampling.

  It always stores the most recent item as its final item.
  """

  def __init__(self, _max_size, _random=None):
    """Create the _ReservoirBucket.

    Args:
      _max_size: The maximum size the reservoir bucket may grow to. If size is
        zero, the bucket has unbounded size.
      _random: The random number generator to use. If not specified, defaults to
        random.Random(0).

    Raises:
      ValueError: if the size is not a nonnegative integer.
    """
    if _max_size < 0 or _max_size != round(_max_size):
      raise ValueError('_max_size must be nonegative int, was %s' % _max_size)
    self.items = []
    # This mutex protects the internal items, ensuring that calls to Items and
    # AddItem are thread-safe
    self._mutex = threading.Lock()
    self._max_size = _max_size
    self._num_items_seen = 0
    if _random is not None:
      self._random = _random
    else:
      self._random = random.Random(0)

  def AddItem(self, item):
    """Add an item to the ReservoirBucket, replacing an old item if necessary.

    The new item is guaranteed to be added to the bucket, and to be the last
    element in the bucket. If the bucket has reached capacity, then an old item
    will be replaced. With probability (_max_size/_num_items_seen) a random item
    in the bucket will be popped out and the new item will be appended
    to the end. With probability (1 - _max_size/_num_items_seen)
    the last item in the bucket will be replaced.

    Since the O(n) replacements occur with O(1/_num_items_seen) likelihood,
    the amortized runtime is O(1).

    Args:
      item: The item to add to the bucket.
    """
    with self._mutex:
      if len(self.items) < self._max_size or self._max_size == 0:
        self.items.append(item)
      else:
        r = self._random.randint(0, self._num_items_seen)
        if r < self._max_size:
          self.items.pop(r)
          self.items.append(item)
        else:
          self.items[-1] = item
      self._num_items_seen += 1

  def FilterItems(self, filterFn):
    """Filter items in a ReservoirBucket, using a filtering function.

    Filtering items from the reservoir bucket must update the
    internal state variable self._num_items_seen, which is used for determining
    the rate of replacement in reservoir sampling. Ideally, self._num_items_seen
    would contain the exact number of items that have ever seen by the
    ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not
    have access to all items seen -- it only has access to the subset of items
    that have survived sampling (self.items). Therefore, we estimate
    self._num_items_seen by scaling it by the same ratio as the ratio of items
    not removed from self.items.

    Args:
      filterFn: A function that returns True for items to be kept.

    Returns:
      The number of items removed from the bucket.
    """
    with self._mutex:
      size_before = len(self.items)
      self.items = filter(filterFn, self.items)
      size_diff = size_before - len(self.items)

      # Estimate a correction the the number of items seen
      prop_remaining = len(self.items) / float(
          size_before) if size_before > 0 else 0
      self._num_items_seen = int(round(self._num_items_seen * prop_remaining))
      return size_diff

  def Items(self):
    """Get all the items in the bucket."""
    with self._mutex:
      return self.items