diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index 1612c004b..47025a510 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -214,18 +214,25 @@ struct ThreadPoolDevice { Barrier barrier(static_cast<unsigned int>(block_count)); std::function<void(Index, Index)> handleRange; handleRange = [=, &handleRange, &barrier, &f](Index firstIdx, Index lastIdx) { - if (lastIdx - firstIdx <= block_size) { - // Single block or less, execute directly. - f(firstIdx, lastIdx); - barrier.Notify(); - return; + while (lastIdx - firstIdx > block_size) { + // Split into halves and schedule the second half on a different thread. + const Index midIdx = firstIdx + divup((lastIdx - firstIdx) / 2, block_size) * block_size; + pool_->Schedule([=, &handleRange]() { handleRange(midIdx, lastIdx); }); + lastIdx = midIdx; } - // Split into halves and submit to the pool. - Index mid = firstIdx + divup((lastIdx - firstIdx) / 2, block_size) * block_size; - pool_->Schedule([=, &handleRange]() { handleRange(mid, lastIdx); }); - handleRange(firstIdx, mid); + // Single block or less, execute directly. + f(firstIdx, lastIdx); + barrier.Notify(); }; - handleRange(0, n); + if (block_count <= numThreads()) { + // Avoid a thread hop by running the root of the tree and one block on the + // main thread. + handleRange(0, n); + } else { + // Execute the root in the thread pool to avoid running work on more than + // numThreads() threads. + pool_->Schedule([=, &handleRange]() { handleRange(0, n); }); + } barrier.Wait(); } |