aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h4
-rw-r--r--unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h26
-rw-r--r--unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h32
-rw-r--r--unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h7
-rw-r--r--unsupported/test/cxx11_non_blocking_thread_pool.cpp5
5 files changed, 65 insertions, 9 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
index d31b0ad38..90fded8ad 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
@@ -172,6 +172,10 @@ struct ThreadPoolDevice {
pool_->Schedule(func);
}
+ EIGEN_STRONG_INLINE size_t currentThreadId() const {
+ return pool_->CurrentThreadId();  // Forwards to the underlying pool's CurrentThreadId().
+ }
+
// parallelFor executes f with [0, n) arguments in parallel and waits for
// completion. F accepts a half-open interval [first, last).
// Block size is chosen based on the iteration cost and resulting parallel
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index c094563b7..1465878b7 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -74,7 +74,7 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
PerThread* pt = GetPerThread();
if (pt->pool == this) {
// Worker thread of this pool, push onto the thread's queue.
- Queue* q = queues_[pt->index];
+ Queue* q = queues_[pt->thread_id];
t = q->PushFront(std::move(t));
} else {
// A free-standing thread (or worker of another pool), push onto a random
@@ -95,13 +95,27 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
env_.ExecuteTask(t); // Push failed, execute directly.
}
+ size_t NumThreads() const final {
+ return threads_.size();  // Number of entries in threads_.
+ }
+
+ size_t CurrentThreadId() const final {  // Implements ThreadPoolInterface::CurrentThreadId(); `final` matches NumThreads().
+ const PerThread* pt =
+ const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();  // NOTE(review): cast implies GetPerThread() is non-const here - confirm.
+ if (pt->pool == this) {  // Calling thread is a worker of this pool.
+ return static_cast<size_t>(pt->thread_id);
+ } else {
+ return threads_.size();  // Any other thread: report threads_.size(), i.e. NumThreads().
+ }
+ }
+
private:
typedef typename Environment::EnvThread Thread;
struct PerThread {
bool inited;
NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
- unsigned index; // Worker thread index in pool.
+ unsigned thread_id; // Worker thread index in pool.
unsigned rand; // Random generator state.
};
@@ -116,12 +130,12 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
EventCount ec_;
// Main worker thread loop.
- void WorkerLoop(unsigned index) {
+ void WorkerLoop(unsigned thread_id) {
PerThread* pt = GetPerThread();
pt->pool = this;
- pt->index = index;
- Queue* q = queues_[index];
- EventCount::Waiter* waiter = &waiters_[index];
+ pt->thread_id = thread_id;
+ Queue* q = queues_[thread_id];
+ EventCount::Waiter* waiter = &waiters_[thread_id];
for (;;) {
Task t = q->PopFront();
if (!t.f) {
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h
index 17fd1658b..fde80afdf 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h
@@ -24,7 +24,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
explicit SimpleThreadPoolTempl(int num_threads, Environment env = Environment())
: env_(env), threads_(num_threads), waiters_(num_threads) {
for (int i = 0; i < num_threads; i++) {
- threads_.push_back(env.CreateThread([this]() { WorkerLoop(); }));
+ threads_.push_back(env.CreateThread([this, i]() { WorkerLoop(i); }));
}
}
@@ -55,7 +55,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
// Schedule fn() for execution in the pool of threads. The functions are
// executed in the order in which they are scheduled.
- void Schedule(std::function<void()> fn) {
+ void Schedule(std::function<void()> fn) final {
Task t = env_.CreateTask(std::move(fn));
std::unique_lock<std::mutex> l(mu_);
if (waiters_.empty()) {
@@ -69,9 +69,25 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
}
}
+ size_t NumThreads() const final {
+ return threads_.size();  // Number of entries in threads_.
+ }
+
+ size_t CurrentThreadId() const final {
+ const PerThread* pt = this->GetPerThread();
+ if (pt->pool == this) {  // WorkerLoop() stores `this` in pt->pool, so this identifies one of our workers.
+ return pt->thread_id;
+ } else {
+ return threads_.size();  // Not a worker of this pool: report the same value NumThreads() returns.
+ }
+ }
+
protected:
- void WorkerLoop() {
+ void WorkerLoop(size_t thread_id) {
std::unique_lock<std::mutex> l(mu_);
+ PerThread* pt = GetPerThread();
+ pt->pool = this;
+ pt->thread_id = thread_id;
Waiter w;
Task t;
while (!exiting_) {
@@ -111,6 +127,11 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
bool ready;
};
+ struct PerThread {  // Per-thread record filled in by WorkerLoop() and read by CurrentThreadId().
+ SimpleThreadPoolTempl* pool; // Parent pool, or null for normal threads.
+ size_t thread_id; // Worker thread index in pool.
+ };
+
Environment env_;
std::mutex mu_;
MaxSizeVector<Thread*> threads_; // All threads
@@ -118,6 +139,11 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
std::deque<Task> pending_; // Queue of pending work
std::condition_variable empty_; // Signaled on pending_.empty()
bool exiting_ = false;
+
+ PerThread* GetPerThread() const {
+ static EIGEN_THREAD_LOCAL PerThread per_thread;  // Function-local static; presumably one instance per thread via EIGEN_THREAD_LOCAL - confirm macro definition.
+ return &per_thread;  // Address of the function-local static above.
+ }
};
typedef SimpleThreadPoolTempl<StlThreadEnvironment> SimpleThreadPool;
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h
index 38b40aceb..b1beccdde 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h
@@ -18,6 +18,13 @@ class ThreadPoolInterface {
public:
virtual void Schedule(std::function<void()> fn) = 0;
+ // Returns the number of threads in the pool.
+ virtual size_t NumThreads() const = 0;
+
+ // Returns a logical thread index between 0 and NumThreads() - 1 if called
+ // from one of the threads in the pool. Returns NumThreads() otherwise.
+ virtual size_t CurrentThreadId() const = 0;
+
virtual ~ThreadPoolInterface() {}
};
diff --git a/unsupported/test/cxx11_non_blocking_thread_pool.cpp b/unsupported/test/cxx11_non_blocking_thread_pool.cpp
index 6569218c4..844a1fbf4 100644
--- a/unsupported/test/cxx11_non_blocking_thread_pool.cpp
+++ b/unsupported/test/cxx11_non_blocking_thread_pool.cpp
@@ -27,6 +27,8 @@ static void test_parallelism()
// Test we never-ever fail to match available tasks with idle threads.
const int kThreads = 16; // code below expects that this is a multiple of 4
NonBlockingThreadPool tp(kThreads);
+ VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
+ VERIFY_IS_EQUAL(tp.CurrentThreadId(), kThreads);
for (int iter = 0; iter < 100; ++iter) {
std::atomic<int> running(0);
std::atomic<int> done(0);
@@ -34,6 +36,9 @@ static void test_parallelism()
// Schedule kThreads tasks and ensure that they all are running.
for (int i = 0; i < kThreads; ++i) {
tp.Schedule([&]() {
+ const size_t thread_id = tp.CurrentThreadId();
+ VERIFY_GE(thread_id, 0);
+ VERIFY_LE(thread_id, kThreads - 1);
running++;
while (phase < 1) {
}