aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random/exact_uniform_int.h
blob: 616354cc5c2a2684d35e84e43e5e10571f5582c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Exact uniform integers in [0, n) using rejection sampling.

#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_

#include <limits>
#include <type_traits>

namespace tensorflow {
namespace random {

// Returns an integer drawn uniformly from [0, n), using only the generator
// `random`. The generator must return values uniform over the full range of
// UintType, i.e. [0, 2^{bits}) where bits is the width of UintType.
//
// The result is exact: no value in [0, n) is favoured. A plain
// `random() % n` would be biased toward small values whenever n does not
// divide 2^{bits}.
//
// Template parameters:
//   UintType   - an unsigned integer type; both `n` and the result have it.
//   RandomBits - a callable whose const operator() returns exactly UintType.
//
// Behavior by case:
//   n == 0            -> consumes one value from `random` and returns 0.
//   n is a power of 2 -> consumes exactly one value and masks its low bits.
//   otherwise         -> rejection sampling. Consumes one or more values;
//                        each draw is accepted with probability > 1/2, so
//                        fewer than 2 draws are expected.
template <typename UintType, typename RandomBits>
UintType ExactUniformInt(const UintType n, const RandomBits& random) {
  static_assert(std::is_unsigned<UintType>::value,
                "UintType must be an unsigned int");
  static_assert(std::is_same<UintType, decltype(random())>::value,
                "random() should return UintType");
  if (n == 0) {
    // Still consume one value, as the other branches always do at least once.
    // TODO(irving): Assert n != 0, since this case makes no sense.
    static_cast<void>(random());
    return 0;
  }
  if ((n & (n - 1)) == 0) {
    // n is a power of two, so the low log2(n) bits of a uniform word are
    // themselves uniform over [0, n): just mask them off.
    return static_cast<UintType>(random() & (n - 1));
  }

  // General case: reject the low values that would otherwise skew `rnd % n`
  // toward 0.
  //
  // random's output is uniform over the half-open interval [0, 2^{bits}).
  // For any half-open interval [a, b), the number of elements in it is b - a.
  //
  //   rem = ((2^{bits} - 1) \bmod n) + 1,   1 <= rem <= n.
  //
  // This cannot overflow, since (2^{bits} - 1) \bmod n <= n - 1. Also,
  // rem == n would require n to divide 2^{bits}, i.e. n to be a power of two,
  // which the branch above already handled.
  //
  // numeric_limits::max() is used rather than ~UintType(0) because `~` first
  // promotes narrow types (uint8_t, uint16_t) to int.
  const UintType range = std::numeric_limits<UintType>::max();  // 2^{bits}-1
  const UintType rem = static_cast<UintType>((range % n) + 1);

  UintType rnd = random();  // uniform over [0, 2^{bits})
  while (rnd < rem) {       // reject [0, rem)
    rnd = random();
  }

  // rnd is now uniform over [rem, 2^{bits}). That interval has
  //
  //   2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1
  //                  = (2^{bits}-1) - ((2^{bits}-1) \bmod n)
  //                  = n \cdot \lfloor (2^{bits}-1)/n \rfloor
  //
  // elements, so n evenly divides its size. The map v -> v % n sends
  // [rem, 2^{bits}) onto [0, n), giving each value in [0, n) exactly
  // \lfloor (2^{bits}-1)/n \rfloor preimages. Therefore rnd % n is uniform
  // over [0, n).  QED.
  return static_cast<UintType>(rnd % n);
}

}  // namespace random
}  // namespace tensorflow

#endif  // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_