about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h 27
1 files changed, 17 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
index 1612c004b..47025a510 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
@@ -214,18 +214,25 @@ struct ThreadPoolDevice {
Barrier barrier(static_cast<unsigned int>(block_count));
std::function<void(Index, Index)> handleRange;
handleRange = [=, &handleRange, &barrier, &f](Index firstIdx, Index lastIdx) {
- if (lastIdx - firstIdx <= block_size) {
- // Single block or less, execute directly.
- f(firstIdx, lastIdx);
- barrier.Notify();
- return;
+ while (lastIdx - firstIdx > block_size) {
+ // Split into halves and schedule the second half on a different thread.
+ const Index midIdx = firstIdx + divup((lastIdx - firstIdx) / 2, block_size) * block_size;
+ pool_->Schedule([=, &handleRange]() { handleRange(midIdx, lastIdx); });
+ lastIdx = midIdx;
}
- // Split into halves and submit to the pool.
- Index mid = firstIdx + divup((lastIdx - firstIdx) / 2, block_size) * block_size;
- pool_->Schedule([=, &handleRange]() { handleRange(mid, lastIdx); });
- handleRange(firstIdx, mid);
+ // Single block or less, execute directly.
+ f(firstIdx, lastIdx);
+ barrier.Notify();
};
- handleRange(0, n);
+ if (block_count <= numThreads()) {
+ // Avoid a thread hop by running the root of the tree and one block on the
+ // main thread.
+ handleRange(0, n);
+ } else {
+ // Execute the root in the thread pool to avoid running work on more than
+ // numThreads() threads.
+ pool_->Schedule([=, &handleRange]() { handleRange(0, n); });
+ }
barrier.Wait();
}