about summary refs log tree commit diff homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2019-02-04 12:59:33 -0800
committer: Eugene Zhulenev <ezhulenev@google.com> 2019-02-04 12:59:33 -0800
commit 8491127082e5f6568983255a459ca737271aaf3f (patch)
tree 2006ecddbc8a833085e412529552ff1b65baa022 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent eb21bab769b11546d08f7db0b5bb78bfde6cdbae (diff)
Do not reduce parallelism too much in contractions with small number of threads
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 15
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 4932514c7..4af8d3b18 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -339,10 +339,19 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// If there is enough available parallelism in sharding dimension we can
// call kernels in sync mode and use thread local memory for packed data.
const Index sharding_dim_tasks = shard_by_col ? nn : nm;
- if (!parallel_pack_ && sharding_dim_tasks >= device_.numThreadsInPool()) {
- parallelize_by_sharding_dim_only_ = true;
- int num_worker_threads = device_.numThreadsInPool();
+ const int num_worker_threads = device_.numThreadsInPool();
+
+ // With small number of threads we want to make sure that we do not reduce
+ // parallelism too much.
+ const int oversharding_factor =
+ num_worker_threads <= 4 ? 8 :
+ num_worker_threads <= 8 ? 4 :
+ num_worker_threads <= 16 ? 2 : 1;
+
+ if (!parallel_pack_ &&
+ sharding_dim_tasks >= oversharding_factor * num_worker_threads) {
+ parallelize_by_sharding_dim_only_ = true;
if (shard_by_col) {
can_use_thread_local_packed_ = new std::atomic<bool>[nn_];