aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random/weighted_picker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/random/weighted_picker.cc')
-rw-r--r--tensorflow/core/lib/random/weighted_picker.cc203
1 files changed, 203 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/weighted_picker.cc b/tensorflow/core/lib/random/weighted_picker.cc
new file mode 100644
index 0000000000..f96da578ec
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.cc
@@ -0,0 +1,203 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+namespace random {
+
+// Builds a picker over N elements, every one starting with weight 1.
+//
+// Weights live in level_, a binary tree stored level by level: the last
+// level holds the per-element weights and each entry of a shallower level
+// is the sum of its two children (see RebuildTreeWeights). N must be
+// non-negative.
+WeightedPicker::WeightedPicker(int N) {
+  CHECK_GE(N, 0);
+  N_ = N;
+
+  // Add levels until the deepest one has room for all N elements.
+  for (num_levels_ = 1; LevelSize(num_levels_ - 1) < N; ++num_levels_) {
+  }
+
+  // Allocate the table of levels, then the storage of each level.
+  level_ = new int32*[num_levels_];
+  for (int depth = 0; depth < num_levels_; ++depth) {
+    level_[depth] = new int32[LevelSize(depth)];
+  }
+
+  SetAllWeights(1);
+}
+
+// Frees the storage of every level, then the table of level pointers.
+WeightedPicker::~WeightedPicker() {
+  int32** const levels_end = level_ + num_levels_;
+  for (int32** current = level_; current != levels_end; ++current) {
+    delete[] *current;
+  }
+  delete[] level_;
+}
+
+// Returns an integer in [0, n) derived from the 32-bit outputs of *r.
+// Provided Rand32() is uniform over [0, 2^{32}), the result is uniform
+// over [0, n). Requires n >= 0; for n == 0 the result is 0, and one
+// Rand32() value is still consumed.
+static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {
+  CHECK_LE(0, n);
+
+  if (n == 0) {
+    // Degenerate range: the draw is multiplied by zero.
+    return r->Rand32() * n;
+  }
+
+  if ((n & (n - 1)) == 0) {
+    // n is a power of two, so keeping only the low bits is already
+    // unbiased.
+    return r->Rand32() & (n - 1);
+  }
+
+  // General case: rejection sampling. A plain "Rand32() % n" would favour
+  // small results because n does not divide 2^{32}.
+  //
+  //   reject_below = ((2^{32} - 1) mod n) + 1
+  //
+  // Since n is not a power of two it cannot divide 2^{32}, hence
+  // reject_below == 2^{32} mod n and 1 <= reject_below < n.
+  const uint32 max_value = ~static_cast<uint32>(0);
+  const uint32 reject_below = (max_value % n) + 1;
+
+  // Redraw until the sample lands in [reject_below, 2^{32}). That interval
+  // holds 2^{32} - reject_below = n * floor((2^{32} - 1) / n) integers, a
+  // multiple of n, so every residue in [0, n) has the same number of
+  // preimages under "sample % n".
+  uint32 sample;
+  do {
+    sample = r->Rand32();
+  } while (sample < reject_below);
+
+  return sample % n;
+}
+
+// Picks an element at random by drawing a position in
+// [0, total_weight()) and mapping it through PickAt(). Returns -1 when
+// the total weight is zero.
+int WeightedPicker::Pick(SimplePhilox* rnd) const {
+  if (total_weight() != 0) {
+    // An unbiased draw is used so that large total weights do not skew
+    // the choice toward low positions.
+    const int32 weight_index = UnbiasedUniform(rnd, total_weight());
+    return PickAt(weight_index);
+  }
+  return -1;
+}
+
+// Maps weight_index, a position in [0, total_weight()), to the index of
+// the element whose slice of the cumulative weight contains it. Returns
+// -1 when weight_index lies outside that range.
+int WeightedPicker::PickAt(int32 weight_index) const {
+  const bool in_range = weight_index >= 0 && weight_index < total_weight();
+  if (!in_range) return -1;
+
+  int32 position = weight_index;
+  int index = 0;
+
+  // Walk from the root to the leaves. At each depth "index" names a node
+  // of the previous level and we choose one of its two children.
+  for (int depth = 1; depth < num_levels_; ++depth) {
+    const int left_child = 2 * index;
+    const int32 left_weight = level_[depth][left_child];
+    if (position >= left_weight) {
+      // The position falls past the left subtree: skip its weight and
+      // continue in the right child.
+      position -= left_weight;
+      index = left_child + 1;
+    } else {
+      index = left_child;
+    }
+  }
+
+  CHECK_GE(index, 0);
+  CHECK_LT(index, N_);
+  CHECK_LE(position, level_[num_levels_ - 1][index]);
+  return index;
+}
+
+// Sets the weight of element "index" to "weight" and adjusts the entry
+// on its path at every level, from the leaves up to the root, by the
+// same difference. "index" must be in [0, N_).
+void WeightedPicker::set_weight(int index, int32 weight) {
+  assert(index >= 0);
+  assert(index < N_);
+
+  // Difference between the requested weight and get_weight(index).
+  const int32 delta = weight - get_weight(index);
+
+  // Start at the deepest level; halving the index moves to the parent.
+  for (int depth = num_levels_; depth-- > 0;) {
+    level_[depth][index] += delta;
+    index >>= 1;
+  }
+}
+
+// Gives each of the N_ elements the same weight and recomputes every
+// sum above the leaves.
+void WeightedPicker::SetAllWeights(int32 weight) {
+  int32* leaves = level_[num_levels_ - 1];
+  const int leaf_capacity = LevelSize(num_levels_ - 1);
+
+  // Real elements receive "weight"; the unused tail of the leaf level is
+  // zeroed so it contributes nothing to the sums.
+  for (int i = 0; i < leaf_capacity; ++i) {
+    leaves[i] = (i < N_) ? weight : 0;
+  }
+
+  RebuildTreeWeights();
+}
+
+// Resizes the picker to N elements and takes their weights from
+// weights[0..N), then recomputes every sum above the leaves.
+void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
+  Resize(N);
+
+  int32* leaves = level_[num_levels_ - 1];
+  int32* const leaves_end = leaves + LevelSize(num_levels_ - 1);
+
+  // Copy the supplied weights, then zero the unused tail of the leaf
+  // level so it contributes nothing to the sums.
+  std::copy_n(weights, N_, leaves);
+  std::fill(leaves + N_, leaves_end, 0);
+
+  RebuildTreeWeights();
+}
+
+// Recomputes every non-leaf entry as the sum of its two children,
+// working from the level just above the leaves up to the root. The leaf
+// level itself is left untouched.
+void WeightedPicker::RebuildTreeWeights() {
+  for (int child_depth = num_levels_ - 1; child_depth > 0; --child_depth) {
+    const int32* children = level_[child_depth];
+    int32* parents = level_[child_depth - 1];
+    const int parent_count = LevelSize(child_depth - 1);
+    for (int i = 0; i < parent_count; ++i) {
+      parents[i] = children[2 * i] + children[2 * i + 1];
+    }
+  }
+}
+
+// Grows the picker by one element and assigns "weight" to that new last
+// element.
+void WeightedPicker::Append(int32 weight) {
+  Resize(num_elements() + 1);
+  const int last_index = num_elements() - 1;
+  set_weight(last_index, weight);
+}
+
+// Changes the number of elements to new_size (which must be >= 0).
+// Elements that survive keep their weights, dropped elements have their
+// weights zeroed, and newly added elements start with weight zero.
+void WeightedPicker::Resize(int new_size) {
+  CHECK_GE(new_size, 0);
+
+  const bool needs_more_levels = new_size > LevelSize(num_levels_ - 1);
+  if (needs_more_levels) {
+    // Simple strategy: build a larger picker, carry the old leaf weights
+    // over, and trade internals with it. The cost is O(N) regardless.
+    assert(new_size > N_);
+    WeightedPicker new_picker(new_size);
+    int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
+    const int32* src = this->level_[this->num_levels_ - 1];
+    std::copy_n(src, N_, dst);
+    std::fill_n(dst + N_, new_size - N_, 0);
+    new_picker.RebuildTreeWeights();
+
+    // After the swap new_picker holds the old storage and releases it
+    // when it goes out of scope.
+    std::swap(this->level_, new_picker.level_);
+    std::swap(this->num_levels_, new_picker.num_levels_);
+    std::swap(this->N_, new_picker.N_);
+    assert(this->N_ == new_size);
+    return;
+  }
+
+  // The new size fits in the existing levels.
+  //
+  // When shrinking, zero the weights being dropped so the sums stay
+  // correct. When enlarging nothing has to be written, because entries
+  // beyond N_ always have weight zero.
+  for (int i = new_size; i < N_; ++i) {
+    set_weight(i, 0);
+  }
+  N_ = new_size;
+}
+
+} // namespace random
+} // namespace tensorflow