diff options
author | Justin Bassett <jbassett@google.com> | 2024-05-20 10:44:01 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2024-05-20 10:45:19 -0700 |
commit | 254b3a5326932026fd23923fd367619d2837f0ad (patch) | |
tree | ae7b15f8c6462ed407c89ef5144fdbf5a5a91e2b /absl | |
parent | 93ac3a4f9ee7792af399cebd873ee99ce15aed08 (diff) |
Add (unused) validation to absl::MockingBitGen
`absl::Uniform(tag, rng, a, b)` has some restrictions on the values it can produce in that it will always be in the range specified by `a` and `b`, but these restrictions can be violated by `absl::MockingBitGen`. This makes it easier than necessary to introduce a bug in tests using a mock RNG.
We can fix this by making `MockingBitGen` emit a runtime error if the value produced is out of bounds.
Immediately fixing all the internal buggy uses of `MockingBitGen` is currently infeasible, so the plan is this:
1. Add turned-off validation to `MockingBitGen` to avoid the costs of maintaining unsubmitted code.
2. Temporarily migrate the internal buggy use cases to keep the current behavior, to be fixed later.
3. Turn on validation for `MockingBitGen`.
4. Fix the internal buggy use cases over time.
---
A few of the different categories of errors I found:
- `Call(tag, rng, a, b) -> a or b`, for open/half-open intervals (i.e. incorrect boundary condition). This case happens quite a lot, e.g. by specifying `absl::Uniform<double>(rng, 0, 1)` to return `1.0`.
- `Call(tag, rng, 0, 1) -> 42` (i.e. return an arbitrary value). These may be straightforward to fix by just returning an in-range value, or sometimes they are difficult to fix because other data structures depend on those values.
PiperOrigin-RevId: 635503223
Change-Id: I9293ab78e79450e2b7b682dcb05149f238ecc550
Diffstat (limited to 'absl')
-rw-r--r-- | absl/random/BUILD.bazel | 13 | ||||
-rw-r--r-- | absl/random/CMakeLists.txt | 33 | ||||
-rw-r--r-- | absl/random/internal/BUILD.bazel | 16 | ||||
-rw-r--r-- | absl/random/internal/mock_helpers.h | 40 | ||||
-rw-r--r-- | absl/random/internal/mock_overload_set.h | 82 | ||||
-rw-r--r-- | absl/random/internal/mock_validators.h | 98 | ||||
-rw-r--r-- | absl/random/mock_distributions.h | 19 | ||||
-rw-r--r-- | absl/random/mock_distributions_test.cc | 206 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen.h | 158 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen_test.cc | 32 |
10 files changed, 553 insertions, 144 deletions
diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel index 9ae3bc8a..0f31d919 100644 --- a/absl/random/BUILD.bazel +++ b/absl/random/BUILD.bazel @@ -140,9 +140,9 @@ cc_library( deps = [ ":distributions", ":mocking_bit_gen", - "//absl/meta:type_traits", + "//absl/base:config", "//absl/random/internal:mock_overload_set", - "@com_google_googletest//:gtest", + "//absl/random/internal:mock_validators", ], ) @@ -154,15 +154,13 @@ cc_library( ], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ - ":distributions", ":random", + "//absl/base:config", + "//absl/base:core_headers", "//absl/base:fast_type_id", "//absl/container:flat_hash_map", "//absl/meta:type_traits", - "//absl/random/internal:distribution_caller", - "//absl/strings", - "//absl/types:span", - "//absl/types:variant", + "//absl/random/internal:mock_helpers", "//absl/utility", "@com_google_googletest//:gtest", ], @@ -481,6 +479,7 @@ cc_test( "no_test_wasm", ], deps = [ + ":distributions", ":mock_distributions", ":mocking_bit_gen", ":random", diff --git a/absl/random/CMakeLists.txt b/absl/random/CMakeLists.txt index d30b43fc..3cf65b69 100644 --- a/absl/random/CMakeLists.txt +++ b/absl/random/CMakeLists.txt @@ -77,6 +77,7 @@ absl_cc_library( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::config absl::fast_type_id absl::optional ) @@ -92,6 +93,7 @@ absl_cc_library( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::config absl::random_mocking_bit_gen absl::random_internal_mock_helpers TESTONLY @@ -108,17 +110,15 @@ absl_cc_library( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::config + absl::core_headers absl::flat_hash_map absl::raw_logging_internal - absl::random_distributions - absl::random_internal_distribution_caller + absl::random_internal_mock_helpers absl::random_internal_mock_overload_set + absl::random_internal_mock_validators absl::random_random - absl::strings - absl::span - absl::type_traits absl::utility - absl::variant GTest::gmock GTest::gtest PUBLIC @@ -135,6 +135,7 @@ absl_cc_test( LINKOPTS ${ABSL_DEFAULT_LINKOPTS} DEPS + absl::random_distributions absl::random_mocking_bit_gen absl::random_random GTest::gmock @@ -1173,6 +1174,26 @@ absl_cc_library( ) # Internal-only target, do not depend on directly. +absl_cc_library( + NAME + random_internal_mock_validators + HDRS + "internal/mock_validators.h" + COPTS + ${ABSL_DEFAULT_COPTS} + LINKOPTS + ${ABSL_DEFAULT_LINKOPTS} + DEPS + absl::random_internal_iostream_state_saver + absl::random_internal_uniform_helper + absl::config + absl::raw_logging_internal + absl::strings + absl::string_view + TESTONLY +) + +# Internal-only target, do not depend on directly. absl_cc_test( NAME random_internal_uniform_helper_test diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel index 69fb5f2b..5e05130d 100644 --- a/absl/random/internal/BUILD.bazel +++ b/absl/random/internal/BUILD.bazel @@ -527,6 +527,7 @@ cc_library( hdrs = ["mock_helpers.h"], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ + "//absl/base:config", "//absl/base:fast_type_id", "//absl/types:optional", ], @@ -539,6 +540,7 @@ cc_library( linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ ":mock_helpers", + "//absl/base:config", "//absl/random:mocking_bit_gen", "@com_google_googletest//:gtest", ], @@ -712,7 +714,19 @@ cc_library( ":traits", "//absl/base:config", "//absl/meta:type_traits", - "//absl/numeric:int128", + ], +) + +cc_library( + name = "mock_validators", + hdrs = ["mock_validators.h"], + deps = [ + ":iostream_state_saver", + ":uniform_helper", + "//absl/base:config", + "//absl/base:raw_logging_internal", + "//absl/strings", + "//absl/strings:string_view", ], ) diff --git a/absl/random/internal/mock_helpers.h b/absl/random/internal/mock_helpers.h index a7a97bfc..19d05612 100644 --- a/absl/random/internal/mock_helpers.h +++ b/absl/random/internal/mock_helpers.h @@ -16,10 +16,9 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_ #define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_ -#include <tuple> -#include <type_traits> #include <utility> +#include "absl/base/config.h" #include "absl/base/internal/fast_type_id.h" #include "absl/types/optional.h" @@ -27,6 +26,16 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { +// A no-op validator meeting the ValidatorT requirements for MockHelpers. +// +// Custom validators should follow a similar structure, passing the type to +// MockHelpers::MockFor<KeyT>(m, CustomValidatorT()). +struct NoOpValidator { + // Default validation: do nothing. + template <typename ResultT, typename... Args> + static void Validate(ResultT, Args&&...) {} +}; + // MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and // BitGenRef to enable the mocking capability for absl distribution functions. // @@ -109,22 +118,39 @@ class MockHelpers { 0, urbg, std::forward<Args>(args)...); } - // Acquire a mock for the KeyT (may or may not be a signature). + // Acquire a mock for the KeyT (may or may not be a signature), set up to use + // the ValidatorT to verify that the result is in the range of the RNG + // function. // // KeyT is used to generate a typeid-based lookup for the mock. // KeyT is a signature of the form: // result_type(discriminator_type, std::tuple<args...>) // The mocked function signature will be composed from KeyT as: // result_type(args...) - template <typename KeyT, typename MockURBG> - static auto MockFor(MockURBG& m) + // ValidatorT::Validate will be called after the result of the RNG. The + // signature is expected to be of the form: + // ValidatorT::Validate(result, args...) + template <typename KeyT, typename ValidatorT, typename MockURBG> + static auto MockFor(MockURBG& m, ValidatorT) -> decltype(m.template RegisterMock< typename KeySignature<KeyT>::result_type, typename KeySignature<KeyT>::arg_tuple_type>( - m, std::declval<IdType>())) { + m, std::declval<IdType>(), ValidatorT())) { return m.template RegisterMock<typename KeySignature<KeyT>::result_type, typename KeySignature<KeyT>::arg_tuple_type>( - m, ::absl::base_internal::FastTypeId<KeyT>()); + m, ::absl::base_internal::FastTypeId<KeyT>(), ValidatorT()); + } + + // Acquire a mock for the KeyT (may or may not be a signature). + // + // KeyT is used to generate a typeid-based lookup for the mock. + // KeyT is a signature of the form: + // result_type(discriminator_type, std::tuple<args...>) + // The mocked function signature will be composed from KeyT as: + // result_type(args...) + template <typename KeyT, typename MockURBG> + static decltype(auto) MockFor(MockURBG& m) { + return MockFor<KeyT>(m, NoOpValidator()); } }; diff --git a/absl/random/internal/mock_overload_set.h b/absl/random/internal/mock_overload_set.h index 0d9c6c12..cfaeeeef 100644 --- a/absl/random/internal/mock_overload_set.h +++ b/absl/random/internal/mock_overload_set.h @@ -16,9 +16,11 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_ #define ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_ +#include <tuple> #include <type_traits> #include "gmock/gmock.h" +#include "absl/base/config.h" #include "absl/random/internal/mock_helpers.h" #include "absl/random/mocking_bit_gen.h" @@ -26,7 +28,7 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { -template <typename DistrT, typename Fn> +template <typename DistrT, typename ValidatorT, typename Fn> struct MockSingleOverload; // MockSingleOverload @@ -38,8 +40,8 @@ struct MockSingleOverload; // arguments to MockingBitGen::Register. // // The underlying KeyT must match the KeyT constructed by DistributionCaller. -template <typename DistrT, typename Ret, typename... Args> -struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> { +template <typename DistrT, typename ValidatorT, typename Ret, typename... Args> +struct MockSingleOverload<DistrT, ValidatorT, Ret(MockingBitGen&, Args...)> { static_assert(std::is_same<typename DistrT::result_type, Ret>::value, "Overload signature must have return type matching the " "distribution result_type."); @@ -47,15 +49,21 @@ struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> { template <typename MockURBG> auto gmock_Call(MockURBG& gen, const ::testing::Matcher<Args>&... matchers) - -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) { - static_assert(std::is_base_of<MockingBitGen, MockURBG>::value, - "Mocking requires an absl::MockingBitGen"); - return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...); + -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matchers...)) { + static_assert( + std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value || + std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value, + "Mocking requires an absl::MockingBitGen"); + return MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matchers...); } }; -template <typename DistrT, typename Ret, typename Arg, typename... Args> -struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { +template <typename DistrT, typename ValidatorT, typename Ret, typename Arg, + typename... Args> +struct MockSingleOverload<DistrT, ValidatorT, + Ret(Arg, MockingBitGen&, Args...)> { static_assert(std::is_same<typename DistrT::result_type, Ret>::value, "Overload signature must have return type matching the " "distribution result_type."); @@ -64,14 +72,44 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { template <typename MockURBG> auto gmock_Call(const ::testing::Matcher<Arg>& matcher, MockURBG& gen, const ::testing::Matcher<Args>&... matchers) - -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, - matchers...)) { - static_assert(std::is_base_of<MockingBitGen, MockURBG>::value, - "Mocking requires an absl::MockingBitGen"); - return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...); + -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matcher, matchers...)) { + static_assert( + std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value || + std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value, + "Mocking requires an absl::MockingBitGen"); + return MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matcher, matchers...); } }; +// MockOverloadSetWithValidator +// +// MockOverloadSetWithValidator is a wrapper around MockOverloadSet which takes +// an additional Validator parameter, allowing for customization of the mock +// behavior. +// +// `ValidatorT::Validate(result, args...)` will be called after the mock +// distribution returns a value in `result`, allowing for validation against the +// args. +template <typename DistrT, typename ValidatorT, typename... Fns> +struct MockOverloadSetWithValidator; + +template <typename DistrT, typename ValidatorT, typename Sig> +struct MockOverloadSetWithValidator<DistrT, ValidatorT, Sig> + : public MockSingleOverload<DistrT, ValidatorT, Sig> { + using MockSingleOverload<DistrT, ValidatorT, Sig>::gmock_Call; +}; + +template <typename DistrT, typename ValidatorT, typename FirstSig, + typename... Rest> +struct MockOverloadSetWithValidator<DistrT, ValidatorT, FirstSig, Rest...> + : public MockSingleOverload<DistrT, ValidatorT, FirstSig>, + public MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...> { + using MockSingleOverload<DistrT, ValidatorT, FirstSig>::gmock_Call; + using MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...>::gmock_Call; +}; + // MockOverloadSet // // MockOverloadSet takes a distribution and a collection of signatures and @@ -79,20 +117,8 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { // `EXPECT_CALL(mock_overload_set, Call(...))` expand and do overload resolution // correctly. template <typename DistrT, typename... Signatures> -struct MockOverloadSet; - -template <typename DistrT, typename Sig> -struct MockOverloadSet<DistrT, Sig> : public MockSingleOverload<DistrT, Sig> { - using MockSingleOverload<DistrT, Sig>::gmock_Call; -}; - -template <typename DistrT, typename FirstSig, typename... Rest> -struct MockOverloadSet<DistrT, FirstSig, Rest...> - : public MockSingleOverload<DistrT, FirstSig>, - public MockOverloadSet<DistrT, Rest...> { - using MockSingleOverload<DistrT, FirstSig>::gmock_Call; - using MockOverloadSet<DistrT, Rest...>::gmock_Call; -}; +using MockOverloadSet = + MockOverloadSetWithValidator<DistrT, NoOpValidator, Signatures...>; } // namespace random_internal ABSL_NAMESPACE_END diff --git a/absl/random/internal/mock_validators.h b/absl/random/internal/mock_validators.h new file mode 100644 index 00000000..0ab2ee9b --- /dev/null +++ b/absl/random/internal/mock_validators.h @@ -0,0 +1,98 @@ +// Copyright 2024 The Abseil Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ +#define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ + +#include <type_traits> + +#include "absl/base/config.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/internal/uniform_helper.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace absl { +ABSL_NAMESPACE_BEGIN +namespace random_internal { + +template <typename NumType> +class UniformDistributionValidator { + public: + template <typename TagType> + static void Validate(NumType x, TagType tag, NumType lo, NumType hi) { + // For invalid ranges, absl::Uniform() simply returns one of the bounds. + if (x == lo && lo == hi) return; + + ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi); + } + + static void Validate(NumType x, NumType lo, NumType hi) { + Validate(x, IntervalClosedOpenTag(), lo, hi); + } + + template <typename NumType_ = NumType> + static void Validate(NumType) { + // absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any + // value is okay. + static_assert(std::is_integral<NumType_>{}, + "Non-integer types may have valid values outside of the full " + "range (e.g. floating point NaN)."); + } + + private: + static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; } + static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; } + static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; } + static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; } + static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; } + static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; } + static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; } + static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; } + + template <typename TagType> + static void ValidateImpl(std::true_type /* is_floating_point */, NumType x, + TagType tag, NumType lo, NumType hi) { + UniformDistributionWrapper<NumType> dist(tag, lo, hi); + NumType lb = dist.a(); + NumType ub = dist.b(); + // uniform_real_distribution is always closed-open, so the upper bound is + // always non-inclusive. + ABSL_INTERNAL_CHECK(lb <= x && x < ub, + absl::StrCat(x, " is not in ", TagLbBound(tag), lo, + ", ", hi, TagUbBound(tag))); + } + + template <typename TagType> + static void ValidateImpl(std::false_type /* is_floating_point */, NumType x, + TagType tag, NumType lo, NumType hi) { + using stream_type = + typename random_internal::stream_format_type<NumType>::type; + + UniformDistributionWrapper<NumType> dist(tag, lo, hi); + NumType lb = dist.a(); + NumType ub = dist.b(); + ABSL_INTERNAL_CHECK( + lb <= x && x <= ub, + absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag), + stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag))); + } +}; + +} // namespace random_internal +ABSL_NAMESPACE_END +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ diff --git a/absl/random/mock_distributions.h b/absl/random/mock_distributions.h index 764ab370..b379262c 100644 --- a/absl/random/mock_distributions.h +++ b/absl/random/mock_distributions.h @@ -46,16 +46,18 @@ #ifndef ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_ #define ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_ -#include <limits> -#include <type_traits> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/meta/type_traits.h" +#include "absl/base/config.h" +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/beta_distribution.h" #include "absl/random/distributions.h" +#include "absl/random/exponential_distribution.h" +#include "absl/random/gaussian_distribution.h" #include "absl/random/internal/mock_overload_set.h" +#include "absl/random/internal/mock_validators.h" +#include "absl/random/log_uniform_int_distribution.h" #include "absl/random/mocking_bit_gen.h" +#include "absl/random/poisson_distribution.h" +#include "absl/random/zipf_distribution.h" namespace absl { ABSL_NAMESPACE_BEGIN @@ -80,8 +82,9 @@ ABSL_NAMESPACE_BEGIN // assert(x == 123456) // template <typename R> -using MockUniform = random_internal::MockOverloadSet< +using MockUniform = random_internal::MockOverloadSetWithValidator< random_internal::UniformDistributionWrapper<R>, + random_internal::UniformDistributionValidator<R>, R(IntervalClosedOpenTag, MockingBitGen&, R, R), R(IntervalClosedClosedTag, MockingBitGen&, R, R), R(IntervalOpenOpenTag, MockingBitGen&, R, R), diff --git a/absl/random/mock_distributions_test.cc b/absl/random/mock_distributions_test.cc index de23bafe..917799f0 100644 --- a/absl/random/mock_distributions_test.cc +++ b/absl/random/mock_distributions_test.cc @@ -14,7 +14,12 @@ #include "absl/random/mock_distributions.h" +#include <cmath> +#include <limits> + +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/random/distributions.h" #include "absl/random/mocking_bit_gen.h" #include "absl/random/random.h" @@ -69,4 +74,205 @@ TEST(MockDistributions, Examples) { EXPECT_EQ(absl::LogUniform<int>(gen, 0, 1000000, 2), 2040); } +TEST(MockUniform, OutOfBoundsIsAllowed) { + absl::MockingBitGen gen; + + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)).WillOnce(Return(0)); + EXPECT_EQ(absl::Uniform<int>(gen, 1, 100), 0); +} + +TEST(ValidatedMockDistributions, UniformDoubleBoundaryCases) { + absl::random_internal::MockingBitGenImpl<true> gen; + + EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0)) + .WillOnce(Return( + std::nextafter(10.0, -std::numeric_limits<double>::infinity()))); + EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0), + std::nextafter(10.0, -std::numeric_limits<double>::infinity())); + + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce(Return( + std::nextafter(10.0, -std::numeric_limits<double>::infinity()))); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + std::nextafter(10.0, -std::numeric_limits<double>::infinity())); + + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce( + Return(std::nextafter(1.0, std::numeric_limits<double>::infinity()))); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + std::nextafter(1.0, std::numeric_limits<double>::infinity())); +} + +TEST(ValidatedMockDistributions, UniformDoubleEmptyRangeCases) { + absl::random_internal::MockingBitGenImpl<true> gen; + + ON_CALL(absl::MockUniform<double>(), Call(absl::IntervalOpen, gen, 1.0, 1.0)) + .WillByDefault(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 1.0), 1.0); + + ON_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpenClosed, gen, 1.0, 1.0)) + .WillByDefault(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpenClosed, gen, 1.0, 1.0), + 1.0); + + ON_CALL(absl::MockUniform<double>(), + Call(absl::IntervalClosedOpen, gen, 1.0, 1.0)) + .WillByDefault(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalClosedOpen, gen, 1.0, 1.0), + 1.0); +} + +TEST(ValidatedMockDistributions, UniformIntEmptyRangeCases) { + absl::random_internal::MockingBitGenImpl<true> gen; + + ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpen, gen, 1, 1)) + .WillByDefault(Return(1)); + EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpen, gen, 1, 1), 1); + + ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpenClosed, gen, 1, 1)) + .WillByDefault(Return(1)); + EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 1), 1); + + ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalClosedOpen, gen, 1, 1)) + .WillByDefault(Return(1)); + EXPECT_EQ(absl::Uniform<int>(absl::IntervalClosedOpen, gen, 1, 1), 1); +} + +TEST(ValidatedMockUniformDeathTest, Examples) { + absl::random_internal::MockingBitGenImpl<true> gen; + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)) + .WillOnce(Return(0)); + absl::Uniform<int>(gen, 1, 100); + }, + " 0 is not in \\[1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(gen, 1, 100); + }, + " 101 is not in \\[1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)) + .WillOnce(Return(100)); + absl::Uniform<int>(gen, 1, 100); + }, + " 100 is not in \\[1, 100\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpen, gen, 1, 100)) + .WillOnce(Return(1)); + absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100); + }, + " 1 is not in \\(1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpen, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100); + }, + " 101 is not in \\(1, 100\\)"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpen, gen, 1, 100)) + .WillOnce(Return(100)); + absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100); + }, + " 100 is not in \\(1, 100\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(1)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 1 is not in \\(1, 100\\]"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 101 is not in \\(1, 100\\]"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(0)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 0 is not in \\(1, 100\\]"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalOpenClosed, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100); + }, + " 101 is not in \\(1, 100\\]"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalClosed, gen, 1, 100)) + .WillOnce(Return(0)); + absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100); + }, + " 0 is not in \\[1, 100\\]"); + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<int>(), + Call(absl::IntervalClosed, gen, 1, 100)) + .WillOnce(Return(101)); + absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100); + }, + " 101 is not in \\[1, 100\\]"); +} + +TEST(ValidatedMockUniformDeathTest, DoubleBoundaryCases) { + absl::random_internal::MockingBitGenImpl<true> gen; + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0)) + .WillOnce(Return(10.0)); + EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0), 10.0); + }, + " 10 is not in \\[1, 10\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce(Return(10.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + 10.0); + }, + " 10 is not in \\(1, 10\\)"); + + EXPECT_DEATH_IF_SUPPORTED( + { + EXPECT_CALL(absl::MockUniform<double>(), + Call(absl::IntervalOpen, gen, 1.0, 10.0)) + .WillOnce(Return(1.0)); + EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0), + 1.0); + }, + " 1 is not in \\(1, 10\\)"); +} + } // namespace diff --git a/absl/random/mocking_bit_gen.h b/absl/random/mocking_bit_gen.h index 89fa5a47..92f2e4fc 100644 --- a/absl/random/mocking_bit_gen.h +++ b/absl/random/mocking_bit_gen.h @@ -28,83 +28,37 @@ #ifndef ABSL_RANDOM_MOCKING_BIT_GEN_H_ #define ABSL_RANDOM_MOCKING_BIT_GEN_H_ -#include <iterator> -#include <limits> #include <memory> #include <tuple> #include <type_traits> #include <utility> #include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/base/attributes.h" +#include "absl/base/config.h" #include "absl/base/internal/fast_type_id.h" #include "absl/container/flat_hash_map.h" #include "absl/meta/type_traits.h" -#include "absl/random/distributions.h" -#include "absl/random/internal/distribution_caller.h" +#include "absl/random/internal/mock_helpers.h" #include "absl/random/random.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/types/span.h" -#include "absl/types/variant.h" #include "absl/utility/utility.h" namespace absl { ABSL_NAMESPACE_BEGIN +class BitGenRef; + namespace random_internal { template <typename> struct DistributionCaller; class MockHelpers; -} // namespace random_internal -class BitGenRef; - -// MockingBitGen -// -// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class -// which can act in place of an `absl::BitGen` URBG within tests using the -// Googletest testing framework. -// -// Usage: -// -// Use an `absl::MockingBitGen` along with a mock distribution object (within -// mock_distributions.h) inside Googletest constructs such as ON_CALL(), -// EXPECT_TRUE(), etc. to produce deterministic results conforming to the -// distribution's API contract. -// -// Example: -// -// // Mock a call to an `absl::Bernoulli` distribution using Googletest -// absl::MockingBitGen bitgen; -// -// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5)) -// .WillByDefault(testing::Return(true)); -// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5)); -// -// // Mock a call to an `absl::Uniform` distribution within Googletest -// absl::MockingBitGen bitgen; -// -// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_)) -// .WillByDefault([] (int low, int high) { -// return low + (high - low) / 2; -// }); -// -// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5); -// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35); -// -// At this time, only mock distributions supplied within the Abseil random -// library are officially supported. -// -// EXPECT_CALL and ON_CALL need to be made within the same DLL component as -// the call to absl::Uniform and related methods, otherwise mocking will fail -// since the underlying implementation creates a type-specific pointer which -// will be distinct across different DLL boundaries. -// -class MockingBitGen { +// Implements MockingBitGen with an option to turn on extra validation. +template <bool EnableValidation> +class MockingBitGenImpl { public: - MockingBitGen() = default; - ~MockingBitGen() = default; + MockingBitGenImpl() = default; + ~MockingBitGenImpl() = default; // URBG interface using result_type = absl::BitGen::result_type; @@ -125,15 +79,19 @@ class MockingBitGen { // NOTE: MockFnCaller is essentially equivalent to the lambda: // [fn](auto... args) { return fn->Call(std::move(args)...)} // however that fails to build on some supported platforms. - template <typename MockFnType, typename ResultT, typename Tuple> + template <typename MockFnType, typename ValidatorT, typename ResultT, + typename Tuple> struct MockFnCaller; // specialization for std::tuple. - template <typename MockFnType, typename ResultT, typename... Args> - struct MockFnCaller<MockFnType, ResultT, std::tuple<Args...>> { + template <typename MockFnType, typename ValidatorT, typename ResultT, + typename... Args> + struct MockFnCaller<MockFnType, ValidatorT, ResultT, std::tuple<Args...>> { MockFnType* fn; inline ResultT operator()(Args... args) { - return fn->Call(std::move(args)...); + ResultT result = fn->Call(args...); + ValidatorT::Validate(result, args...); + return result; } }; @@ -150,16 +108,17 @@ class MockingBitGen { /*ResultT*/ void* result) = 0; }; - template <typename MockFnType, typename ResultT, typename ArgTupleT> + template <typename MockFnType, typename ValidatorT, typename ResultT, + typename ArgTupleT> class FunctionHolderImpl final : public FunctionHolder { public: - void Apply(void* args_tuple, void* result) override { + void Apply(void* args_tuple, void* result) final { // Requires tuple_args to point to a ArgTupleT, which is a // std::tuple<Args...> used to invoke the mock function. Requires result // to point to a ResultT, which is the result of the call. - *static_cast<ResultT*>(result) = - absl::apply(MockFnCaller<MockFnType, ResultT, ArgTupleT>{&mock_fn_}, - *static_cast<ArgTupleT*>(args_tuple)); + *static_cast<ResultT*>(result) = absl::apply( + MockFnCaller<MockFnType, ValidatorT, ResultT, ArgTupleT>{&mock_fn_}, + *static_cast<ArgTupleT*>(args_tuple)); } MockFnType mock_fn_; @@ -175,26 +134,29 @@ class MockingBitGen { // // The returned MockFunction<...> type can be used to setup additional // distribution parameters of the expectation. - template <typename ResultT, typename ArgTupleT, typename SelfT> - auto RegisterMock(SelfT&, base_internal::FastTypeIdType type) + template <typename ResultT, typename ArgTupleT, typename SelfT, + typename ValidatorT> + auto RegisterMock(SelfT&, base_internal::FastTypeIdType type, ValidatorT) -> decltype(GetMockFnType(std::declval<ResultT>(), std::declval<ArgTupleT>()))& { + using ActualValidatorT = + std::conditional_t<EnableValidation, ValidatorT, NoOpValidator>; using MockFnType = decltype(GetMockFnType(std::declval<ResultT>(), std::declval<ArgTupleT>())); using WrappedFnType = absl::conditional_t< - std::is_same<SelfT, ::testing::NiceMock<absl::MockingBitGen>>::value, + std::is_same<SelfT, ::testing::NiceMock<MockingBitGenImpl>>::value, ::testing::NiceMock<MockFnType>, absl::conditional_t< - std::is_same<SelfT, - ::testing::NaggyMock<absl::MockingBitGen>>::value, + std::is_same<SelfT, ::testing::NaggyMock<MockingBitGenImpl>>::value, ::testing::NaggyMock<MockFnType>, absl::conditional_t< std::is_same<SelfT, - ::testing::StrictMock<absl::MockingBitGen>>::value, + ::testing::StrictMock<MockingBitGenImpl>>::value, ::testing::StrictMock<MockFnType>, MockFnType>>>; - using ImplT = FunctionHolderImpl<WrappedFnType, ResultT, ArgTupleT>; + using ImplT = + FunctionHolderImpl<WrappedFnType, ActualValidatorT, ResultT, ArgTupleT>; auto& mock = mocks_[type]; if (!mock) { mock = absl::make_unique<ImplT>(); @@ -234,6 +196,58 @@ class MockingBitGen { // InvokeMock }; +} // namespace random_internal + +// MockingBitGen +// +// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class +// which can act in place of an `absl::BitGen` URBG within tests using the +// Googletest testing framework. +// +// Usage: +// +// Use an `absl::MockingBitGen` along with a mock distribution object (within +// mock_distributions.h) inside Googletest constructs such as ON_CALL(), +// EXPECT_TRUE(), etc. to produce deterministic results conforming to the +// distribution's API contract. +// +// Example: +// +// // Mock a call to an `absl::Bernoulli` distribution using Googletest +// absl::MockingBitGen bitgen; +// +// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5)) +// .WillByDefault(testing::Return(true)); +// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5)); +// +// // Mock a call to an `absl::Uniform` distribution within Googletest +// absl::MockingBitGen bitgen; +// +// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_)) +// .WillByDefault([] (int low, int high) { +// return low + (high - low) / 2; +// }); +// +// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5); +// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35); +// +// At this time, only mock distributions supplied within the Abseil random +// library are officially supported. +// +// EXPECT_CALL and ON_CALL need to be made within the same DLL component as +// the call to absl::Uniform and related methods, otherwise mocking will fail +// since the underlying implementation creates a type-specific pointer which +// will be distinct across different DLL boundaries. +// +using MockingBitGen = random_internal::MockingBitGenImpl<false>; + +// UnvalidatedMockingBitGen +// +// UnvalidatedMockingBitGen is a variant of MockingBitGen which does no extra +// validation. +using UnvalidatedMockingBitGen ABSL_DEPRECATED("Use MockingBitGen instead") = + random_internal::MockingBitGenImpl<false>; + ABSL_NAMESPACE_END } // namespace absl diff --git a/absl/random/mocking_bit_gen_test.cc b/absl/random/mocking_bit_gen_test.cc index 9ccdf568..26e673ac 100644 --- a/absl/random/mocking_bit_gen_test.cc +++ b/absl/random/mocking_bit_gen_test.cc @@ -16,9 +16,11 @@ #include "absl/random/mocking_bit_gen.h" #include <cmath> +#include <cstddef> #include <cstdint> +#include <iterator> #include <numeric> -#include <random> +#include <vector> #include "gmock/gmock.h" #include "gtest/gtest-spi.h" @@ -246,33 +248,33 @@ TEST(WillOnce, DistinctCounters) { absl::MockingBitGen gen; EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000)) .Times(3) - .WillRepeatedly(Return(0)); + .WillRepeatedly(Return(1)); EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1000001, 2000000)) .Times(3) - .WillRepeatedly(Return(1)); - EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0); - EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0); - EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0); + .WillRepeatedly(Return(1000001)); + EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001); + EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1); + EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001); + EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1); + EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001); + EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1); } TEST(TimesModifier, ModifierSaturatesAndExpires) { EXPECT_NONFATAL_FAILURE( []() { absl::MockingBitGen gen; - EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000)) + EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 0, 1000000)) .Times(3) .WillRepeatedly(Return(15)) .RetiresOnSaturation(); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15); - EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15); + EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15); + EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15); + EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15); // Times(3) has expired - Should get a different value now. - EXPECT_NE(absl::Uniform(gen, 1, 1000000), 15); + EXPECT_NE(absl::Uniform(gen, 0, 1000000), 15); }(), ""); } @@ -394,7 +396,7 @@ TEST(MockingBitGen, StrictMock_TooMany) { EXPECT_EQ(absl::Uniform(gen, 1, 1000), 145); EXPECT_NONFATAL_FAILURE( - [&]() { EXPECT_EQ(absl::Uniform(gen, 10, 1000), 0); }(), + [&]() { EXPECT_EQ(absl::Uniform(gen, 0, 1000), 0); }(), "over-saturated and active"); } |