-rw-r--r--  unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h            | 187
-rw-r--r--  unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h |   6
-rw-r--r--  unsupported/test/cxx11_eventcount.cpp                          |  10
3 files changed, 110 insertions(+), 93 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h b/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
index 7a9ebe40a..8b3b210b1 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
@@ -20,7 +20,8 @@ namespace Eigen {
 //   if (predicate)
 //     return act();
 //   EventCount::Waiter& w = waiters[my_index];
-//   ec.Prewait(&w);
+//   if (!ec.Prewait())
+//     return act();
 //   if (predicate) {
-//     ec.CancelWait(&w);
+//     ec.CancelWait();
 //     return act();
@@ -50,78 +51,78 @@ class EventCount {
  public:
   class Waiter;
 
-  EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
+  EventCount(MaxSizeVector<Waiter>& waiters)
+      : state_(kStackMask), waiters_(waiters) {
     eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
-    // Initialize epoch to something close to overflow to test overflow.
-    state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
   }
 
   ~EventCount() {
     // Ensure there are no waiters.
-    eigen_plain_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
+    eigen_plain_assert(state_.load() == kStackMask);
   }
 
   // Prewait prepares for waiting.
-  // After calling this function the thread must re-check the wait predicate
-  // and call either CancelWait or CommitWait passing the same Waiter object.
-  void Prewait(Waiter* w) {
-    w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
-    std::atomic_thread_fence(std::memory_order_seq_cst);
+  // If Prewait returns true, the thread must re-check the wait predicate
+  // and then call either CancelWait or CommitWait.
+  // Otherwise, the thread should assume the predicate may be true
+  // and must not call CancelWait/CommitWait (there was a concurrent Notify).
+  bool Prewait() {
+    uint64_t state = state_.load(std::memory_order_relaxed);
+    for (;;) {
+      CheckState(state);
+      uint64_t newstate = state + kWaiterInc;
+      if ((state & kSignalMask) != 0) {
+        // Consume the signal and cancel waiting.
+        newstate -= kSignalInc + kWaiterInc;
+      }
+      CheckState(newstate);
+      if (state_.compare_exchange_weak(state, newstate,
+                                       std::memory_order_seq_cst))
+        return (state & kSignalMask) == 0;
+    }
   }
 
-  // CommitWait commits waiting.
+  // CommitWait commits waiting after Prewait.
   void CommitWait(Waiter* w) {
+    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
     w->state = Waiter::kNotSignaled;
-    // Modification epoch of this waiter.
-    uint64_t epoch =
-        (w->epoch & kEpochMask) +
-        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+    const uint64_t me = (w - &waiters_[0]) | w->epoch;
     uint64_t state = state_.load(std::memory_order_seq_cst);
     for (;;) {
-      if (int64_t((state & kEpochMask) - epoch) < 0) {
-        // The preceding waiter has not decided on its fate. Wait until it
-        // calls either CancelWait or CommitWait, or is notified.
-        EIGEN_THREAD_YIELD();
-        state = state_.load(std::memory_order_seq_cst);
-        continue;
+      CheckState(state, true);
+      uint64_t newstate;
+      if ((state & kSignalMask) != 0) {
+        // Consume the signal and return immediately.
+        newstate = state - kWaiterInc - kSignalInc;
+      } else {
+        // Remove this thread from pre-wait counter and add to the waiter stack.
+        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
+        w->next.store(state & (kStackMask | kEpochMask),
+                      std::memory_order_relaxed);
       }
-      // We've already been notified.
-      if (int64_t((state & kEpochMask) - epoch) > 0) return;
-      // Remove this thread from prewait counter and add it to the waiter list.
-      eigen_plain_assert((state & kWaiterMask) != 0);
-      uint64_t newstate = state - kWaiterInc + kEpochInc;
-      newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
-      if ((state & kStackMask) == kStackMask)
-        w->next.store(nullptr, std::memory_order_relaxed);
-      else
-        w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
+      CheckState(newstate);
       if (state_.compare_exchange_weak(state, newstate,
-                                       std::memory_order_release))
-        break;
+                                       std::memory_order_acq_rel)) {
+        if ((state & kSignalMask) == 0) {
+          w->epoch += kEpochInc;
+          Park(w);
+        }
+        return;
+      }
     }
-    Park(w);
   }
 
   // CancelWait cancels effects of the previous Prewait call.
-  void CancelWait(Waiter* w) {
-    uint64_t epoch =
-        (w->epoch & kEpochMask) +
-        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+  void CancelWait() {
     uint64_t state = state_.load(std::memory_order_relaxed);
     for (;;) {
-      if (int64_t((state & kEpochMask) - epoch) < 0) {
-        // The preceding waiter has not decided on its fate. Wait until it
-        // calls either CancelWait or CommitWait, or is notified.
-        EIGEN_THREAD_YIELD();
-        state = state_.load(std::memory_order_relaxed);
-        continue;
-      }
-      // We've already been notified.
-      if (int64_t((state & kEpochMask) - epoch) > 0) return;
-      // Remove this thread from prewait counter.
-      eigen_plain_assert((state & kWaiterMask) != 0);
-      if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
-                                       std::memory_order_relaxed))
+      CheckState(state, true);
+      uint64_t newstate = state - kWaiterInc;
+      // Also take away a signal if any.
+      if ((state & kSignalMask) != 0) newstate -= kSignalInc;
+      CheckState(newstate);
+      if (state_.compare_exchange_weak(state, newstate,
+                                       std::memory_order_acq_rel))
        return;
     }
   }
@@ -132,35 +133,33 @@ class EventCount {
     std::atomic_thread_fence(std::memory_order_seq_cst);
     uint64_t state = state_.load(std::memory_order_acquire);
     for (;;) {
+      CheckState(state);
+      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
       // Easy case: no waiters.
-      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
-        return;
-      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+      if ((state & kStackMask) == kStackMask && waiters == signals) return;
       uint64_t newstate;
       if (notifyAll) {
-        // Reset prewait counter and empty wait list.
-        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
-      } else if (waiters) {
+        // Empty wait stack and set signal to number of pre-wait threads.
+        newstate =
+            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
+      } else if (signals < waiters) {
         // There is a thread in pre-wait state, unblock it.
-        newstate = state + kEpochInc - kWaiterInc;
+        newstate = state + kSignalInc;
       } else {
         // Pop a waiter from list and unpark it.
         Waiter* w = &waiters_[state & kStackMask];
-        Waiter* wnext = w->next.load(std::memory_order_relaxed);
-        uint64_t next = kStackMask;
-        if (wnext != nullptr) next = wnext - &waiters_[0];
-        // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
-        // can't happen because a waiter is re-pushed onto the stack only after
-        // it was in the pre-wait state which inevitably leads to epoch
-        // increment.
-        newstate = (state & kEpochMask) + next;
+        uint64_t next = w->next.load(std::memory_order_relaxed);
+        newstate = (state & (kWaiterMask | kSignalMask)) | next;
       }
+      CheckState(newstate);
       if (state_.compare_exchange_weak(state, newstate,
-                                       std::memory_order_acquire)) {
-        if (!notifyAll && waiters) return;  // unblocked pre-wait thread
+                                       std::memory_order_acq_rel)) {
+        if (!notifyAll && (signals < waiters))
+          return;  // unblocked pre-wait thread
         if ((state & kStackMask) == kStackMask) return;
         Waiter* w = &waiters_[state & kStackMask];
-        if (!notifyAll) w->next.store(nullptr, std::memory_order_relaxed);
+        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
         Unpark(w);
         return;
       }
@@ -171,11 +170,11 @@ class EventCount {
     friend class EventCount;
     // Align to 128 byte boundary to prevent false sharing with other Waiter
     // objects in the same vector.
-    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
+    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
     std::mutex mu;
     std::condition_variable cv;
-    uint64_t epoch;
-    unsigned state;
+    uint64_t epoch = 0;
+    unsigned state = kNotSignaled;
     enum {
       kNotSignaled,
       kWaiting,
@@ -185,23 +184,41 @@ class EventCount {
 
  private:
   // State_ layout:
-  // - low kStackBits is a stack of waiters committed wait.
+  // - low kWaiterBits is a stack of waiters committed wait
+  //   (indexes in waiters_ array are used as stack elements,
+  //   kStackMask means empty stack).
   // - next kWaiterBits is count of waiters in prewait state.
-  // - next kEpochBits is modification counter.
-  static const uint64_t kStackBits = 16;
-  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
-  static const uint64_t kWaiterBits = 16;
-  static const uint64_t kWaiterShift = 16;
+  // - next kWaiterBits is count of pending signals.
+  // - remaining bits are ABA counter for the stack.
+  //   (stored in Waiter node and incremented on push).
+  static const uint64_t kWaiterBits = 14;
+  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
+  static const uint64_t kWaiterShift = kWaiterBits;
   static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                       << kWaiterShift;
-  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
-  static const uint64_t kEpochBits = 32;
-  static const uint64_t kEpochShift = 32;
+  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
+  static const uint64_t kSignalShift = 2 * kWaiterBits;
+  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
+                                      << kSignalShift;
+  static const uint64_t kSignalInc = 1ull << kSignalShift;
+  static const uint64_t kEpochShift = 3 * kWaiterBits;
+  static const uint64_t kEpochBits = 64 - kEpochShift;
   static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
   static const uint64_t kEpochInc = 1ull << kEpochShift;
   std::atomic<uint64_t> state_;
   MaxSizeVector<Waiter>& waiters_;
 
+  static void CheckState(uint64_t state, bool waiter = false) {
+    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
+    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
+    eigen_plain_assert(waiters >= signals);
+    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
+    eigen_plain_assert(!waiter || waiters > 0);
+    (void)waiters;
+    (void)signals;
+  }
+
   void Park(Waiter* w) {
     std::unique_lock<std::mutex> lock(w->mu);
     while (w->state != Waiter::kSignaled) {
@@ -210,10 +227,10 @@ class EventCount {
     }
   }
 
-  void Unpark(Waiter* waiters) {
-    Waiter* next = nullptr;
-    for (Waiter* w = waiters; w; w = next) {
-      next = w->next.load(std::memory_order_relaxed);
+  void Unpark(Waiter* w) {
+    for (Waiter* next; w; w = next) {
+      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
+      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
       unsigned state;
       {
         std::unique_lock<std::mutex> lock(w->mu);
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index 8fafcdab5..49603d6c1 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -374,11 +374,11 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
     eigen_plain_assert(!t->f);
     // We already did best-effort emptiness check in Steal, so prepare for
     // blocking.
-    ec_.Prewait(waiter);
+    if (!ec_.Prewait()) return true;
     // Now do a reliable emptiness check.
     int victim = NonEmptyQueueIndex();
     if (victim != -1) {
-      ec_.CancelWait(waiter);
+      ec_.CancelWait();
       if (cancelled_) {
         return false;
       } else {
@@ -392,7 +392,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
     blocked_++;
     // TODO is blocked_ required to be unsigned?
     if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
-      ec_.CancelWait(waiter);
+      ec_.CancelWait();
       // Almost done, but need to re-check queues.
       // Consider that all queues are empty and all worker threads are preempted
       // right after incrementing blocked_ above. Now a free-standing thread
diff --git a/unsupported/test/cxx11_eventcount.cpp b/unsupported/test/cxx11_eventcount.cpp
index 2f1418684..3ca8598c7 100644
--- a/unsupported/test/cxx11_eventcount.cpp
+++ b/unsupported/test/cxx11_eventcount.cpp
@@ -30,11 +30,11 @@ static void test_basic_eventcount()
   EventCount ec(waiters);
   EventCount::Waiter& w = waiters[0];
   ec.Notify(false);
-  ec.Prewait(&w);
+  VERIFY(ec.Prewait());
   ec.Notify(true);
   ec.CommitWait(&w);
-  ec.Prewait(&w);
-  ec.CancelWait(&w);
+  VERIFY(ec.Prewait());
+  ec.CancelWait();
 }
 
 // Fake bounded counter-based queue.
@@ -112,7 +112,7 @@ static void test_stress_eventcount()
       unsigned idx = rand_reentrant(&rnd) % kQueues;
      if (queues[idx].Pop()) continue;
       j--;
-      ec.Prewait(&w);
+      if (!ec.Prewait()) continue;
       bool empty = true;
       for (int q = 0; q < kQueues; q++) {
         if (!queues[q].Empty()) {
@@ -121,7 +121,7 @@ static void test_stress_eventcount()
         }
       }
       if (!empty) {
-        ec.CancelWait(&w);
+        ec.CancelWait();
         continue;
       }
       ec.CommitWait(&w);
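
For orientation: after this change the 64-bit state_ packs four fields, with kWaiterBits = 14 bits each for the waiter-stack head, the pre-wait count, and the signal count, and the remaining 22 bits serving as the stack's ABA epoch. A minimal standalone sketch of the layout follows; the constants are copied from the patch, but the Decode helper itself is hypothetical, for illustration only:

```cpp
#include <cstdint>
#include <cstdio>

// Constants as introduced by the patch (kWaiterBits = 14).
constexpr uint64_t kWaiterBits = 14;
constexpr uint64_t kStackMask = (1ull << kWaiterBits) - 1;      // bits 0..13: stack head index
constexpr uint64_t kWaiterShift = kWaiterBits;                  // bits 14..27: pre-wait count
constexpr uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1) << kWaiterShift;
constexpr uint64_t kSignalShift = 2 * kWaiterBits;              // bits 28..41: signal count
constexpr uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) << kSignalShift;
constexpr uint64_t kEpochShift = 3 * kWaiterBits;               // bits 42..63: ABA epoch

// Hypothetical helper: print the four fields of a state word.
void Decode(uint64_t state) {
  std::printf("stack=%llu waiters=%llu signals=%llu epoch=%llu\n",
              (unsigned long long)(state & kStackMask),
              (unsigned long long)((state & kWaiterMask) >> kWaiterShift),
              (unsigned long long)((state & kSignalMask) >> kSignalShift),
              (unsigned long long)(state >> kEpochShift));
}

int main() {
  // The new constructor initializes state_ to kStackMask:
  // empty stack, zero waiters, zero signals, zero epoch.
  Decode(kStackMask);
  return 0;
}
```

The invariant asserted by CheckState follows directly from this layout: the signal count can never exceed the pre-wait count, because Notify only banks a signal when it observes signals < waiters, and Prewait, CommitWait, and CancelWait always consume a signal and a pre-wait slot together.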
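The new calling convention (argument-less bool Prewait, argument-less CancelWait) is what the updated header comment and tests exercise. Below is a minimal usage sketch, assuming the unsupported ThreadPool module is on the include path; the producer/consumer framing and names are illustrative, not part of the patch:

```cpp
#include <atomic>
#include <thread>

#include <unsupported/Eigen/CXX11/ThreadPool>

int main() {
  // One Waiter slot per potentially-blocking thread, as in the tests.
  Eigen::MaxSizeVector<Eigen::EventCount::Waiter> waiters(1);
  waiters.resize(1);
  Eigen::EventCount ec(waiters);
  std::atomic<bool> ready(false);

  std::thread consumer([&] {
    Eigen::EventCount::Waiter& w = waiters[0];
    for (;;) {
      if (ready.load()) return;     // predicate already true
      if (!ec.Prewait()) continue;  // false: a concurrent Notify was consumed
      if (ready.load()) {           // re-check predicate after Prewait
        ec.CancelWait();
        return;
      }
      ec.CommitWait(&w);  // park until a Notify pops us off the stack
    }
  });

  ready.store(true);
  ec.Notify(false);  // wake at most one waiter
  consumer.join();
  return 0;
}
```

Note the key behavioral change: when Prewait returns false it has already consumed a pending signal, so the caller must not call CancelWait or CommitWait and should simply re-evaluate its predicate, exactly as the updated test_stress_eventcount loop does with `if (!ec.Prewait()) continue;`.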