blob: 3d2c2dbb397721a21e3df9d770310d254ff93781 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
// An abstraction to pick from one of N elements with a specified
// weight per element.
//
// The weight of a given element can be changed in O(lg N) time.
// An element can be picked in O(lg N) time.
//
// Uses O(N) bytes of memory.
//
// Alternative: distribution-sampler.h allows O(1) time picking, but no weight
// adjustment after construction.
#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
#include <assert.h>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
namespace random {
class SimplePhilox;
// Picks one of N elements at random, with probability proportional to a
// per-element non-negative int32 weight. Weights are kept in a binary tree
// of partial sums so that changing one weight and picking an element both
// take O(lg N) time.
class WeightedPicker {
 public:
  // REQUIRES N >= 0
  // Initializes the elements with a weight of one per element
  explicit WeightedPicker(int N);

  // Releases all resources
  ~WeightedPicker();

  // Pick a random element with probability proportional to its weight.
  // If total weight is zero, returns -1.
  int Pick(SimplePhilox* rnd) const;

  // Deterministically pick element x whose weight covers the
  // specified weight_index.
  // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
  int PickAt(int32 weight_index) const;

  // Get the weight associated with an element
  // REQUIRES 0 <= index < N
  int32 get_weight(int index) const;

  // Set the weight associated with an element
  // REQUIRES weight >= 0 (weights are integers, not floats)
  // REQUIRES 0 <= index < N
  void set_weight(int index, int32 weight);

  // Get the total combined weight of all elements
  int32 total_weight() const;

  // Get the number of elements in the picker
  int num_elements() const;

  // Set weight of each element to "weight"
  void SetAllWeights(int32 weight);

  // Resizes the picker to N and
  // sets the weight of each element i to weight[i].
  // The sum of the weights should not exceed 2^31 - 2
  // Complexity O(N).
  void SetWeightsFromArray(int N, const int32* weights);

  // REQUIRES N >= 0
  //
  // Resize the weighted picker so that it has "N" elements.
  // Any newly added entries have zero weight.
  //
  // Note: Resizing to a smaller size than num_elements() will
  // not reclaim any memory. If you wish to reduce memory usage,
  // allocate a new WeightedPicker of the appropriate size.
  //
  // It is efficient to use repeated calls to Resize(num_elements() + 1)
  // to grow the picker to size X (takes total time O(X)).
  void Resize(int N);

  // Grow the picker by one and set the weight of the new entry to "weight".
  //
  // Repeated calls to Append() in order to grow the
  // picker to size X takes a total time of O(X lg(X)).
  // Consider using SetWeightsFromArray instead.
  void Append(int32 weight);

 private:
  // We keep a binary tree with N leaves. The "i"th leaf contains
  // the weight of the "i"th element. An internal node contains
  // the sum of the weights of its children.
  int N_;           // Number of elements
  int num_levels_;  // Number of levels in tree (level-0 is root)
  // level_[i] is the array of nodes at tree level i: level_[0][0] is the
  // root (the total weight) and level_[num_levels_ - 1] holds the leaves.
  // NOTE(review): allocation/ownership lives in the .cc file; the
  // destructor is documented to release it.
  int32** level_;

  // Number of node slots at the given level: level i holds 2^i nodes.
  static int LevelSize(int level) { return 1 << level; }

  // Rebuild the tree weights using the leaf weights
  void RebuildTreeWeights();

  TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
};
// Returns the weight of element `index`, read straight from the leaf row
// of the tree (the deepest level, num_levels_ - 1).
// Debug builds check that 0 <= index < N_.
inline int32 WeightedPicker::get_weight(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, N_);
  const int32* const leaf_row = level_[num_levels_ - 1];
  return leaf_row[index];
}
inline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
inline int WeightedPicker::num_elements() const { return N_; }
} // namespace random
} // namespace tensorflow
#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
|