diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2016-06-03 16:28:58 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2016-06-03 16:28:58 -0700 |
commit | 76308e7fd277ad962a87724040670da827a27db4 (patch) | |
tree | a4b431787f1b826e3660c85eb4b8b1e044a9d750 | |
parent | 8d97ba6b2251aabf325ff74f24959ceaa85cf11e (diff) |
Add CurrentThreadId and NumThreads methods to Eigen threadpools and TensorDeviceThreadPool.
5 files changed, 65 insertions, 9 deletions
```diff
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
index d31b0ad38..90fded8ad 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
@@ -172,6 +172,10 @@ struct ThreadPoolDevice {
     pool_->Schedule(func);
   }

+  EIGEN_STRONG_INLINE size_t currentThreadId() const {
+    return pool_->CurrentThreadId();
+  }
+
   // parallelFor executes f with [0, n) arguments in parallel and waits for
   // completion. F accepts a half-open interval [first, last).
   // Block size is choosen based on the iteration cost and resulting parallel
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
index c094563b7..1465878b7 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h
@@ -74,7 +74,7 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
     PerThread* pt = GetPerThread();
     if (pt->pool == this) {
       // Worker thread of this pool, push onto the thread's queue.
-      Queue* q = queues_[pt->index];
+      Queue* q = queues_[pt->thread_id];
       t = q->PushFront(std::move(t));
     } else {
       // A free-standing thread (or worker of another pool), push onto a random
@@ -95,13 +95,27 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
       env_.ExecuteTask(t); // Push failed, execute directly.
   }

+  size_t NumThreads() const final {
+    return threads_.size();
+  }
+
+  size_t CurrentThreadId() const {
+    const PerThread* pt =
+        const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
+    if (pt->pool == this) {
+      return static_cast<size_t>(pt->thread_id);
+    } else {
+      return threads_.size();
+    }
+  }
+
  private:
   typedef typename Environment::EnvThread Thread;

   struct PerThread {
     bool inited;
     NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
-    unsigned index; // Worker thread index in pool.
+    unsigned thread_id; // Worker thread index in pool.
     unsigned rand; // Random generator state.
   };

@@ -116,12 +130,12 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
   EventCount ec_;

   // Main worker thread loop.
-  void WorkerLoop(unsigned index) {
+  void WorkerLoop(unsigned thread_id) {
     PerThread* pt = GetPerThread();
     pt->pool = this;
-    pt->index = index;
-    Queue* q = queues_[index];
-    EventCount::Waiter* waiter = &waiters_[index];
+    pt->thread_id = thread_id;
+    Queue* q = queues_[thread_id];
+    EventCount::Waiter* waiter = &waiters_[thread_id];
     for (;;) {
       Task t = q->PopFront();
       if (!t.f) {
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h
index 17fd1658b..fde80afdf 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h
@@ -24,7 +24,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
   explicit SimpleThreadPoolTempl(int num_threads, Environment env = Environment())
       : env_(env), threads_(num_threads), waiters_(num_threads) {
     for (int i = 0; i < num_threads; i++) {
-      threads_.push_back(env.CreateThread([this]() { WorkerLoop(); }));
+      threads_.push_back(env.CreateThread([this, i]() { WorkerLoop(i); }));
     }
   }

@@ -55,7 +55,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {

   // Schedule fn() for execution in the pool of threads. The functions are
   // executed in the order in which they are scheduled.
-  void Schedule(std::function<void()> fn) {
+  void Schedule(std::function<void()> fn) final {
     Task t = env_.CreateTask(std::move(fn));
     std::unique_lock<std::mutex> l(mu_);
     if (waiters_.empty()) {
@@ -69,9 +69,25 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
     }
   }

+  size_t NumThreads() const final {
+    return threads_.size();
+  }
+
+  size_t CurrentThreadId() const final {
+    const PerThread* pt = this->GetPerThread();
+    if (pt->pool == this) {
+      return pt->thread_id;
+    } else {
+      return threads_.size();
+    }
+  }
+
  protected:
-  void WorkerLoop() {
+  void WorkerLoop(size_t thread_id) {
     std::unique_lock<std::mutex> l(mu_);
+    PerThread* pt = GetPerThread();
+    pt->pool = this;
+    pt->thread_id = thread_id;
     Waiter w;
     Task t;
     while (!exiting_) {
@@ -111,6 +127,11 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
     bool ready;
   };

+  struct PerThread {
+    ThreadPoolTempl* pool; // Parent pool, or null for normal threads.
+    size_t thread_id; // Worker thread index in pool.
+  };
+
   Environment env_;
   std::mutex mu_;
   MaxSizeVector<Thread*> threads_; // All threads
@@ -118,6 +139,11 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
   std::deque<Task> pending_; // Queue of pending work
   std::condition_variable empty_; // Signaled on pending_.empty()
   bool exiting_ = false;
+
+  PerThread* GetPerThread() const {
+    static EIGEN_THREAD_LOCAL PerThread per_thread;
+    return &per_thread;
+  }
 };

 typedef SimpleThreadPoolTempl<StlThreadEnvironment> SimpleThreadPool;
diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h
index 38b40aceb..b1beccdde 100644
--- a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h
+++ b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h
@@ -18,6 +18,13 @@ class ThreadPoolInterface {
  public:
   virtual void Schedule(std::function<void()> fn) = 0;

+  // Returns the number of threads in the pool.
+  virtual size_t NumThreads() const = 0;
+
+  // Returns a logical thread index between 0 and NumThreads() - 1 if called
+  // from one of the threads in the pool. Returns NumThreads() otherwise.
+  virtual size_t CurrentThreadId() const = 0;
+
   virtual ~ThreadPoolInterface() {}
 };

diff --git a/unsupported/test/cxx11_non_blocking_thread_pool.cpp b/unsupported/test/cxx11_non_blocking_thread_pool.cpp
index 6569218c4..844a1fbf4 100644
--- a/unsupported/test/cxx11_non_blocking_thread_pool.cpp
+++ b/unsupported/test/cxx11_non_blocking_thread_pool.cpp
@@ -27,6 +27,8 @@ static void test_parallelism()
   // Test we never-ever fail to match available tasks with idle threads.
   const int kThreads = 16; // code below expects that this is a multiple of 4
   NonBlockingThreadPool tp(kThreads);
+  VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
+  VERIFY_IS_EQUAL(tp.CurrentThreadId(), kThreads);
   for (int iter = 0; iter < 100; ++iter) {
     std::atomic<int> running(0);
     std::atomic<int> done(0);
@@ -34,6 +36,9 @@ static void test_parallelism()
     // Schedule kThreads tasks and ensure that they all are running.
     for (int i = 0; i < kThreads; ++i) {
       tp.Schedule([&]() {
+        const size_t thread_id = tp.CurrentThreadId();
+        VERIFY_GE(thread_id, 0);
+        VERIFY_LE(thread_id, kThreads - 1);
         running++;
         while (phase < 1) {
         }
```