Diffstat (limited to 'tensorflow/core/kernels/range_sampler.h')
-rw-r--r--  tensorflow/core/kernels/range_sampler.h  237
1 file changed, 237 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
new file mode 100644
index 0000000000..18364c2c03
--- /dev/null
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -0,0 +1,237 @@
+#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/lib/random/weighted_picker.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Env;
+
+// Abstract base class for sampling from the set of non-negative integers
+// [0, range)
+class RangeSampler {
+ public:
+  explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); }
+  virtual ~RangeSampler();
+
+  // Sample a single value
+  virtual int64 Sample(random::SimplePhilox* rnd) const = 0;
+
+  // The probability that a single call to Sample() returns the given value.
+  // Assumes that value is in [0, range). No range checking is done.
+  virtual float Probability(int64 value) const = 0;
+
+  // Fill "batch" with samples from the distribution.
+  // If unique=true, then we re-pick each element until we get a
+  // value distinct from all previously picked values in the batch.
+  void SampleBatch(random::SimplePhilox* rnd, bool unique,
+                   gtl::MutableArraySlice<int64> batch) const;
+
+  // Fill "batch" with samples from the distribution, and report
+  // "expected counts".
+  //
+  // The "expected count" of a value is an estimate of the expected
+  // number of occurrences of the value in the batch returned by a
+  // call to this function with the given parameters. If unique=true,
+  // the expected count is an inclusion probability. For details on
+  // this estimation, see the comment to "ExpectedCountHelper" in the
+  // .cc file.
+  //
+  // Expected counts for the elements of the returned "batch" are reported
+  // in the aligned array "batch_expected_count".
+  //
+  // The user can optionally provide "extras", containing values in the range.
+  // The expected counts for the extras are reported in the aligned array
+  // "extras_expected_count".
+  //
+  // "batch_expected_count" must have size equal to 0 or to the size of "batch".
+  // "extras" and "extras_expected_count" must have equal size.
+  void SampleBatchGetExpectedCount(
+      random::SimplePhilox* rnd, bool unique,
+      gtl::MutableArraySlice<int64> batch,
+      gtl::MutableArraySlice<float> batch_expected_count,
+      gtl::ArraySlice<int64> extras,
+      gtl::MutableArraySlice<float> extras_expected_count) const;
+
+  // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
+  // We repick to avoid all of the values in "avoided_values".
+  // "avoided_values" is only supported with unique=true. If
+  // unique=false, then avoided_values must be empty.
+  virtual void SampleBatchGetExpectedCountAvoid(
+      random::SimplePhilox* rnd, bool unique,
+      gtl::MutableArraySlice<int64> batch,
+      gtl::MutableArraySlice<float> batch_expected_count,
+      gtl::ArraySlice<int64> extras,
+      gtl::MutableArraySlice<float> extras_expected_count,
+      gtl::ArraySlice<int64> avoided_values) const;
+
+  // Does this sampler need to be updated with values, e.g. UnigramSampler.
+  virtual bool NeedsUpdates() const { return false; }
+
+  // Updates the underlying distribution.
+  virtual void Update(gtl::ArraySlice<int64> values) {
+    LOG(FATAL) << "Update not supported for this sampler type.";
+  }
+
+  int64 range() const { return range_; }
+
+ protected:
+  const int64 range_;
+};
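+
+// Example usage (a minimal, illustrative sketch, not part of this API; the
+// seeds, the batch size, and the choice of UniformSampler are arbitrary
+// assumptions for the sake of the example):
+//
+//   random::PhiloxRandom philox(301, 17);
+//   random::SimplePhilox rnd(&philox);
+//   UniformSampler sampler(1000000);   // declared below
+//   std::vector<int64> batch(256);
+//   std::vector<float> batch_expected(256);
+//   sampler.SampleBatchGetExpectedCount(
+//       &rnd, true /* unique */,
+//       gtl::MutableArraySlice<int64>(batch.data(), batch.size()),
+//       gtl::MutableArraySlice<float>(batch_expected.data(),
+//                                     batch_expected.size()),
+//       gtl::ArraySlice<int64>(),           // no extras
+//       gtl::MutableArraySlice<float>());   // so no extra expected counts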
+
+// An AllSampler only samples batches of size equal to range.
+// It returns the entire range.
+// It cannot sample single values.
+class AllSampler : public RangeSampler {
+ public:
+  explicit AllSampler(int64 range);
+
+  ~AllSampler() override {}
+
+  int64 Sample(random::SimplePhilox* rnd) const override {
+    LOG(FATAL) << "Should not be called";
+  }
+
+  float Probability(int64 value) const override {
+    LOG(FATAL) << "Should not be called";
+  }
+
+  void SampleBatchGetExpectedCountAvoid(
+      random::SimplePhilox* rnd, bool unique,
+      gtl::MutableArraySlice<int64> batch,
+      gtl::MutableArraySlice<float> batch_expected_count,
+      gtl::ArraySlice<int64> extras,
+      gtl::MutableArraySlice<float> extras_expected_count,
+      gtl::ArraySlice<int64> avoided_values) const override;
+
+ private:
+  const float inv_range_;
+};
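+
+// Example usage (a hedged sketch; the parameter choices are illustrative
+// assumptions, not requirements stated in this header). Since an AllSampler
+// returns the entire range, "batch" must have size equal to range; "rnd" is a
+// random::SimplePhilox as in the example above:
+//
+//   AllSampler all(vocab_size);
+//   std::vector<int64> batch(vocab_size);
+//   std::vector<float> expected(vocab_size);
+//   all.SampleBatchGetExpectedCountAvoid(
+//       &rnd, false /* unique */,
+//       gtl::MutableArraySlice<int64>(batch.data(), batch.size()),
+//       gtl::MutableArraySlice<float>(expected.data(), expected.size()),
+//       gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>(),
+//       gtl::ArraySlice<int64>());   // no extras, nothing avoided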
+
+class UniformSampler : public RangeSampler {
+ public:
+  explicit UniformSampler(int64 range);
+
+  ~UniformSampler() override {}
+
+  int64 Sample(random::SimplePhilox* rnd) const override;
+
+  float Probability(int64 value) const override;
+
+ private:
+  const float inv_range_;
+};
+
+class LogUniformSampler : public RangeSampler {
+ public:
+  explicit LogUniformSampler(int64 range);
+
+  ~LogUniformSampler() override {}
+
+  int64 Sample(random::SimplePhilox* rnd) const override;
+
+  float Probability(int64 value) const override;
+
+ private:
+  const double log_range_;
+};
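+
+// Note (not from this header): a log-uniform (Zipfian-like) sampler favors
+// small IDs. TensorFlow's log-uniform candidate sampling uses
+//   P(value) = (log(value + 2) - log(value + 1)) / log(range + 1),
+// which is why only log_range_ needs to be stored; see the .cc file for the
+// authoritative implementation.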
+
+// Thread-unsafe unigram sampler
+class ThreadUnsafeUnigramSampler : public RangeSampler {
+ public:
+  explicit ThreadUnsafeUnigramSampler(int64 range);
+  ~ThreadUnsafeUnigramSampler() override {}
+
+  int64 Sample(random::SimplePhilox* rnd) const override;
+
+  float Probability(int64 value) const override;
+
+  bool NeedsUpdates() const override { return true; }
+  void Update(gtl::ArraySlice<int64> values) override;
+
+ private:
+  random::WeightedPicker picker_;
+};
+
+// Thread-safe unigram sampler
+class UnigramSampler : public RangeSampler {
+ public:
+  explicit UnigramSampler(int64 range);
+  ~UnigramSampler() override {}
+
+  int64 Sample(random::SimplePhilox* rnd) const override;
+
+  float Probability(int64 value) const override;
+
+  // Overriding at a high level results in far fewer lock acquisitions.
+  void SampleBatchGetExpectedCountAvoid(
+      random::SimplePhilox* rnd, bool unique,
+      gtl::MutableArraySlice<int64> batch,
+      gtl::MutableArraySlice<float> batch_expected_count,
+      gtl::ArraySlice<int64> extras,
+      gtl::MutableArraySlice<float> extras_expected_count,
+      gtl::ArraySlice<int64> avoided_values) const override;
+
+  bool NeedsUpdates() const override { return true; }
+  void Update(gtl::ArraySlice<int64> values) override;
+
+ private:
+  ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
+  mutable mutex mu_;
+};
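+
+// Example usage (a minimal sketch; the IDs below are arbitrary). Unigram
+// samplers are updated with observed values so that frequently seen IDs
+// become more likely to be sampled; "rnd" is a random::SimplePhilox as in the
+// example above:
+//
+//   UnigramSampler sampler(vocab_size);
+//   if (sampler.NeedsUpdates()) {
+//     std::vector<int64> seen = {3, 3, 7};   // hypothetical observed IDs
+//     sampler.Update(gtl::ArraySlice<int64>(seen.data(), seen.size()));
+//   }
+//   int64 id = sampler.Sample(&rnd);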
+
+// A unigram sampler that uses a fixed unigram distribution read from a
+// file or passed in as an in-memory array instead of building up the
+// distribution from data on the fly. There is also an option to skew the
+// distribution by applying a distortion power to the weights.
+class FixedUnigramSampler : public RangeSampler {
+ public:
+  // The vocab_file is assumed to be a CSV, with the last entry of each row a
+  // value representing the counts or probabilities for the corresponding ID.
+  FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
+                      float distortion, int32 num_reserved_ids,
+                      int32 num_shards, int32 shard);
+
+  FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
+                      float distortion, int32 num_reserved_ids,
+                      int32 num_shards, int32 shard);
+
+  float Probability(int64 value) const override;
+
+  int64 Sample(random::SimplePhilox* rnd) const override;
+
+ private:
+  // Underlying distribution sampler.
+  std::unique_ptr<random::DistributionSampler> dist_sampler_;
+  // Weights for individual samples. The probability of a sample i is defined
+  // as weights_.at(i) / total_weight_.
+  std::vector<float> weights_;
+  // The total weight of all samples.
+  float total_weight_;
+  // Sharding information of the sampler. The whole vocabulary is sharded
+  // into num_shards_ smaller ranges and each sampler is responsible for one
+  // such smaller range, identified by the shard number.
+  int32 num_shards_;
+  int32 shard_;
+
+  // Fill the sampler with the appropriate number of reserved IDs.
+  void FillReservedIds(int32 num_reserved_ids);
+  // Load IDs to sample from a CSV file. It is assumed that the last item of
+  // each row contains a count or probability for the corresponding ID.
+  Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
+  // Load from an in-memory array.
+  void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
+};
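+
+// Example usage (a hedged sketch; the weights, distortion, and sharding
+// values are illustrative assumptions). Constructing from an in-memory
+// unigram array with no reserved IDs and a single shard; "rnd" is a
+// random::SimplePhilox as in the examples above:
+//
+//   std::vector<float> unigrams = {10.0f, 5.0f, 1.0f, 1.0f};  // counts per ID
+//   FixedUnigramSampler sampler(/*range=*/4, unigrams, /*distortion=*/0.75f,
+//                               /*num_reserved_ids=*/0, /*num_shards=*/1,
+//                               /*shard=*/0);
+//   int64 id = sampler.Sample(&rnd);
+//   float p = sampler.Probability(id);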
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_