aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/range_sampler.h
blob: 18364c2c03c4d33b0a91d9822a9fc94f3a962d29 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_

#include <memory>
#include <vector>

#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/weighted_picker.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// Forward declaration: this header only uses Env through a pointer (the
// file-loading FixedUnigramSampler constructor and its LoadFromFile helper),
// so the full definition is not required here.
class Env;

// Abstract subclass for sampling from the set of non-negative integers
// [0, range)
class RangeSampler {
 public:
  explicit RangeSampler(int range) : range_(range) { CHECK_GT(range_, 0); }
  virtual ~RangeSampler();

  // Sample a single value
  virtual int64 Sample(random::SimplePhilox* rnd) const = 0;

  // The probability that a single call to Sample() returns the given value.
  // Assumes that value is in [0, range).  No range checking is done.
  virtual float Probability(int64 value) const = 0;

  // Fill "batch" with samples from the distribution.
  // If unique=true, then we re-pick each element until we get a
  // value distinct from all previously picked values in the batch.
  void SampleBatch(random::SimplePhilox* rnd, bool unique,
                   gtl::MutableArraySlice<int64> batch) const;

  // Fill "batch" with samples from the distribution, and report
  // "expected counts".
  //
  // The "expected count" of a value is an estimate of the expected
  // number of occurrences of the value in the batch returned by a
  // call to this function with the given parameters.  If unique=true,
  // the expected count is an inclusion probability.  For details on
  // this estimation, see the comment to "ExpectedCountHelper" in the
  // .cc file.
  //
  // Expected counts for the elements of the returned "batch" are reported
  // in the aligned array "batch_expected_count".
  //
  // The user can optionally provide "extras", containg values in the range.
  // The expected counts for the extras are reported in the aligned array
  // "extras_expected_count".
  //
  // "batch_expected_count" must have size equal to 0 or to the size of "batch".
  // "extras" and "extras_expected_count" must have equal size.
  void SampleBatchGetExpectedCount(
      random::SimplePhilox* rnd, bool unique,
      gtl::MutableArraySlice<int64> batch,
      gtl::MutableArraySlice<float> batch_expected_count,
      gtl::ArraySlice<int64> extras,
      gtl::MutableArraySlice<float> extras_expected_count) const;

  // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
  // We repick to avoid all of the values in "avoided_values".
  // "avoided_values" is only supported with unique=true.  If
  // unique=false, then avoided_values must be empty.
  virtual void SampleBatchGetExpectedCountAvoid(
      random::SimplePhilox* rnd, bool unique,
      gtl::MutableArraySlice<int64> batch,
      gtl::MutableArraySlice<float> batch_expected_count,
      gtl::ArraySlice<int64> extras,
      gtl::MutableArraySlice<float> extras_expected_count,
      gtl::ArraySlice<int64> avoided_values) const;

  // Does this sampler need to be updated with values, e.g. UnigramSampler
  virtual bool NeedsUpdates() const { return false; }

  // Updates the underlying distribution
  virtual void Update(gtl::ArraySlice<int64> values) {
    LOG(FATAL) << "Update not supported for this sampler type.";
  }

  int64 range() { return range_; }

 protected:
  const int64 range_;
};

// An AllSampler only samples batches of size equal to range.
// It returns the entire range.
// It cannot sample single values: Sample() and Probability() abort with
// LOG(FATAL) if called.
class AllSampler : public RangeSampler {
 public:
  explicit AllSampler(int64 range);

  ~AllSampler() override {}

  // Not supported; aborts. The trailing return keeps this non-void function
  // well-formed (no fall-off-the-end, no -Wreturn-type warning) on toolchains
  // that cannot prove LOG(FATAL) never returns. It is never reached.
  int64 Sample(random::SimplePhilox* rnd) const override {
    LOG(FATAL) << "Should not be called";
    return 0;
  }

  // Not supported; aborts. The trailing return is unreachable, as in Sample().
  float Probability(int64 value) const override {
    LOG(FATAL) << "Should not be called";
    return 0;
  }

  // Batch entry point for this sampler; defined in the .cc file.
  void SampleBatchGetExpectedCountAvoid(
      random::SimplePhilox* rnd, bool unique,
      gtl::MutableArraySlice<int64> batch,
      gtl::MutableArraySlice<float> batch_expected_count,
      gtl::ArraySlice<int64> extras,
      gtl::MutableArraySlice<float> extras_expected_count,
      gtl::ArraySlice<int64> avoided_values) const override;

 private:
  // Presumably 1 / range, set by the constructor in the .cc file -- confirm.
  const float inv_range_;
};

// Sampler over [0, range) whose Sample()/Probability() are defined in the
// .cc file. The name and the cached inv_range_ suggest every value is
// equally likely (probability 1/range) -- confirm against the .cc file.
class UniformSampler : public RangeSampler {
 public:
  explicit UniformSampler(int64 range);
  ~UniformSampler() override = default;

  // Probability that one Sample() call yields "value".
  float Probability(int64 value) const override;

  // Draws one value in [0, range).
  int64 Sample(random::SimplePhilox* rnd) const override;

 private:
  // Cached reciprocal-of-range factor; initialized in the .cc constructor.
  const float inv_range_;
};

// Sampler over [0, range) with a log-uniform distribution. The exact formula
// lives in the .cc file; this header only declares the interface.
class LogUniformSampler : public RangeSampler {
 public:
  explicit LogUniformSampler(int64 range);
  ~LogUniformSampler() override = default;

  // Probability that one Sample() call yields "value".
  float Probability(int64 value) const override;

  // Draws one value in [0, range).
  int64 Sample(random::SimplePhilox* rnd) const override;

 private:
  // Presumably a precomputed logarithm of the range, initialized by the
  // constructor in the .cc file -- confirm.
  const double log_range_;
};

// Unigram sampler whose distribution is built from values passed to
// Update(). It performs no locking, so concurrent use must be serialized by
// the caller (UnigramSampler wraps one of these behind a mutex).
class ThreadUnsafeUnigramSampler : public RangeSampler {
 public:
  explicit ThreadUnsafeUnigramSampler(int64 range);
  ~ThreadUnsafeUnigramSampler() override = default;

  // Probability that one Sample() call yields "value".
  float Probability(int64 value) const override;

  // Draws one value in [0, range).
  int64 Sample(random::SimplePhilox* rnd) const override;

  // This sampler's distribution depends on observed data, so it always
  // wants updates.
  bool NeedsUpdates() const override { return true; }

  // Incorporates the observed "values" into the distribution.
  void Update(gtl::ArraySlice<int64> values) override;

 private:
  // Weighted picker over [0, range) backing Sample() and Probability().
  random::WeightedPicker picker_;
};

// Thread-safe unigram sampler: a ThreadUnsafeUnigramSampler whose every
// access is guarded by the mutex mu_.
class UnigramSampler : public RangeSampler {
 public:
  explicit UnigramSampler(int64 range);
  ~UnigramSampler() override = default;

  // Probability that one Sample() call yields "value".
  float Probability(int64 value) const override;

  // Draws one value in [0, range).
  int64 Sample(random::SimplePhilox* rnd) const override;

  // Overridden at the batch level so that a whole batch can be handled under
  // far fewer lock acquisitions than one per sampled value.
  void SampleBatchGetExpectedCountAvoid(
      random::SimplePhilox* rnd, bool unique,
      gtl::MutableArraySlice<int64> batch,
      gtl::MutableArraySlice<float> batch_expected_count,
      gtl::ArraySlice<int64> extras,
      gtl::MutableArraySlice<float> extras_expected_count,
      gtl::ArraySlice<int64> avoided_values) const override;

  // Distribution is learned from data, so updates are always required.
  bool NeedsUpdates() const override { return true; }

  // Incorporates the observed "values" into the wrapped sampler.
  void Update(gtl::ArraySlice<int64> values) override;

 private:
  // Wrapped sampler; only touched while holding mu_.
  ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
  // mutable so that const methods such as Sample() can lock it.
  mutable mutex mu_;
};

// Unigram sampler with a fixed distribution, supplied up front instead of
// being learned on the fly. The weights come either from a vocabulary file
// or from an in-memory array, and may be skewed by a distortion power.
// The vocabulary can be sharded so that each sampler only covers one shard.
class FixedUnigramSampler : public RangeSampler {
 public:
  // Builds the distribution from "vocab_file", read through "env". The file
  // is treated as CSV where the last field of each row is the count or
  // probability for the corresponding ID.
  FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
                      float distortion, int32 num_reserved_ids,
                      int32 num_shards, int32 shard);

  // Builds the distribution from the in-memory weights in "unigrams".
  FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
                      float distortion, int32 num_reserved_ids,
                      int32 num_shards, int32 shard);

  // Draws one value in [0, range).
  int64 Sample(random::SimplePhilox* rnd) const override;

  // Probability that one Sample() call yields "value".
  float Probability(int64 value) const override;

 private:
  // Adds "num_reserved_ids" reserved IDs to the sampler.
  void FillReservedIds(int32 num_reserved_ids);

  // Reads per-ID weights from a CSV file whose last field in each row is a
  // count or probability.
  Status LoadFromFile(Env* env, const string& vocab_file, float distortion);

  // Reads per-ID weights from an in-memory array.
  void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);

  // Sampler that performs the actual draws over weights_.
  std::unique_ptr<random::DistributionSampler> dist_sampler_;
  // Per-ID weights: P(i) == weights_.at(i) / total_weight_.
  std::vector<float> weights_;
  // Sum of all entries in weights_.
  float total_weight_;
  // Sharding: the vocabulary is split into num_shards_ sub-ranges and this
  // sampler handles the one numbered shard_.
  int32 num_shards_;
  int32 shard_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_