diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/lib/random/distribution_sampler.h |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/lib/random/distribution_sampler.h')
-rw-r--r-- | tensorflow/core/lib/random/distribution_sampler.h | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h new file mode 100644 index 0000000000..ab9598a205 --- /dev/null +++ b/tensorflow/core/lib/random/distribution_sampler.h @@ -0,0 +1,79 @@ +// DistributionSampler allows generating a discrete random variable with a given +// distribution. +// The values taken by the variable are [0, N) and relative weights for each +// value are specified using a vector of size N. +// +// The Algorithm takes O(N) time to precompute data at construction time and +// takes O(1) time (2 random number generation, 2 lookups) for each sample. +// The data structure takes O(N) memory. +// +// In contrast, util/random/weighted-picker.h provides O(lg N) sampling. +// The advantage of that implementation is that weights can be adjusted +// dynamically, while DistributionSampler doesn't allow weight adjustment. +// +// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. + +#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ + +#include <memory> +#include <utility> +#include <vector> + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +class DistributionSampler { + public: + explicit DistributionSampler(const gtl::ArraySlice<float>& weights); + + ~DistributionSampler() {} + + int Sample(SimplePhilox* rand) const { + float r = rand->RandFloat(); + // Since n is typically low, we don't bother with UnbiasedUniform. + int idx = rand->Uniform(num_); + if (r < prob(idx)) return idx; + // else pick alt from that bucket. + DCHECK_NE(-1, alt(idx)); + return alt(idx); + } + + int num() const { return num_; } + + private: + float prob(int idx) const { + DCHECK_LT(idx, num_); + return data_[idx].first; + } + + int alt(int idx) const { + DCHECK_LT(idx, num_); + return data_[idx].second; + } + + void set_prob(int idx, float f) { + DCHECK_LT(idx, num_); + data_[idx].first = f; + } + + void set_alt(int idx, int val) { + DCHECK_LT(idx, num_); + data_[idx].second = val; + } + + int num_; + std::unique_ptr<std::pair<float, int>[]> data_; + + TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler); +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ |