summaryrefslogtreecommitdiff
path: root/absl/random/internal
diff options
context:
space:
mode:
authorGravatar Abseil Team <absl-team@google.com>2021-11-22 08:22:28 -0800
committerGravatar Derek Mauro <dmauro@google.com>2021-11-22 14:44:20 -0500
commitec0d76f1d012cc1a4b3b08dfafcfc5237f5ba2c9 (patch)
tree3ed0065b80edc2f3c10b06e0cc9314118f9caa63 /absl/random/internal
parent72c765111173a61de6e4184bb837f855b7869952 (diff)
Export of internal Abseil changes
-- a0847bf19789c97689f1a8b0133a53b8af5e5caa by Samuel Benzaquen <sbenza@google.com>: Add support for absl::(u)int128 to random distributions and generators. PiperOrigin-RevId: 411565479 GitOrigin-RevId: a0847bf19789c97689f1a8b0133a53b8af5e5caa Change-Id: Ide434bdd93fcab8e90f791796498de14833b667c
Diffstat (limited to 'absl/random/internal')
-rw-r--r--absl/random/internal/BUILD.bazel8
-rw-r--r--absl/random/internal/fast_uniform_bits.h3
-rw-r--r--absl/random/internal/traits.h58
-rw-r--r--absl/random/internal/uniform_helper.h10
-rw-r--r--absl/random/internal/wide_multiply.h81
-rw-r--r--absl/random/internal/wide_multiply_test.cc110
6 files changed, 182 insertions, 88 deletions
diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel
index e93eebb6..67808aa0 100644
--- a/absl/random/internal/BUILD.bazel
+++ b/absl/random/internal/BUILD.bazel
@@ -35,7 +35,11 @@ cc_library(
hdrs = ["traits.h"],
copts = ABSL_DEFAULT_COPTS,
linkopts = ABSL_DEFAULT_LINKOPTS,
- deps = ["//absl/base:config"],
+ deps = [
+ "//absl/base:config",
+ "//absl/numeric:bits",
+ "//absl/numeric:int128",
+ ],
)
cc_library(
@@ -58,6 +62,7 @@ cc_library(
copts = ABSL_DEFAULT_COPTS,
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
+ ":traits",
"//absl/base:config",
"//absl/meta:type_traits",
],
@@ -674,6 +679,7 @@ cc_library(
":traits",
"//absl/base:config",
"//absl/meta:type_traits",
+ "//absl/numeric:int128",
],
)
diff --git a/absl/random/internal/fast_uniform_bits.h b/absl/random/internal/fast_uniform_bits.h
index 425aaf7d..f3a5c00f 100644
--- a/absl/random/internal/fast_uniform_bits.h
+++ b/absl/random/internal/fast_uniform_bits.h
@@ -22,6 +22,7 @@
#include "absl/base/config.h"
#include "absl/meta/type_traits.h"
+#include "absl/random/internal/traits.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
@@ -98,7 +99,7 @@ class FastUniformBits {
result_type operator()(URBG& g); // NOLINT(runtime/references)
private:
- static_assert(std::is_unsigned<UIntType>::value,
+ static_assert(IsUnsigned<UIntType>::value,
"Class-template FastUniformBits<> must be parameterized using "
"an unsigned type.");
diff --git a/absl/random/internal/traits.h b/absl/random/internal/traits.h
index 75772bd9..f874a0f7 100644
--- a/absl/random/internal/traits.h
+++ b/absl/random/internal/traits.h
@@ -20,6 +20,8 @@
#include <type_traits>
#include "absl/base/config.h"
+#include "absl/numeric/bits.h"
+#include "absl/numeric/int128.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
@@ -59,6 +61,31 @@ class is_widening_convertible {
rank<A>() <= rank<B>();
};
+template <typename T>
+struct IsIntegral : std::is_integral<T> {};
+template <>
+struct IsIntegral<absl::int128> : std::true_type {};
+template <>
+struct IsIntegral<absl::uint128> : std::true_type {};
+
+template <typename T>
+struct MakeUnsigned : std::make_unsigned<T> {};
+template <>
+struct MakeUnsigned<absl::int128> {
+ using type = absl::uint128;
+};
+template <>
+struct MakeUnsigned<absl::uint128> {
+ using type = absl::uint128;
+};
+
+template <typename T>
+struct IsUnsigned : std::is_unsigned<T> {};
+template <>
+struct IsUnsigned<absl::int128> : std::false_type {};
+template <>
+struct IsUnsigned<absl::uint128> : std::true_type {};
+
// unsigned_bits<N>::type returns the unsigned int type with the indicated
// number of bits.
template <size_t N>
@@ -81,19 +108,40 @@ struct unsigned_bits<64> {
using type = uint64_t;
};
-#ifdef ABSL_HAVE_INTRINSIC_INT128
template <>
struct unsigned_bits<128> {
- using type = __uint128_t;
+ using type = absl::uint128;
+};
+
+// 256-bit wrapper for wide multiplications.
+struct U256 {
+ uint128 hi;
+ uint128 lo;
+};
+template <>
+struct unsigned_bits<256> {
+ using type = U256;
};
-#endif
template <typename IntType>
struct make_unsigned_bits {
- using type = typename unsigned_bits<std::numeric_limits<
- typename std::make_unsigned<IntType>::type>::digits>::type;
+ using type = typename unsigned_bits<
+ std::numeric_limits<typename MakeUnsigned<IntType>::type>::digits>::type;
};
+template <typename T>
+int BitWidth(T v) {
+ // Workaround for bit_width not supporting int128.
+ // Don't hardcode `64` to make sure this code does not trigger compiler
+ // warnings in smaller types.
+ constexpr int half_bits = sizeof(T) * 8 / 2;
+ if (sizeof(T) == 16 && (v >> half_bits) != 0) {
+ return bit_width(static_cast<uint64_t>(v >> half_bits)) + half_bits;
+ } else {
+ return bit_width(static_cast<uint64_t>(v));
+ }
+}
+
} // namespace random_internal
ABSL_NAMESPACE_END
} // namespace absl
diff --git a/absl/random/internal/uniform_helper.h b/absl/random/internal/uniform_helper.h
index 1243bc1c..e68b82ee 100644
--- a/absl/random/internal/uniform_helper.h
+++ b/absl/random/internal/uniform_helper.h
@@ -100,7 +100,7 @@ using uniform_inferred_return_t =
template <typename IntType, typename Tag>
typename absl::enable_if_t<
absl::conjunction<
- std::is_integral<IntType>,
+ IsIntegral<IntType>,
absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>,
std::is_same<Tag, IntervalOpenOpenTag>>>::value,
IntType>
@@ -131,7 +131,7 @@ uniform_lower_bound(Tag, NumType a, NumType) {
template <typename IntType, typename Tag>
typename absl::enable_if_t<
absl::conjunction<
- std::is_integral<IntType>,
+ IsIntegral<IntType>,
absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>,
std::is_same<Tag, IntervalOpenOpenTag>>>::value,
IntType>
@@ -153,7 +153,7 @@ uniform_upper_bound(Tag, FloatType, FloatType b) {
template <typename IntType, typename Tag>
typename absl::enable_if_t<
absl::conjunction<
- std::is_integral<IntType>,
+ IsIntegral<IntType>,
absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
std::is_same<Tag, IntervalOpenClosedTag>>>::value,
IntType>
@@ -201,7 +201,7 @@ is_uniform_range_valid(FloatType a, FloatType b) {
}
template <typename IntType>
-absl::enable_if_t<std::is_integral<IntType>::value, bool>
+absl::enable_if_t<IsIntegral<IntType>::value, bool>
is_uniform_range_valid(IntType a, IntType b) {
return a <= b;
}
@@ -210,7 +210,7 @@ is_uniform_range_valid(IntType a, IntType b) {
// or absl::uniform_real_distribution depending on the NumType parameter.
template <typename NumType>
using UniformDistribution =
- typename std::conditional<std::is_integral<NumType>::value,
+ typename std::conditional<IsIntegral<NumType>::value,
absl::uniform_int_distribution<NumType>,
absl::uniform_real_distribution<NumType>>::type;
diff --git a/absl/random/internal/wide_multiply.h b/absl/random/internal/wide_multiply.h
index b6e6c4b6..891e3630 100644
--- a/absl/random/internal/wide_multiply.h
+++ b/absl/random/internal/wide_multiply.h
@@ -34,43 +34,6 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
-// Helper object to multiply two 64-bit values to a 128-bit value.
-// MultiplyU64ToU128 multiplies two 64-bit values to a 128-bit value.
-// If an intrinsic is available, it is used, otherwise use native 32-bit
-// multiplies to construct the result.
-inline absl::uint128 MultiplyU64ToU128(uint64_t a, uint64_t b) {
-#if defined(ABSL_HAVE_INTRINSIC_INT128)
- return absl::uint128(static_cast<__uint128_t>(a) * b);
-#elif defined(ABSL_INTERNAL_USE_UMUL128)
- // uint64_t * uint64_t => uint128 multiply using imul intrinsic on MSVC.
- uint64_t high = 0;
- const uint64_t low = _umul128(a, b, &high);
- return absl::MakeUint128(high, low);
-#else
- // uint128(a) * uint128(b) in emulated mode computes a full 128-bit x 128-bit
- // multiply. However there are many cases where that is not necessary, and it
- // is only necessary to support a 64-bit x 64-bit = 128-bit multiply. This is
- // for those cases.
- const uint64_t a00 = static_cast<uint32_t>(a);
- const uint64_t a32 = a >> 32;
- const uint64_t b00 = static_cast<uint32_t>(b);
- const uint64_t b32 = b >> 32;
-
- const uint64_t c00 = a00 * b00;
- const uint64_t c32a = a00 * b32;
- const uint64_t c32b = a32 * b00;
- const uint64_t c64 = a32 * b32;
-
- const uint32_t carry =
- static_cast<uint32_t>(((c00 >> 32) + static_cast<uint32_t>(c32a) +
- static_cast<uint32_t>(c32b)) >>
- 32);
-
- return absl::MakeUint128(c64 + (c32a >> 32) + (c32b >> 32) + carry,
- c00 + (c32a << 32) + (c32b << 32));
-#endif
-}
-
// wide_multiply<T> multiplies two N-bit values to a 2N-bit result.
template <typename UIntType>
struct wide_multiply {
@@ -82,27 +45,49 @@ struct wide_multiply {
return static_cast<result_type>(a) * b;
}
- static input_type hi(result_type r) { return r >> kN; }
- static input_type lo(result_type r) { return r; }
+ static input_type hi(result_type r) {
+ return static_cast<input_type>(r >> kN);
+ }
+ static input_type lo(result_type r) { return static_cast<input_type>(r); }
static_assert(std::is_unsigned<UIntType>::value,
"Class-template wide_multiply<> argument must be unsigned.");
};
-#ifndef ABSL_HAVE_INTRINSIC_INT128
+// MultiplyU128ToU256 multiplies two 128-bit values to a 256-bit value.
+inline U256 MultiplyU128ToU256(uint128 a, uint128 b) {
+ const uint128 a00 = static_cast<uint64_t>(a);
+ const uint128 a64 = a >> 64;
+ const uint128 b00 = static_cast<uint64_t>(b);
+ const uint128 b64 = b >> 64;
+
+ const uint128 c00 = a00 * b00;
+ const uint128 c64a = a00 * b64;
+ const uint128 c64b = a64 * b00;
+ const uint128 c128 = a64 * b64;
+
+ const uint64_t carry =
+ static_cast<uint64_t>(((c00 >> 64) + static_cast<uint64_t>(c64a) +
+ static_cast<uint64_t>(c64b)) >>
+ 64);
+
+ return {c128 + (c64a >> 64) + (c64b >> 64) + carry,
+ c00 + (c64a << 64) + (c64b << 64)};
+}
+
+
template <>
-struct wide_multiply<uint64_t> {
- using input_type = uint64_t;
- using result_type = absl::uint128;
+struct wide_multiply<uint128> {
+ using input_type = uint128;
+ using result_type = U256;
- static result_type multiply(uint64_t a, uint64_t b) {
- return MultiplyU64ToU128(a, b);
+ static result_type multiply(input_type a, input_type b) {
+ return MultiplyU128ToU256(a, b);
}
- static uint64_t hi(result_type r) { return absl::Uint128High64(r); }
- static uint64_t lo(result_type r) { return absl::Uint128Low64(r); }
+ static input_type hi(result_type r) { return r.hi; }
+ static input_type lo(result_type r) { return r.lo; }
};
-#endif
} // namespace random_internal
ABSL_NAMESPACE_END
diff --git a/absl/random/internal/wide_multiply_test.cc b/absl/random/internal/wide_multiply_test.cc
index e276cb51..f8ee35c0 100644
--- a/absl/random/internal/wide_multiply_test.cc
+++ b/absl/random/internal/wide_multiply_test.cc
@@ -14,52 +14,106 @@
#include "absl/random/internal/wide_multiply.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/numeric/int128.h"
-using absl::random_internal::MultiplyU64ToU128;
+using absl::random_internal::MultiplyU128ToU256;
+using absl::random_internal::U256;
namespace {
-TEST(WideMultiplyTest, MultiplyU64ToU128Test) {
- constexpr uint64_t k1 = 1;
- constexpr uint64_t kMax = ~static_cast<uint64_t>(0);
+U256 LeftShift(U256 v, int s) {
+ if (s == 0) {
+ return v;
+ } else if (s < 128) {
+ return {(v.hi << s) | (v.lo >> (128 - s)), v.lo << s};
+ } else {
+ return {v.lo << (s - 128), 0};
+ }
+}
+
+MATCHER_P2(Eq256, hi, lo, "") { return arg.hi == hi && arg.lo == lo; }
+MATCHER_P(Eq256, v, "") { return arg.hi == v.hi && arg.lo == v.lo; }
+
+TEST(WideMultiplyTest, MultiplyU128ToU256Test) {
+ using absl::uint128;
+ constexpr uint128 k1 = 1;
+ constexpr uint128 kMax = ~static_cast<uint128>(0);
- EXPECT_EQ(absl::uint128(0), MultiplyU64ToU128(0, 0));
+ EXPECT_THAT(MultiplyU128ToU256(0, 0), Eq256(0, 0));
- // Max uint64_t
- EXPECT_EQ(MultiplyU64ToU128(kMax, kMax),
- absl::MakeUint128(0xfffffffffffffffe, 0x0000000000000001));
- EXPECT_EQ(absl::MakeUint128(0, kMax), MultiplyU64ToU128(kMax, 1));
- EXPECT_EQ(absl::MakeUint128(0, kMax), MultiplyU64ToU128(1, kMax));
+ // Max uin128_t
+ EXPECT_THAT(MultiplyU128ToU256(kMax, kMax), Eq256(kMax << 1, 1));
+ EXPECT_THAT(MultiplyU128ToU256(kMax, 1), Eq256(0, kMax));
+ EXPECT_THAT(MultiplyU128ToU256(1, kMax), Eq256(0, kMax));
for (int i = 0; i < 64; ++i) {
- EXPECT_EQ(absl::MakeUint128(0, kMax) << i,
- MultiplyU64ToU128(kMax, k1 << i));
- EXPECT_EQ(absl::MakeUint128(0, kMax) << i,
- MultiplyU64ToU128(k1 << i, kMax));
+ SCOPED_TRACE(i);
+ EXPECT_THAT(MultiplyU128ToU256(kMax, k1 << i),
+ Eq256(LeftShift({0, kMax}, i)));
+ EXPECT_THAT(MultiplyU128ToU256(k1 << i, kMax),
+ Eq256(LeftShift({0, kMax}, i)));
}
// 1-bit x 1-bit.
for (int i = 0; i < 64; ++i) {
for (int j = 0; j < 64; ++j) {
- EXPECT_EQ(absl::MakeUint128(0, 1) << (i + j),
- MultiplyU64ToU128(k1 << i, k1 << j));
- EXPECT_EQ(absl::MakeUint128(0, 1) << (i + j),
- MultiplyU64ToU128(k1 << i, k1 << j));
+ EXPECT_THAT(MultiplyU128ToU256(k1 << i, k1 << j),
+ Eq256(LeftShift({0, 1}, i + j)));
}
}
// Verified multiplies
- EXPECT_EQ(MultiplyU64ToU128(0xffffeeeeddddcccc, 0xbbbbaaaa99998888),
- absl::MakeUint128(0xbbbb9e2692c5dddc, 0xc28f7531048d2c60));
- EXPECT_EQ(MultiplyU64ToU128(0x0123456789abcdef, 0xfedcba9876543210),
- absl::MakeUint128(0x0121fa00ad77d742, 0x2236d88fe5618cf0));
- EXPECT_EQ(MultiplyU64ToU128(0x0123456789abcdef, 0xfdb97531eca86420),
- absl::MakeUint128(0x0120ae99d26725fc, 0xce197f0ecac319e0));
- EXPECT_EQ(MultiplyU64ToU128(0x97a87f4f261ba3f2, 0xfedcba9876543210),
- absl::MakeUint128(0x96fbf1a8ae78d0ba, 0x5a6dd4b71f278320));
- EXPECT_EQ(MultiplyU64ToU128(0xfedcba9876543210, 0xfdb97531eca86420),
- absl::MakeUint128(0xfc98c6981a413e22, 0x342d0bbf48948200));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0xc502da0d6ea99fe8, 0xfa3c9141a1f50912),
+ absl::MakeUint128(0x96bcf1ac37f97bd6, 0x27e2cdeb5fb2299e)),
+ Eq256(absl::MakeUint128(0x740113d838f96a64, 0x22e8cfa4d71f89ea),
+ absl::MakeUint128(0x19184a345c62e993, 0x237871b630337b1c)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0x6f29e670cee07230, 0xc3d8e6c3e4d86759),
+ absl::MakeUint128(0x3227d29fa6386db1, 0x231682bb1e4b764f)),
+ Eq256(absl::MakeUint128(0x15c779d9d5d3b07c, 0xd7e6c827f0c81cbe),
+ absl::MakeUint128(0xf88e3914f7fa287a, 0x15b79975137dea77)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0xafb77107215646e1, 0x3b844cb1ac5769e7),
+ absl::MakeUint128(0x1ff7b2d888b62479, 0x92f758ae96fcba0b)),
+ Eq256(absl::MakeUint128(0x15f13b70181f6985, 0x2adb36bbabce7d02),
+ absl::MakeUint128(0x6c470d72e13aad04, 0x63fba3f5841762ed)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0xd85d5558d67ac905, 0xf88c70654dae19b1),
+ absl::MakeUint128(0x17252c6727db3738, 0x399ff658c511eedc)),
+ Eq256(absl::MakeUint128(0x138fcdaf8b0421ee, 0x1b465ddf2a0d03f6),
+ absl::MakeUint128(0x8f573ba68296860f, 0xf327d2738741a21c)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0x46f0421a37ff6bee, 0xa61df89f09d140b1),
+ absl::MakeUint128(0x3d712ec9f37ca2e1, 0x9658a2cba47ef4b1)),
+ Eq256(absl::MakeUint128(0x11069cc48ee7c95d, 0xd35fb1c7aa91c978),
+ absl::MakeUint128(0xbe2f4a6de874b015, 0xd2f7ac1b76746e61)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0x730d27c72d58fa49, 0x3ebeda7498f8827c),
+ absl::MakeUint128(0xa2c959eca9f503af, 0x189c687eb842bbd8)),
+ Eq256(absl::MakeUint128(0x4928d0ea356ba022, 0x1546d34a2963393),
+ absl::MakeUint128(0x7481531e1e0a16d1, 0xdd8025015cf6aca0)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0x6ca41020f856d2f1, 0xb9b0838c04a7f4aa),
+ absl::MakeUint128(0x9cf41d28a8396f54, 0x1d681695e377ffe6)),
+ Eq256(absl::MakeUint128(0x429b92934d9be6f1, 0xea182877157c1e7),
+ absl::MakeUint128(0x7135c23f0a4a475, 0xc1adc366f4a126bc)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0x57472833797c332, 0x6c79272fdec4687a),
+ absl::MakeUint128(0xb5f022ea3838e46b, 0x16face2f003e27a6)),
+ Eq256(absl::MakeUint128(0x3e072e0962b3400, 0x5d9fe8fdc3d0e1f4),
+ absl::MakeUint128(0x7dc0df47cedafd62, 0xbe6501f1acd2551c)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0xf0fb4198322eb1c2, 0xfe7f5f31f3885938),
+ absl::MakeUint128(0xd99012b71bb7aa31, 0xac7a6f9eb190789)),
+ Eq256(absl::MakeUint128(0xcccc998cf075ca01, 0x642d144322fb873a),
+ absl::MakeUint128(0xc79dc12b69d91ed4, 0xa83459132ce046f8)));
+ EXPECT_THAT(MultiplyU128ToU256(
+ absl::MakeUint128(0xb5c04120848cdb47, 0x8aa62a827bf52635),
+ absl::MakeUint128(0x8d07a359be2f1380, 0x467bb90d59da0dea)),
+ Eq256(absl::MakeUint128(0x64205019d139a9ce, 0x99425c5fb6e7a977),
+ absl::MakeUint128(0xd3e99628a9e5fca7, 0x9c7824cb7279d72)));
}
} // namespace