summaryrefslogtreecommitdiff
path: root/absl
diff options
context:
space:
mode:
authorGravatar Justin Bassett <jbassett@google.com>2024-05-20 10:44:01 -0700
committerGravatar Copybara-Service <copybara-worker@google.com>2024-05-20 10:45:19 -0700
commit254b3a5326932026fd23923fd367619d2837f0ad (patch)
treeae7b15f8c6462ed407c89ef5144fdbf5a5a91e2b /absl
parent93ac3a4f9ee7792af399cebd873ee99ce15aed08 (diff)
Add (unused) validation to absl::MockingBitGen
`absl::Uniform(tag, rng, a, b)` has some restrictions on the values it can produce in that it will always be in the range specified by `a` and `b`, but these restrictions can be violated by `absl::MockingBitGen`. This makes it easier than necessary to introduce a bug in tests using a mock RNG. We can fix this by making `MockingBitGen` emit a runtime error if the value produced is out of bounds. Immediately fixing all the internal buggy uses of `MockingBitGen` is currently infeasible, so the plan is this: 1. Add turned-off validation to `MockingBitGen` to avoid the costs of maintaining unsubmitted code. 2. Temporarily migrate the internal buggy use cases to keep the current behavior, to be fixed later. 3. Turn on validation for `MockingBitGen`. 4. Fix the internal buggy use cases over time. --- A few of the different categories of errors I found: - `Call(tag, rng, a, b) -> a or b`, for open/half-open intervals (i.e. incorrect boundary condition). This case happens quite a lot, e.g. by specifying `absl::Uniform<double>(rng, 0, 1)` to return `1.0`. - `Call(tag, rng, 0, 1) -> 42` (i.e. return an arbitrary value). These may be straightforward to fix by just returning an in-range value, or sometimes they are difficult to fix because other data structures depend on those values. PiperOrigin-RevId: 635503223 Change-Id: I9293ab78e79450e2b7b682dcb05149f238ecc550
Diffstat (limited to 'absl')
-rw-r--r--absl/random/BUILD.bazel13
-rw-r--r--absl/random/CMakeLists.txt33
-rw-r--r--absl/random/internal/BUILD.bazel16
-rw-r--r--absl/random/internal/mock_helpers.h40
-rw-r--r--absl/random/internal/mock_overload_set.h82
-rw-r--r--absl/random/internal/mock_validators.h98
-rw-r--r--absl/random/mock_distributions.h19
-rw-r--r--absl/random/mock_distributions_test.cc206
-rw-r--r--absl/random/mocking_bit_gen.h158
-rw-r--r--absl/random/mocking_bit_gen_test.cc32
10 files changed, 553 insertions, 144 deletions
diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel
index 9ae3bc8a..0f31d919 100644
--- a/absl/random/BUILD.bazel
+++ b/absl/random/BUILD.bazel
@@ -140,9 +140,9 @@ cc_library(
deps = [
":distributions",
":mocking_bit_gen",
- "//absl/meta:type_traits",
+ "//absl/base:config",
"//absl/random/internal:mock_overload_set",
- "@com_google_googletest//:gtest",
+ "//absl/random/internal:mock_validators",
],
)
@@ -154,15 +154,13 @@ cc_library(
],
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
- ":distributions",
":random",
+ "//absl/base:config",
+ "//absl/base:core_headers",
"//absl/base:fast_type_id",
"//absl/container:flat_hash_map",
"//absl/meta:type_traits",
- "//absl/random/internal:distribution_caller",
- "//absl/strings",
- "//absl/types:span",
- "//absl/types:variant",
+ "//absl/random/internal:mock_helpers",
"//absl/utility",
"@com_google_googletest//:gtest",
],
@@ -481,6 +479,7 @@ cc_test(
"no_test_wasm",
],
deps = [
+ ":distributions",
":mock_distributions",
":mocking_bit_gen",
":random",
diff --git a/absl/random/CMakeLists.txt b/absl/random/CMakeLists.txt
index d30b43fc..3cf65b69 100644
--- a/absl/random/CMakeLists.txt
+++ b/absl/random/CMakeLists.txt
@@ -77,6 +77,7 @@ absl_cc_library(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
+ absl::config
absl::fast_type_id
absl::optional
)
@@ -92,6 +93,7 @@ absl_cc_library(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
+ absl::config
absl::random_mocking_bit_gen
absl::random_internal_mock_helpers
TESTONLY
@@ -108,17 +110,15 @@ absl_cc_library(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
+ absl::config
+ absl::core_headers
absl::flat_hash_map
absl::raw_logging_internal
- absl::random_distributions
- absl::random_internal_distribution_caller
+ absl::random_internal_mock_helpers
absl::random_internal_mock_overload_set
+ absl::random_internal_mock_validators
absl::random_random
- absl::strings
- absl::span
- absl::type_traits
absl::utility
- absl::variant
GTest::gmock
GTest::gtest
PUBLIC
@@ -135,6 +135,7 @@ absl_cc_test(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
+ absl::random_distributions
absl::random_mocking_bit_gen
absl::random_random
GTest::gmock
@@ -1173,6 +1174,26 @@ absl_cc_library(
)
# Internal-only target, do not depend on directly.
+absl_cc_library(
+ NAME
+ random_internal_mock_validators
+ HDRS
+ "internal/mock_validators.h"
+ COPTS
+ ${ABSL_DEFAULT_COPTS}
+ LINKOPTS
+ ${ABSL_DEFAULT_LINKOPTS}
+ DEPS
+ absl::random_internal_iostream_state_saver
+ absl::random_internal_uniform_helper
+ absl::config
+ absl::raw_logging_internal
+ absl::strings
+ absl::string_view
+ TESTONLY
+)
+
+# Internal-only target, do not depend on directly.
absl_cc_test(
NAME
random_internal_uniform_helper_test
diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel
index 69fb5f2b..5e05130d 100644
--- a/absl/random/internal/BUILD.bazel
+++ b/absl/random/internal/BUILD.bazel
@@ -527,6 +527,7 @@ cc_library(
hdrs = ["mock_helpers.h"],
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
+ "//absl/base:config",
"//absl/base:fast_type_id",
"//absl/types:optional",
],
@@ -539,6 +540,7 @@ cc_library(
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
":mock_helpers",
+ "//absl/base:config",
"//absl/random:mocking_bit_gen",
"@com_google_googletest//:gtest",
],
@@ -712,7 +714,19 @@ cc_library(
":traits",
"//absl/base:config",
"//absl/meta:type_traits",
- "//absl/numeric:int128",
+ ],
+)
+
+cc_library(
+ name = "mock_validators",
+ hdrs = ["mock_validators.h"],
+ deps = [
+ ":iostream_state_saver",
+ ":uniform_helper",
+ "//absl/base:config",
+ "//absl/base:raw_logging_internal",
+ "//absl/strings",
+ "//absl/strings:string_view",
],
)
diff --git a/absl/random/internal/mock_helpers.h b/absl/random/internal/mock_helpers.h
index a7a97bfc..19d05612 100644
--- a/absl/random/internal/mock_helpers.h
+++ b/absl/random/internal/mock_helpers.h
@@ -16,10 +16,9 @@
#ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
#define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
-#include <tuple>
-#include <type_traits>
#include <utility>
+#include "absl/base/config.h"
#include "absl/base/internal/fast_type_id.h"
#include "absl/types/optional.h"
@@ -27,6 +26,16 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
+// A no-op validator meeting the ValidatorT requirements for MockHelpers.
+//
+// Custom validators should follow a similar structure, passing the type to
+// MockHelpers::MockFor<KeyT>(m, CustomValidatorT()).
+struct NoOpValidator {
+ // Default validation: do nothing.
+ template <typename ResultT, typename... Args>
+ static void Validate(ResultT, Args&&...) {}
+};
+
// MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and
// BitGenRef to enable the mocking capability for absl distribution functions.
//
@@ -109,22 +118,39 @@ class MockHelpers {
0, urbg, std::forward<Args>(args)...);
}
- // Acquire a mock for the KeyT (may or may not be a signature).
+ // Acquire a mock for the KeyT (may or may not be a signature), set up to use
+ // the ValidatorT to verify that the result is in the range of the RNG
+ // function.
//
// KeyT is used to generate a typeid-based lookup for the mock.
// KeyT is a signature of the form:
// result_type(discriminator_type, std::tuple<args...>)
// The mocked function signature will be composed from KeyT as:
// result_type(args...)
- template <typename KeyT, typename MockURBG>
- static auto MockFor(MockURBG& m)
+ // ValidatorT::Validate will be called after the result of the RNG. The
+ // signature is expected to be of the form:
+ // ValidatorT::Validate(result, args...)
+ template <typename KeyT, typename ValidatorT, typename MockURBG>
+ static auto MockFor(MockURBG& m, ValidatorT)
-> decltype(m.template RegisterMock<
typename KeySignature<KeyT>::result_type,
typename KeySignature<KeyT>::arg_tuple_type>(
- m, std::declval<IdType>())) {
+ m, std::declval<IdType>(), ValidatorT())) {
return m.template RegisterMock<typename KeySignature<KeyT>::result_type,
typename KeySignature<KeyT>::arg_tuple_type>(
- m, ::absl::base_internal::FastTypeId<KeyT>());
+ m, ::absl::base_internal::FastTypeId<KeyT>(), ValidatorT());
+ }
+
+ // Acquire a mock for the KeyT (may or may not be a signature).
+ //
+ // KeyT is used to generate a typeid-based lookup for the mock.
+ // KeyT is a signature of the form:
+ // result_type(discriminator_type, std::tuple<args...>)
+ // The mocked function signature will be composed from KeyT as:
+ // result_type(args...)
+ template <typename KeyT, typename MockURBG>
+ static decltype(auto) MockFor(MockURBG& m) {
+ return MockFor<KeyT>(m, NoOpValidator());
}
};
diff --git a/absl/random/internal/mock_overload_set.h b/absl/random/internal/mock_overload_set.h
index 0d9c6c12..cfaeeeef 100644
--- a/absl/random/internal/mock_overload_set.h
+++ b/absl/random/internal/mock_overload_set.h
@@ -16,9 +16,11 @@
#ifndef ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_
#define ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_
+#include <tuple>
#include <type_traits>
#include "gmock/gmock.h"
+#include "absl/base/config.h"
#include "absl/random/internal/mock_helpers.h"
#include "absl/random/mocking_bit_gen.h"
@@ -26,7 +28,7 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
-template <typename DistrT, typename Fn>
+template <typename DistrT, typename ValidatorT, typename Fn>
struct MockSingleOverload;
// MockSingleOverload
@@ -38,8 +40,8 @@ struct MockSingleOverload;
// arguments to MockingBitGen::Register.
//
// The underlying KeyT must match the KeyT constructed by DistributionCaller.
-template <typename DistrT, typename Ret, typename... Args>
-struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
+template <typename DistrT, typename ValidatorT, typename Ret, typename... Args>
+struct MockSingleOverload<DistrT, ValidatorT, Ret(MockingBitGen&, Args...)> {
static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
"Overload signature must have return type matching the "
"distribution result_type.");
@@ -47,15 +49,21 @@ struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
template <typename MockURBG>
auto gmock_Call(MockURBG& gen, const ::testing::Matcher<Args>&... matchers)
- -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) {
- static_assert(std::is_base_of<MockingBitGen, MockURBG>::value,
- "Mocking requires an absl::MockingBitGen");
- return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...);
+ -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matchers...)) {
+ static_assert(
+ std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value ||
+ std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value,
+ "Mocking requires an absl::MockingBitGen");
+ return MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matchers...);
}
};
-template <typename DistrT, typename Ret, typename Arg, typename... Args>
-struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
+template <typename DistrT, typename ValidatorT, typename Ret, typename Arg,
+ typename... Args>
+struct MockSingleOverload<DistrT, ValidatorT,
+ Ret(Arg, MockingBitGen&, Args...)> {
static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
"Overload signature must have return type matching the "
"distribution result_type.");
@@ -64,14 +72,44 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
template <typename MockURBG>
auto gmock_Call(const ::testing::Matcher<Arg>& matcher, MockURBG& gen,
const ::testing::Matcher<Args>&... matchers)
- -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher,
- matchers...)) {
- static_assert(std::is_base_of<MockingBitGen, MockURBG>::value,
- "Mocking requires an absl::MockingBitGen");
- return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...);
+ -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matcher, matchers...)) {
+ static_assert(
+ std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value ||
+ std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value,
+ "Mocking requires an absl::MockingBitGen");
+ return MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matcher, matchers...);
}
};
+// MockOverloadSetWithValidator
+//
+// MockOverloadSetWithValidator is a wrapper around MockOverloadSet which takes
+// an additional Validator parameter, allowing for customization of the mock
+// behavior.
+//
+// `ValidatorT::Validate(result, args...)` will be called after the mock
+// distribution returns a value in `result`, allowing for validation against the
+// args.
+template <typename DistrT, typename ValidatorT, typename... Fns>
+struct MockOverloadSetWithValidator;
+
+template <typename DistrT, typename ValidatorT, typename Sig>
+struct MockOverloadSetWithValidator<DistrT, ValidatorT, Sig>
+ : public MockSingleOverload<DistrT, ValidatorT, Sig> {
+ using MockSingleOverload<DistrT, ValidatorT, Sig>::gmock_Call;
+};
+
+template <typename DistrT, typename ValidatorT, typename FirstSig,
+ typename... Rest>
+struct MockOverloadSetWithValidator<DistrT, ValidatorT, FirstSig, Rest...>
+ : public MockSingleOverload<DistrT, ValidatorT, FirstSig>,
+ public MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...> {
+ using MockSingleOverload<DistrT, ValidatorT, FirstSig>::gmock_Call;
+ using MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...>::gmock_Call;
+};
+
// MockOverloadSet
//
// MockOverloadSet takes a distribution and a collection of signatures and
@@ -79,20 +117,8 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
// `EXPECT_CALL(mock_overload_set, Call(...))` expand and do overload resolution
// correctly.
template <typename DistrT, typename... Signatures>
-struct MockOverloadSet;
-
-template <typename DistrT, typename Sig>
-struct MockOverloadSet<DistrT, Sig> : public MockSingleOverload<DistrT, Sig> {
- using MockSingleOverload<DistrT, Sig>::gmock_Call;
-};
-
-template <typename DistrT, typename FirstSig, typename... Rest>
-struct MockOverloadSet<DistrT, FirstSig, Rest...>
- : public MockSingleOverload<DistrT, FirstSig>,
- public MockOverloadSet<DistrT, Rest...> {
- using MockSingleOverload<DistrT, FirstSig>::gmock_Call;
- using MockOverloadSet<DistrT, Rest...>::gmock_Call;
-};
+using MockOverloadSet =
+ MockOverloadSetWithValidator<DistrT, NoOpValidator, Signatures...>;
} // namespace random_internal
ABSL_NAMESPACE_END
diff --git a/absl/random/internal/mock_validators.h b/absl/random/internal/mock_validators.h
new file mode 100644
index 00000000..0ab2ee9b
--- /dev/null
+++ b/absl/random/internal/mock_validators.h
@@ -0,0 +1,98 @@
+// Copyright 2024 The Abseil Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
+#define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
+
+#include <type_traits>
+
+#include "absl/base/config.h"
+#include "absl/base/internal/raw_logging.h"
+#include "absl/random/internal/iostream_state_saver.h"
+#include "absl/random/internal/uniform_helper.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+
+namespace absl {
+ABSL_NAMESPACE_BEGIN
+namespace random_internal {
+
+template <typename NumType>
+class UniformDistributionValidator {
+ public:
+ template <typename TagType>
+ static void Validate(NumType x, TagType tag, NumType lo, NumType hi) {
+ // For invalid ranges, absl::Uniform() simply returns one of the bounds.
+ if (x == lo && lo == hi) return;
+
+ ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi);
+ }
+
+ static void Validate(NumType x, NumType lo, NumType hi) {
+ Validate(x, IntervalClosedOpenTag(), lo, hi);
+ }
+
+ template <typename NumType_ = NumType>
+ static void Validate(NumType) {
+ // absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any
+ // value is okay.
+ static_assert(std::is_integral<NumType_>{},
+ "Non-integer types may have valid values outside of the full "
+ "range (e.g. floating point NaN).");
+ }
+
+ private:
+ static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; }
+ static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; }
+ static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; }
+ static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; }
+ static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; }
+ static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; }
+ static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; }
+ static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; }
+
+ template <typename TagType>
+ static void ValidateImpl(std::true_type /* is_floating_point */, NumType x,
+ TagType tag, NumType lo, NumType hi) {
+ UniformDistributionWrapper<NumType> dist(tag, lo, hi);
+ NumType lb = dist.a();
+ NumType ub = dist.b();
+ // uniform_real_distribution is always closed-open, so the upper bound is
+ // always non-inclusive.
+ ABSL_INTERNAL_CHECK(lb <= x && x < ub,
+ absl::StrCat(x, " is not in ", TagLbBound(tag), lo,
+ ", ", hi, TagUbBound(tag)));
+ }
+
+ template <typename TagType>
+ static void ValidateImpl(std::false_type /* is_floating_point */, NumType x,
+ TagType tag, NumType lo, NumType hi) {
+ using stream_type =
+ typename random_internal::stream_format_type<NumType>::type;
+
+ UniformDistributionWrapper<NumType> dist(tag, lo, hi);
+ NumType lb = dist.a();
+ NumType ub = dist.b();
+ ABSL_INTERNAL_CHECK(
+ lb <= x && x <= ub,
+ absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag),
+ stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag)));
+ }
+};
+
+} // namespace random_internal
+ABSL_NAMESPACE_END
+} // namespace absl
+
+#endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
diff --git a/absl/random/mock_distributions.h b/absl/random/mock_distributions.h
index 764ab370..b379262c 100644
--- a/absl/random/mock_distributions.h
+++ b/absl/random/mock_distributions.h
@@ -46,16 +46,18 @@
#ifndef ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_
#define ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_
-#include <limits>
-#include <type_traits>
-#include <utility>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "absl/meta/type_traits.h"
+#include "absl/base/config.h"
+#include "absl/random/bernoulli_distribution.h"
+#include "absl/random/beta_distribution.h"
#include "absl/random/distributions.h"
+#include "absl/random/exponential_distribution.h"
+#include "absl/random/gaussian_distribution.h"
#include "absl/random/internal/mock_overload_set.h"
+#include "absl/random/internal/mock_validators.h"
+#include "absl/random/log_uniform_int_distribution.h"
#include "absl/random/mocking_bit_gen.h"
+#include "absl/random/poisson_distribution.h"
+#include "absl/random/zipf_distribution.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
@@ -80,8 +82,9 @@ ABSL_NAMESPACE_BEGIN
// assert(x == 123456)
//
template <typename R>
-using MockUniform = random_internal::MockOverloadSet<
+using MockUniform = random_internal::MockOverloadSetWithValidator<
random_internal::UniformDistributionWrapper<R>,
+ random_internal::UniformDistributionValidator<R>,
R(IntervalClosedOpenTag, MockingBitGen&, R, R),
R(IntervalClosedClosedTag, MockingBitGen&, R, R),
R(IntervalOpenOpenTag, MockingBitGen&, R, R),
diff --git a/absl/random/mock_distributions_test.cc b/absl/random/mock_distributions_test.cc
index de23bafe..917799f0 100644
--- a/absl/random/mock_distributions_test.cc
+++ b/absl/random/mock_distributions_test.cc
@@ -14,7 +14,12 @@
#include "absl/random/mock_distributions.h"
+#include <cmath>
+#include <limits>
+
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/random/distributions.h"
#include "absl/random/mocking_bit_gen.h"
#include "absl/random/random.h"
@@ -69,4 +74,205 @@ TEST(MockDistributions, Examples) {
EXPECT_EQ(absl::LogUniform<int>(gen, 0, 1000000, 2), 2040);
}
+TEST(MockUniform, OutOfBoundsIsAllowed) {
+ absl::MockingBitGen gen;
+
+ EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)).WillOnce(Return(0));
+ EXPECT_EQ(absl::Uniform<int>(gen, 1, 100), 0);
+}
+
+TEST(ValidatedMockDistributions, UniformDoubleBoundaryCases) {
+ absl::random_internal::MockingBitGenImpl<true> gen;
+
+ EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0))
+ .WillOnce(Return(
+ std::nextafter(10.0, -std::numeric_limits<double>::infinity())));
+ EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0),
+ std::nextafter(10.0, -std::numeric_limits<double>::infinity()));
+
+ EXPECT_CALL(absl::MockUniform<double>(),
+ Call(absl::IntervalOpen, gen, 1.0, 10.0))
+ .WillOnce(Return(
+ std::nextafter(10.0, -std::numeric_limits<double>::infinity())));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
+ std::nextafter(10.0, -std::numeric_limits<double>::infinity()));
+
+ EXPECT_CALL(absl::MockUniform<double>(),
+ Call(absl::IntervalOpen, gen, 1.0, 10.0))
+ .WillOnce(
+ Return(std::nextafter(1.0, std::numeric_limits<double>::infinity())));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
+ std::nextafter(1.0, std::numeric_limits<double>::infinity()));
+}
+
+TEST(ValidatedMockDistributions, UniformDoubleEmptyRangeCases) {
+ absl::random_internal::MockingBitGenImpl<true> gen;
+
+ ON_CALL(absl::MockUniform<double>(), Call(absl::IntervalOpen, gen, 1.0, 1.0))
+ .WillByDefault(Return(1.0));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 1.0), 1.0);
+
+ ON_CALL(absl::MockUniform<double>(),
+ Call(absl::IntervalOpenClosed, gen, 1.0, 1.0))
+ .WillByDefault(Return(1.0));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpenClosed, gen, 1.0, 1.0),
+ 1.0);
+
+ ON_CALL(absl::MockUniform<double>(),
+ Call(absl::IntervalClosedOpen, gen, 1.0, 1.0))
+ .WillByDefault(Return(1.0));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalClosedOpen, gen, 1.0, 1.0),
+ 1.0);
+}
+
+TEST(ValidatedMockDistributions, UniformIntEmptyRangeCases) {
+ absl::random_internal::MockingBitGenImpl<true> gen;
+
+ ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpen, gen, 1, 1))
+ .WillByDefault(Return(1));
+ EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpen, gen, 1, 1), 1);
+
+ ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpenClosed, gen, 1, 1))
+ .WillByDefault(Return(1));
+ EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 1), 1);
+
+ ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalClosedOpen, gen, 1, 1))
+ .WillByDefault(Return(1));
+ EXPECT_EQ(absl::Uniform<int>(absl::IntervalClosedOpen, gen, 1, 1), 1);
+}
+
+TEST(ValidatedMockUniformDeathTest, Examples) {
+ absl::random_internal::MockingBitGenImpl<true> gen;
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100))
+ .WillOnce(Return(0));
+ absl::Uniform<int>(gen, 1, 100);
+ },
+ " 0 is not in \\[1, 100\\)");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100))
+ .WillOnce(Return(101));
+ absl::Uniform<int>(gen, 1, 100);
+ },
+ " 101 is not in \\[1, 100\\)");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100))
+ .WillOnce(Return(100));
+ absl::Uniform<int>(gen, 1, 100);
+ },
+ " 100 is not in \\[1, 100\\)");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpen, gen, 1, 100))
+ .WillOnce(Return(1));
+ absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100);
+ },
+ " 1 is not in \\(1, 100\\)");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpen, gen, 1, 100))
+ .WillOnce(Return(101));
+ absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100);
+ },
+ " 101 is not in \\(1, 100\\)");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpen, gen, 1, 100))
+ .WillOnce(Return(100));
+ absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100);
+ },
+ " 100 is not in \\(1, 100\\)");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpenClosed, gen, 1, 100))
+ .WillOnce(Return(1));
+ absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
+ },
+ " 1 is not in \\(1, 100\\]");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpenClosed, gen, 1, 100))
+ .WillOnce(Return(101));
+ absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
+ },
+ " 101 is not in \\(1, 100\\]");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpenClosed, gen, 1, 100))
+ .WillOnce(Return(0));
+ absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
+ },
+ " 0 is not in \\(1, 100\\]");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalOpenClosed, gen, 1, 100))
+ .WillOnce(Return(101));
+ absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
+ },
+ " 101 is not in \\(1, 100\\]");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalClosed, gen, 1, 100))
+ .WillOnce(Return(0));
+ absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100);
+ },
+ " 0 is not in \\[1, 100\\]");
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<int>(),
+ Call(absl::IntervalClosed, gen, 1, 100))
+ .WillOnce(Return(101));
+ absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100);
+ },
+ " 101 is not in \\[1, 100\\]");
+}
+
+TEST(ValidatedMockUniformDeathTest, DoubleBoundaryCases) {
+ absl::random_internal::MockingBitGenImpl<true> gen;
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0))
+ .WillOnce(Return(10.0));
+ EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0), 10.0);
+ },
+ " 10 is not in \\[1, 10\\)");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<double>(),
+ Call(absl::IntervalOpen, gen, 1.0, 10.0))
+ .WillOnce(Return(10.0));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
+ 10.0);
+ },
+ " 10 is not in \\(1, 10\\)");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ {
+ EXPECT_CALL(absl::MockUniform<double>(),
+ Call(absl::IntervalOpen, gen, 1.0, 10.0))
+ .WillOnce(Return(1.0));
+ EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
+ 1.0);
+ },
+ " 1 is not in \\(1, 10\\)");
+}
+
} // namespace
diff --git a/absl/random/mocking_bit_gen.h b/absl/random/mocking_bit_gen.h
index 89fa5a47..92f2e4fc 100644
--- a/absl/random/mocking_bit_gen.h
+++ b/absl/random/mocking_bit_gen.h
@@ -28,83 +28,37 @@
#ifndef ABSL_RANDOM_MOCKING_BIT_GEN_H_
#define ABSL_RANDOM_MOCKING_BIT_GEN_H_
-#include <iterator>
-#include <limits>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include "gmock/gmock.h"
-#include "gtest/gtest.h"
+#include "absl/base/attributes.h"
+#include "absl/base/config.h"
#include "absl/base/internal/fast_type_id.h"
#include "absl/container/flat_hash_map.h"
#include "absl/meta/type_traits.h"
-#include "absl/random/distributions.h"
-#include "absl/random/internal/distribution_caller.h"
+#include "absl/random/internal/mock_helpers.h"
#include "absl/random/random.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/types/span.h"
-#include "absl/types/variant.h"
#include "absl/utility/utility.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
+class BitGenRef;
+
namespace random_internal {
template <typename>
struct DistributionCaller;
class MockHelpers;
-} // namespace random_internal
-class BitGenRef;
-
-// MockingBitGen
-//
-// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class
-// which can act in place of an `absl::BitGen` URBG within tests using the
-// Googletest testing framework.
-//
-// Usage:
-//
-// Use an `absl::MockingBitGen` along with a mock distribution object (within
-// mock_distributions.h) inside Googletest constructs such as ON_CALL(),
-// EXPECT_TRUE(), etc. to produce deterministic results conforming to the
-// distribution's API contract.
-//
-// Example:
-//
-// // Mock a call to an `absl::Bernoulli` distribution using Googletest
-// absl::MockingBitGen bitgen;
-//
-// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5))
-// .WillByDefault(testing::Return(true));
-// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5));
-//
-// // Mock a call to an `absl::Uniform` distribution within Googletest
-// absl::MockingBitGen bitgen;
-//
-// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_))
-// .WillByDefault([] (int low, int high) {
-// return low + (high - low) / 2;
-// });
-//
-// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5);
-// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35);
-//
-// At this time, only mock distributions supplied within the Abseil random
-// library are officially supported.
-//
-// EXPECT_CALL and ON_CALL need to be made within the same DLL component as
-// the call to absl::Uniform and related methods, otherwise mocking will fail
-// since the underlying implementation creates a type-specific pointer which
-// will be distinct across different DLL boundaries.
-//
-class MockingBitGen {
+// Implements MockingBitGen with an option to turn on extra validation.
+template <bool EnableValidation>
+class MockingBitGenImpl {
public:
- MockingBitGen() = default;
- ~MockingBitGen() = default;
+ MockingBitGenImpl() = default;
+ ~MockingBitGenImpl() = default;
// URBG interface
using result_type = absl::BitGen::result_type;
@@ -125,15 +79,19 @@ class MockingBitGen {
// NOTE: MockFnCaller is essentially equivalent to the lambda:
// [fn](auto... args) { return fn->Call(std::move(args)...)}
// however that fails to build on some supported platforms.
- template <typename MockFnType, typename ResultT, typename Tuple>
+ template <typename MockFnType, typename ValidatorT, typename ResultT,
+ typename Tuple>
struct MockFnCaller;
// specialization for std::tuple.
- template <typename MockFnType, typename ResultT, typename... Args>
- struct MockFnCaller<MockFnType, ResultT, std::tuple<Args...>> {
+ template <typename MockFnType, typename ValidatorT, typename ResultT,
+ typename... Args>
+ struct MockFnCaller<MockFnType, ValidatorT, ResultT, std::tuple<Args...>> {
MockFnType* fn;
inline ResultT operator()(Args... args) {
- return fn->Call(std::move(args)...);
+ ResultT result = fn->Call(args...);
+ ValidatorT::Validate(result, args...);
+ return result;
}
};
@@ -150,16 +108,17 @@ class MockingBitGen {
/*ResultT*/ void* result) = 0;
};
- template <typename MockFnType, typename ResultT, typename ArgTupleT>
+ template <typename MockFnType, typename ValidatorT, typename ResultT,
+ typename ArgTupleT>
class FunctionHolderImpl final : public FunctionHolder {
public:
- void Apply(void* args_tuple, void* result) override {
+ void Apply(void* args_tuple, void* result) final {
// Requires tuple_args to point to a ArgTupleT, which is a
// std::tuple<Args...> used to invoke the mock function. Requires result
// to point to a ResultT, which is the result of the call.
- *static_cast<ResultT*>(result) =
- absl::apply(MockFnCaller<MockFnType, ResultT, ArgTupleT>{&mock_fn_},
- *static_cast<ArgTupleT*>(args_tuple));
+ *static_cast<ResultT*>(result) = absl::apply(
+ MockFnCaller<MockFnType, ValidatorT, ResultT, ArgTupleT>{&mock_fn_},
+ *static_cast<ArgTupleT*>(args_tuple));
}
MockFnType mock_fn_;
@@ -175,26 +134,29 @@ class MockingBitGen {
//
// The returned MockFunction<...> type can be used to setup additional
// distribution parameters of the expectation.
- template <typename ResultT, typename ArgTupleT, typename SelfT>
- auto RegisterMock(SelfT&, base_internal::FastTypeIdType type)
+ template <typename ResultT, typename ArgTupleT, typename SelfT,
+ typename ValidatorT>
+ auto RegisterMock(SelfT&, base_internal::FastTypeIdType type, ValidatorT)
-> decltype(GetMockFnType(std::declval<ResultT>(),
std::declval<ArgTupleT>()))& {
+ using ActualValidatorT =
+ std::conditional_t<EnableValidation, ValidatorT, NoOpValidator>;
using MockFnType = decltype(GetMockFnType(std::declval<ResultT>(),
std::declval<ArgTupleT>()));
using WrappedFnType = absl::conditional_t<
- std::is_same<SelfT, ::testing::NiceMock<absl::MockingBitGen>>::value,
+ std::is_same<SelfT, ::testing::NiceMock<MockingBitGenImpl>>::value,
::testing::NiceMock<MockFnType>,
absl::conditional_t<
- std::is_same<SelfT,
- ::testing::NaggyMock<absl::MockingBitGen>>::value,
+ std::is_same<SelfT, ::testing::NaggyMock<MockingBitGenImpl>>::value,
::testing::NaggyMock<MockFnType>,
absl::conditional_t<
std::is_same<SelfT,
- ::testing::StrictMock<absl::MockingBitGen>>::value,
+ ::testing::StrictMock<MockingBitGenImpl>>::value,
::testing::StrictMock<MockFnType>, MockFnType>>>;
- using ImplT = FunctionHolderImpl<WrappedFnType, ResultT, ArgTupleT>;
+ using ImplT =
+ FunctionHolderImpl<WrappedFnType, ActualValidatorT, ResultT, ArgTupleT>;
auto& mock = mocks_[type];
if (!mock) {
mock = absl::make_unique<ImplT>();
@@ -234,6 +196,58 @@ class MockingBitGen {
// InvokeMock
};
+} // namespace random_internal
+
+// MockingBitGen
+//
+// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class
+// which can act in place of an `absl::BitGen` URBG within tests using the
+// Googletest testing framework.
+//
+// Usage:
+//
+// Use an `absl::MockingBitGen` along with a mock distribution object (within
+// mock_distributions.h) inside Googletest constructs such as ON_CALL(),
+// EXPECT_TRUE(), etc. to produce deterministic results conforming to the
+// distribution's API contract.
+//
+// Example:
+//
+// // Mock a call to an `absl::Bernoulli` distribution using Googletest
+// absl::MockingBitGen bitgen;
+//
+// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5))
+// .WillByDefault(testing::Return(true));
+// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5));
+//
+// // Mock a call to an `absl::Uniform` distribution within Googletest
+// absl::MockingBitGen bitgen;
+//
+// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_))
+// .WillByDefault([] (int low, int high) {
+// return low + (high - low) / 2;
+// });
+//
+// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5);
+// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35);
+//
+// At this time, only mock distributions supplied within the Abseil random
+// library are officially supported.
+//
+// EXPECT_CALL and ON_CALL need to be made within the same DLL component as
+// the call to absl::Uniform and related methods, otherwise mocking will fail
+// since the underlying implementation creates a type-specific pointer which
+// will be distinct across different DLL boundaries.
+//
+using MockingBitGen = random_internal::MockingBitGenImpl<false>;
+
+// UnvalidatedMockingBitGen
+//
+// UnvalidatedMockingBitGen is a variant of MockingBitGen which does no extra
+// validation.
+using UnvalidatedMockingBitGen ABSL_DEPRECATED("Use MockingBitGen instead") =
+ random_internal::MockingBitGenImpl<false>;
+
ABSL_NAMESPACE_END
} // namespace absl
diff --git a/absl/random/mocking_bit_gen_test.cc b/absl/random/mocking_bit_gen_test.cc
index 9ccdf568..26e673ac 100644
--- a/absl/random/mocking_bit_gen_test.cc
+++ b/absl/random/mocking_bit_gen_test.cc
@@ -16,9 +16,11 @@
#include "absl/random/mocking_bit_gen.h"
#include <cmath>
+#include <cstddef>
#include <cstdint>
+#include <iterator>
#include <numeric>
-#include <random>
+#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest-spi.h"
@@ -246,33 +248,33 @@ TEST(WillOnce, DistinctCounters) {
absl::MockingBitGen gen;
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000))
.Times(3)
- .WillRepeatedly(Return(0));
+ .WillRepeatedly(Return(1));
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1000001, 2000000))
.Times(3)
- .WillRepeatedly(Return(1));
- EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1);
- EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0);
- EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1);
- EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0);
- EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1);
- EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0);
+ .WillRepeatedly(Return(1000001));
+ EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001);
+ EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1);
+ EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001);
+ EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1);
+ EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001);
+ EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1);
}
TEST(TimesModifier, ModifierSaturatesAndExpires) {
EXPECT_NONFATAL_FAILURE(
[]() {
absl::MockingBitGen gen;
- EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000))
+ EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 0, 1000000))
.Times(3)
.WillRepeatedly(Return(15))
.RetiresOnSaturation();
- EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15);
- EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15);
- EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15);
+ EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15);
+ EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15);
+ EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15);
// Times(3) has expired - Should get a different value now.
- EXPECT_NE(absl::Uniform(gen, 1, 1000000), 15);
+ EXPECT_NE(absl::Uniform(gen, 0, 1000000), 15);
}(),
"");
}
@@ -394,7 +396,7 @@ TEST(MockingBitGen, StrictMock_TooMany) {
EXPECT_EQ(absl::Uniform(gen, 1, 1000), 145);
EXPECT_NONFATAL_FAILURE(
- [&]() { EXPECT_EQ(absl::Uniform(gen, 10, 1000), 0); }(),
+ [&]() { EXPECT_EQ(absl::Uniform(gen, 0, 1000), 0); }(),
"over-saturated and active");
}