diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-03-28 10:01:04 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-03-28 10:01:04 -0700 |
commit | 6772f653c33bd78c25623619581836bac1d1d20a (patch) | |
tree | 2a548a5df095ccaf6cd246d05af43bf5a7163dea /unsupported/Eigen/CXX11/src | |
parent | 1bc81f78895effe972ef8df5a138d267a74295fb (diff) |
Made it possible to customize the threadpool
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h | 79 |
1 files changed, 57 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index 23b1765ba..cd3dd214b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -24,36 +24,40 @@ class ThreadPoolInterface { // The implementation of the ThreadPool type ensures that the Schedule method // runs the functions it is provided in FIFO order when the scheduling is done // by a single thread. -class ThreadPool : public ThreadPoolInterface { +// Environment provides a way to create threads and also allows to intercept +// task submission and execution. +template <typename Environment> +class ThreadPoolTempl : public ThreadPoolInterface { public: // Construct a pool that contains "num_threads" threads. - explicit ThreadPool(int num_threads) : threads_(num_threads), waiters_(num_threads) { + explicit ThreadPoolTempl(int num_threads, Environment env = Environment()) + : env_(env), threads_(num_threads), waiters_(num_threads) { for (int i = 0; i < num_threads; i++) { - threads_.push_back(new std::thread([this]() { WorkerLoop(); })); + threads_.push_back(env.CreateThread([this]() { WorkerLoop(); })); } } // Wait until all scheduled work has finished and then destroy the // set of threads. - ~ThreadPool() - { + ~ThreadPoolTempl() { { // Wait for all work to get done. std::unique_lock<std::mutex> l(mu_); - empty_.wait(l, [this]() { return pending_.empty(); }); + while (!pending_.empty()) { + empty_.wait(l); + } exiting_ = true; // Wakeup all waiters. for (auto w : waiters_) { w->ready = true; - w->work = nullptr; + w->task.f = nullptr; w->cv.notify_one(); } } // Wait for threads to finish. for (auto t : threads_) { - t->join(); delete t; } } @@ -61,14 +65,15 @@ class ThreadPool : public ThreadPoolInterface { // Schedule fn() for execution in the pool of threads. The functions are // executed in the order in which they are scheduled. void Schedule(std::function<void()> fn) { + Task t = env_.CreateTask(std::move(fn)); std::unique_lock<std::mutex> l(mu_); if (waiters_.empty()) { - pending_.push_back(fn); + pending_.push_back(std::move(t)); } else { Waiter* w = waiters_.back(); waiters_.pop_back(); w->ready = true; - w->work = fn; + w->task = std::move(t); w->cv.notify_one(); } } @@ -77,46 +82,76 @@ class ThreadPool : public ThreadPoolInterface { void WorkerLoop() { std::unique_lock<std::mutex> l(mu_); Waiter w; + Task t; while (!exiting_) { - std::function<void()> fn; if (pending_.empty()) { // Wait for work to be assigned to me w.ready = false; waiters_.push_back(&w); - w.cv.wait(l, [&w]() { return w.ready; }); - fn = w.work; - w.work = nullptr; + while (!w.ready) { + w.cv.wait(l); + } + t = w.task; + w.task.f = nullptr; } else { // Pick up pending work - fn = pending_.front(); + t = std::move(pending_.front()); pending_.pop_front(); if (pending_.empty()) { empty_.notify_all(); } } - if (fn) { + if (t.f) { mu_.unlock(); - fn(); + env_.ExecuteTask(t); + t.f = nullptr; mu_.lock(); } } } private: + typedef typename Environment::Task Task; + typedef typename Environment::EnvThread Thread; + struct Waiter { std::condition_variable cv; - std::function<void()> work; + Task task; bool ready; }; + Environment env_; std::mutex mu_; - MaxSizeVector<std::thread*> threads_; // All threads - MaxSizeVector<Waiter*> waiters_; // Stack of waiting threads. - std::deque<std::function<void()>> pending_; // Queue of pending work - std::condition_variable empty_; // Signaled on pending_.empty() + MaxSizeVector<Thread*> threads_; // All threads + MaxSizeVector<Waiter*> waiters_; // Stack of waiting threads. + std::deque<Task> pending_; // Queue of pending work + std::condition_variable empty_; // Signaled on pending_.empty() bool exiting_ = false; }; +struct StlThreadEnvironment { + struct Task { + std::function<void()> f; + }; + + // EnvThread constructor must start the thread, + // destructor must join the thread. + class EnvThread { + public: + EnvThread(std::function<void()> f) : thr_(f) {} + ~EnvThread() { thr_.join(); } + + private: + std::thread thr_; + }; + + EnvThread* CreateThread(std::function<void()> f) { return new EnvThread(f); } + Task CreateTask(std::function<void()> f) { return Task{std::move(f)}; } + void ExecuteTask(const Task& t) { t.f(); } +}; + +typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool; + // Barrier is an object that allows one or more threads to wait until // Notify has been called a specified number of times. |