aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-06-20 17:51:48 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-06-20 17:51:48 -0700
commit5418154a45db637211e94f11ee04c6ae4dc8cf85 (patch)
tree5262c4d27a7f35739fddb48031e36e90b8ef2556
parentb8271bb368d4d2be11f9f493495840481d2e5f2a (diff)
Fix oversharding bug in parallelFor.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h11
1 files changed, 7 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
index ca9ba402e..90fd99027 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h
@@ -189,9 +189,11 @@ struct ThreadPoolDevice {
// of blocks to be evenly dividable across threads.
double block_size_f = 1.0 / CostModel::taskSize(1, cost);
- Index block_size = numext::mini(n, numext::maxi<Index>(1, block_size_f));
- const Index max_block_size =
- numext::mini(n, numext::maxi<Index>(1, 2 * block_size_f));
+ const Index max_oversharding_factor = 4;
+ Index block_size = numext::mini(
+ n, numext::maxi<Index>(divup<Index>(n, max_oversharding_factor * numThreads()),
+ block_size_f));
+ const Index max_block_size = numext::mini(n, 2 * block_size);
if (block_align) {
Index new_block_size = block_align(block_size);
eigen_assert(new_block_size >= block_size);
@@ -205,7 +207,8 @@ struct ThreadPoolDevice {
(divup<int>(block_count, numThreads()) * numThreads());
// Now try to increase block size up to max_block_size as long as it
// doesn't decrease parallel efficiency.
- for (Index prev_block_count = block_count; prev_block_count > 1;) {
+ for (Index prev_block_count = block_count;
+ max_efficiency < 1.0 && prev_block_count > 1;) {
// This is the next block size that divides size into a smaller number
// of blocks than the current block_size.
Index coarser_block_size = divup(n, prev_block_count - 1);