summaryrefslogtreecommitdiff
path: root/absl/random/internal/fast_uniform_bits.h
diff options
context:
space:
mode:
Diffstat (limited to 'absl/random/internal/fast_uniform_bits.h')
-rw-r--r--absl/random/internal/fast_uniform_bits.h202
1 files changed, 103 insertions, 99 deletions
diff --git a/absl/random/internal/fast_uniform_bits.h b/absl/random/internal/fast_uniform_bits.h
index f13c8729..425aaf7d 100644
--- a/absl/random/internal/fast_uniform_bits.h
+++ b/absl/random/internal/fast_uniform_bits.h
@@ -21,6 +21,7 @@
#include <type_traits>
#include "absl/base/config.h"
+#include "absl/meta/type_traits.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
@@ -38,28 +39,17 @@ constexpr bool IsPowerOfTwoOrZero(UIntType n) {
template <typename URBG>
constexpr typename URBG::result_type RangeSize() {
using result_type = typename URBG::result_type;
+ static_assert((URBG::max)() != (URBG::min)(), "URBG range cannot be 0.");
return ((URBG::max)() == (std::numeric_limits<result_type>::max)() &&
(URBG::min)() == std::numeric_limits<result_type>::lowest())
? result_type{0}
- : (URBG::max)() - (URBG::min)() + result_type{1};
-}
-
-template <typename UIntType>
-constexpr UIntType LargestPowerOfTwoLessThanOrEqualTo(UIntType n) {
- return n < 2 ? n : 2 * LargestPowerOfTwoLessThanOrEqualTo(n / 2);
-}
-
-// Given a URBG generating values in the closed interval [Lo, Hi], returns the
-// largest power of two less than or equal to `Hi - Lo + 1`.
-template <typename URBG>
-constexpr typename URBG::result_type PowerOfTwoSubRangeSize() {
- return LargestPowerOfTwoLessThanOrEqualTo(RangeSize<URBG>());
+ : ((URBG::max)() - (URBG::min)() + result_type{1});
}
// Computes the floor of the log. (i.e., std::floor(std::log2(N));
template <typename UIntType>
constexpr UIntType IntegerLog2(UIntType n) {
- return (n <= 1) ? 0 : 1 + IntegerLog2(n / 2);
+ return (n <= 1) ? 0 : 1 + IntegerLog2(n >> 1);
}
// Returns the number of bits of randomness returned through
@@ -68,18 +58,23 @@ template <typename URBG>
constexpr size_t NumBits() {
return RangeSize<URBG>() == 0
? std::numeric_limits<typename URBG::result_type>::digits
- : IntegerLog2(PowerOfTwoSubRangeSize<URBG>());
+ : IntegerLog2(RangeSize<URBG>());
}
// Given a shift value `n`, constructs a mask with exactly the low `n` bits set.
// If `n == 0`, all bits are set.
template <typename UIntType>
-constexpr UIntType MaskFromShift(UIntType n) {
+constexpr UIntType MaskFromShift(size_t n) {
return ((n % std::numeric_limits<UIntType>::digits) == 0)
? ~UIntType{0}
: (UIntType{1} << n) - UIntType{1};
}
+// Tags used to dispatch FastUniformBits::generate to the simple or more complex
+// entropy extraction algorithm.
+struct SimplifiedLoopTag {};
+struct RejectionLoopTag {};
+
// FastUniformBits implements a fast path to acquire uniform independent bits
// from a type which conforms to the [rand.req.urbg] concept.
// Parameterized by:
@@ -107,50 +102,16 @@ class FastUniformBits {
"Class-template FastUniformBits<> must be parameterized using "
"an unsigned type.");
- // PowerOfTwoVariate() generates a single random variate, always returning a
- // value in the half-open interval `[0, PowerOfTwoSubRangeSize<URBG>())`. If
- // the URBG already generates values in a power-of-two range, the generator
- // itself is used. Otherwise, we use rejection sampling on the largest
- // possible power-of-two-sized subrange.
- struct PowerOfTwoTag {};
- struct RejectionSamplingTag {};
- template <typename URBG>
- static typename URBG::result_type PowerOfTwoVariate(
- URBG& g) { // NOLINT(runtime/references)
- using tag =
- typename std::conditional<IsPowerOfTwoOrZero(RangeSize<URBG>()),
- PowerOfTwoTag, RejectionSamplingTag>::type;
- return PowerOfTwoVariate(g, tag{});
- }
-
- template <typename URBG>
- static typename URBG::result_type PowerOfTwoVariate(
- URBG& g, // NOLINT(runtime/references)
- PowerOfTwoTag) {
- return g() - (URBG::min)();
- }
-
- template <typename URBG>
- static typename URBG::result_type PowerOfTwoVariate(
- URBG& g, // NOLINT(runtime/references)
- RejectionSamplingTag) {
- // Use rejection sampling to ensure uniformity across the range.
- typename URBG::result_type u;
- do {
- u = g() - (URBG::min)();
- } while (u >= PowerOfTwoSubRangeSize<URBG>());
- return u;
- }
-
// Generate() generates a random value, dispatched on whether
- // the underlying URBG must loop over multiple calls or not.
+ // the underlying URBG must use rejection sampling to generate a value,
+ // or whether a simplified loop will suffice.
template <typename URBG>
result_type Generate(URBG& g, // NOLINT(runtime/references)
- std::true_type /* avoid_looping */);
+ SimplifiedLoopTag);
template <typename URBG>
result_type Generate(URBG& g, // NOLINT(runtime/references)
- std::false_type /* avoid_looping */);
+ RejectionLoopTag);
};
template <typename UIntType>
@@ -162,31 +123,47 @@ FastUniformBits<UIntType>::operator()(URBG& g) { // NOLINT(runtime/references)
// Y = (2 ^ kRange) - 1
static_assert((URBG::max)() > (URBG::min)(),
"URBG::max and URBG::min may not be equal.");
- using urbg_result_type = typename URBG::result_type;
- constexpr urbg_result_type kRangeMask =
- RangeSize<URBG>() == 0
- ? (std::numeric_limits<urbg_result_type>::max)()
- : static_cast<urbg_result_type>(PowerOfTwoSubRangeSize<URBG>() - 1);
- return Generate(g, std::integral_constant<bool, (kRangeMask >= (max)())>{});
+
+ using tag = absl::conditional_t<IsPowerOfTwoOrZero(RangeSize<URBG>()),
+ SimplifiedLoopTag, RejectionLoopTag>;
+ return Generate(g, tag{});
}
template <typename UIntType>
template <typename URBG>
typename FastUniformBits<UIntType>::result_type
FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
- std::true_type /* avoid_looping */) {
- // The width of the result_type is less than than the width of the random bits
- // provided by URBG. Thus, generate a single value and then simply mask off
- // the required bits.
+ SimplifiedLoopTag) {
+ // The simplified version of FastUniformBits works only on URBGs that have
+ // a range that is a power of 2. In this case we simply loop and shift without
+ // attempting to balance the bits across calls.
+ static_assert(IsPowerOfTwoOrZero(RangeSize<URBG>()),
+ "incorrect Generate tag for URBG instance");
+
+ static constexpr size_t kResultBits =
+ std::numeric_limits<result_type>::digits;
+ static constexpr size_t kUrbgBits = NumBits<URBG>();
+ static constexpr size_t kIters =
+ (kResultBits / kUrbgBits) + (kResultBits % kUrbgBits != 0);
+ static constexpr size_t kShift = (kIters == 1) ? 0 : kUrbgBits;
+ static constexpr auto kMin = (URBG::min)();
- return PowerOfTwoVariate(g) & (max)();
+ result_type r = static_cast<result_type>(g() - kMin);
+ for (size_t n = 1; n < kIters; ++n) {
+ r = (r << kShift) + static_cast<result_type>(g() - kMin);
+ }
+ return r;
}
template <typename UIntType>
template <typename URBG>
typename FastUniformBits<UIntType>::result_type
FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
- std::false_type /* avoid_looping */) {
+ RejectionLoopTag) {
+ static_assert(!IsPowerOfTwoOrZero(RangeSize<URBG>()),
+ "incorrect Generate tag for URBG instance");
+ using urbg_result_type = typename URBG::result_type;
+
// See [rand.adapt.ibits] for more details on the constants calculated below.
//
// It is preferable to use roughly the same number of bits from each generator
@@ -199,21 +176,44 @@ FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
// `kSmallIters` and `kLargeIters` times respectively such
// that
//
- // `kTotalWidth == kSmallIters * kSmallWidth
- // + kLargeIters * kLargeWidth`
+ // `kResultBits == kSmallIters * kSmallBits
+ // + kLargeIters * kLargeBits`
//
- // where `kTotalWidth` is the total number of bits in `result_type`.
+ // where `kResultBits` is the total number of bits in `result_type`.
//
- constexpr size_t kTotalWidth = std::numeric_limits<result_type>::digits;
- constexpr size_t kUrbgWidth = NumBits<URBG>();
- constexpr size_t kTotalIters =
- kTotalWidth / kUrbgWidth + (kTotalWidth % kUrbgWidth != 0);
- constexpr size_t kSmallWidth = kTotalWidth / kTotalIters;
- constexpr size_t kLargeWidth = kSmallWidth + 1;
+ static constexpr size_t kResultBits =
+ std::numeric_limits<result_type>::digits; // w
+ static constexpr urbg_result_type kUrbgRange = RangeSize<URBG>(); // R
+ static constexpr size_t kUrbgBits = NumBits<URBG>(); // m
+
+ // compute the initial estimate of the bits used.
+ // [rand.adapt.ibits] 2 (c)
+ static constexpr size_t kA = // ceil(w/m)
+ (kResultBits / kUrbgBits) + ((kResultBits % kUrbgBits) != 0); // n'
+
+ static constexpr size_t kABits = kResultBits / kA; // w0'
+ static constexpr urbg_result_type kARejection =
+ ((kUrbgRange >> kABits) << kABits); // y0'
+
+ // refine the selection to reduce the rejection frequency.
+ static constexpr size_t kTotalIters =
+ ((kUrbgRange - kARejection) <= (kARejection / kA)) ? kA : (kA + 1); // n
+
+ // [rand.adapt.ibits] 2 (b)
+ static constexpr size_t kSmallIters =
+ kTotalIters - (kResultBits % kTotalIters); // n0
+ static constexpr size_t kSmallBits = kResultBits / kTotalIters; // w0
+ static constexpr urbg_result_type kSmallRejection =
+ ((kUrbgRange >> kSmallBits) << kSmallBits); // y0
+
+ static constexpr size_t kLargeBits = kSmallBits + 1; // w0+1
+ static constexpr urbg_result_type kLargeRejection =
+ ((kUrbgRange >> kLargeBits) << kLargeBits); // y1
+
//
- // Because `kLargeWidth == kSmallWidth + 1`, it follows that
+ // Because `kLargeBits == kSmallBits + 1`, it follows that
//
- // `kTotalWidth == kTotalIters * kSmallWidth + kLargeIters`
+ // `kResultBits == kSmallIters * kSmallBits + kLargeIters`
//
// and therefore
//
@@ -224,36 +224,40 @@ FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
// mentioned above, if the URBG width is a divisor of `kTotalWidth`, then
// there would be no need for any large iterations (i.e., one loop would
// suffice), and indeed, in this case, `kLargeIters` would be zero.
- constexpr size_t kLargeIters = kTotalWidth % kSmallWidth;
- constexpr size_t kSmallIters =
- (kTotalWidth - (kLargeWidth * kLargeIters)) / kSmallWidth;
+ static_assert(kResultBits == kSmallIters * kSmallBits +
+ (kTotalIters - kSmallIters) * kLargeBits,
+ "Error in looping constant calculations.");
- static_assert(
- kTotalWidth == kSmallIters * kSmallWidth + kLargeIters * kLargeWidth,
- "Error in looping constant calculations.");
+ // The small shift is essentially small bits, but due to the potential
+ // of generating a smaller result_type from a larger urbg type, the actual
+ // shift might be 0.
+ static constexpr size_t kSmallShift = kSmallBits % kResultBits;
+ static constexpr auto kSmallMask =
+ MaskFromShift<urbg_result_type>(kSmallShift);
+ static constexpr size_t kLargeShift = kLargeBits % kResultBits;
+ static constexpr auto kLargeMask =
+ MaskFromShift<urbg_result_type>(kLargeShift);
- result_type s = 0;
+ static constexpr auto kMin = (URBG::min)();
- constexpr size_t kSmallShift = kSmallWidth % kTotalWidth;
- constexpr result_type kSmallMask = MaskFromShift(result_type{kSmallShift});
+ result_type s = 0;
for (size_t n = 0; n < kSmallIters; ++n) {
- s = (s << kSmallShift) +
- (static_cast<result_type>(PowerOfTwoVariate(g)) & kSmallMask);
- }
+ urbg_result_type v;
+ do {
+ v = g() - kMin;
+ } while (v >= kSmallRejection);
- constexpr size_t kLargeShift = kLargeWidth % kTotalWidth;
- constexpr result_type kLargeMask = MaskFromShift(result_type{kLargeShift});
- for (size_t n = 0; n < kLargeIters; ++n) {
- s = (s << kLargeShift) +
- (static_cast<result_type>(PowerOfTwoVariate(g)) & kLargeMask);
+ s = (s << kSmallShift) + static_cast<result_type>(v & kSmallMask);
}
- static_assert(
- kLargeShift == kSmallShift + 1 ||
- (kLargeShift == 0 &&
- kSmallShift == std::numeric_limits<result_type>::digits - 1),
- "Error in looping constant calculations");
+ for (size_t n = kSmallIters; n < kTotalIters; ++n) {
+ urbg_result_type v;
+ do {
+ v = g() - kMin;
+ } while (v >= kLargeRejection);
+ s = (s << kLargeShift) + static_cast<result_type>(v & kLargeMask);
+ }
return s;
}