// Source: tensorflow/core/lib/random/weighted_picker.h
// blob: 3d2c2dbb397721a21e3df9d770310d254ff93781
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

// An abstraction to pick from one of N elements with a specified
// weight per element.
//
// The weight for a given element can be changed in O(lg N) time.
// An element can be picked in O(lg N) time.
//
// Uses O(N) bytes of memory.
//
// Alternative: distribution_sampler.h allows O(1) time picking, but does not
// allow weights to be adjusted after construction.

#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_

#include <assert.h>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"

namespace tensorflow {
namespace random {

class SimplePhilox;

// Picks one of N elements at random with probability proportional to each
// element's weight. Weights are kept in the leaves of a binary tree whose
// internal nodes hold the sum of their children, so both changing a weight
// and picking an element take O(lg N) time.
//
// Every tree node, and therefore total_weight(), is an int32, so the sum of
// all element weights must always fit in an int32.
class WeightedPicker {
 public:
  // REQUIRES   N >= 0
  // Initializes the elements with a weight of one per element
  explicit WeightedPicker(int N);

  // Releases all resources
  ~WeightedPicker();

  // Pick a random element with probability proportional to its weight,
  // using "rnd" as the source of randomness.
  // If total weight is zero, returns -1.
  int Pick(SimplePhilox* rnd) const;

  // Deterministically pick element x whose weight covers the
  // specified weight_index.
  // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
  int PickAt(int32 weight_index) const;

  // Get the weight associated with an element
  // REQUIRES 0 <= index < N
  int32 get_weight(int index) const;

  // Set the weight associated with an element
  // REQUIRES weight >= 0
  // REQUIRES 0 <= index < N
  // The resulting total_weight() must still fit in an int32.
  void set_weight(int index, int32 weight);

  // Get the total combined weight of all elements
  int32 total_weight() const;

  // Get the number of elements in the picker
  int num_elements() const;

  // Set weight of each element to "weight".
  // The resulting total (num_elements() * weight) must fit in an int32.
  void SetAllWeights(int32 weight);

  // Resizes the picker to N and
  // sets the weight of each element i to weight[i].
  // The sum of the weights should not exceed 2^31 - 2
  // Complexity O(N).
  void SetWeightsFromArray(int N, const int32* weights);

  // REQUIRES   N >= 0
  //
  // Resize the weighted picker so that it has "N" elements.
  // Any newly added entries have zero weight.
  //
  // Note: Resizing to a smaller size than num_elements() will
  // not reclaim any memory.  If you wish to reduce memory usage,
  // allocate a new WeightedPicker of the appropriate size.
  //
  // It is efficient to use repeated calls to Resize(num_elements() + 1)
  // to grow the picker to size X (takes total time O(X)).
  void Resize(int N);

  // Grow the picker by one and set the weight of the new entry to "weight".
  // The resulting total_weight() must still fit in an int32.
  //
  // Repeated calls to Append() in order to grow the
  // picker to size X takes a total time of O(X lg(X)).
  // Consider using SetWeightsFromArray instead.
  void Append(int32 weight);

 private:
  // We keep a binary tree with N leaves.  The "i"th leaf contains
  // the weight of the "i"th element.  An internal node contains
  // the sum of the weights of its children.
  //
  // Layout: level_[0][0] is the root (the total weight) and
  // level_[num_levels_ - 1] is the leaf level, indexed by element.
  int N_;           // Number of elements
  int num_levels_;  // Number of levels in tree (level-0 is root)
  int32** level_;   // Array that holds nodes per level

  // Number of node slots in the given level: 2^level.
  // Only meaningful for level < 31; larger shifts overflow an int.
  static int LevelSize(int level) { return 1 << level; }

  // Rebuild the tree weights using the leaf weights
  void RebuildTreeWeights();

  TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
};

// Returns the weight of element "index", read straight from the leaf level
// (the deepest level, num_levels_ - 1) of the partial-sum tree.
// REQUIRES 0 <= index < num_elements() (enforced with DCHECKs).
inline int32 WeightedPicker::get_weight(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, N_);
  const int32* const leaves = level_[num_levels_ - 1];
  return leaves[index];
}

// Returns the combined weight of all elements, which is the value held by
// the root node (level 0, slot 0) of the partial-sum tree.
inline int32 WeightedPicker::total_weight() const {
  const int32* const root_level = level_[0];
  return root_level[0];
}

inline int WeightedPicker::num_elements() const { return N_; }

}  // namespace random
}  // namespace tensorflow

#endif  // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_