 unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h            | 187
 unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h |   6
 unsupported/test/cxx11_eventcount.cpp                          |  10
 3 files changed, 110 insertions(+), 93 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h b/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
index 7a9ebe40a..8b3b210b1 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
@@ -20,7 +20,8 @@ namespace Eigen {
// if (predicate)
// return act();
// EventCount::Waiter& w = waiters[my_index];
-// ec.Prewait(&w);
+// if (!ec.Prewait())
+// return act();
// if (predicate) {
// ec.CancelWait(&w);
// return act();
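The protocol in the comment above, written out against the new API as a minimal caller-side sketch (WaitOrAct, predicate, and act are hypothetical names, and the include assumes Eigen's source root is on the include path):

    #include <unsupported/Eigen/CXX11/ThreadPool>

    // Hypothetical caller of the post-patch API: Prewait() now returns bool
    // and CancelWait() no longer takes the Waiter.
    template <class Predicate, class Act>
    bool WaitOrAct(Eigen::EventCount& ec, Eigen::EventCount::Waiter& w,
                   Predicate predicate, Act act) {
      if (predicate()) { act(); return false; }    // work available, don't block
      if (!ec.Prewait()) { act(); return false; }  // consumed a concurrent Notify
      if (predicate()) {
        ec.CancelWait();   // roll back the pre-wait registration
        act();
        return false;
      }
      ec.CommitWait(&w);   // park until a Notify pops or signals us
      return true;         // woke up; caller re-runs its outer loop
    }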
@@ -50,78 +51,78 @@ class EventCount {
public:
class Waiter;
- EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
+ EventCount(MaxSizeVector<Waiter>& waiters)
+ : state_(kStackMask), waiters_(waiters) {
eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
- // Initialize epoch to something close to overflow to test overflow.
- state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
}
~EventCount() {
// Ensure there are no waiters.
- eigen_plain_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
+ eigen_plain_assert(state_.load() == kStackMask);
}
// Prewait prepares for waiting.
- // After calling this function the thread must re-check the wait predicate
- // and call either CancelWait or CommitWait passing the same Waiter object.
- void Prewait(Waiter* w) {
- w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
- std::atomic_thread_fence(std::memory_order_seq_cst);
+ // If Prewait returns true, the thread must re-check the wait predicate
+ // and then call either CancelWait or CommitWait.
+ // Otherwise, the thread should assume the predicate may be true
+ // and must not call CancelWait/CommitWait (there was a concurrent Notify call).
+ bool Prewait() {
+ uint64_t state = state_.load(std::memory_order_relaxed);
+ for (;;) {
+ CheckState(state);
+ uint64_t newstate = state + kWaiterInc;
+ if ((state & kSignalMask) != 0) {
+ // Consume the signal and cancel waiting.
+ newstate -= kSignalInc + kWaiterInc;
+ }
+ CheckState(newstate);
+ if (state_.compare_exchange_weak(state, newstate,
+ std::memory_order_seq_cst))
+ return (state & kSignalMask) == 0;
+ }
}
- // CommitWait commits waiting.
+ // CommitWait commits waiting after Prewait.
void CommitWait(Waiter* w) {
+ eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
w->state = Waiter::kNotSignaled;
- // Modification epoch of this waiter.
- uint64_t epoch =
- (w->epoch & kEpochMask) +
- (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+ const uint64_t me = (w - &waiters_[0]) | w->epoch;
uint64_t state = state_.load(std::memory_order_seq_cst);
for (;;) {
- if (int64_t((state & kEpochMask) - epoch) < 0) {
- // The preceding waiter has not decided on its fate. Wait until it
- // calls either CancelWait or CommitWait, or is notified.
- EIGEN_THREAD_YIELD();
- state = state_.load(std::memory_order_seq_cst);
- continue;
+ CheckState(state, true);
+ uint64_t newstate;
+ if ((state & kSignalMask) != 0) {
+ // Consume the signal and return immediately.
+ newstate = state - kWaiterInc - kSignalInc;
+ } else {
+ // Remove this thread from pre-wait counter and add to the waiter stack.
+ newstate = ((state & kWaiterMask) - kWaiterInc) | me;
+ w->next.store(state & (kStackMask | kEpochMask),
+ std::memory_order_relaxed);
}
- // We've already been notified.
- if (int64_t((state & kEpochMask) - epoch) > 0) return;
- // Remove this thread from prewait counter and add it to the waiter list.
- eigen_plain_assert((state & kWaiterMask) != 0);
- uint64_t newstate = state - kWaiterInc + kEpochInc;
- newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
- if ((state & kStackMask) == kStackMask)
- w->next.store(nullptr, std::memory_order_relaxed);
- else
- w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
+ CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
- std::memory_order_release))
- break;
+ std::memory_order_acq_rel)) {
+ if ((state & kSignalMask) == 0) {
+ w->epoch += kEpochInc;
+ Park(w);
+ }
+ return;
+ }
}
- Park(w);
}
// CancelWait cancels effects of the previous Prewait call.
- void CancelWait(Waiter* w) {
- uint64_t epoch =
- (w->epoch & kEpochMask) +
- (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+ void CancelWait() {
uint64_t state = state_.load(std::memory_order_relaxed);
for (;;) {
- if (int64_t((state & kEpochMask) - epoch) < 0) {
- // The preceding waiter has not decided on its fate. Wait until it
- // calls either CancelWait or CommitWait, or is notified.
- EIGEN_THREAD_YIELD();
- state = state_.load(std::memory_order_relaxed);
- continue;
- }
- // We've already been notified.
- if (int64_t((state & kEpochMask) - epoch) > 0) return;
- // Remove this thread from prewait counter.
- eigen_plain_assert((state & kWaiterMask) != 0);
- if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
- std::memory_order_relaxed))
+ CheckState(state, true);
+ uint64_t newstate = state - kWaiterInc;
+ // Also take away a signal if any.
+ if ((state & kSignalMask) != 0) newstate -= kSignalInc;
+ CheckState(newstate);
+ if (state_.compare_exchange_weak(state, newstate,
+ std::memory_order_acq_rel))
return;
}
}
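The waiter/signal arithmetic in the new Prewait can be sanity-checked in isolation. A standalone sketch, re-declaring the patch's constants locally (kWaiterBits = 14 as in the hunks below):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint64_t kWaiterBits = 14;
      const uint64_t kWaiterInc  = 1ull << kWaiterBits;        // pre-wait count
      const uint64_t kSignalInc  = 1ull << (2 * kWaiterBits);  // pending signals
      const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) << (2 * kWaiterBits);

      // One pre-waiter already registered, one pending signal.
      uint64_t state = kWaiterInc + kSignalInc;

      // Prewait: register as a waiter, then consume the signal and cancel.
      uint64_t newstate = state + kWaiterInc;
      if ((state & kSignalMask) != 0) newstate -= kSignalInc + kWaiterInc;

      // Net effect: the signal is consumed, the other pre-waiter is untouched,
      // and this thread never shows up in the waiter count.
      assert(newstate == kWaiterInc);
      return 0;
    }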
@@ -132,35 +133,33 @@ class EventCount {
std::atomic_thread_fence(std::memory_order_seq_cst);
uint64_t state = state_.load(std::memory_order_acquire);
for (;;) {
+ CheckState(state);
+ const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+ const uint64_t signals = (state & kSignalMask) >> kSignalShift;
// Easy case: no waiters.
- if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
- return;
- uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+ if ((state & kStackMask) == kStackMask && waiters == signals) return;
uint64_t newstate;
if (notifyAll) {
- // Reset prewait counter and empty wait list.
- newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
- } else if (waiters) {
+ // Empty wait stack and set signal to number of pre-wait threads.
+ newstate =
+ (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
+ } else if (signals < waiters) {
// There is a thread in pre-wait state, unblock it.
- newstate = state + kEpochInc - kWaiterInc;
+ newstate = state + kSignalInc;
} else {
// Pop a waiter from list and unpark it.
Waiter* w = &waiters_[state & kStackMask];
- Waiter* wnext = w->next.load(std::memory_order_relaxed);
- uint64_t next = kStackMask;
- if (wnext != nullptr) next = wnext - &waiters_[0];
- // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
- // can't happen because a waiter is re-pushed onto the stack only after
- // it was in the pre-wait state which inevitably leads to epoch
- // increment.
- newstate = (state & kEpochMask) + next;
+ uint64_t next = w->next.load(std::memory_order_relaxed);
+ newstate = (state & (kWaiterMask | kSignalMask)) | next;
}
+ CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
- std::memory_order_acquire)) {
- if (!notifyAll && waiters) return; // unblocked pre-wait thread
+ std::memory_order_acq_rel)) {
+ if (!notifyAll && (signals < waiters))
+ return; // unblocked pre-wait thread
if ((state & kStackMask) == kStackMask) return;
Waiter* w = &waiters_[state & kStackMask];
- if (!notifyAll) w->next.store(nullptr, std::memory_order_relaxed);
+ if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
Unpark(w);
return;
}
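The three-way case split in the new Notify is compact but dense; a small standalone classifier that mirrors it (constants re-declared locally, return strings are editorial summaries, not Eigen output):

    #include <cstdint>
    #include <cstdio>

    const uint64_t kWaiterBits  = 14;
    const uint64_t kStackMask   = (1ull << kWaiterBits) - 1;
    const uint64_t kWaiterShift = kWaiterBits;
    const uint64_t kWaiterMask  = ((1ull << kWaiterBits) - 1) << kWaiterShift;
    const uint64_t kSignalShift = 2 * kWaiterBits;
    const uint64_t kSignalMask  = ((1ull << kWaiterBits) - 1) << kSignalShift;

    // Mirrors the case split in the new Notify(false).
    const char* NotifyOneCase(uint64_t state) {
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      if ((state & kStackMask) == kStackMask && waiters == signals)
        return "easy case: nobody to wake";
      if (signals < waiters)
        return "leave a signal for a thread in pre-wait state";
      return "pop a committed waiter from the stack and unpark it";
    }

    int main() {
      const uint64_t empty     = kStackMask;                    // no waiters at all
      const uint64_t prewaiter = kStackMask | (1ull << kWaiterShift);
      const uint64_t committed = 0;                             // waiter 0 on stack
      std::printf("%s\n", NotifyOneCase(empty));
      std::printf("%s\n", NotifyOneCase(prewaiter));
      std::printf("%s\n", NotifyOneCase(committed));
      return 0;
    }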
@@ -171,11 +170,11 @@ class EventCount {
friend class EventCount;
// Align to 128 byte boundary to prevent false sharing with other Waiter
// objects in the same vector.
- EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
+ EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
std::mutex mu;
std::condition_variable cv;
- uint64_t epoch;
- unsigned state;
+ uint64_t epoch = 0;
+ unsigned state = kNotSignaled;
enum {
kNotSignaled,
kWaiting,
@@ -185,23 +184,41 @@ class EventCount {
private:
// State_ layout:
- // - low kStackBits is a stack of waiters committed wait.
+ // - low kWaiterBits is a stack of waiters that have committed to wait
+ // (indexes in waiters_ array are used as stack elements,
+ // kStackMask means empty stack).
// - next kWaiterBits is count of waiters in prewait state.
- // - next kEpochBits is modification counter.
- static const uint64_t kStackBits = 16;
- static const uint64_t kStackMask = (1ull << kStackBits) - 1;
- static const uint64_t kWaiterBits = 16;
- static const uint64_t kWaiterShift = 16;
+ // - next kWaiterBits is count of pending signals.
+ // - remaining bits are ABA counter for the stack.
+ // (stored in Waiter node and incremented on push).
+ static const uint64_t kWaiterBits = 14;
+ static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
+ static const uint64_t kWaiterShift = kWaiterBits;
static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
<< kWaiterShift;
- static const uint64_t kWaiterInc = 1ull << kWaiterBits;
- static const uint64_t kEpochBits = 32;
- static const uint64_t kEpochShift = 32;
+ static const uint64_t kWaiterInc = 1ull << kWaiterShift;
+ static const uint64_t kSignalShift = 2 * kWaiterBits;
+ static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
+ << kSignalShift;
+ static const uint64_t kSignalInc = 1ull << kSignalShift;
+ static const uint64_t kEpochShift = 3 * kWaiterBits;
+ static const uint64_t kEpochBits = 64 - kEpochShift;
static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
static const uint64_t kEpochInc = 1ull << kEpochShift;
std::atomic<uint64_t> state_;
MaxSizeVector<Waiter>& waiters_;
+ static void CheckState(uint64_t state, bool waiter = false) {
+ static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
+ const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+ const uint64_t signals = (state & kSignalMask) >> kSignalShift;
+ eigen_plain_assert(waiters >= signals);
+ eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
+ eigen_plain_assert(!waiter || waiters > 0);
+ (void)waiters;
+ (void)signals;
+ }
+
void Park(Waiter* w) {
std::unique_lock<std::mutex> lock(w->mu);
while (w->state != Waiter::kSignaled) {
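With kWaiterBits now 14, the 64-bit state_ packs four fields: stack top in bits 0-13, pre-wait count in bits 14-27, signal count in bits 28-41, and a 22-bit ABA epoch in bits 42-63. A standalone decode of an example word:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint64_t kWaiterBits  = 14;
      const uint64_t kStackMask   = (1ull << kWaiterBits) - 1;
      const uint64_t kWaiterShift = kWaiterBits;      // bits 14..27
      const uint64_t kSignalShift = 2 * kWaiterBits;  // bits 28..41
      const uint64_t kEpochShift  = 3 * kWaiterBits;  // bits 42..63

      // Example: empty stack, 3 pre-waiters, 1 pending signal, epoch 7.
      const uint64_t state = kStackMask | (3ull << kWaiterShift) |
                             (1ull << kSignalShift) | (7ull << kEpochShift);

      std::printf("stack=0x%llx waiters=%llu signals=%llu epoch=%llu\n",
                  (unsigned long long)(state & kStackMask),
                  (unsigned long long)((state >> kWaiterShift) & kStackMask),
                  (unsigned long long)((state >> kSignalShift) & kStackMask),
                  (unsigned long long)(state >> kEpochShift));
      return 0;
    }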
@@ -210,10 +227,10 @@ class EventCount {
}
}
- void Unpark(Waiter* waiters) {
- Waiter* next = nullptr;
- for (Waiter* w = waiters; w; w = next) {
- next = w->next.load(std::memory_order_relaxed);
+ void Unpark(Waiter* w) {
+ for (Waiter* next; w; w = next) {
+ uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
+ next = wnext == kStackMask ? nullptr : &waiters_[wnext];
unsigned state;
{
std::unique_lock<std::mutex> lock(w->mu);
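Park/Unpark above is the standard mutex-plus-condvar handshake. A minimal self-contained version of the same pattern (simplified: no stack threading through next, and the flag is reset inside Park rather than in CommitWait as Eigen does):

    #include <condition_variable>
    #include <mutex>
    #include <thread>

    struct MiniWaiter {
      std::mutex mu;
      std::condition_variable cv;
      enum State { kNotSignaled, kWaiting, kSignaled };
      State state = kNotSignaled;
    };

    void Park(MiniWaiter& w) {
      std::unique_lock<std::mutex> lock(w.mu);
      while (w.state != MiniWaiter::kSignaled) {
        w.state = MiniWaiter::kWaiting;
        w.cv.wait(lock);
      }
      w.state = MiniWaiter::kNotSignaled;  // reset for reuse
    }

    void Unpark(MiniWaiter& w) {
      bool was_waiting;
      {
        std::unique_lock<std::mutex> lock(w.mu);
        was_waiting = (w.state == MiniWaiter::kWaiting);
        w.state = MiniWaiter::kSignaled;
      }
      if (was_waiting) w.cv.notify_one();  // wake only if someone parked
    }

    int main() {
      MiniWaiter w;
      std::thread t([&] { Park(w); });
      Unpark(w);  // safe in either order relative to Park
      t.join();
      return 0;
    }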
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index 8fafcdab5..49603d6c1 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -374,11 +374,11 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
eigen_plain_assert(!t->f);
// We already did best-effort emptiness check in Steal, so prepare for
// blocking.
- ec_.Prewait(waiter);
+ if (!ec_.Prewait()) return true;
// Now do a reliable emptiness check.
int victim = NonEmptyQueueIndex();
if (victim != -1) {
- ec_.CancelWait(waiter);
+ ec_.CancelWait();
if (cancelled_) {
return false;
} else {
@@ -392,7 +392,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
blocked_++;
// TODO is blocked_ required to be unsigned?
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
- ec_.CancelWait(waiter);
+ ec_.CancelWait();
// Almost done, but need to re-check queues.
// Consider that all queues are empty and all worker threads are preempted
// right after incrementing blocked_ above. Now a free-standing thread
diff --git a/unsupported/test/cxx11_eventcount.cpp b/unsupported/test/cxx11_eventcount.cpp
index 2f1418684..3ca8598c7 100644
--- a/unsupported/test/cxx11_eventcount.cpp
+++ b/unsupported/test/cxx11_eventcount.cpp
@@ -30,11 +30,11 @@ static void test_basic_eventcount()
EventCount ec(waiters);
EventCount::Waiter& w = waiters[0];
ec.Notify(false);
- ec.Prewait(&w);
+ VERIFY(ec.Prewait());
ec.Notify(true);
ec.CommitWait(&w);
- ec.Prewait(&w);
- ec.CancelWait(&w);
+ VERIFY(ec.Prewait());
+ ec.CancelWait();
}
// Fake bounded counter-based queue.
@@ -112,7 +112,7 @@ static void test_stress_eventcount()
unsigned idx = rand_reentrant(&rnd) % kQueues;
if (queues[idx].Pop()) continue;
j--;
- ec.Prewait(&w);
+ if (!ec.Prewait()) continue;
bool empty = true;
for (int q = 0; q < kQueues; q++) {
if (!queues[q].Empty()) {
@@ -121,7 +121,7 @@ static void test_stress_eventcount()
}
}
if (!empty) {
- ec.CancelWait(&w);
+ ec.CancelWait();
continue;
}
ec.CommitWait(&w);