Diffstat (limited to 'unsupported')
-rw-r--r--  unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h | 163
-rw-r--r--  unsupported/test/cxx11_non_blocking_thread_pool.cpp            |   7
2 files changed, 111 insertions, 59 deletions
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index ed1a761b6..e28afedb4 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -20,7 +20,9 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
   typedef RunQueue<Task, 1024> Queue;
 
   NonBlockingThreadPoolTempl(int num_threads, Environment env = Environment())
-      : env_(env),
+      : num_threads_(num_threads),
+        allow_spinning_(true),
+        env_(env),
         threads_(num_threads),
         queues_(num_threads),
         coprimes_(num_threads),
@@ -30,34 +32,24 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
         done_(false),
         cancelled_(false),
         ec_(waiters_) {
-    waiters_.resize(num_threads);
+    Init();
+  }
-
-    // Calculate coprimes of num_threads.
-    // Coprimes are used for a random walk over all threads in Steal
-    // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
-    // a walk starting thread index t and calculate num_threads - 1 subsequent
-    // indices as (t + coprime) % num_threads, we will cover all threads without
-    // repetitions (effectively getting a presudo-random permutation of thread
-    // indices).
-    for (int i = 1; i <= num_threads; i++) {
-      unsigned a = i;
-      unsigned b = num_threads;
-      // If GCD(a, b) == 1, then a and b are coprimes.
-      while (b != 0) {
-        unsigned tmp = a;
-        a = b;
-        b = tmp % b;
-      }
-      if (a == 1) {
-        coprimes_.push_back(i);
-      }
-    }
-    for (int i = 0; i < num_threads; i++) {
-      queues_.push_back(new Queue());
-    }
-    for (int i = 0; i < num_threads; i++) {
-      threads_.push_back(env_.CreateThread([this, i]() { WorkerLoop(i); }));
-    }
+
+  NonBlockingThreadPoolTempl(int num_threads, bool allow_spinning,
+                             Environment env = Environment())
+      : num_threads_(num_threads),
+        allow_spinning_(allow_spinning),
+        env_(env),
+        threads_(num_threads),
+        queues_(num_threads),
+        coprimes_(num_threads),
+        waiters_(num_threads),
+        blocked_(0),
+        spinning_(0),
+        done_(false),
+        cancelled_(false),
+        ec_(waiters_) {
+    Init();
   }
 
   ~NonBlockingThreadPoolTempl() {
@@ -77,8 +69,8 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
     }
 
     // Join threads explicitly to avoid destruction order issues.
-    for (size_t i = 0; i < threads_.size(); i++) delete threads_[i];
-    for (size_t i = 0; i < threads_.size(); i++) delete queues_[i];
+    for (size_t i = 0; i < num_threads_; i++) delete threads_[i];
+    for (size_t i = 0; i < num_threads_; i++) delete queues_[i];
   }
 
   void Schedule(std::function<void()> fn) {
@@ -125,7 +117,7 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
   }
 
   int NumThreads() const final {
-    return static_cast<int>(threads_.size());
+    return num_threads_;
  }
 
   int CurrentThreadId() const final {
@@ -149,6 +141,8 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
   };
 
   Environment env_;
+  const int num_threads_;
+  const bool allow_spinning_;
   MaxSizeVector<Thread*> threads_;
   MaxSizeVector<Queue*> queues_;
   MaxSizeVector<unsigned> coprimes_;
@@ -159,6 +153,37 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
   std::atomic<bool> cancelled_;
   EventCount ec_;
 
+  void Init() {
+    waiters_.resize(num_threads_);
+
+    // Calculate coprimes of num_threads_.
+    // Coprimes are used for a random walk over all threads in Steal
+    // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
+    // a walk starting thread index t and calculate num_threads - 1 subsequent
+    // indices as (t + coprime) % num_threads, we will cover all threads without
+    // repetitions (effectively getting a presudo-random permutation of thread
+    // indices).
+    for (int i = 1; i <= num_threads_; i++) {
+      unsigned a = i;
+      unsigned b = num_threads_;
+      // If GCD(a, b) == 1, then a and b are coprimes.
+      while (b != 0) {
+        unsigned tmp = a;
+        a = b;
+        b = tmp % b;
+      }
+      if (a == 1) {
+        coprimes_.push_back(i);
+      }
+    }
+    for (int i = 0; i < num_threads_; i++) {
+      queues_.push_back(new Queue());
+    }
+    for (int i = 0; i < num_threads_; i++) {
+      threads_.push_back(env_.CreateThread([this, i]() { WorkerLoop(i); }));
+    }
+  }
+
   // Main worker thread loop.
   void WorkerLoop(int thread_id) {
     PerThread* pt = GetPerThread();
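A quick illustration of the coprime walk that Init() precomputes: for any step c with GCD(c, n) == 1, the sequence t, (t + c) % n, (t + 2c) % n, ... visits all n thread indices exactly once before repeating, which is what lets Steal() and NonEmptyQueueIndex() probe every victim queue without repetition. A minimal standalone sketch, not part of this patch (the concrete values are arbitrary):

// Standalone demo of the coprime random walk: stepping by a value
// coprime to n visits each of the n queue indices exactly once.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const unsigned n = 6;        // number of threads/queues
  const unsigned coprime = 5;  // GCD(5, 6) == 1
  std::vector<bool> seen(n, false);
  unsigned index = 2;          // arbitrary starting thread index
  for (unsigned i = 0; i < n; i++) {
    assert(!seen[index]);      // no repetitions: this is a permutation
    seen[index] = true;
    std::printf("%u ", index); // prints: 2 1 0 5 4 3
    index = (index + coprime) % n;
  }
  return 0;
}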
@@ -167,36 +192,62 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
     pt->thread_id = thread_id;
     Queue* q = queues_[thread_id];
     EventCount::Waiter* waiter = &waiters_[thread_id];
-    while (!cancelled_) {
-      Task t = q->PopFront();
-      if (!t.f) {
-        t = Steal();
+    // TODO(dvyukov,rmlarsen): The time spent in Steal() is proportional
+    // to num_threads_ and we assume that new work is scheduled at a
+    // constant rate, so we set spin_count to 5000 / num_threads_. The
+    // constant was picked based on a fair dice roll, tune it.
+    const int spin_count =
+        allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0;
+    if (num_threads_ == 1) {
+      // For num_threads_ == 1 there is no point in going through the expensive
+      // steal loop. Moreover, since Steal() calls PopBack() on the victim
+      // queues it might reverse the order in which ops are executed compared to
+      // the order in which they are scheduled, which tends to be
+      // counter-productive for the types of I/O workloads the single thread
+      // pools tend to be used for.
+      while (!cancelled_) {
+        Task t = q->PopFront();
+        for (int i = 0; i < spin_count && !t.f; i++) {
+          if (!cancelled_.load(std::memory_order_relaxed)) {
+            t = q->PopFront();
+          }
+        }
         if (!t.f) {
-          // Leave one thread spinning. This reduces latency.
-          // TODO(dvyukov): 1000 iterations is based on fair dice roll, tune it.
-          // Also, the time it takes to attempt to steal work 1000 times depends
-          // on the size of the thread pool. However the speed at which the user
-          // of the thread pool submit tasks is independent of the size of the
-          // pool. Consider a time based limit instead.
-          if (!spinning_ && !spinning_.exchange(true)) {
-            for (int i = 0; i < 1000 && !t.f; i++) {
-              if (!cancelled_.load(std::memory_order_relaxed)) {
-                t = Steal();
-              } else {
-                return;
-              }
-            }
-            spinning_ = false;
+          if (!WaitForWork(waiter, &t)) {
+            return;
          }
+        }
+        if (t.f) {
+          env_.ExecuteTask(t);
+        }
+      }
+    } else {
+      while (!cancelled_) {
+        Task t = q->PopFront();
+        if (!t.f) {
+          t = Steal();
          if (!t.f) {
-            if (!WaitForWork(waiter, &t)) {
-              return;
+            // Leave one thread spinning. This reduces latency.
+            if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
+              for (int i = 0; i < spin_count && !t.f; i++) {
+                if (!cancelled_.load(std::memory_order_relaxed)) {
+                  t = Steal();
+                } else {
+                  return;
+                }
+              }
+              spinning_ = false;
+            }
+            if (!t.f) {
+              if (!WaitForWork(waiter, &t)) {
+                return;
+              }
            }
          }
        }
-      }
-      if (t.f) {
-        env_.ExecuteTask(t);
+        if (t.f) {
+          env_.ExecuteTask(t);
+        }
       }
     }
   }
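The rewritten loop bounds spinning in two ways: spin_count = 5000 / num_threads_ keeps the pool-wide number of speculative Steal() attempts roughly constant as the pool grows, and the spinning_ exchange admits at most one spinner at a time. A simplified sketch of that single-spinner gate, with hypothetical try_pop/block callbacks standing in for PopFront()/Steal() and WaitForWork() (this is not Eigen's code):

// Simplified sketch (not Eigen's implementation) of the "one spinner at
// a time" gate: the first idle thread to flip `spinning` from false to
// true busy-polls for a bounded number of iterations; all other idle
// threads fall through to the blocking wait immediately.
#include <atomic>

std::atomic<bool> spinning(false);

template <typename TryPop, typename Block>
void idle_wait(int spin_count, TryPop try_pop, Block block) {
  // Plain load before the exchange keeps losing threads from hammering
  // the cache line with read-modify-write operations.
  if (!spinning.load() && !spinning.exchange(true)) {
    bool got_work = false;
    for (int i = 0; i < spin_count && !got_work; i++) {
      got_work = try_pop();  // re-attempt the pop/steal
    }
    spinning.store(false);
    if (got_work) return;
  }
  block();  // park on the blocking path (WaitForWork in the pool)
}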
@@ -244,7 +295,7 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
     // If we are shutting down and all worker threads blocked without work,
     // that's we are done.
     blocked_++;
-    if (done_ && blocked_ == threads_.size()) {
+    if (done_ && blocked_ == num_threads_) {
       ec_.CancelWait(waiter);
       // Almost done, but need to re-check queues.
       // Consider that all queues are empty and all worker threads are preempted
diff --git a/unsupported/test/cxx11_non_blocking_thread_pool.cpp b/unsupported/test/cxx11_non_blocking_thread_pool.cpp
index 2c5765ce4..48cd2d4e4 100644
--- a/unsupported/test/cxx11_non_blocking_thread_pool.cpp
+++ b/unsupported/test/cxx11_non_blocking_thread_pool.cpp
@@ -23,11 +23,11 @@ static void test_create_destroy_empty_pool()
 }
 
 
-static void test_parallelism()
+static void test_parallelism(bool allow_spinning)
 {
   // Test we never-ever fail to match available tasks with idle threads.
   const int kThreads = 16;  // code below expects that this is a multiple of 4
-  NonBlockingThreadPool tp(kThreads);
+  NonBlockingThreadPool tp(kThreads, allow_spinning);
   VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
   VERIFY_IS_EQUAL(tp.CurrentThreadId(), -1);
   for (int iter = 0; iter < 100; ++iter) {
@@ -119,6 +119,7 @@ static void test_cancel()
 void test_cxx11_non_blocking_thread_pool()
 {
   CALL_SUBTEST(test_create_destroy_empty_pool());
-  CALL_SUBTEST(test_parallelism());
+  CALL_SUBTEST(test_parallelism(true));
+  CALL_SUBTEST(test_parallelism(false));
   CALL_SUBTEST(test_cancel());
 }
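From the caller's side the change surfaces as the new two-argument constructor, which the updated test drives with both values of allow_spinning. A hypothetical usage sketch (the include path and the Eigen::NonBlockingThreadPool typedef are assumed from the usual unsupported/Eigen setup; adjust to your build):

// Hypothetical usage of the new two-argument constructor.
#include <unsupported/Eigen/CXX11/ThreadPool>

int main() {
  // allow_spinning = false: idle workers skip the spin phase and go
  // straight to the blocking WaitForWork() path, trading some wake-up
  // latency for not burning CPU; this suits I/O-heavy pools.
  Eigen::NonBlockingThreadPool io_pool(/*num_threads=*/4,
                                       /*allow_spinning=*/false);
  io_pool.Schedule([] { /* enqueue I/O-bound work here */ });

  // The original one-argument constructor keeps the old behavior
  // (spinning enabled), so existing callers are unaffected.
  Eigen::NonBlockingThreadPool compute_pool(8);
  compute_pool.Schedule([] { /* enqueue compute-bound work here */ });
  return 0;  // pool destructors join the worker threads
}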