diff options
Diffstat (limited to 'absl/random/bit_gen_ref.h')
-rw-r--r-- | absl/random/bit_gen_ref.h | 105 |
1 files changed, 66 insertions, 39 deletions
diff --git a/absl/random/bit_gen_ref.h b/absl/random/bit_gen_ref.h index 59591a47..00e36248 100644 --- a/absl/random/bit_gen_ref.h +++ b/absl/random/bit_gen_ref.h @@ -24,11 +24,11 @@ #ifndef ABSL_RANDOM_BIT_GEN_REF_H_ #define ABSL_RANDOM_BIT_GEN_REF_H_ +#include "absl/base/internal/fast_type_id.h" #include "absl/base/macros.h" #include "absl/meta/type_traits.h" #include "absl/random/internal/distribution_caller.h" #include "absl/random/internal/fast_uniform_bits.h" -#include "absl/random/internal/mocking_bit_gen_base.h" namespace absl { ABSL_NAMESPACE_BEGIN @@ -51,6 +51,9 @@ struct is_urbg< typename std::decay<decltype(std::declval<URBG>()())>::type>::value>> : std::true_type {}; +template <typename> +struct DistributionCaller; + } // namespace random_internal // ----------------------------------------------------------------------------- @@ -77,23 +80,50 @@ struct is_urbg< // } // class BitGenRef { - public: - using result_type = uint64_t; + // SFINAE to detect whether the URBG type includes a member matching + // bool InvokeMock(base_internal::FastTypeIdType, void*, void*). + // + // These live inside BitGenRef so that they have friend access + // to MockingBitGen. (see similar methods in DistributionCaller). + template <template <class...> class Trait, class AlwaysVoid, class... Args> + struct detector : std::false_type {}; + template <template <class...> class Trait, class... Args> + struct detector<Trait, absl::void_t<Trait<Args...>>, Args...> + : std::true_type {}; + + template <class T> + using invoke_mock_t = decltype(std::declval<T*>()->InvokeMock( + std::declval<base_internal::FastTypeIdType>(), std::declval<void*>(), + std::declval<void*>())); + + template <typename T> + using HasInvokeMock = typename detector<invoke_mock_t, void, T>::type; - BitGenRef(const absl::BitGenRef&) = default; - BitGenRef(absl::BitGenRef&&) = default; - BitGenRef& operator=(const absl::BitGenRef&) = default; - BitGenRef& operator=(absl::BitGenRef&&) = default; + public: + BitGenRef(const BitGenRef&) = default; + BitGenRef(BitGenRef&&) = default; + BitGenRef& operator=(const BitGenRef&) = default; + BitGenRef& operator=(BitGenRef&&) = default; + + template <typename URBG, typename absl::enable_if_t< + (!std::is_same<URBG, BitGenRef>::value && + random_internal::is_urbg<URBG>::value && + !HasInvokeMock<URBG>::value)>* = nullptr> + BitGenRef(URBG& gen) // NOLINT + : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), + mock_call_(NotAMock), + generate_impl_fn_(ImplFn<URBG>) {} template <typename URBG, - typename absl::enable_if_t< - (!std::is_same<URBG, BitGenRef>::value && - random_internal::is_urbg<URBG>::value)>* = nullptr> + typename absl::enable_if_t<(!std::is_same<URBG, BitGenRef>::value && + random_internal::is_urbg<URBG>::value && + HasInvokeMock<URBG>::value)>* = nullptr> BitGenRef(URBG& gen) // NOLINT - : mocked_gen_ptr_(MakeMockPointer(&gen)), - t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), - generate_impl_fn_(ImplFn<URBG>) { - } + : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), + mock_call_(&MockCall<URBG>), + generate_impl_fn_(ImplFn<URBG>) {} + + using result_type = uint64_t; static constexpr result_type(min)() { return (std::numeric_limits<result_type>::min)(); @@ -106,14 +136,9 @@ class BitGenRef { result_type operator()() { return generate_impl_fn_(t_erased_gen_ptr_); } private: - friend struct absl::random_internal::DistributionCaller<absl::BitGenRef>; using impl_fn = result_type (*)(uintptr_t); - using mocker_base_t = absl::random_internal::MockingBitGenBase; - - // Convert an arbitrary URBG pointer into either a valid mocker_base_t - // pointer or a nullptr. - static inline mocker_base_t* MakeMockPointer(mocker_base_t* t) { return t; } - static inline mocker_base_t* MakeMockPointer(void*) { return nullptr; } + using mock_call_fn = bool (*)(uintptr_t, base_internal::FastTypeIdType, void*, + void*); template <typename URBG> static result_type ImplFn(uintptr_t ptr) { @@ -123,29 +148,31 @@ class BitGenRef { return fast_uniform_bits(*reinterpret_cast<URBG*>(ptr)); } - mocker_base_t* mocked_gen_ptr_; + // Get a type-erased InvokeMock pointer. + template <typename URBG> + static bool MockCall(uintptr_t gen_ptr, base_internal::FastTypeIdType type, + void* result, void* arg_tuple) { + return reinterpret_cast<URBG*>(gen_ptr)->InvokeMock(type, result, + arg_tuple); + } + static bool NotAMock(uintptr_t, base_internal::FastTypeIdType, void*, void*) { + return false; + } + + inline bool InvokeMock(base_internal::FastTypeIdType type, void* args_tuple, + void* result) { + if (mock_call_ == NotAMock) return false; // avoids an indirect call. + return mock_call_(t_erased_gen_ptr_, type, args_tuple, result); + } + uintptr_t t_erased_gen_ptr_; + mock_call_fn mock_call_; impl_fn generate_impl_fn_; -}; - -namespace random_internal { -template <> -struct DistributionCaller<absl::BitGenRef> { - template <typename DistrT, typename... Args> - static typename DistrT::result_type Call(absl::BitGenRef* gen_ref, - Args&&... args) { - auto* mock_ptr = gen_ref->mocked_gen_ptr_; - if (mock_ptr == nullptr) { - DistrT dist(std::forward<Args>(args)...); - return dist(*gen_ref); - } else { - return mock_ptr->template Call<DistrT>(std::forward<Args>(args)...); - } - } + template <typename> + friend struct ::absl::random_internal::DistributionCaller; // for InvokeMock }; -} // namespace random_internal ABSL_NAMESPACE_END } // namespace absl |