summaryrefslogtreecommitdiff
path: root/absl/random/internal/distribution_caller.h
diff options
context:
space:
mode:
Diffstat (limited to 'absl/random/internal/distribution_caller.h')
-rw-r--r--absl/random/internal/distribution_caller.h66
1 files changed, 49 insertions, 17 deletions
diff --git a/absl/random/internal/distribution_caller.h b/absl/random/internal/distribution_caller.h
index 02603cf8..fc81b787 100644
--- a/absl/random/internal/distribution_caller.h
+++ b/absl/random/internal/distribution_caller.h
@@ -20,6 +20,8 @@
#include <utility>
#include "absl/base/config.h"
+#include "absl/base/internal/fast_type_id.h"
+#include "absl/utility/utility.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
@@ -30,27 +32,57 @@ namespace random_internal {
// to intercept such calls.
template <typename URBG>
struct DistributionCaller {
- // Call the provided distribution type. The parameters are expected
- // to be explicitly specified.
- // DistrT is the distribution type.
- // FormatT is the formatter type:
+ // SFINAE to detect whether the URBG type includes a member matching
+ // bool InvokeMock(base_internal::FastTypeIdType, void*, void*).
//
- // struct FormatT {
- // using result_type = distribution_t::result_type;
- // static std::string FormatCall(
- // const distribution_t& distr,
- // absl::Span<const result_type>);
- //
- // static std::string FormatExpectation(
- // absl::string_view match_args,
- // absl::Span<const result_t> results);
- // }
- //
- template <typename DistrT, typename FormatT, typename... Args>
- static typename DistrT::result_type Call(URBG* urbg, Args&&... args) {
+ // These live inside BitGenRef so that they have friend access
+ // to MockingBitGen. (see similar methods in DistributionCaller).
+ template <template <class...> class Trait, class AlwaysVoid, class... Args>
+ struct detector : std::false_type {};
+ template <template <class...> class Trait, class... Args>
+ struct detector<Trait, absl::void_t<Trait<Args...>>, Args...>
+ : std::true_type {};
+
+ template <class T>
+ using invoke_mock_t = decltype(std::declval<T*>()->InvokeMock(
+ std::declval<::absl::base_internal::FastTypeIdType>(),
+ std::declval<void*>(), std::declval<void*>()));
+
+ using HasInvokeMock = typename detector<invoke_mock_t, void, URBG>::type;
+
+ // Default implementation of distribution caller.
+ template <typename DistrT, typename... Args>
+ static typename DistrT::result_type Impl(std::false_type, URBG* urbg,
+ Args&&... args) {
DistrT dist(std::forward<Args>(args)...);
return dist(*urbg);
}
+
+ // Mock implementation of distribution caller.
+ // The underlying KeyT must match the KeyT constructed by MockOverloadSet.
+ template <typename DistrT, typename... Args>
+ static typename DistrT::result_type Impl(std::true_type, URBG* urbg,
+ Args&&... args) {
+ using ResultT = typename DistrT::result_type;
+ using ArgTupleT = std::tuple<absl::decay_t<Args>...>;
+ using KeyT = ResultT(DistrT, ArgTupleT);
+
+ ArgTupleT arg_tuple(std::forward<Args>(args)...);
+ ResultT result;
+ if (!urbg->InvokeMock(::absl::base_internal::FastTypeId<KeyT>(), &arg_tuple,
+ &result)) {
+ auto dist = absl::make_from_tuple<DistrT>(arg_tuple);
+ result = dist(*urbg);
+ }
+ return result;
+ }
+
+ // Default implementation of distribution caller.
+ template <typename DistrT, typename... Args>
+ static typename DistrT::result_type Call(URBG* urbg, Args&&... args) {
+ return Impl<DistrT, Args...>(HasInvokeMock{}, urbg,
+ std::forward<Args>(args)...);
+ }
};
} // namespace random_internal