diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2016-06-23 16:40:07 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2016-06-23 16:40:07 -0700 |
commit | a9c1e4d7b7ce7c9dc5310cee1ed13fdef08e506e (patch) | |
tree | 0da9ce7e94a29f6fc1071b4b0286ad089f803dc5 /unsupported | |
parent | d39df320d29ecc678e019962dfb2bdf64b061197 (diff) |
Return -1 from CurrentThreadId when called by thread outside the pool.
Diffstat (limited to 'unsupported')
5 files changed, 14 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index 0af91fe64..34270730b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -172,6 +172,8 @@ struct ThreadPoolDevice { pool_->Schedule(func); } + // Returns a logical thread index between 0 and pool_->NumThreads() - 1 if + // called from one of the threads in pool_. Returns -1 otherwise. EIGEN_STRONG_INLINE int currentThreadId() const { return pool_->CurrentThreadId(); } diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h index 1369ca183..33ae45131 100644 --- a/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h +++ b/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h @@ -99,13 +99,13 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface { return static_cast<int>(threads_.size()); } - int CurrentThreadId() const { + int CurrentThreadId() const final { const PerThread* pt = const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread(); if (pt->pool == this) { return pt->thread_id; } else { - return NumThreads(); + return -1; } } @@ -113,10 +113,10 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface { typedef typename Environment::EnvThread Thread; struct PerThread { - constexpr PerThread() : pool(NULL), index(-1), rand(0) { } + constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { } NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads. - int thread_id; // Worker thread index in pool. - uint64_t rand; // Random generator state. + uint64_t rand; // Random generator state. + int thread_id; // Worker thread index in pool. }; Environment env_; diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h b/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h index 36eb6950f..e75d0f467 100644 --- a/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h +++ b/unsupported/Eigen/CXX11/src/ThreadPool/SimpleThreadPool.h @@ -78,7 +78,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface { if (pt->pool == this) { return pt->thread_id; } else { - return NumThreads(); + return -1; } } @@ -128,8 +128,9 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface { }; struct PerThread { - ThreadPoolTempl* pool; // Parent pool, or null for normal threads. - int thread_id; // Worker thread index in pool. + constexpr PerThread() : pool(NULL), thread_id(-1) { } + SimpleThreadPoolTempl* pool; // Parent pool, or null for normal threads. + int thread_id; // Worker thread index in pool. }; Environment env_; @@ -141,7 +142,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface { bool exiting_ = false; PerThread* GetPerThread() const { - static EIGEN_THREAD_LOCAL PerThread per_thread; + EIGEN_THREAD_LOCAL PerThread per_thread; return &per_thread; } }; diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h index 569cd4bc8..a65ee97c9 100644 --- a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h +++ b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h @@ -22,7 +22,7 @@ class ThreadPoolInterface { virtual int NumThreads() const = 0; // Returns a logical thread index between 0 and NumThreads() - 1 if called - // from one of the threads in the pool. Returns NumThreads() otherwise. + // from one of the threads in the pool. Returns -1 otherwise. virtual int CurrentThreadId() const = 0; virtual ~ThreadPoolInterface() {} diff --git a/unsupported/test/cxx11_non_blocking_thread_pool.cpp b/unsupported/test/cxx11_non_blocking_thread_pool.cpp index 6e4e5cbab..5f9bb938b 100644 --- a/unsupported/test/cxx11_non_blocking_thread_pool.cpp +++ b/unsupported/test/cxx11_non_blocking_thread_pool.cpp @@ -28,7 +28,7 @@ static void test_parallelism() const int kThreads = 16; // code below expects that this is a multiple of 4 NonBlockingThreadPool tp(kThreads); VERIFY_IS_EQUAL(tp.NumThreads(), kThreads); - VERIFY_IS_EQUAL(tp.CurrentThreadId(), kThreads); + VERIFY_IS_EQUAL(tp.CurrentThreadId(), -1); for (int iter = 0; iter < 100; ++iter) { std::atomic<int> running(0); std::atomic<int> done(0); |