summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--absl/synchronization/mutex.cc38
-rw-r--r--absl/synchronization/mutex.h107
-rw-r--r--absl/synchronization/mutex_test.cc111
3 files changed, 198 insertions, 58 deletions
diff --git a/absl/synchronization/mutex.cc b/absl/synchronization/mutex.cc
index b0f412bf..ff18df5d 100644
--- a/absl/synchronization/mutex.cc
+++ b/absl/synchronization/mutex.cc
@@ -37,6 +37,8 @@
#include <atomic>
#include <cinttypes>
#include <cstddef>
+#include <cstring>
+#include <iterator>
#include <thread> // NOLINT(build/c++11)
#include "absl/base/attributes.h"
@@ -2780,25 +2782,32 @@ static bool Dereference(void *arg) {
return *(static_cast<bool *>(arg));
}
-Condition::Condition() {} // null constructor, used for kTrue only
+Condition::Condition() = default; // null constructor, used for kTrue only
const Condition Condition::kTrue;
Condition::Condition(bool (*func)(void *), void *arg)
: eval_(&CallVoidPtrFunction),
- function_(func),
- method_(nullptr),
- arg_(arg) {}
+ arg_(arg) {
+ static_assert(sizeof(&func) <= sizeof(callback_),
+ "An overlarge function pointer passed to Condition.");
+ StoreCallback(func);
+}
bool Condition::CallVoidPtrFunction(const Condition *c) {
- return (*c->function_)(c->arg_);
+ using FunctionPointer = bool (*)(void *);
+ FunctionPointer function_pointer;
+ std::memcpy(&function_pointer, c->callback_, sizeof(function_pointer));
+ return (*function_pointer)(c->arg_);
}
Condition::Condition(const bool *cond)
: eval_(CallVoidPtrFunction),
- function_(Dereference),
- method_(nullptr),
// const_cast is safe since Dereference does not modify arg
- arg_(const_cast<bool *>(cond)) {}
+ arg_(const_cast<bool *>(cond)) {
+ using FunctionPointer = bool (*)(void *);
+ const FunctionPointer dereference = Dereference;
+ StoreCallback(dereference);
+}
bool Condition::Eval() const {
// eval_ == null for kTrue
@@ -2806,14 +2815,15 @@ bool Condition::Eval() const {
}
bool Condition::GuaranteedEqual(const Condition *a, const Condition *b) {
- if (a == nullptr) {
+ // kTrue logic.
+ if (a == nullptr || a->eval_ == nullptr) {
return b == nullptr || b->eval_ == nullptr;
+ }else if (b == nullptr || b->eval_ == nullptr) {
+ return false;
}
- if (b == nullptr || b->eval_ == nullptr) {
- return a->eval_ == nullptr;
- }
- return a->eval_ == b->eval_ && a->function_ == b->function_ &&
- a->arg_ == b->arg_ && a->method_ == b->method_;
+ // Check equality of the representative fields.
+ return a->eval_ == b->eval_ && a->arg_ == b->arg_ &&
+ !memcmp(a->callback_, b->callback_, sizeof(ConservativeMethodPointer));
}
ABSL_NAMESPACE_END
diff --git a/absl/synchronization/mutex.h b/absl/synchronization/mutex.h
index 8694bb75..54ee703a 100644
--- a/absl/synchronization/mutex.h
+++ b/absl/synchronization/mutex.h
@@ -60,6 +60,8 @@
#include <atomic>
#include <cstdint>
+#include <cstring>
+#include <iterator>
#include <string>
#include "absl/base/const_init.h"
@@ -612,12 +614,12 @@ class ABSL_SCOPED_LOCKABLE WriterMutexLock {
// Condition
// -----------------------------------------------------------------------------
//
-// As noted above, `Mutex` contains a number of member functions which take a
-// `Condition` as an argument; clients can wait for conditions to become `true`
-// before attempting to acquire the mutex. These sections are known as
-// "condition critical" sections. To use a `Condition`, you simply need to
-// construct it, and use within an appropriate `Mutex` member function;
-// everything else in the `Condition` class is an implementation detail.
+// `Mutex` contains a number of member functions which take a `Condition` as an
+// argument; clients can wait for conditions to become `true` before attempting
+// to acquire the mutex. These sections are known as "condition critical"
+// sections. To use a `Condition`, you simply need to construct it, and use
+// within an appropriate `Mutex` member function; everything else in the
+// `Condition` class is an implementation detail.
//
// A `Condition` is specified as a function pointer which returns a boolean.
// `Condition` functions should be pure functions -- their results should depend
@@ -742,22 +744,55 @@ class Condition {
static bool GuaranteedEqual(const Condition *a, const Condition *b);
private:
- typedef bool (*InternalFunctionType)(void * arg);
- typedef bool (Condition::*InternalMethodType)();
- typedef bool (*InternalMethodCallerType)(void * arg,
- InternalMethodType internal_method);
-
- bool (*eval_)(const Condition*); // Actual evaluator
- InternalFunctionType function_; // function taking pointer returning bool
- InternalMethodType method_; // method returning bool
- void *arg_; // arg of function_ or object of method_
-
- Condition(); // null constructor used only to create kTrue
+ // Sizing an allocation for a method pointer can be subtle. In the Itanium
+ // specifications, a method pointer has a predictable, uniform size. On the
+ // other hand, MSVC ABI, method pointer sizes vary based on the
+ // inheritance of the class. Specifically, method pointers from classes with
+ // multiple inheritance are bigger than those of classes with single
+ // inheritance. Other variations also exist.
+
+ // A good way to allocate enough space for *any* pointer in these ABIs is to
+ // employ a class declaration with no definition. Because the inheritance
+ // structure is not available for this declaration, the compiler must
+ // assume, conservatively, that its method pointers have the largest possible
+ // size.
+ class OpaqueClass;
+ using ConservativeMethodPointer = bool (OpaqueClass::*)();
+ static_assert(sizeof(bool(OpaqueClass::*)()) >= sizeof(bool (*)(void *)),
+ "Unsupported platform.");
+
+ // Allocation for a function pointer or method pointer.
+ // The {0} initializer ensures that all unused bytes of this buffer are
+ // always zeroed out. This is necessary, because GuaranteedEqual() compares
+ // all of the bytes, unaware of which bytes are relevant to a given `eval_`.
+ char callback_[sizeof(ConservativeMethodPointer)] = {0};
+
+ // Function with which to evaluate callbacks and/or arguments.
+ bool (*eval_)(const Condition*);
+
+ // Either an argument for a function call or an object for a method call.
+ void *arg_;
// Various functions eval_ can point to:
static bool CallVoidPtrFunction(const Condition*);
template <typename T> static bool CastAndCallFunction(const Condition* c);
template <typename T> static bool CastAndCallMethod(const Condition* c);
+
+ // Helper methods for storing, validating, and reading callback arguments.
+ template <typename T>
+ inline void StoreCallback(T callback) {
+ static_assert(
+ sizeof(callback) <= sizeof(callback_),
+ "An overlarge pointer was passed as a callback to Condition.");
+ std::memcpy(callback_, &callback, sizeof(callback));
+ }
+
+ template <typename T>
+ inline void ReadCallback(T *callback) const {
+ std::memcpy(callback, callback_, sizeof(*callback));
+ }
+
+ Condition(); // null constructor used only to create kTrue
};
// -----------------------------------------------------------------------------
@@ -949,44 +984,48 @@ inline CondVar::CondVar() : cv_(0) {}
// static
template <typename T>
bool Condition::CastAndCallMethod(const Condition *c) {
- typedef bool (T::*MemberType)();
- MemberType rm = reinterpret_cast<MemberType>(c->method_);
- T *x = static_cast<T *>(c->arg_);
- return (x->*rm)();
+ T *object = static_cast<T *>(c->arg_);
+ bool (T::*method_pointer)();
+ c->ReadCallback(&method_pointer);
+ return (object->*method_pointer)();
}
// static
template <typename T>
bool Condition::CastAndCallFunction(const Condition *c) {
- typedef bool (*FuncType)(T *);
- FuncType fn = reinterpret_cast<FuncType>(c->function_);
- T *x = static_cast<T *>(c->arg_);
- return (*fn)(x);
+ bool (*function)(T *);
+ c->ReadCallback(&function);
+ T *argument = static_cast<T *>(c->arg_);
+ return (*function)(argument);
}
template <typename T>
inline Condition::Condition(bool (*func)(T *), T *arg)
: eval_(&CastAndCallFunction<T>),
- function_(reinterpret_cast<InternalFunctionType>(func)),
- method_(nullptr),
- arg_(const_cast<void *>(static_cast<const void *>(arg))) {}
+ arg_(const_cast<void *>(static_cast<const void *>(arg))) {
+ static_assert(sizeof(&func) <= sizeof(callback_),
+ "An overlarge function pointer was passed to Condition.");
+ StoreCallback(func);
+}
template <typename T>
inline Condition::Condition(T *object,
bool (absl::internal::identity<T>::type::*method)())
: eval_(&CastAndCallMethod<T>),
- function_(nullptr),
- method_(reinterpret_cast<InternalMethodType>(method)),
- arg_(object) {}
+ arg_(object) {
+ static_assert(sizeof(&method) <= sizeof(callback_),
+ "An overlarge method pointer was passed to Condition.");
+ StoreCallback(method);
+}
template <typename T>
inline Condition::Condition(const T *object,
bool (absl::internal::identity<T>::type::*method)()
const)
: eval_(&CastAndCallMethod<T>),
- function_(nullptr),
- method_(reinterpret_cast<InternalMethodType>(method)),
- arg_(reinterpret_cast<void *>(const_cast<T *>(object))) {}
+ arg_(reinterpret_cast<void *>(const_cast<T *>(object))) {
+ StoreCallback(method);
+}
// Register hooks for profiling support.
//
diff --git a/absl/synchronization/mutex_test.cc b/absl/synchronization/mutex_test.cc
index 99bb0175..f3d60852 100644
--- a/absl/synchronization/mutex_test.cc
+++ b/absl/synchronization/mutex_test.cc
@@ -295,8 +295,9 @@ static void TestTime(TestContext *cxt, int c, bool use_cv) {
"TestTime failed");
}
elapsed = absl::Now() - start;
- ABSL_RAW_CHECK(absl::Seconds(0.9) <= elapsed &&
- elapsed <= absl::Seconds(2.0), "TestTime failed");
+ ABSL_RAW_CHECK(
+ absl::Seconds(0.9) <= elapsed && elapsed <= absl::Seconds(2.0),
+ "TestTime failed");
ABSL_RAW_CHECK(cxt->g0 == cxt->threads, "TestTime failed");
} else if (c == 1) {
@@ -343,7 +344,7 @@ static void TestMuTime(TestContext *cxt, int c) { TestTime(cxt, c, false); }
static void TestCVTime(TestContext *cxt, int c) { TestTime(cxt, c, true); }
static void EndTest(int *c0, int *c1, absl::Mutex *mu, absl::CondVar *cv,
- const std::function<void(int)>& cb) {
+ const std::function<void(int)> &cb) {
mu->Lock();
int c = (*c0)++;
mu->Unlock();
@@ -366,9 +367,9 @@ static int RunTestCommon(TestContext *cxt, void (*test)(TestContext *cxt, int),
cxt->threads = threads;
absl::synchronization_internal::ThreadPool tp(threads);
for (int i = 0; i != threads; i++) {
- tp.Schedule(std::bind(&EndTest, &c0, &c1, &mu2, &cv2,
- std::function<void(int)>(
- std::bind(test, cxt, std::placeholders::_1))));
+ tp.Schedule(std::bind(
+ &EndTest, &c0, &c1, &mu2, &cv2,
+ std::function<void(int)>(std::bind(test, cxt, std::placeholders::_1))));
}
mu2.Lock();
while (c1 != threads) {
@@ -682,14 +683,14 @@ struct LockWhenTestStruct {
bool waiting = false;
};
-static bool LockWhenTestIsCond(LockWhenTestStruct* s) {
+static bool LockWhenTestIsCond(LockWhenTestStruct *s) {
s->mu2.Lock();
s->waiting = true;
s->mu2.Unlock();
return s->cond;
}
-static void LockWhenTestWaitForIsCond(LockWhenTestStruct* s) {
+static void LockWhenTestWaitForIsCond(LockWhenTestStruct *s) {
s->mu1.LockWhen(absl::Condition(&LockWhenTestIsCond, s));
s->mu1.Unlock();
}
@@ -1694,8 +1695,7 @@ TEST(Mutex, Timed) {
TEST(Mutex, CVTime) {
int threads = 10; // Use a fixed thread count of 10
int iterations = 1;
- EXPECT_EQ(RunTest(&TestCVTime, threads, iterations, 1),
- threads * iterations);
+ EXPECT_EQ(RunTest(&TestCVTime, threads, iterations, 1), threads * iterations);
}
TEST(Mutex, MuTime) {
@@ -1730,4 +1730,95 @@ TEST(Mutex, SignalExitedThread) {
for (auto &th : top) th.join();
}
+#ifdef _MSC_VER
+
+// Declare classes of the various MSVC inheritance types.
+class __single_inheritance SingleInheritance{};
+class __multiple_inheritance MultipleInheritance;
+class __virtual_inheritance VirtualInheritance;
+class UnknownInheritance;
+
+TEST(ConditionTest, MicrosoftMethodPointerSize) {
+ // This test verifies expectations about sizes of MSVC pointers to methods.
+ // Pointers to methods are distinguished by whether their class hierachies
+ // contain single inheritance, multiple inheritance, or virtual inheritence.
+ void (SingleInheritance::*single_inheritance)();
+ void (MultipleInheritance::*multiple_inheritance)();
+ void (VirtualInheritance::*virtual_inheritance)();
+ void (UnknownInheritance::*unknown_inheritance)();
+
+#if defined(_M_IX86) || defined(_M_ARM)
+ static_assert(sizeof(single_inheritance) == 4,
+ "Unexpected sizeof(single_inheritance).");
+ static_assert(sizeof(multiple_inheritance) == 8,
+ "Unexpected sizeof(multiple_inheritance).");
+ static_assert(sizeof(virtual_inheritance) == 12,
+ "Unexpected sizeof(virtual_inheritance).");
+#elif defined(_M_X64) || defined(__aarch64__)
+ static_assert(sizeof(single_inheritance) == 8,
+ "Unexpected sizeof(single_inheritance).");
+ static_assert(sizeof(multiple_inheritance) == 16,
+ "Unexpected sizeof(multiple_inheritance).");
+ static_assert(sizeof(virtual_inheritance) == 16,
+ "Unexpected sizeof(virtual_inheritance).");
+#endif
+ static_assert(sizeof(unknown_inheritance) >= sizeof(virtual_inheritance),
+ "Failed invariant: sizeof(unknown_inheritance) >= "
+ "sizeof(virtual_inheritance)!");
+}
+
+class Callback {
+ bool x = true;
+
+ public:
+ Callback() {}
+ bool method() {
+ x = !x;
+ return x;
+ }
+};
+
+class M2 {
+ bool x = true;
+
+ public:
+ M2() {}
+ bool method2() {
+ x = !x;
+ return x;
+ }
+};
+
+class MultipleInheritance : public Callback, public M2 {};
+
+TEST(ConditionTest, ConditionWithMultipleInheritanceMethod) {
+ // This test ensures that Condition can deal with method pointers from classes
+ // with multiple inheritance.
+ MultipleInheritance object = MultipleInheritance();
+ absl::Condition condition(&object, &MultipleInheritance::method);
+ EXPECT_FALSE(condition.Eval());
+ EXPECT_TRUE(condition.Eval());
+}
+
+class __virtual_inheritance VirtualInheritance : virtual public Callback {
+ bool x = false;
+
+ public:
+ VirtualInheritance() {}
+ bool method() {
+ x = !x;
+ return x;
+ }
+};
+
+TEST(ConditionTest, ConditionWithVirtualInheritanceMethod) {
+ // This test ensures that Condition can deal with method pointers from classes
+ // with virtual inheritance.
+ VirtualInheritance object = VirtualInheritance();
+ absl::Condition condition(&object, &VirtualInheritance::method);
+ EXPECT_TRUE(condition.Eval());
+ EXPECT_FALSE(condition.Eval());
+}
+#endif
+
} // namespace