#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ #define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ #include #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/weighted_picker.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/port.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/status.h" namespace tensorflow { class Env; // Abstract subclass for sampling from the set of non-negative integers // [0, range) class RangeSampler { public: explicit RangeSampler(int range) : range_(range) { CHECK_GT(range_, 0); } virtual ~RangeSampler(); // Sample a single value virtual int64 Sample(random::SimplePhilox* rnd) const = 0; // The probability that a single call to Sample() returns the given value. // Assumes that value is in [0, range). No range checking is done. virtual float Probability(int64 value) const = 0; // Fill "batch" with samples from the distribution. // If unique=true, then we re-pick each element until we get a // value distinct from all previously picked values in the batch. void SampleBatch(random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice batch) const; // Fill "batch" with samples from the distribution, and report // "expected counts". // // The "expected count" of a value is an estimate of the expected // number of occurrences of the value in the batch returned by a // call to this function with the given parameters. If unique=true, // the expected count is an inclusion probability. For details on // this estimation, see the comment to "ExpectedCountHelper" in the // .cc file. // // Expected counts for the elements of the returned "batch" are reported // in the aligned array "batch_expected_count". // // The user can optionally provide "extras", containg values in the range. // The expected counts for the extras are reported in the aligned array // "extras_expected_count". // // "batch_expected_count" must have size equal to 0 or to the size of "batch". // "extras" and "extras_expected_count" must have equal size. void SampleBatchGetExpectedCount( random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice batch, gtl::MutableArraySlice batch_expected_count, gtl::ArraySlice extras, gtl::MutableArraySlice extras_expected_count) const; // Same as SampleBatchGetExpectedCount (see above), but with avoided values. // We repick to avoid all of the values in "avoided_values". // "avoided_values" is only supported with unique=true. If // unique=false, then avoided_values must be empty. virtual void SampleBatchGetExpectedCountAvoid( random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice batch, gtl::MutableArraySlice batch_expected_count, gtl::ArraySlice extras, gtl::MutableArraySlice extras_expected_count, gtl::ArraySlice avoided_values) const; // Does this sampler need to be updated with values, e.g. UnigramSampler virtual bool NeedsUpdates() const { return false; } // Updates the underlying distribution virtual void Update(gtl::ArraySlice values) { LOG(FATAL) << "Update not supported for this sampler type."; } int64 range() { return range_; } protected: const int64 range_; }; // An AllSampler only samples batches of size equal to range. // It returns the entire range. // It cannot sample single values. class AllSampler : public RangeSampler { public: explicit AllSampler(int64 range); ~AllSampler() override {} int64 Sample(random::SimplePhilox* rnd) const override { LOG(FATAL) << "Should not be called"; } float Probability(int64 value) const override { LOG(FATAL) << "Should not be called"; } void SampleBatchGetExpectedCountAvoid( random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice batch, gtl::MutableArraySlice batch_expected_count, gtl::ArraySlice extras, gtl::MutableArraySlice extras_expected_count, gtl::ArraySlice avoided_values) const override; private: const float inv_range_; }; class UniformSampler : public RangeSampler { public: explicit UniformSampler(int64 range); ~UniformSampler() override {} int64 Sample(random::SimplePhilox* rnd) const override; float Probability(int64 value) const override; private: const float inv_range_; }; class LogUniformSampler : public RangeSampler { public: explicit LogUniformSampler(int64 range); ~LogUniformSampler() override {} int64 Sample(random::SimplePhilox* rnd) const override; float Probability(int64 value) const override; private: const double log_range_; }; // Thread-unsafe unigram sampler class ThreadUnsafeUnigramSampler : public RangeSampler { public: explicit ThreadUnsafeUnigramSampler(int64 range); ~ThreadUnsafeUnigramSampler() override {} int64 Sample(random::SimplePhilox* rnd) const override; float Probability(int64 value) const override; bool NeedsUpdates() const override { return true; } void Update(gtl::ArraySlice values) override; private: random::WeightedPicker picker_; }; // Thread-safe unigram sampler class UnigramSampler : public RangeSampler { public: explicit UnigramSampler(int64 range); ~UnigramSampler() override {} int64 Sample(random::SimplePhilox* rnd) const override; float Probability(int64 value) const override; // Overriding at a high level results in far fewer lock aquisitions. void SampleBatchGetExpectedCountAvoid( random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice batch, gtl::MutableArraySlice batch_expected_count, gtl::ArraySlice extras, gtl::MutableArraySlice extras_expected_count, gtl::ArraySlice avoided_values) const override; bool NeedsUpdates() const override { return true; } void Update(gtl::ArraySlice values) override; private: ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_); mutable mutex mu_; }; // A unigram sampler that uses a fixed unigram distribution read from a // file or passed in as an in-memory array instead of building up the // distribution from data on the fly. There is also an option to skew the // distribution by applying a distortion power to the weights. class FixedUnigramSampler : public RangeSampler { public: // The vocab_file is assumed to be a CSV, with the last entry of each row a // value representing the counts or probabilities for the corresponding ID. FixedUnigramSampler(Env* env, int64 range, const string& vocab_file, float distortion, int32 num_reserved_ids, int32 num_shards, int32 shard); FixedUnigramSampler(int64 range, const std::vector& unigrams, float distortion, int32 num_reserved_ids, int32 num_shards, int32 shard); float Probability(int64 value) const override; int64 Sample(random::SimplePhilox* rnd) const override; private: // Underlying distribution sampler. std::unique_ptr dist_sampler_; // Weights for individual samples. The probability of a sample i is defined // as weights_.at(i) / total_weight_. std::vector weights_; // The total weights of all samples. float total_weight_; // Sharding information of the sampler. The whole vocabulary is sharded // into num_shards_ smaller ranges and each sampler is responsible for one // such smaller range, identified by the shard number. int32 num_shards_; int32 shard_; // Fill the sampler with the appropriate number of reserved IDs. void FillReservedIds(int32 num_reserved_ids); // Load IDs to sample from a CSV file. It is assumed that the last item of // each row contains a count or probability for the corresponding ID. Status LoadFromFile(Env* env, const string& vocab_file, float distortion); // Load from an in-memory array. void LoadFromUnigrams(const std::vector& unigrams, float distortion); }; } // namespace tensorflow #endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_