aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random/exact_uniform_int.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/random/exact_uniform_int.h')
-rw-r--r--tensorflow/core/lib/random/exact_uniform_int.h68
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h
new file mode 100644
index 0000000000..616354cc5c
--- /dev/null
+++ b/tensorflow/core/lib/random/exact_uniform_int.h
@@ -0,0 +1,68 @@
+// Exact uniform integers using rejection sampling
+
+#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+
+#include <type_traits>
+
+namespace tensorflow {
+namespace random {
+
+template <typename UintType, typename RandomBits>
+UintType ExactUniformInt(const UintType n, const RandomBits& random) {
+ static_assert(std::is_unsigned<UintType>::value,
+ "UintType must be an unsigned int");
+ static_assert(std::is_same<UintType, decltype(random())>::value,
+ "random() should return UintType");
+ if (n == 0) {
+ // Consume a value anyway
+ // TODO(irving): Assert n != 0, since this case makes no sense.
+ return random() * n;
+ } else if (0 == (n & (n - 1))) {
+ // N is a power of two, so just mask off the lower bits.
+ return random() & (n - 1);
+ } else {
+ // Reject all numbers that skew the distribution towards 0.
+
+ // random's output is uniform in the half-open interval [0, 2^{bits}).
+ // For any interval [m,n), the number of elements in it is n-m.
+
+ const UintType range = ~static_cast<UintType>(0);
+ const UintType rem = (range % n) + 1;
+ UintType rnd;
+
+ // rem = ((2^bits-1) \bmod n) + 1
+ // 1 <= rem <= n
+
+ // NB: rem == n is impossible, since n is not a power of 2 (from
+ // earlier check).
+
+ do {
+ rnd = random(); // rnd uniform over [0, 2^{bits})
+ } while (rnd < rem); // reject [0, rem)
+ // rnd is uniform over [rem, 2^{bits})
+ //
+ // The number of elements in the half-open interval is
+ //
+ // 2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1
+ // = 2^{bits}-1 - ((2^{bits}-1) \bmod n)
+ // = n \cdot \lfloor (2^{bits}-1)/n \rfloor
+ //
+ // therefore n evenly divides the number of integers in the
+ // interval.
+ //
+ // The function v \rightarrow v % n takes values from [bias,
+ // 2^{bits}) to [0, n). Each integer in the range interval [0, n)
+ // will have exactly \lfloor (2^{bits}-1)/n \rfloor preimages from
+ // the domain interval.
+ //
+ // Therefore, v % n is uniform over [0, n). QED.
+
+ return rnd % n;
+ }
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_