about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h 65
1 files changed, 50 insertions, 15 deletions
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index ecd49f382..ede70da8d 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -10,7 +10,6 @@
#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
-
namespace Eigen {
template <typename Environment>
@@ -23,7 +22,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
: ThreadPoolTempl(num_threads, true, env) {}
ThreadPoolTempl(int num_threads, bool allow_spinning,
- Environment env = Environment())
+ Environment env = Environment())
: env_(env),
num_threads_(num_threads),
allow_spinning_(allow_spinning),
@@ -61,9 +60,17 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
for (int i = 0; i < num_threads_; i++) {
queues_.push_back(new Queue());
}
+#ifndef EIGEN_THREAD_LOCAL
+ init_barrier_.reset(new Barrier(num_threads_));
+#endif
for (int i = 0; i < num_threads_; i++) {
threads_.push_back(env_.CreateThread([this, i]() { WorkerLoop(i); }));
}
+#ifndef EIGEN_THREAD_LOCAL
+ // Wait for workers to initialize per_thread_map_. Otherwise we might race
+ // with them in Schedule or CurrentThreadId.
+ init_barrier_->Wait();
+#endif
}
~ThreadPoolTempl() {
@@ -85,6 +92,9 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
// Join threads explicitly to avoid destruction order issues.
for (size_t i = 0; i < num_threads_; i++) delete threads_[i];
for (size_t i = 0; i < num_threads_; i++) delete queues_[i];
+#ifndef EIGEN_THREAD_LOCAL
+ for (auto it : per_thread_map_) delete it.second;
+#endif
}
void Schedule(std::function<void()> fn) {
@@ -109,8 +119,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
// this is kept alive while any threads can potentially be in Schedule.
if (!t.f) {
ec_.Notify(false);
- }
- else {
+ } else {
env_.ExecuteTask(t); // Push failed, execute directly.
}
}
@@ -130,13 +139,10 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
ec_.Notify(true);
}
- int NumThreads() const final {
- return num_threads_;
- }
+ int NumThreads() const final { return num_threads_; }
int CurrentThreadId() const final {
- const PerThread* pt =
- const_cast<ThreadPoolTempl*>(this)->GetPerThread();
+ const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
if (pt->pool == this) {
return pt->thread_id;
} else {
@@ -148,10 +154,10 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
typedef typename Environment::EnvThread Thread;
struct PerThread {
- constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { }
+ constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
ThreadPoolTempl* pool; // Parent pool, or null for normal threads.
- uint64_t rand; // Random generator state.
- int thread_id; // Worker thread index in pool.
+ uint64_t rand; // Random generator state.
+ int thread_id; // Worker thread index in pool.
};
Environment env_;
@@ -166,12 +172,26 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
std::atomic<bool> done_;
std::atomic<bool> cancelled_;
EventCount ec_;
+#ifndef EIGEN_THREAD_LOCAL
+ std::unique_ptr<Barrier> init_barrier_;
+ std::mutex mu; // Protects per_thread_map_.
+ std::unordered_map<uint64_t, PerThread*> per_thread_map_;
+#endif
// Main worker thread loop.
void WorkerLoop(int thread_id) {
+#ifndef EIGEN_THREAD_LOCAL
+ PerThread* pt = new PerThread();
+ mu.lock();
+ per_thread_map_[GlobalThreadIdHash()] = pt;
+ mu.unlock();
+ init_barrier_->Notify();
+ init_barrier_->Wait();
+#else
PerThread* pt = GetPerThread();
+#endif
pt->pool = this;
- pt->rand = std::hash<std::thread::id>()(std::this_thread::get_id());
+ pt->rand = GlobalThreadIdHash();
pt->thread_id = thread_id;
Queue* q = queues_[thread_id];
EventCount::Waiter* waiter = &waiters_[thread_id];
@@ -322,10 +342,24 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
return -1;
}
- static EIGEN_STRONG_INLINE PerThread* GetPerThread() {
+ static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {  // Hash of the calling thread's std::thread::id.
+ return std::hash<std::thread::id>()(std::this_thread::get_id());
+ }
+
+ EIGEN_STRONG_INLINE PerThread* GetPerThread() {  // Returns the PerThread record for the calling thread.
+#ifndef EIGEN_THREAD_LOCAL
+ static PerThread dummy;  // Default-constructed fallback handed to threads that have no map entry.
+ auto it = per_thread_map_.find(GlobalThreadIdHash());  // Keyed by thread-id hash; read without taking mu. NOTE(review): threads whose ids hash equal would share one entry - confirm acceptable.
+ if (it == per_thread_map_.end()) {
+ return &dummy;
+ } else {
+ return it->second;
+ }
+#else
EIGEN_THREAD_LOCAL PerThread per_thread_;  // Presumably one instance per thread via the EIGEN_THREAD_LOCAL macro - TODO confirm its expansion.
PerThread* pt = &per_thread_;
return pt;
+#endif
}
static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
@@ -333,7 +367,8 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
// Update the internal state
*state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
// Generate the random output (using the PCG-XSH-RS scheme)
- return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
+ return static_cast<unsigned>((current ^ (current >> 22)) >>
+ (22 + (current >> 61)));
}
};