diff options
author | Benjamin Barenblat <bbaren@google.com> | 2024-09-03 11:49:29 -0400 |
---|---|---|
committer | Benjamin Barenblat <bbaren@google.com> | 2024-09-03 11:49:29 -0400 |
commit | c1afa8b8238c25591ca80d068477aa7d4ce05fc8 (patch) | |
tree | 284a9f8b319de5783ff83ad004a9e390cb60fd0d /absl/random | |
parent | 23778b53f420f54eebc195dd8430e79bda165e5b (diff) | |
parent | 4447c7562e3bc702ade25105912dce503f0c4010 (diff) |
Merge new upstream LTS 20240722.0
Diffstat (limited to 'absl/random')
-rw-r--r-- | absl/random/BUILD.bazel | 24 | ||||
-rw-r--r-- | absl/random/CMakeLists.txt | 38 | ||||
-rw-r--r-- | absl/random/benchmarks.cc | 4 | ||||
-rw-r--r-- | absl/random/beta_distribution.h | 16 | ||||
-rw-r--r-- | absl/random/bit_gen_ref.h | 16 | ||||
-rw-r--r-- | absl/random/discrete_distribution_test.cc | 2 | ||||
-rw-r--r-- | absl/random/distributions.h | 22 | ||||
-rw-r--r-- | absl/random/distributions_test.cc | 60 | ||||
-rw-r--r-- | absl/random/internal/BUILD.bazel | 24 | ||||
-rw-r--r-- | absl/random/internal/mock_helpers.h | 40 | ||||
-rw-r--r-- | absl/random/internal/mock_overload_set.h | 82 | ||||
-rw-r--r-- | absl/random/internal/mock_validators.h | 98 | ||||
-rw-r--r-- | absl/random/mock_distributions.h | 19 | ||||
-rw-r--r-- | absl/random/mock_distributions_test.cc | 215 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen.h | 158 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen_test.cc | 49 | ||||
-rw-r--r-- | absl/random/seed_sequences.h | 2 |
17 files changed, 676 insertions, 193 deletions
diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel index 80c4f055..f276cc08 100644 --- a/absl/random/BUILD.bazel +++ b/absl/random/BUILD.bazel @@ -108,9 +108,11 @@ cc_library( deps = [ ":seed_gen_exception", "//absl/base:config", + "//absl/base:nullability", "//absl/random/internal:pool_urbg", "//absl/random/internal:salted_seed_seq", "//absl/random/internal:seed_material", + "//absl/strings:string_view", "//absl/types:span", ], ) @@ -132,35 +134,33 @@ cc_library( cc_library( name = "mock_distributions", - testonly = 1, + testonly = True, hdrs = ["mock_distributions.h"], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ ":distributions", ":mocking_bit_gen", - "//absl/meta:type_traits", + "//absl/base:config", "//absl/random/internal:mock_overload_set", - "@com_google_googletest//:gtest", + "//absl/random/internal:mock_validators", ], ) cc_library( name = "mocking_bit_gen", - testonly = 1, + testonly = True, hdrs = [ "mocking_bit_gen.h", ], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ - ":distributions", ":random", + "//absl/base:config", + "//absl/base:core_headers", "//absl/base:fast_type_id", "//absl/container:flat_hash_map", "//absl/meta:type_traits", - "//absl/random/internal:distribution_caller", - "//absl/strings", - "//absl/types:span", - "//absl/types:variant", + "//absl/random/internal:mock_helpers", "//absl/utility", "@com_google_googletest//:gtest", ], @@ -221,6 +221,8 @@ cc_test( deps = [ ":distributions", ":random", + "//absl/meta:type_traits", + "//absl/numeric:int128", "//absl/random/internal:distribution_test_util", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -479,9 +481,11 @@ cc_test( "no_test_wasm", ], deps = [ + ":distributions", ":mock_distributions", ":mocking_bit_gen", ":random", + "//absl/numeric:int128", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], @@ -521,7 +525,7 @@ cc_test( # Benchmarks for various methods / test utilities cc_binary( name = "benchmarks", - testonly = 1, + testonly = True, srcs = [ "benchmarks.cc", ], diff --git a/absl/random/CMakeLists.txt b/absl/random/CMakeLists.txt index bd363d88..ad5477e3 100644 --- a/absl/random/CMakeLists.txt +++ b/absl/random/CMakeLists.txt @@ -77,6 +77,7 @@ absl_cc_library( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::config absl::fast_type_id absl::optional ) @@ -92,6 +93,7 @@ absl_cc_library( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::config absl::random_mocking_bit_gen absl::random_internal_mock_helpers TESTONLY @@ -108,17 +110,15 @@ absl_cc_library( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::config + absl::core_headers + absl::fast_type_id absl::flat_hash_map absl::raw_logging_internal - absl::random_distributions - absl::random_internal_distribution_caller - absl::random_internal_mock_overload_set + absl::random_internal_mock_helpers absl::random_random - absl::strings - absl::span absl::type_traits absl::utility - absl::variant GTest::gmock GTest::gtest PUBLIC @@ -135,6 +135,7 @@ absl_cc_test( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::random_distributions absl::random_mocking_bit_gen absl::random_random GTest::gmock @@ -225,11 +226,13 @@ absl_cc_library( DEPS absl::config absl::inlined_vector + absl::nullability absl::random_internal_pool_urbg absl::random_internal_salted_seed_seq absl::random_internal_seed_material absl::random_seed_gen_exception absl::span + absl::string_view ) absl_cc_test( @@ -285,6 +288,8 @@ absl_cc_test( DEPS absl::random_distributions absl::random_random + absl::type_traits + absl::int128 absl::random_internal_distribution_test_util GTest::gmock GTest::gtest_main @@ -1171,6 +1176,26 @@ absl_cc_library( ) # Internal-only target, do not depend on directly. +absl_cc_library( + NAME + random_internal_mock_validators + HDRS + "internal/mock_validators.h" + COPTS + ${ABSL_DEFAULT_COPTS} + LINKOPTS + ${ABSL_DEFAULT_LINKOPTS} + DEPS + absl::random_internal_iostream_state_saver + absl::random_internal_uniform_helper + absl::config + absl::raw_logging_internal + absl::strings + absl::string_view + TESTONLY +) + +# Internal-only target, do not depend on directly. absl_cc_test( NAME random_internal_uniform_helper_test @@ -1183,6 +1208,7 @@ absl_cc_test( DEPS absl::random_internal_uniform_helper GTest::gtest_main + absl::int128 ) # Internal-only target, do not depend on directly. diff --git a/absl/random/benchmarks.cc b/absl/random/benchmarks.cc index 0900e818..26bc95e8 100644 --- a/absl/random/benchmarks.cc +++ b/absl/random/benchmarks.cc @@ -291,7 +291,7 @@ void BM_Thread(benchmark::State& state) { BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 100)->ThreadPerCpu(); \ BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 1000)->ThreadPerCpu(); \ BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 100)->ThreadPerCpu(); \ - BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 1000)->ThreadPerCpu(); + BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 1000)->ThreadPerCpu() #define BM_EXTENDED(Engine) \ /* -------------- Extended Uniform -----------------------*/ \ @@ -355,7 +355,7 @@ void BM_Thread(benchmark::State& state) { BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<float>, 410, \ 580); \ BENCHMARK_TEMPLATE(BM_Gamma, Engine, std::gamma_distribution<float>, 199); \ - BENCHMARK_TEMPLATE(BM_Gamma, Engine, std::gamma_distribution<double>, 199); + BENCHMARK_TEMPLATE(BM_Gamma, Engine, std::gamma_distribution<double>, 199) // ABSL Recommended interfaces. BM_BASIC(absl::InsecureBitGen); // === pcg64_2018_engine diff --git a/absl/random/beta_distribution.h b/absl/random/beta_distribution.h index c154066f..432c5161 100644 --- a/absl/random/beta_distribution.h +++ b/absl/random/beta_distribution.h @@ -181,18 +181,18 @@ class beta_distribution { result_type alpha_; result_type beta_; - result_type a_; // the smaller of {alpha, beta}, or 1.0/alpha_ in JOEHNK - result_type b_; // the larger of {alpha, beta}, or 1.0/beta_ in JOEHNK - result_type x_; // alpha + beta, or the result in degenerate cases - result_type log_x_; // log(x_) - result_type y_; // "beta" in Cheng - result_type gamma_; // "gamma" in Cheng + result_type a_{}; // the smaller of {alpha, beta}, or 1.0/alpha_ in JOEHNK + result_type b_{}; // the larger of {alpha, beta}, or 1.0/beta_ in JOEHNK + result_type x_{}; // alpha + beta, or the result in degenerate cases + result_type log_x_{}; // log(x_) + result_type y_{}; // "beta" in Cheng + result_type gamma_{}; // "gamma" in Cheng - Method method_; + Method method_{}; // Placing this last for optimal alignment. // Whether alpha_ != a_, i.e. true iff alpha_ > beta_. - bool inverted_; + bool inverted_{}; static_assert(std::is_floating_point<RealType>::value, "Class-template absl::beta_distribution<> must be " diff --git a/absl/random/bit_gen_ref.h b/absl/random/bit_gen_ref.h index e475221a..ac26d9d4 100644 --- a/absl/random/bit_gen_ref.h +++ b/absl/random/bit_gen_ref.h @@ -28,6 +28,7 @@ #include <type_traits> #include <utility> +#include "absl/base/attributes.h" #include "absl/base/internal/fast_type_id.h" #include "absl/base/macros.h" #include "absl/meta/type_traits.h" @@ -110,20 +111,21 @@ class BitGenRef { BitGenRef& operator=(const BitGenRef&) = default; BitGenRef& operator=(BitGenRef&&) = default; - template <typename URBG, typename absl::enable_if_t< - (!std::is_same<URBG, BitGenRef>::value && - random_internal::is_urbg<URBG>::value && - !HasInvokeMock<URBG>::value)>* = nullptr> - BitGenRef(URBG& gen) // NOLINT + template < + typename URBGRef, typename URBG = absl::remove_cvref_t<URBGRef>, + typename absl::enable_if_t<(!std::is_same<URBG, BitGenRef>::value && + random_internal::is_urbg<URBG>::value && + !HasInvokeMock<URBG>::value)>* = nullptr> + BitGenRef(URBGRef&& gen ABSL_ATTRIBUTE_LIFETIME_BOUND) // NOLINT : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), mock_call_(NotAMock), generate_impl_fn_(ImplFn<URBG>) {} - template <typename URBG, + template <typename URBGRef, typename URBG = absl::remove_cvref_t<URBGRef>, typename absl::enable_if_t<(!std::is_same<URBG, BitGenRef>::value && random_internal::is_urbg<URBG>::value && HasInvokeMock<URBG>::value)>* = nullptr> - BitGenRef(URBG& gen) // NOLINT + BitGenRef(URBGRef&& gen ABSL_ATTRIBUTE_LIFETIME_BOUND) // NOLINT : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), mock_call_(&MockCall<URBG>), generate_impl_fn_(ImplFn<URBG>) {} diff --git a/absl/random/discrete_distribution_test.cc b/absl/random/discrete_distribution_test.cc index 32405ea9..f82ef840 100644 --- a/absl/random/discrete_distribution_test.cc +++ b/absl/random/discrete_distribution_test.cc @@ -200,7 +200,7 @@ TEST(DiscreteDistributionTest, ChiSquaredTest50) { } TEST(DiscreteDistributionTest, StabilityTest) { - // absl::discrete_distribution stabilitiy relies on + // absl::discrete_distribution stability relies on // absl::uniform_int_distribution and absl::bernoulli_distribution. absl::random_internal::sequence_urbg urbg( {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, diff --git a/absl/random/distributions.h b/absl/random/distributions.h index 4e3b332e..b6ade685 100644 --- a/absl/random/distributions.h +++ b/absl/random/distributions.h @@ -32,8 +32,8 @@ // continuously and independently at a constant average rate // * `absl::Gaussian` (also known as "normal distributions") for continuous // distributions using an associated quadratic function -// * `absl::LogUniform` for continuous uniform distributions where the log -// to the given base of all values is uniform +// * `absl::LogUniform` for discrete distributions where the log to the given +// base of all values is uniform // * `absl::Poisson` for discrete probability distributions that express the // probability of a given number of events occurring within a fixed interval // * `absl::Zipf` for discrete probability distributions commonly used for @@ -46,23 +46,23 @@ #ifndef ABSL_RANDOM_DISTRIBUTIONS_H_ #define ABSL_RANDOM_DISTRIBUTIONS_H_ -#include <algorithm> -#include <cmath> #include <limits> -#include <random> #include <type_traits> +#include "absl/base/config.h" #include "absl/base/internal/inline_variable.h" +#include "absl/meta/type_traits.h" #include "absl/random/bernoulli_distribution.h" #include "absl/random/beta_distribution.h" #include "absl/random/exponential_distribution.h" #include "absl/random/gaussian_distribution.h" #include "absl/random/internal/distribution_caller.h" // IWYU pragma: export +#include "absl/random/internal/traits.h" #include "absl/random/internal/uniform_helper.h" // IWYU pragma: export #include "absl/random/log_uniform_int_distribution.h" #include "absl/random/poisson_distribution.h" -#include "absl/random/uniform_int_distribution.h" -#include "absl/random/uniform_real_distribution.h" +#include "absl/random/uniform_int_distribution.h" // IWYU pragma: export +#include "absl/random/uniform_real_distribution.h" // IWYU pragma: export #include "absl/random/zipf_distribution.h" namespace absl { @@ -176,7 +176,7 @@ Uniform(TagType tag, return random_internal::DistributionCaller<gen_t>::template Call< distribution_t>(&urbg, tag, static_cast<return_t>(lo), - static_cast<return_t>(hi)); + static_cast<return_t>(hi)); } // absl::Uniform(bitgen, lo, hi) @@ -200,7 +200,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) return random_internal::DistributionCaller<gen_t>::template Call< distribution_t>(&urbg, static_cast<return_t>(lo), - static_cast<return_t>(hi)); + static_cast<return_t>(hi)); } // absl::Uniform<unsigned T>(bitgen) @@ -208,7 +208,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) // Overload of Uniform() using the minimum and maximum values of a given type // `T` (which must be unsigned), returning a value of type `unsigned T` template <typename R, typename URBG> -typename absl::enable_if_t<!std::is_signed<R>::value, R> // +typename absl::enable_if_t<!std::numeric_limits<R>::is_signed, R> // Uniform(URBG&& urbg) { // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = random_internal::UniformDistributionWrapper<R>; @@ -362,7 +362,7 @@ RealType Gaussian(URBG&& urbg, // NOLINT(runtime/references) // If `lo` is nonzero then this distribution is shifted to the desired interval, // so LogUniform(lo, hi, b) is equivalent to LogUniform(0, hi-lo, b)+lo. // -// See https://en.wikipedia.org/wiki/Log-normal_distribution +// See https://en.wikipedia.org/wiki/Reciprocal_distribution // // Example: // diff --git a/absl/random/distributions_test.cc b/absl/random/distributions_test.cc index 5321a11c..ea321839 100644 --- a/absl/random/distributions_test.cc +++ b/absl/random/distributions_test.cc @@ -17,10 +17,14 @@ #include <cfloat> #include <cmath> #include <cstdint> -#include <random> +#include <limits> +#include <type_traits> +#include <utility> #include <vector> #include "gtest/gtest.h" +#include "absl/meta/type_traits.h" +#include "absl/numeric/int128.h" #include "absl/random/internal/distribution_test_util.h" #include "absl/random/random.h" @@ -30,7 +34,6 @@ constexpr int kSize = 400000; class RandomDistributionsTest : public testing::Test {}; - struct Invalid {}; template <typename A, typename B> @@ -93,17 +96,18 @@ void CheckArgsInferType() { } template <typename A, typename B, typename ExplicitRet> -auto ExplicitUniformReturnT(int) -> decltype( - absl::Uniform<ExplicitRet>(*std::declval<absl::InsecureBitGen*>(), - std::declval<A>(), std::declval<B>())); +auto ExplicitUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>( + std::declval<absl::InsecureBitGen&>(), + std::declval<A>(), std::declval<B>())); template <typename, typename, typename ExplicitRet> Invalid ExplicitUniformReturnT(...); template <typename TagType, typename A, typename B, typename ExplicitRet> -auto ExplicitTaggedUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>( - std::declval<TagType>(), *std::declval<absl::InsecureBitGen*>(), - std::declval<A>(), std::declval<B>())); +auto ExplicitTaggedUniformReturnT(int) + -> decltype(absl::Uniform<ExplicitRet>( + std::declval<TagType>(), std::declval<absl::InsecureBitGen&>(), + std::declval<A>(), std::declval<B>())); template <typename, typename, typename, typename ExplicitRet> Invalid ExplicitTaggedUniformReturnT(...); @@ -135,6 +139,14 @@ void CheckArgsReturnExpectedType() { ""); } +// Takes the type of `absl::Uniform<R>(gen)` if valid or `Invalid` otherwise. +template <typename R> +auto UniformNoBoundsReturnT(int) + -> decltype(absl::Uniform<R>(std::declval<absl::InsecureBitGen&>())); + +template <typename> +Invalid UniformNoBoundsReturnT(...); + TEST_F(RandomDistributionsTest, UniformTypeInference) { // Infers common types. CheckArgsInferType<uint16_t, uint16_t, uint16_t>(); @@ -221,6 +233,38 @@ TEST_F(RandomDistributionsTest, UniformNoBounds) { absl::Uniform<uint32_t>(gen); absl::Uniform<uint64_t>(gen); absl::Uniform<absl::uint128>(gen); + + // Compile-time validity tests. + + // Allows unsigned ints. + testing::StaticAssertTypeEq<uint8_t, + decltype(UniformNoBoundsReturnT<uint8_t>(0))>(); + testing::StaticAssertTypeEq<uint16_t, + decltype(UniformNoBoundsReturnT<uint16_t>(0))>(); + testing::StaticAssertTypeEq<uint32_t, + decltype(UniformNoBoundsReturnT<uint32_t>(0))>(); + testing::StaticAssertTypeEq<uint64_t, + decltype(UniformNoBoundsReturnT<uint64_t>(0))>(); + testing::StaticAssertTypeEq< + absl::uint128, decltype(UniformNoBoundsReturnT<absl::uint128>(0))>(); + + // Disallows signed ints. + testing::StaticAssertTypeEq<Invalid, + decltype(UniformNoBoundsReturnT<int8_t>(0))>(); + testing::StaticAssertTypeEq<Invalid, + decltype(UniformNoBoundsReturnT<int16_t>(0))>(); + testing::StaticAssertTypeEq<Invalid, + decltype(UniformNoBoundsReturnT<int32_t>(0))>(); + testing::StaticAssertTypeEq<Invalid, + decltype(UniformNoBoundsReturnT<int64_t>(0))>(); + testing::StaticAssertTypeEq< + Invalid, decltype(UniformNoBoundsReturnT<absl::int128>(0))>(); + + // Disallows float types. + testing::StaticAssertTypeEq<Invalid, + decltype(UniformNoBoundsReturnT<float>(0))>(); + testing::StaticAssertTypeEq<Invalid, + decltype(UniformNoBoundsReturnT<double>(0))>(); } TEST_F(RandomDistributionsTest, UniformNonsenseRanges) { diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel index 71a742ee..5e05130d 100644 --- a/absl/random/internal/BUILD.bazel +++ b/absl/random/internal/BUILD.bazel @@ -137,7 +137,7 @@ cc_library( cc_library( name = "explicit_seed_seq", - testonly = 1, + testonly = True, hdrs = [ "explicit_seed_seq.h", ], @@ -151,7 +151,7 @@ cc_library( cc_library( name = "sequence_urbg", - testonly = 1, + testonly = True, hdrs = [ "sequence_urbg.h", ], @@ -375,7 +375,7 @@ cc_binary( cc_library( name = "distribution_test_util", - testonly = 1, + testonly = True, srcs = [ "chi_square.cc", "distribution_test_util.cc", @@ -527,6 +527,7 @@ cc_library( hdrs = ["mock_helpers.h"], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ + "//absl/base:config", "//absl/base:fast_type_id", "//absl/types:optional", ], @@ -534,11 +535,12 @@ cc_library( cc_library( name = "mock_overload_set", - testonly = 1, + testonly = True, hdrs = ["mock_overload_set.h"], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ ":mock_helpers", + "//absl/base:config", "//absl/random:mocking_bit_gen", "@com_google_googletest//:gtest", ], @@ -712,7 +714,19 @@ cc_library( ":traits", "//absl/base:config", "//absl/meta:type_traits", - "//absl/numeric:int128", + ], +) + +cc_library( + name = "mock_validators", + hdrs = ["mock_validators.h"], + deps = [ + ":iostream_state_saver", + ":uniform_helper", + "//absl/base:config", + "//absl/base:raw_logging_internal", + "//absl/strings", + "//absl/strings:string_view", ], ) diff --git a/absl/random/internal/mock_helpers.h b/absl/random/internal/mock_helpers.h index a7a97bfc..19d05612 100644 --- a/absl/random/internal/mock_helpers.h +++ b/absl/random/internal/mock_helpers.h @@ -16,10 +16,9 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_ #define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_ -#include <tuple> -#include <type_traits> #include <utility> +#include "absl/base/config.h" #include "absl/base/internal/fast_type_id.h" #include "absl/types/optional.h" @@ -27,6 +26,16 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { +// A no-op validator meeting the ValidatorT requirements for MockHelpers. +// +// Custom validators should follow a similar structure, passing the type to +// MockHelpers::MockFor<KeyT>(m, CustomValidatorT()). +struct NoOpValidator { + // Default validation: do nothing. + template <typename ResultT, typename... Args> + static void Validate(ResultT, Args&&...) {} +}; + // MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and // BitGenRef to enable the mocking capability for absl distribution functions. // @@ -109,22 +118,39 @@ class MockHelpers { 0, urbg, std::forward<Args>(args)...); } - // Acquire a mock for the KeyT (may or may not be a signature). + // Acquire a mock for the KeyT (may or may not be a signature), set up to use + // the ValidatorT to verify that the result is in the range of the RNG + // function. // // KeyT is used to generate a typeid-based lookup for the mock. // KeyT is a signature of the form: // result_type(discriminator_type, std::tuple<args...>) // The mocked function signature will be composed from KeyT as: // result_type(args...) - template <typename KeyT, typename MockURBG> - static auto MockFor(MockURBG& m) + // ValidatorT::Validate will be called after the result of the RNG. The + // signature is expected to be of the form: + // ValidatorT::Validate(result, args...) + template <typename KeyT, typename ValidatorT, typename MockURBG> + static auto MockFor(MockURBG& m, ValidatorT) -> decltype(m.template RegisterMock< typename KeySignature<KeyT>::result_type, typename KeySignature<KeyT>::arg_tuple_type>( - m, std::declval<IdType>())) { + m, std::declval<IdType>(), ValidatorT())) { return m.template RegisterMock<typename KeySignature<KeyT>::result_type, typename KeySignature<KeyT>::arg_tuple_type>( - m, ::absl::base_internal::FastTypeId<KeyT>()); + m, ::absl::base_internal::FastTypeId<KeyT>(), ValidatorT()); + } + + // Acquire a mock for the KeyT (may or may not be a signature). + // + // KeyT is used to generate a typeid-based lookup for the mock. + // KeyT is a signature of the form: + // result_type(discriminator_type, std::tuple<args...>) + // The mocked function signature will be composed from KeyT as: + // result_type(args...) + template <typename KeyT, typename MockURBG> + static decltype(auto) MockFor(MockURBG& m) { + return MockFor<KeyT>(m, NoOpValidator()); } }; diff --git a/absl/random/internal/mock_overload_set.h b/absl/random/internal/mock_overload_set.h index 0d9c6c12..cfaeeeef 100644 --- a/absl/random/internal/mock_overload_set.h +++ b/absl/random/internal/mock_overload_set.h @@ -16,9 +16,11 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_ #define ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_ +#include <tuple> #include <type_traits> #include "gmock/gmock.h" +#include "absl/base/config.h" #include "absl/random/internal/mock_helpers.h" #include "absl/random/mocking_bit_gen.h" @@ -26,7 +28,7 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { -template <typename DistrT, typename Fn> +template <typename DistrT, typename ValidatorT, typename Fn> struct MockSingleOverload; // MockSingleOverload @@ -38,8 +40,8 @@ struct MockSingleOverload; // arguments to MockingBitGen::Register. // // The underlying KeyT must match the KeyT constructed by DistributionCaller. -template <typename DistrT, typename Ret, typename... Args> -struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> { +template <typename DistrT, typename ValidatorT, typename Ret, typename... Args> +struct MockSingleOverload<DistrT, ValidatorT, Ret(MockingBitGen&, Args...)> { static_assert(std::is_same<typename DistrT::result_type, Ret>::value, "Overload signature must have return type matching the " "distribution result_type."); @@ -47,15 +49,21 @@ struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> { template <typename MockURBG> auto gmock_Call(MockURBG& gen, const ::testing::Matcher<Args>&... matchers) - -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) { - static_assert(std::is_base_of<MockingBitGen, MockURBG>::value, - "Mocking requires an absl::MockingBitGen"); - return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...); + -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matchers...)) { + static_assert( + std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value || + std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value, + "Mocking requires an absl::MockingBitGen"); + return MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matchers...); } }; -template <typename DistrT, typename Ret, typename Arg, typename... Args> -struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { +template <typename DistrT, typename ValidatorT, typename Ret, typename Arg, + typename... Args> +struct MockSingleOverload<DistrT, ValidatorT, + Ret(Arg, MockingBitGen&, Args...)> { static_assert(std::is_same<typename DistrT::result_type, Ret>::value, "Overload signature must have return type matching the " "distribution result_type."); @@ -64,14 +72,44 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { template <typename MockURBG> auto gmock_Call(const ::testing::Matcher<Arg>& matcher, MockURBG& gen, const ::testing::Matcher<Args>&... matchers) - -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, - matchers...)) { - static_assert(std::is_base_of<MockingBitGen, MockURBG>::value, - "Mocking requires an absl::MockingBitGen"); - return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...); + -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matcher, matchers...)) { + static_assert( + std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value || + std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value, + "Mocking requires an absl::MockingBitGen"); + return MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matcher, matchers...); } }; +// MockOverloadSetWithValidator +// +// MockOverloadSetWithValidator is a wrapper around MockOverloadSet which takes +// an additional Validator parameter, allowing for customization of the mock +// behavior. +// +// `ValidatorT::Validate(result, args...)` will be called after the mock +// distribution returns a value in `result`, allowing for validation against the +// args. +template <typename DistrT, typename ValidatorT, typename... Fns> +struct MockOverloadSetWithValidator; + +template <typename DistrT, typename ValidatorT, typename Sig> +struct MockOverloadSetWithValidator<DistrT, ValidatorT, Sig> + : public MockSingleOverload<DistrT, ValidatorT, Sig> { + using MockSingleOverload<DistrT, ValidatorT, Sig>::gmock_Call; +}; + +template <typename DistrT, typename ValidatorT, typename FirstSig, + typename... Rest> +struct MockOverloadSetWithValidator<DistrT, ValidatorT, FirstSig, Rest...> + : public MockSingleOverload<DistrT, ValidatorT, FirstSig>, + public MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...> { + using MockSingleOverload<DistrT, ValidatorT, FirstSig>::gmock_Call; + using MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...>::gmock_Call; +}; + // MockOverloadSet // // MockOverloadSet takes a distribution and a collection of signatures and @@ -79,20 +117,8 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { // `EXPECT_CALL(mock_overload_set, Call(...))` expand and do overload resolution // correctly. template <typename DistrT, typename... Signatures> -struct MockOverloadSet; - -template <typename DistrT, typename Sig> -struct MockOverloadSet<DistrT, Sig> : public MockSingleOverload<DistrT, Sig> { - using MockSingleOverload<DistrT, Sig>::gmock_Call; -}; - -template <typename DistrT, typename FirstSig, typename... Rest> -struct MockOverloadSet<DistrT, FirstSig, Rest...> - : public MockSingleOverload<DistrT, FirstSig>, - public MockOverloadSet<DistrT, Rest...> { - using MockSingleOverload<DistrT, FirstSig>::gmock_Call; - using MockOverloadSet<DistrT, Rest...>::gmock_Call; -}; +using MockOverloadSet = + MockOverloadSetWithValidator<DistrT, NoOpValidator, Signatures...>; } // namespace random_internal ABSL_NAMESPACE_END diff --git a/absl/random/internal/mock_validators.h b/absl/random/internal/mock_validators.h new file mode 100644 index 00000000..d76d169c --- /dev/null +++ b/absl/random/internal/mock_validators.h @@ -0,0 +1,98 @@ +// Copyright 2024 The Abseil Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ +#define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ + +#include <type_traits> + +#include "absl/base/config.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/internal/uniform_helper.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace absl { +ABSL_NAMESPACE_BEGIN +namespace random_internal { + +template <typename NumType> +class UniformDistributionValidator { + public: + // Handle absl::Uniform<NumType>(gen, absl::IntervalTag, lo, hi). + template <typename TagType> + static void Validate(NumType x, TagType tag, NumType lo, NumType hi) { + // For invalid ranges, absl::Uniform() simply returns one of the bounds. + if (x == lo && lo == hi) return; + + ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi); + } + + // Handle absl::Uniform<NumType>(gen, lo, hi). + static void Validate(NumType x, NumType lo, NumType hi) { + Validate(x, IntervalClosedOpenTag(), lo, hi); + } + + // Handle absl::Uniform<NumType>(gen). + static void Validate(NumType) { + // absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any + // value is okay. This overload exists because the validation logic attempts + // to call it anyway rather than adding extra SFINAE. + } + + private: + static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; } + static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; } + static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; } + static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; } + static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; } + static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; } + static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; } + static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; } + + template <typename TagType> + static void ValidateImpl(std::true_type /* is_floating_point */, NumType x, + TagType tag, NumType lo, NumType hi) { + UniformDistributionWrapper<NumType> dist(tag, lo, hi); + NumType lb = dist.a(); + NumType ub = dist.b(); + // uniform_real_distribution is always closed-open, so the upper bound is + // always non-inclusive. + ABSL_INTERNAL_CHECK(lb <= x && x < ub, + absl::StrCat(x, " is not in ", TagLbBound(tag), lo, + ", ", hi, TagUbBound(tag))); + } + + template <typename TagType> + static void ValidateImpl(std::false_type /* is_floating_point */, NumType x, + TagType tag, NumType lo, NumType hi) { + using stream_type = + typename random_internal::stream_format_type<NumType>::type; + + UniformDistributionWrapper<NumType> dist(tag, lo, hi); + NumType lb = dist.a(); + NumType ub = dist.b(); + ABSL_INTERNAL_CHECK( + lb <= x && x <= ub, + absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag), + stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag))); + } +}; + +} // namespace random_internal +ABSL_NAMESPACE_END +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ diff --git a/absl/random/mock_distributions.h b/absl/random/mock_distributions.h index 764ab370..b379262c 100644 --- a/absl/random/mock_distributions.h +++ b/absl/random/mock_distributions.h @@ -46,16 +46,18 @@ #ifndef ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_ #define ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_ -#include <limits> -#include <type_traits> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/meta/type_traits.h" +#include "absl/base/config.h" +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/beta_distribution.h" #include "absl/random/distributions.h" +#include "absl/random/exponential_distribution.h" +#include "absl/random/gaussian_distribution.h" #include "absl/random/internal/mock_overload_set.h" +#include "absl/random/internal/mock_validators.h" +#include "absl/random/log_uniform_int_distribution.h" #include "absl/random/mocking_bit_gen.h" +#include "absl/random/poisson_distribution.h" +#include "absl/random/zipf_distribution.h" namespace absl { ABSL_NAMESPACE_BEGIN @@ -80,8 +82,9 @@ ABSL_NAMESPACE_BEGIN // assert(x == 123456) // template <typename R> -using MockUniform = random_internal::MockOverloadSet< +using MockUniform = random_internal::MockOverloadSetWithValidator< random_internal::UniformDistributionWrapper<R>, + random_internal::UniformDistributionValidator<R>, R(IntervalClosedOpenTag, MockingBitGen&, R, R), R(IntervalClosedClosedTag, MockingBitGen&, R, R), R(IntervalOpenOpenTag, MockingBitGen&, R, R), diff --git a/absl/random/mock_distributions_test.cc b/absl/random/mock_distributions_test.cc index de23bafe..05e313cd 100644 --- a/absl/random/mock_distributions_test.cc +++ b/absl/random/mock_distributions_test.cc @@ -14,7 +14,13 @@ #include "absl/random/mock_distributions.h" +#include <cmath> +#include <limits> + +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/numeric/int128.h" +#include "absl/random/distributions.h" #include "absl/random/mocking_bit_gen.h" #include "absl/random/random.h" @@ -69,4 +75,213 @@ TEST(MockDistributions, Examples) { EXPECT_EQ(absl::LogUniform<int>(gen, 0, 1000000, 2), 2040); } +TEST(MockUniform, OutOfBoundsIsAllowed) { + absl::UnvalidatedMockingBitGen gen; + + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)).WillOnce(Return(0)); + EXPECT_EQ(absl::Uniform<int>(gen, 1, 100), 0); +} + +TEST(ValidatedMockDistributions, UniformUInt128Works) { + absl::MockingBitGen gen; + + EXPECT_CALL(absl::MockUniform<absl::uint128>(), Call(gen)) + .WillOnce(Return(absl::Uint128Max())); + EXPECT_EQ(absl::Uniform<absl::uint128>(gen), absl::Uint128Max()); +} + +TEST(ValidatedMockDistributions, UniformDoubleBoundaryCases) { + absl::MockingBitGen gen; + + EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0)) + .WillOnce(Return( + std::nextafter(10.0, -std::numeric_limits<double>::infinity()))); + EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0), + std::nextafter(10.0, -std::numeric_limits<double>::infinity())); + + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce(Return( + std::nextafter(10.0, -std::numeric_limits<double>::infinity()))); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + std::nextafter(10.0, -std::numeric_limits<double>::infinity())); + + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce( + Return(std::nextafter(1.0, std::numeric_limits<double>::infinity()))); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + std::nextafter(1.0, std::numeric_limits<double>::infinity())); +} + +TEST(ValidatedMockDistributions, UniformDoubleEmptyRangeCases) { + absl::MockingBitGen gen; + + ON_CALL(absl::MockUniform<double>(), Call(absl::IntervalOpen, gen, 1.0, 1.0)) + .WillByDefault(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 1.0), 1.0); + + ON_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpenClosed, gen, 1.0, 1.0)) + .WillByDefault(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpenClosed, gen, 1.0, 1.0), + 1.0); + + ON_CALL(absl::MockUniform<double>(), + Call(absl::IntervalClosedOpen, gen, 1.0, 1.0)) + .WillByDefault(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalClosedOpen, gen, 1.0, 1.0), + 1.0); +} + +TEST(ValidatedMockDistributions, UniformIntEmptyRangeCases) { + absl::MockingBitGen gen; + + ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpen, gen, 1, 1)) + .WillByDefault(Return(1)); + EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpen, gen, 1, 1), 1); + + ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpenClosed, gen, 1, 1)) + .WillByDefault(Return(1)); + EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 1), 1); + + ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalClosedOpen, gen, 1, 1)) + .WillByDefault(Return(1)); + EXPECT_EQ(absl::Uniform<int>(absl::IntervalClosedOpen, gen, 1, 1), 1); +} + +TEST(ValidatedMockUniformDeathTest, Examples) { + absl::MockingBitGen gen; + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)) + .WillOnce(Return(0)); + absl::Uniform<int>(gen, 1, 100); + }, + " 0 is not in \\[1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(gen, 1, 100); + }, + " 101 is not in \\[1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)) + .WillOnce(Return(100)); + absl::Uniform<int>(gen, 1, 100); + }, + " 100 is not in \\[1, 100\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpen, gen, 1, 100)) + .WillOnce(Return(1)); + absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100); + }, + " 1 is not in \\(1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpen, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100); + }, + " 101 is not in \\(1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpen, gen, 1, 100)) + .WillOnce(Return(100)); + absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100); + }, + " 100 is not in \\(1, 100\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(1)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 1 is not in \\(1, 100\\]"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 101 is not in \\(1, 100\\]"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(0)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 0 is not in \\(1, 100\\]"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 101 is not in \\(1, 100\\]"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalClosed, gen, 1, 100)) + .WillOnce(Return(0)); + absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100); + }, + " 0 is not in \\[1, 100\\]"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalClosed, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100); + }, + " 101 is not in \\[1, 100\\]"); +} + +TEST(ValidatedMockUniformDeathTest, DoubleBoundaryCases) { + absl::MockingBitGen gen; + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0)) + .WillOnce(Return(10.0)); + EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0), 10.0); + }, + " 10 is not in \\[1, 10\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce(Return(10.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + 10.0); + }, + " 10 is not in \\(1, 10\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + 1.0); + }, + " 1 is not in \\(1, 10\\)"); +} + } // namespace diff --git a/absl/random/mocking_bit_gen.h b/absl/random/mocking_bit_gen.h index 89fa5a47..041989de 100644 --- a/absl/random/mocking_bit_gen.h +++ b/absl/random/mocking_bit_gen.h @@ -28,83 +28,37 @@ #ifndef ABSL_RANDOM_MOCKING_BIT_GEN_H_ #define ABSL_RANDOM_MOCKING_BIT_GEN_H_ -#include <iterator> -#include <limits> #include <memory> #include <tuple> #include <type_traits> #include <utility> #include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/base/attributes.h" +#include "absl/base/config.h" #include "absl/base/internal/fast_type_id.h" #include "absl/container/flat_hash_map.h" #include "absl/meta/type_traits.h" -#include "absl/random/distributions.h" -#include "absl/random/internal/distribution_caller.h" +#include "absl/random/internal/mock_helpers.h" #include "absl/random/random.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/types/span.h" -#include "absl/types/variant.h" #include "absl/utility/utility.h" namespace absl { ABSL_NAMESPACE_BEGIN +class BitGenRef; + namespace random_internal { template <typename> struct DistributionCaller; class MockHelpers; -} // namespace random_internal -class BitGenRef; - -// MockingBitGen -// -// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class -// which can act in place of an `absl::BitGen` URBG within tests using the -// Googletest testing framework. -// -// Usage: -// -// Use an `absl::MockingBitGen` along with a mock distribution object (within -// mock_distributions.h) inside Googletest constructs such as ON_CALL(), -// EXPECT_TRUE(), etc. to produce deterministic results conforming to the -// distribution's API contract. -// -// Example: -// -// // Mock a call to an `absl::Bernoulli` distribution using Googletest -// absl::MockingBitGen bitgen; -// -// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5)) -// .WillByDefault(testing::Return(true)); -// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5)); -// -// // Mock a call to an `absl::Uniform` distribution within Googletest -// absl::MockingBitGen bitgen; -// -// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_)) -// .WillByDefault([] (int low, int high) { -// return low + (high - low) / 2; -// }); -// -// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5); -// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35); -// -// At this time, only mock distributions supplied within the Abseil random -// library are officially supported. -// -// EXPECT_CALL and ON_CALL need to be made within the same DLL component as -// the call to absl::Uniform and related methods, otherwise mocking will fail -// since the underlying implementation creates a type-specific pointer which -// will be distinct across different DLL boundaries. -// -class MockingBitGen { +// Implements MockingBitGen with an option to turn on extra validation. +template <bool EnableValidation> +class MockingBitGenImpl { public: - MockingBitGen() = default; - ~MockingBitGen() = default; + MockingBitGenImpl() = default; + ~MockingBitGenImpl() = default; // URBG interface using result_type = absl::BitGen::result_type; @@ -125,15 +79,19 @@ class MockingBitGen { // NOTE: MockFnCaller is essentially equivalent to the lambda: // [fn](auto... args) { return fn->Call(std::move(args)...)} // however that fails to build on some supported platforms. - template <typename MockFnType, typename ResultT, typename Tuple> + template <typename MockFnType, typename ValidatorT, typename ResultT, + typename Tuple> struct MockFnCaller; // specialization for std::tuple. - template <typename MockFnType, typename ResultT, typename... Args> - struct MockFnCaller<MockFnType, ResultT, std::tuple<Args...>> { + template <typename MockFnType, typename ValidatorT, typename ResultT, + typename... Args> + struct MockFnCaller<MockFnType, ValidatorT, ResultT, std::tuple<Args...>> { MockFnType* fn; inline ResultT operator()(Args... args) { - return fn->Call(std::move(args)...); + ResultT result = fn->Call(args...); + ValidatorT::Validate(result, args...); + return result; } }; @@ -150,16 +108,17 @@ class MockingBitGen { /*ResultT*/ void* result) = 0; }; - template <typename MockFnType, typename ResultT, typename ArgTupleT> + template <typename MockFnType, typename ValidatorT, typename ResultT, + typename ArgTupleT> class FunctionHolderImpl final : public FunctionHolder { public: - void Apply(void* args_tuple, void* result) override { + void Apply(void* args_tuple, void* result) final { // Requires tuple_args to point to a ArgTupleT, which is a // std::tuple<Args...> used to invoke the mock function. Requires result // to point to a ResultT, which is the result of the call. - *static_cast<ResultT*>(result) = - absl::apply(MockFnCaller<MockFnType, ResultT, ArgTupleT>{&mock_fn_}, - *static_cast<ArgTupleT*>(args_tuple)); + *static_cast<ResultT*>(result) = absl::apply( + MockFnCaller<MockFnType, ValidatorT, ResultT, ArgTupleT>{&mock_fn_}, + *static_cast<ArgTupleT*>(args_tuple)); } MockFnType mock_fn_; @@ -175,26 +134,29 @@ class MockingBitGen { // // The returned MockFunction<...> type can be used to setup additional // distribution parameters of the expectation. - template <typename ResultT, typename ArgTupleT, typename SelfT> - auto RegisterMock(SelfT&, base_internal::FastTypeIdType type) + template <typename ResultT, typename ArgTupleT, typename SelfT, + typename ValidatorT> + auto RegisterMock(SelfT&, base_internal::FastTypeIdType type, ValidatorT) -> decltype(GetMockFnType(std::declval<ResultT>(), std::declval<ArgTupleT>()))& { + using ActualValidatorT = + std::conditional_t<EnableValidation, ValidatorT, NoOpValidator>; using MockFnType = decltype(GetMockFnType(std::declval<ResultT>(), std::declval<ArgTupleT>())); using WrappedFnType = absl::conditional_t< - std::is_same<SelfT, ::testing::NiceMock<absl::MockingBitGen>>::value, + std::is_same<SelfT, ::testing::NiceMock<MockingBitGenImpl>>::value, ::testing::NiceMock<MockFnType>, absl::conditional_t< - std::is_same<SelfT, - ::testing::NaggyMock<absl::MockingBitGen>>::value, + std::is_same<SelfT, ::testing::NaggyMock<MockingBitGenImpl>>::value, ::testing::NaggyMock<MockFnType>, absl::conditional_t< std::is_same<SelfT, - ::testing::StrictMock<absl::MockingBitGen>>::value, + ::testing::StrictMock<MockingBitGenImpl>>::value, ::testing::StrictMock<MockFnType>, MockFnType>>>; - using ImplT = FunctionHolderImpl<WrappedFnType, ResultT, ArgTupleT>; + using ImplT = + FunctionHolderImpl<WrappedFnType, ActualValidatorT, ResultT, ArgTupleT>; auto& mock = mocks_[type]; if (!mock) { mock = absl::make_unique<ImplT>(); @@ -234,6 +196,58 @@ class MockingBitGen { // InvokeMock }; +} // namespace random_internal + +// MockingBitGen +// +// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class +// which can act in place of an `absl::BitGen` URBG within tests using the +// Googletest testing framework. +// +// Usage: +// +// Use an `absl::MockingBitGen` along with a mock distribution object (within +// mock_distributions.h) inside Googletest constructs such as ON_CALL(), +// EXPECT_TRUE(), etc. to produce deterministic results conforming to the +// distribution's API contract. +// +// Example: +// +// // Mock a call to an `absl::Bernoulli` distribution using Googletest +// absl::MockingBitGen bitgen; +// +// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5)) +// .WillByDefault(testing::Return(true)); +// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5)); +// +// // Mock a call to an `absl::Uniform` distribution within Googletest +// absl::MockingBitGen bitgen; +// +// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_)) +// .WillByDefault([] (int low, int high) { +// return low + (high - low) / 2; +// }); +// +// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5); +// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35); +// +// At this time, only mock distributions supplied within the Abseil random +// library are officially supported. +// +// EXPECT_CALL and ON_CALL need to be made within the same DLL component as +// the call to absl::Uniform and related methods, otherwise mocking will fail +// since the underlying implementation creates a type-specific pointer which +// will be distinct across different DLL boundaries. +// +using MockingBitGen = random_internal::MockingBitGenImpl<true>; + +// UnvalidatedMockingBitGen +// +// UnvalidatedMockingBitGen is a variant of MockingBitGen which does no extra +// validation. +using UnvalidatedMockingBitGen ABSL_DEPRECATED("Use MockingBitGen instead") = + random_internal::MockingBitGenImpl<false>; + ABSL_NAMESPACE_END } // namespace absl diff --git a/absl/random/mocking_bit_gen_test.cc b/absl/random/mocking_bit_gen_test.cc index c713ceaf..26e673ac 100644 --- a/absl/random/mocking_bit_gen_test.cc +++ b/absl/random/mocking_bit_gen_test.cc @@ -16,8 +16,11 @@ #include "absl/random/mocking_bit_gen.h" #include <cmath> +#include <cstddef> +#include <cstdint> +#include <iterator> #include <numeric> -#include <random> +#include <vector> #include "gmock/gmock.h" #include "gtest/gtest-spi.h" @@ -176,12 +179,18 @@ TEST(BasicMocking, MultipleGenerators) { EXPECT_NE(get_value(mocked_with_11), 11); } -TEST(BasicMocking, MocksNotTrigeredForIncorrectTypes) { +TEST(BasicMocking, MocksNotTriggeredForIncorrectTypes) { absl::MockingBitGen gen; - EXPECT_CALL(absl::MockUniform<uint32_t>(), Call(gen)).WillOnce(Return(42)); - - EXPECT_NE(absl::Uniform<uint16_t>(gen), 42); // Not mocked - EXPECT_EQ(absl::Uniform<uint32_t>(gen), 42); // Mock triggered + EXPECT_CALL(absl::MockUniform<uint32_t>(), Call(gen)) + .WillRepeatedly(Return(42)); + + bool uint16_always42 = true; + for (int i = 0; i < 10000; i++) { + EXPECT_EQ(absl::Uniform<uint32_t>(gen), 42); // Mock triggered. + // uint16_t not mocked. + uint16_always42 = uint16_always42 && absl::Uniform<uint16_t>(gen) == 42; + } + EXPECT_FALSE(uint16_always42); } TEST(BasicMocking, FailsOnUnsatisfiedMocks) { @@ -239,33 +248,33 @@ TEST(WillOnce, DistinctCounters) { absl::MockingBitGen gen; EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000)) .Times(3) - .WillRepeatedly(Return(0)); + .WillRepeatedly(Return(1)); EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1000001, 2000000)) .Times(3) - .WillRepeatedly(Return(1)); - EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0); - EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0); - EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0); + .WillRepeatedly(Return(1000001)); + EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001); + EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1); + EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001); + EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1); + EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001); + EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1); } TEST(TimesModifier, ModifierSaturatesAndExpires) { EXPECT_NONFATAL_FAILURE( []() { absl::MockingBitGen gen; - EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000)) + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 0, 1000000)) .Times(3) .WillRepeatedly(Return(15)) .RetiresOnSaturation(); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15); + EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15); + EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15); + EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15); // Times(3) has expired - Should get a different value now. - EXPECT_NE(absl::Uniform(gen, 1, 1000000), 15); + EXPECT_NE(absl::Uniform(gen, 0, 1000000), 15); }(), ""); } @@ -387,7 +396,7 @@ TEST(MockingBitGen, StrictMock_TooMany) { EXPECT_EQ(absl::Uniform(gen, 1, 1000), 145); EXPECT_NONFATAL_FAILURE( - [&]() { EXPECT_EQ(absl::Uniform(gen, 10, 1000), 0); }(), + [&]() { EXPECT_EQ(absl::Uniform(gen, 0, 1000), 0); }(), "over-saturated and active"); } diff --git a/absl/random/seed_sequences.h b/absl/random/seed_sequences.h index c3af4b00..33970be5 100644 --- a/absl/random/seed_sequences.h +++ b/absl/random/seed_sequences.h @@ -29,9 +29,11 @@ #include <random> #include "absl/base/config.h" +#include "absl/base/nullability.h" #include "absl/random/internal/salted_seed_seq.h" #include "absl/random/internal/seed_material.h" #include "absl/random/seed_gen_exception.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" namespace absl { |