From 6772f653c33bd78c25623619581836bac1d1d20a Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 28 Mar 2016 10:01:04 -0700 Subject: Made it possible to customize the threadpool --- .../CXX11/src/Tensor/TensorDeviceThreadPool.h | 79 ++++++++++++++++------ 1 file changed, 57 insertions(+), 22 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index 23b1765ba..cd3dd214b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -24,36 +24,40 @@ class ThreadPoolInterface { // The implementation of the ThreadPool type ensures that the Schedule method // runs the functions it is provided in FIFO order when the scheduling is done // by a single thread. -class ThreadPool : public ThreadPoolInterface { +// Environment provides a way to create threads and also allows to intercept +// task submission and execution. +template +class ThreadPoolTempl : public ThreadPoolInterface { public: // Construct a pool that contains "num_threads" threads. - explicit ThreadPool(int num_threads) : threads_(num_threads), waiters_(num_threads) { + explicit ThreadPoolTempl(int num_threads, Environment env = Environment()) + : env_(env), threads_(num_threads), waiters_(num_threads) { for (int i = 0; i < num_threads; i++) { - threads_.push_back(new std::thread([this]() { WorkerLoop(); })); + threads_.push_back(env.CreateThread([this]() { WorkerLoop(); })); } } // Wait until all scheduled work has finished and then destroy the // set of threads. - ~ThreadPool() - { + ~ThreadPoolTempl() { { // Wait for all work to get done. std::unique_lock l(mu_); - empty_.wait(l, [this]() { return pending_.empty(); }); + while (!pending_.empty()) { + empty_.wait(l); + } exiting_ = true; // Wakeup all waiters. for (auto w : waiters_) { w->ready = true; - w->work = nullptr; + w->task.f = nullptr; w->cv.notify_one(); } } // Wait for threads to finish. for (auto t : threads_) { - t->join(); delete t; } } @@ -61,14 +65,15 @@ class ThreadPool : public ThreadPoolInterface { // Schedule fn() for execution in the pool of threads. The functions are // executed in the order in which they are scheduled. void Schedule(std::function fn) { + Task t = env_.CreateTask(std::move(fn)); std::unique_lock l(mu_); if (waiters_.empty()) { - pending_.push_back(fn); + pending_.push_back(std::move(t)); } else { Waiter* w = waiters_.back(); waiters_.pop_back(); w->ready = true; - w->work = fn; + w->task = std::move(t); w->cv.notify_one(); } } @@ -77,46 +82,76 @@ class ThreadPool : public ThreadPoolInterface { void WorkerLoop() { std::unique_lock l(mu_); Waiter w; + Task t; while (!exiting_) { - std::function fn; if (pending_.empty()) { // Wait for work to be assigned to me w.ready = false; waiters_.push_back(&w); - w.cv.wait(l, [&w]() { return w.ready; }); - fn = w.work; - w.work = nullptr; + while (!w.ready) { + w.cv.wait(l); + } + t = w.task; + w.task.f = nullptr; } else { // Pick up pending work - fn = pending_.front(); + t = std::move(pending_.front()); pending_.pop_front(); if (pending_.empty()) { empty_.notify_all(); } } - if (fn) { + if (t.f) { mu_.unlock(); - fn(); + env_.ExecuteTask(t); + t.f = nullptr; mu_.lock(); } } } private: + typedef typename Environment::Task Task; + typedef typename Environment::EnvThread Thread; + struct Waiter { std::condition_variable cv; - std::function work; + Task task; bool ready; }; + Environment env_; std::mutex mu_; - MaxSizeVector threads_; // All threads - MaxSizeVector waiters_; // Stack of waiting threads. - std::deque> pending_; // Queue of pending work - std::condition_variable empty_; // Signaled on pending_.empty() + MaxSizeVector threads_; // All threads + MaxSizeVector waiters_; // Stack of waiting threads. + std::deque pending_; // Queue of pending work + std::condition_variable empty_; // Signaled on pending_.empty() bool exiting_ = false; }; +struct StlThreadEnvironment { + struct Task { + std::function f; + }; + + // EnvThread constructor must start the thread, + // destructor must join the thread. + class EnvThread { + public: + EnvThread(std::function f) : thr_(f) {} + ~EnvThread() { thr_.join(); } + + private: + std::thread thr_; + }; + + EnvThread* CreateThread(std::function f) { return new EnvThread(f); } + Task CreateTask(std::function f) { return Task{std::move(f)}; } + void ExecuteTask(const Task& t) { t.f(); } +}; + +typedef ThreadPoolTempl ThreadPool; + // Barrier is an object that allows one or more threads to wait until // Notify has been called a specified number of times. -- cgit v1.2.3