diff options
-rw-r--r-- | absl/synchronization/mutex.cc | 38 | ||||
-rw-r--r-- | absl/synchronization/mutex.h | 107 | ||||
-rw-r--r-- | absl/synchronization/mutex_test.cc | 111 |
3 files changed, 198 insertions, 58 deletions
diff --git a/absl/synchronization/mutex.cc b/absl/synchronization/mutex.cc index b0f412bf..ff18df5d 100644 --- a/absl/synchronization/mutex.cc +++ b/absl/synchronization/mutex.cc @@ -37,6 +37,8 @@ #include <atomic> #include <cinttypes> #include <cstddef> +#include <cstring> +#include <iterator> #include <thread> // NOLINT(build/c++11) #include "absl/base/attributes.h" @@ -2780,25 +2782,32 @@ static bool Dereference(void *arg) { return *(static_cast<bool *>(arg)); } -Condition::Condition() {} // null constructor, used for kTrue only +Condition::Condition() = default; // null constructor, used for kTrue only const Condition Condition::kTrue; Condition::Condition(bool (*func)(void *), void *arg) : eval_(&CallVoidPtrFunction), - function_(func), - method_(nullptr), - arg_(arg) {} + arg_(arg) { + static_assert(sizeof(&func) <= sizeof(callback_), + "An overlarge function pointer passed to Condition."); + StoreCallback(func); +} bool Condition::CallVoidPtrFunction(const Condition *c) { - return (*c->function_)(c->arg_); + using FunctionPointer = bool (*)(void *); + FunctionPointer function_pointer; + std::memcpy(&function_pointer, c->callback_, sizeof(function_pointer)); + return (*function_pointer)(c->arg_); } Condition::Condition(const bool *cond) : eval_(CallVoidPtrFunction), - function_(Dereference), - method_(nullptr), // const_cast is safe since Dereference does not modify arg - arg_(const_cast<bool *>(cond)) {} + arg_(const_cast<bool *>(cond)) { + using FunctionPointer = bool (*)(void *); + const FunctionPointer dereference = Dereference; + StoreCallback(dereference); +} bool Condition::Eval() const { // eval_ == null for kTrue @@ -2806,14 +2815,15 @@ bool Condition::Eval() const { } bool Condition::GuaranteedEqual(const Condition *a, const Condition *b) { - if (a == nullptr) { + // kTrue logic. + if (a == nullptr || a->eval_ == nullptr) { return b == nullptr || b->eval_ == nullptr; + }else if (b == nullptr || b->eval_ == nullptr) { + return false; } - if (b == nullptr || b->eval_ == nullptr) { - return a->eval_ == nullptr; - } - return a->eval_ == b->eval_ && a->function_ == b->function_ && - a->arg_ == b->arg_ && a->method_ == b->method_; + // Check equality of the representative fields. + return a->eval_ == b->eval_ && a->arg_ == b->arg_ && + !memcmp(a->callback_, b->callback_, sizeof(ConservativeMethodPointer)); } ABSL_NAMESPACE_END diff --git a/absl/synchronization/mutex.h b/absl/synchronization/mutex.h index 8694bb75..54ee703a 100644 --- a/absl/synchronization/mutex.h +++ b/absl/synchronization/mutex.h @@ -60,6 +60,8 @@ #include <atomic> #include <cstdint> +#include <cstring> +#include <iterator> #include <string> #include "absl/base/const_init.h" @@ -612,12 +614,12 @@ class ABSL_SCOPED_LOCKABLE WriterMutexLock { // Condition // ----------------------------------------------------------------------------- // -// As noted above, `Mutex` contains a number of member functions which take a -// `Condition` as an argument; clients can wait for conditions to become `true` -// before attempting to acquire the mutex. These sections are known as -// "condition critical" sections. To use a `Condition`, you simply need to -// construct it, and use within an appropriate `Mutex` member function; -// everything else in the `Condition` class is an implementation detail. +// `Mutex` contains a number of member functions which take a `Condition` as an +// argument; clients can wait for conditions to become `true` before attempting +// to acquire the mutex. These sections are known as "condition critical" +// sections. To use a `Condition`, you simply need to construct it, and use +// within an appropriate `Mutex` member function; everything else in the +// `Condition` class is an implementation detail. // // A `Condition` is specified as a function pointer which returns a boolean. // `Condition` functions should be pure functions -- their results should depend @@ -742,22 +744,55 @@ class Condition { static bool GuaranteedEqual(const Condition *a, const Condition *b); private: - typedef bool (*InternalFunctionType)(void * arg); - typedef bool (Condition::*InternalMethodType)(); - typedef bool (*InternalMethodCallerType)(void * arg, - InternalMethodType internal_method); - - bool (*eval_)(const Condition*); // Actual evaluator - InternalFunctionType function_; // function taking pointer returning bool - InternalMethodType method_; // method returning bool - void *arg_; // arg of function_ or object of method_ - - Condition(); // null constructor used only to create kTrue + // Sizing an allocation for a method pointer can be subtle. In the Itanium + // specifications, a method pointer has a predictable, uniform size. On the + // other hand, MSVC ABI, method pointer sizes vary based on the + // inheritance of the class. Specifically, method pointers from classes with + // multiple inheritance are bigger than those of classes with single + // inheritance. Other variations also exist. + + // A good way to allocate enough space for *any* pointer in these ABIs is to + // employ a class declaration with no definition. Because the inheritance + // structure is not available for this declaration, the compiler must + // assume, conservatively, that its method pointers have the largest possible + // size. + class OpaqueClass; + using ConservativeMethodPointer = bool (OpaqueClass::*)(); + static_assert(sizeof(bool(OpaqueClass::*)()) >= sizeof(bool (*)(void *)), + "Unsupported platform."); + + // Allocation for a function pointer or method pointer. + // The {0} initializer ensures that all unused bytes of this buffer are + // always zeroed out. This is necessary, because GuaranteedEqual() compares + // all of the bytes, unaware of which bytes are relevant to a given `eval_`. + char callback_[sizeof(ConservativeMethodPointer)] = {0}; + + // Function with which to evaluate callbacks and/or arguments. + bool (*eval_)(const Condition*); + + // Either an argument for a function call or an object for a method call. + void *arg_; // Various functions eval_ can point to: static bool CallVoidPtrFunction(const Condition*); template <typename T> static bool CastAndCallFunction(const Condition* c); template <typename T> static bool CastAndCallMethod(const Condition* c); + + // Helper methods for storing, validating, and reading callback arguments. + template <typename T> + inline void StoreCallback(T callback) { + static_assert( + sizeof(callback) <= sizeof(callback_), + "An overlarge pointer was passed as a callback to Condition."); + std::memcpy(callback_, &callback, sizeof(callback)); + } + + template <typename T> + inline void ReadCallback(T *callback) const { + std::memcpy(callback, callback_, sizeof(*callback)); + } + + Condition(); // null constructor used only to create kTrue }; // ----------------------------------------------------------------------------- @@ -949,44 +984,48 @@ inline CondVar::CondVar() : cv_(0) {} // static template <typename T> bool Condition::CastAndCallMethod(const Condition *c) { - typedef bool (T::*MemberType)(); - MemberType rm = reinterpret_cast<MemberType>(c->method_); - T *x = static_cast<T *>(c->arg_); - return (x->*rm)(); + T *object = static_cast<T *>(c->arg_); + bool (T::*method_pointer)(); + c->ReadCallback(&method_pointer); + return (object->*method_pointer)(); } // static template <typename T> bool Condition::CastAndCallFunction(const Condition *c) { - typedef bool (*FuncType)(T *); - FuncType fn = reinterpret_cast<FuncType>(c->function_); - T *x = static_cast<T *>(c->arg_); - return (*fn)(x); + bool (*function)(T *); + c->ReadCallback(&function); + T *argument = static_cast<T *>(c->arg_); + return (*function)(argument); } template <typename T> inline Condition::Condition(bool (*func)(T *), T *arg) : eval_(&CastAndCallFunction<T>), - function_(reinterpret_cast<InternalFunctionType>(func)), - method_(nullptr), - arg_(const_cast<void *>(static_cast<const void *>(arg))) {} + arg_(const_cast<void *>(static_cast<const void *>(arg))) { + static_assert(sizeof(&func) <= sizeof(callback_), + "An overlarge function pointer was passed to Condition."); + StoreCallback(func); +} template <typename T> inline Condition::Condition(T *object, bool (absl::internal::identity<T>::type::*method)()) : eval_(&CastAndCallMethod<T>), - function_(nullptr), - method_(reinterpret_cast<InternalMethodType>(method)), - arg_(object) {} + arg_(object) { + static_assert(sizeof(&method) <= sizeof(callback_), + "An overlarge method pointer was passed to Condition."); + StoreCallback(method); +} template <typename T> inline Condition::Condition(const T *object, bool (absl::internal::identity<T>::type::*method)() const) : eval_(&CastAndCallMethod<T>), - function_(nullptr), - method_(reinterpret_cast<InternalMethodType>(method)), - arg_(reinterpret_cast<void *>(const_cast<T *>(object))) {} + arg_(reinterpret_cast<void *>(const_cast<T *>(object))) { + StoreCallback(method); +} // Register hooks for profiling support. // diff --git a/absl/synchronization/mutex_test.cc b/absl/synchronization/mutex_test.cc index 99bb0175..f3d60852 100644 --- a/absl/synchronization/mutex_test.cc +++ b/absl/synchronization/mutex_test.cc @@ -295,8 +295,9 @@ static void TestTime(TestContext *cxt, int c, bool use_cv) { "TestTime failed"); } elapsed = absl::Now() - start; - ABSL_RAW_CHECK(absl::Seconds(0.9) <= elapsed && - elapsed <= absl::Seconds(2.0), "TestTime failed"); + ABSL_RAW_CHECK( + absl::Seconds(0.9) <= elapsed && elapsed <= absl::Seconds(2.0), + "TestTime failed"); ABSL_RAW_CHECK(cxt->g0 == cxt->threads, "TestTime failed"); } else if (c == 1) { @@ -343,7 +344,7 @@ static void TestMuTime(TestContext *cxt, int c) { TestTime(cxt, c, false); } static void TestCVTime(TestContext *cxt, int c) { TestTime(cxt, c, true); } static void EndTest(int *c0, int *c1, absl::Mutex *mu, absl::CondVar *cv, - const std::function<void(int)>& cb) { + const std::function<void(int)> &cb) { mu->Lock(); int c = (*c0)++; mu->Unlock(); @@ -366,9 +367,9 @@ static int RunTestCommon(TestContext *cxt, void (*test)(TestContext *cxt, int), cxt->threads = threads; absl::synchronization_internal::ThreadPool tp(threads); for (int i = 0; i != threads; i++) { - tp.Schedule(std::bind(&EndTest, &c0, &c1, &mu2, &cv2, - std::function<void(int)>( - std::bind(test, cxt, std::placeholders::_1)))); + tp.Schedule(std::bind( + &EndTest, &c0, &c1, &mu2, &cv2, + std::function<void(int)>(std::bind(test, cxt, std::placeholders::_1)))); } mu2.Lock(); while (c1 != threads) { @@ -682,14 +683,14 @@ struct LockWhenTestStruct { bool waiting = false; }; -static bool LockWhenTestIsCond(LockWhenTestStruct* s) { +static bool LockWhenTestIsCond(LockWhenTestStruct *s) { s->mu2.Lock(); s->waiting = true; s->mu2.Unlock(); return s->cond; } -static void LockWhenTestWaitForIsCond(LockWhenTestStruct* s) { +static void LockWhenTestWaitForIsCond(LockWhenTestStruct *s) { s->mu1.LockWhen(absl::Condition(&LockWhenTestIsCond, s)); s->mu1.Unlock(); } @@ -1694,8 +1695,7 @@ TEST(Mutex, Timed) { TEST(Mutex, CVTime) { int threads = 10; // Use a fixed thread count of 10 int iterations = 1; - EXPECT_EQ(RunTest(&TestCVTime, threads, iterations, 1), - threads * iterations); + EXPECT_EQ(RunTest(&TestCVTime, threads, iterations, 1), threads * iterations); } TEST(Mutex, MuTime) { @@ -1730,4 +1730,95 @@ TEST(Mutex, SignalExitedThread) { for (auto &th : top) th.join(); } +#ifdef _MSC_VER + +// Declare classes of the various MSVC inheritance types. +class __single_inheritance SingleInheritance{}; +class __multiple_inheritance MultipleInheritance; +class __virtual_inheritance VirtualInheritance; +class UnknownInheritance; + +TEST(ConditionTest, MicrosoftMethodPointerSize) { + // This test verifies expectations about sizes of MSVC pointers to methods. + // Pointers to methods are distinguished by whether their class hierachies + // contain single inheritance, multiple inheritance, or virtual inheritence. + void (SingleInheritance::*single_inheritance)(); + void (MultipleInheritance::*multiple_inheritance)(); + void (VirtualInheritance::*virtual_inheritance)(); + void (UnknownInheritance::*unknown_inheritance)(); + +#if defined(_M_IX86) || defined(_M_ARM) + static_assert(sizeof(single_inheritance) == 4, + "Unexpected sizeof(single_inheritance)."); + static_assert(sizeof(multiple_inheritance) == 8, + "Unexpected sizeof(multiple_inheritance)."); + static_assert(sizeof(virtual_inheritance) == 12, + "Unexpected sizeof(virtual_inheritance)."); +#elif defined(_M_X64) || defined(__aarch64__) + static_assert(sizeof(single_inheritance) == 8, + "Unexpected sizeof(single_inheritance)."); + static_assert(sizeof(multiple_inheritance) == 16, + "Unexpected sizeof(multiple_inheritance)."); + static_assert(sizeof(virtual_inheritance) == 16, + "Unexpected sizeof(virtual_inheritance)."); +#endif + static_assert(sizeof(unknown_inheritance) >= sizeof(virtual_inheritance), + "Failed invariant: sizeof(unknown_inheritance) >= " + "sizeof(virtual_inheritance)!"); +} + +class Callback { + bool x = true; + + public: + Callback() {} + bool method() { + x = !x; + return x; + } +}; + +class M2 { + bool x = true; + + public: + M2() {} + bool method2() { + x = !x; + return x; + } +}; + +class MultipleInheritance : public Callback, public M2 {}; + +TEST(ConditionTest, ConditionWithMultipleInheritanceMethod) { + // This test ensures that Condition can deal with method pointers from classes + // with multiple inheritance. + MultipleInheritance object = MultipleInheritance(); + absl::Condition condition(&object, &MultipleInheritance::method); + EXPECT_FALSE(condition.Eval()); + EXPECT_TRUE(condition.Eval()); +} + +class __virtual_inheritance VirtualInheritance : virtual public Callback { + bool x = false; + + public: + VirtualInheritance() {} + bool method() { + x = !x; + return x; + } +}; + +TEST(ConditionTest, ConditionWithVirtualInheritanceMethod) { + // This test ensures that Condition can deal with method pointers from classes + // with virtual inheritance. + VirtualInheritance object = VirtualInheritance(); + absl::Condition condition(&object, &VirtualInheritance::method); + EXPECT_TRUE(condition.Eval()); + EXPECT_FALSE(condition.Eval()); +} +#endif + } // namespace |