diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h | 65 |
1 files changed, 50 insertions, 15 deletions
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h index ecd49f382..ede70da8d 100644 --- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h +++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h @@ -10,7 +10,6 @@ #ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H #define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H - namespace Eigen { template <typename Environment> @@ -23,7 +22,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { : ThreadPoolTempl(num_threads, true, env) {} ThreadPoolTempl(int num_threads, bool allow_spinning, - Environment env = Environment()) + Environment env = Environment()) : env_(env), num_threads_(num_threads), allow_spinning_(allow_spinning), @@ -61,9 +60,17 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { for (int i = 0; i < num_threads_; i++) { queues_.push_back(new Queue()); } +#ifndef EIGEN_THREAD_LOCAL + init_barrier_.reset(new Barrier(num_threads_)); +#endif for (int i = 0; i < num_threads_; i++) { threads_.push_back(env_.CreateThread([this, i]() { WorkerLoop(i); })); } +#ifndef EIGEN_THREAD_LOCAL + // Wait for workers to initialize per_thread_map_. Otherwise we might race + // with them in Schedule or CurrentThreadId. + init_barrier_->Wait(); +#endif } ~ThreadPoolTempl() { @@ -85,6 +92,9 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { // Join threads explicitly to avoid destruction order issues. for (size_t i = 0; i < num_threads_; i++) delete threads_[i]; for (size_t i = 0; i < num_threads_; i++) delete queues_[i]; +#ifndef EIGEN_THREAD_LOCAL + for (auto it : per_thread_map_) delete it.second; +#endif } void Schedule(std::function<void()> fn) { @@ -109,8 +119,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { // this is kept alive while any threads can potentially be in Schedule. if (!t.f) { ec_.Notify(false); - } - else { + } else { env_.ExecuteTask(t); // Push failed, execute directly. } } @@ -130,13 +139,10 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { ec_.Notify(true); } - int NumThreads() const final { - return num_threads_; - } + int NumThreads() const final { return num_threads_; } int CurrentThreadId() const final { - const PerThread* pt = - const_cast<ThreadPoolTempl*>(this)->GetPerThread(); + const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread(); if (pt->pool == this) { return pt->thread_id; } else { @@ -148,10 +154,10 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { typedef typename Environment::EnvThread Thread; struct PerThread { - constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { } + constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {} ThreadPoolTempl* pool; // Parent pool, or null for normal threads. - uint64_t rand; // Random generator state. - int thread_id; // Worker thread index in pool. + uint64_t rand; // Random generator state. + int thread_id; // Worker thread index in pool. }; Environment env_; @@ -166,12 +172,26 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { std::atomic<bool> done_; std::atomic<bool> cancelled_; EventCount ec_; +#ifndef EIGEN_THREAD_LOCAL + std::unique_ptr<Barrier> init_barrier_; + std::mutex mu; // Protects per_thread_map_. + std::unordered_map<uint64_t, PerThread*> per_thread_map_; +#endif // Main worker thread loop. void WorkerLoop(int thread_id) { +#ifndef EIGEN_THREAD_LOCAL + PerThread* pt = new PerThread(); + mu.lock(); + per_thread_map_[GlobalThreadIdHash()] = pt; + mu.unlock(); + init_barrier_->Notify(); + init_barrier_->Wait(); +#else PerThread* pt = GetPerThread(); +#endif pt->pool = this; - pt->rand = std::hash<std::thread::id>()(std::this_thread::get_id()); + pt->rand = GlobalThreadIdHash(); pt->thread_id = thread_id; Queue* q = queues_[thread_id]; EventCount::Waiter* waiter = &waiters_[thread_id]; @@ -322,10 +342,24 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { return -1; } - static EIGEN_STRONG_INLINE PerThread* GetPerThread() { + static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() { + return std::hash<std::thread::id>()(std::this_thread::get_id()); + } + + EIGEN_STRONG_INLINE PerThread* GetPerThread() { +#ifndef EIGEN_THREAD_LOCAL + static PerThread dummy; + auto it = per_thread_map_.find(GlobalThreadIdHash()); + if (it == per_thread_map_.end()) { + return &dummy; + } else { + return it->second; + } +#else EIGEN_THREAD_LOCAL PerThread per_thread_; PerThread* pt = &per_thread_; return pt; +#endif } static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) { @@ -333,7 +367,8 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { // Update the internal state *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL; // Generate the random output (using the PCG-XSH-RS scheme) - return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61))); + return static_cast<unsigned>((current ^ (current >> 22)) >> + (22 + (current >> 61))); } }; |