aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random/weighted_picker.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/random/weighted_picker.h')
-rw-r--r--tensorflow/core/lib/random/weighted_picker.h118
1 files changed, 118 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h
new file mode 100644
index 0000000000..3d2c2dbb39
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.h
@@ -0,0 +1,118 @@
+
+// An abstraction to pick from one of N elements with a specified
+// weight per element.
+//
+// The weight for a given element can be changed in O(lg N) time
+// An element can be picked in O(lg N) time.
+//
+// Uses O(N) bytes of memory.
+//
+// Alternative: distribution-sampler.h allows O(1) time picking, but no weight
+// adjustment after construction.
+
+#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+
+#include <assert.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class SimplePhilox;
+
+class WeightedPicker {
+ public:
+ // REQUIRES N >= 0
+ // Initializes the elements with a weight of one per element
+ explicit WeightedPicker(int N);
+
+ // Releases all resources
+ ~WeightedPicker();
+
+ // Pick a random element with probability proportional to its weight.
+ // If total weight is zero, returns -1.
+ int Pick(SimplePhilox* rnd) const;
+
+ // Deterministically pick element x whose weight covers the
+ // specified weight_index.
+ // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
+ int PickAt(int32 weight_index) const;
+
+ // Get the weight associated with an element
+ // REQUIRES 0 <= index < N
+ int32 get_weight(int index) const;
+
+ // Set the weight associated with an element
+ // REQUIRES weight >= 0.0f
+ // REQUIRES 0 <= index < N
+ void set_weight(int index, int32 weight);
+
+ // Get the total combined weight of all elements
+ int32 total_weight() const;
+
+ // Get the number of elements in the picker
+ int num_elements() const;
+
+ // Set weight of each element to "weight"
+ void SetAllWeights(int32 weight);
+
+ // Resizes the picker to N and
+ // sets the weight of each element i to weight[i].
+ // The sum of the weights should not exceed 2^31 - 2
+ // Complexity O(N).
+ void SetWeightsFromArray(int N, const int32* weights);
+
+ // REQUIRES N >= 0
+ //
+ // Resize the weighted picker so that it has "N" elements.
+ // Any newly added entries have zero weight.
+ //
+ // Note: Resizing to a smaller size than num_elements() will
+ // not reclaim any memory. If you wish to reduce memory usage,
+ // allocate a new WeightedPicker of the appropriate size.
+ //
+ // It is efficient to use repeated calls to Resize(num_elements() + 1)
+ // to grow the picker to size X (takes total time O(X)).
+ void Resize(int N);
+
+ // Grow the picker by one and set the weight of the new entry to "weight".
+ //
+ // Repeated calls to Append() in order to grow the
+ // picker to size X takes a total time of O(X lg(X)).
+ // Consider using SetWeightsFromArray instead.
+ void Append(int32 weight);
+
+ private:
+ // We keep a binary tree with N leaves. The "i"th leaf contains
+ // the weight of the "i"th element. An internal node contains
+ // the sum of the weights of its children.
+ int N_; // Number of elements
+ int num_levels_; // Number of levels in tree (level-0 is root)
+ int32** level_; // Array that holds nodes per level
+
+ // Size of each level
+ static int LevelSize(int level) { return 1 << level; }
+
+ // Rebuild the tree weights using the leaf weights
+ void RebuildTreeWeights();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
+};
+
+inline int32 WeightedPicker::get_weight(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, N_);
+ return level_[num_levels_ - 1][index];
+}
+
+inline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
+
+inline int WeightedPicker::num_elements() const { return N_; }
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_