aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/ThreadPool
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-02-22 13:56:26 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-02-22 13:56:26 -0800
commit6560692c670bcf34fc922474bf37f3c18b8768af (patch)
tree551dd7bd544c2db42f34b92bfaf368744bcc525f /unsupported/Eigen/CXX11/src/ThreadPool
parent0b25a5c431f2764cd46a04f07536d60256ecd256 (diff)
Improve EventCount used by the non-blocking threadpool.
The current algorithm requires threads to commit/cancel waiting in order they called Prewait. Spinning caused by that serialization can consume lots of CPU time on some workloads. Restructure the algorithm to not require that serialization and remove spin waits from Commit/CancelWait. Note: this reduces max number of threads from 2^16 to 2^14 to leave more space for ABA counter (which is now 22 bits). Implementation details are explained in comments.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/ThreadPool')
-rw-r--r--unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h187
-rw-r--r--unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h6
2 files changed, 105 insertions, 88 deletions
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h b/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
index 7a9ebe40a..8b3b210b1 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
@@ -20,7 +20,8 @@ namespace Eigen {
// if (predicate)
// return act();
// EventCount::Waiter& w = waiters[my_index];
-// ec.Prewait(&w);
+// if (!ec.Prewait(&w))
+// return act();
// if (predicate) {
// ec.CancelWait(&w);
// return act();
@@ -50,78 +51,78 @@ class EventCount {
public:
class Waiter;
- EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
+ EventCount(MaxSizeVector<Waiter>& waiters)
+ : state_(kStackMask), waiters_(waiters) {
eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
- // Initialize epoch to something close to overflow to test overflow.
- state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
}
~EventCount() {
// Ensure there are no waiters.
- eigen_plain_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
+ eigen_plain_assert(state_.load() == kStackMask);
}
// Prewait prepares for waiting.
- // After calling this function the thread must re-check the wait predicate
- // and call either CancelWait or CommitWait passing the same Waiter object.
- void Prewait(Waiter* w) {
- w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
- std::atomic_thread_fence(std::memory_order_seq_cst);
+ // If Prewait returns true, the thread must re-check the wait predicate
+ // and then call either CancelWait or CommitWait.
+ // Otherwise, the thread should assume the predicate may be true
+ // and don't call CancelWait/CommitWait (there was a concurrent Notify call).
+ bool Prewait() {
+ uint64_t state = state_.load(std::memory_order_relaxed);
+ for (;;) {
+ CheckState(state);
+ uint64_t newstate = state + kWaiterInc;
+ if ((state & kSignalMask) != 0) {
+ // Consume the signal and cancel waiting.
+ newstate -= kSignalInc + kWaiterInc;
+ }
+ CheckState(newstate);
+ if (state_.compare_exchange_weak(state, newstate,
+ std::memory_order_seq_cst))
+ return (state & kSignalMask) == 0;
+ }
}
- // CommitWait commits waiting.
+ // CommitWait commits waiting after Prewait.
void CommitWait(Waiter* w) {
+ eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
w->state = Waiter::kNotSignaled;
- // Modification epoch of this waiter.
- uint64_t epoch =
- (w->epoch & kEpochMask) +
- (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+ const uint64_t me = (w - &waiters_[0]) | w->epoch;
uint64_t state = state_.load(std::memory_order_seq_cst);
for (;;) {
- if (int64_t((state & kEpochMask) - epoch) < 0) {
- // The preceding waiter has not decided on its fate. Wait until it
- // calls either CancelWait or CommitWait, or is notified.
- EIGEN_THREAD_YIELD();
- state = state_.load(std::memory_order_seq_cst);
- continue;
+ CheckState(state, true);
+ uint64_t newstate;
+ if ((state & kSignalMask) != 0) {
+ // Consume the signal and return immidiately.
+ newstate = state - kWaiterInc - kSignalInc;
+ } else {
+ // Remove this thread from pre-wait counter and add to the waiter stack.
+ newstate = ((state & kWaiterMask) - kWaiterInc) | me;
+ w->next.store(state & (kStackMask | kEpochMask),
+ std::memory_order_relaxed);
}
- // We've already been notified.
- if (int64_t((state & kEpochMask) - epoch) > 0) return;
- // Remove this thread from prewait counter and add it to the waiter list.
- eigen_plain_assert((state & kWaiterMask) != 0);
- uint64_t newstate = state - kWaiterInc + kEpochInc;
- newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
- if ((state & kStackMask) == kStackMask)
- w->next.store(nullptr, std::memory_order_relaxed);
- else
- w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
+ CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
- std::memory_order_release))
- break;
+ std::memory_order_acq_rel)) {
+ if ((state & kSignalMask) == 0) {
+ w->epoch += kEpochInc;
+ Park(w);
+ }
+ return;
+ }
}
- Park(w);
}
// CancelWait cancels effects of the previous Prewait call.
- void CancelWait(Waiter* w) {
- uint64_t epoch =
- (w->epoch & kEpochMask) +
- (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+ void CancelWait() {
uint64_t state = state_.load(std::memory_order_relaxed);
for (;;) {
- if (int64_t((state & kEpochMask) - epoch) < 0) {
- // The preceding waiter has not decided on its fate. Wait until it
- // calls either CancelWait or CommitWait, or is notified.
- EIGEN_THREAD_YIELD();
- state = state_.load(std::memory_order_relaxed);
- continue;
- }
- // We've already been notified.
- if (int64_t((state & kEpochMask) - epoch) > 0) return;
- // Remove this thread from prewait counter.
- eigen_plain_assert((state & kWaiterMask) != 0);
- if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
- std::memory_order_relaxed))
+ CheckState(state, true);
+ uint64_t newstate = state - kWaiterInc;
+ // Also take away a signal if any.
+ if ((state & kSignalMask) != 0) newstate -= kSignalInc;
+ CheckState(newstate);
+ if (state_.compare_exchange_weak(state, newstate,
+ std::memory_order_acq_rel))
return;
}
}
@@ -132,35 +133,33 @@ class EventCount {
std::atomic_thread_fence(std::memory_order_seq_cst);
uint64_t state = state_.load(std::memory_order_acquire);
for (;;) {
+ CheckState(state);
+ const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+ const uint64_t signals = (state & kSignalMask) >> kSignalShift;
// Easy case: no waiters.
- if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
- return;
- uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+ if ((state & kStackMask) == kStackMask && waiters == signals) return;
uint64_t newstate;
if (notifyAll) {
- // Reset prewait counter and empty wait list.
- newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
- } else if (waiters) {
+ // Empty wait stack and set signal to number of pre-wait threads.
+ newstate =
+ (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
+ } else if (signals < waiters) {
// There is a thread in pre-wait state, unblock it.
- newstate = state + kEpochInc - kWaiterInc;
+ newstate = state + kSignalInc;
} else {
// Pop a waiter from list and unpark it.
Waiter* w = &waiters_[state & kStackMask];
- Waiter* wnext = w->next.load(std::memory_order_relaxed);
- uint64_t next = kStackMask;
- if (wnext != nullptr) next = wnext - &waiters_[0];
- // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
- // can't happen because a waiter is re-pushed onto the stack only after
- // it was in the pre-wait state which inevitably leads to epoch
- // increment.
- newstate = (state & kEpochMask) + next;
+ uint64_t next = w->next.load(std::memory_order_relaxed);
+ newstate = (state & (kWaiterMask | kSignalMask)) | next;
}
+ CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
- std::memory_order_acquire)) {
- if (!notifyAll && waiters) return; // unblocked pre-wait thread
+ std::memory_order_acq_rel)) {
+ if (!notifyAll && (signals < waiters))
+ return; // unblocked pre-wait thread
if ((state & kStackMask) == kStackMask) return;
Waiter* w = &waiters_[state & kStackMask];
- if (!notifyAll) w->next.store(nullptr, std::memory_order_relaxed);
+ if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
Unpark(w);
return;
}
@@ -171,11 +170,11 @@ class EventCount {
friend class EventCount;
// Align to 128 byte boundary to prevent false sharing with other Waiter
// objects in the same vector.
- EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
+ EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
std::mutex mu;
std::condition_variable cv;
- uint64_t epoch;
- unsigned state;
+ uint64_t epoch = 0;
+ unsigned state = kNotSignaled;
enum {
kNotSignaled,
kWaiting,
@@ -185,23 +184,41 @@ class EventCount {
private:
// State_ layout:
- // - low kStackBits is a stack of waiters committed wait.
+ // - low kWaiterBits is a stack of waiters committed wait
+ // (indexes in waiters_ array are used as stack elements,
+ // kStackMask means empty stack).
// - next kWaiterBits is count of waiters in prewait state.
- // - next kEpochBits is modification counter.
- static const uint64_t kStackBits = 16;
- static const uint64_t kStackMask = (1ull << kStackBits) - 1;
- static const uint64_t kWaiterBits = 16;
- static const uint64_t kWaiterShift = 16;
+ // - next kWaiterBits is count of pending signals.
+ // - remaining bits are ABA counter for the stack.
+ // (stored in Waiter node and incremented on push).
+ static const uint64_t kWaiterBits = 14;
+ static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
+ static const uint64_t kWaiterShift = kWaiterBits;
static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
<< kWaiterShift;
- static const uint64_t kWaiterInc = 1ull << kWaiterBits;
- static const uint64_t kEpochBits = 32;
- static const uint64_t kEpochShift = 32;
+ static const uint64_t kWaiterInc = 1ull << kWaiterShift;
+ static const uint64_t kSignalShift = 2 * kWaiterBits;
+ static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
+ << kSignalShift;
+ static const uint64_t kSignalInc = 1ull << kSignalShift;
+ static const uint64_t kEpochShift = 3 * kWaiterBits;
+ static const uint64_t kEpochBits = 64 - kEpochShift;
static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
static const uint64_t kEpochInc = 1ull << kEpochShift;
std::atomic<uint64_t> state_;
MaxSizeVector<Waiter>& waiters_;
+ static void CheckState(uint64_t state, bool waiter = false) {
+ static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
+ const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+ const uint64_t signals = (state & kSignalMask) >> kSignalShift;
+ eigen_plain_assert(waiters >= signals);
+ eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
+ eigen_plain_assert(!waiter || waiters > 0);
+ (void)waiters;
+ (void)signals;
+ }
+
void Park(Waiter* w) {
std::unique_lock<std::mutex> lock(w->mu);
while (w->state != Waiter::kSignaled) {
@@ -210,10 +227,10 @@ class EventCount {
}
}
- void Unpark(Waiter* waiters) {
- Waiter* next = nullptr;
- for (Waiter* w = waiters; w; w = next) {
- next = w->next.load(std::memory_order_relaxed);
+ void Unpark(Waiter* w) {
+ for (Waiter* next; w; w = next) {
+ uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
+ next = wnext == kStackMask ? nullptr : &waiters_[wnext];
unsigned state;
{
std::unique_lock<std::mutex> lock(w->mu);
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index 8fafcdab5..49603d6c1 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -374,11 +374,11 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
eigen_plain_assert(!t->f);
// We already did best-effort emptiness check in Steal, so prepare for
// blocking.
- ec_.Prewait(waiter);
+ if (!ec_.Prewait()) return true;
// Now do a reliable emptiness check.
int victim = NonEmptyQueueIndex();
if (victim != -1) {
- ec_.CancelWait(waiter);
+ ec_.CancelWait();
if (cancelled_) {
return false;
} else {
@@ -392,7 +392,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
blocked_++;
// TODO is blocked_ required to be unsigned?
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
- ec_.CancelWait(waiter);
+ ec_.CancelWait();
// Almost done, but need to re-check queues.
// Consider that all queues are empty and all worker threads are preempted
// right after incrementing blocked_ above. Now a free-standing thread