aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-02-07 09:21:25 -0800
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-02-07 09:21:25 -0800
commit59998117bb0e4e0dc4b37b062f02ea5e6aab711e (patch)
treea264408f3b00aa5c7cfb04daf935a50eddd380ee /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent013cc3a6b39c5962a3261a063d2a4ab4810cb757 (diff)
Don't do parallel_pack if we can use thread_local memory in tensor contractions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h55
1 files changed, 30 insertions, 25 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 4af8d3b18..d7cd995fb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -208,6 +208,23 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Index nm = divup(nm0, gm);
Index nn = divup(nn0, gn);
+ // If there is enough concurrency in the sharding dimension, we choose not
+ // to paralellize by the other dimension, and execute all kernels in sync
+ // mode. This reduces parallelism from the nm x nn down to nn
+ // (shard_by_col==true) or nm (shard_by_col==false).
+ const Index sharding_dim_tasks = shard_by_col ? nn : nm;
+ const int num_worker_threads = this->m_device.numThreadsInPool();
+
+ // With small number of threads we want to make sure that we do not reduce
+ // parallelism too much.
+ const int oversharding_factor =
+ num_worker_threads <= 4 ? 8 :
+ num_worker_threads <= 8 ? 4 :
+ num_worker_threads <= 16 ? 2 : 1;
+
+ const bool parallelize_by_sharding_dim_only =
+ sharding_dim_tasks >= oversharding_factor * num_worker_threads;
+
// Last by not least, decide whether we want to issue both lhs and rhs
// packing in parallel; or issue lhs packing first, and then issue rhs
// packing when lhs packing completes (for !shard_by_col lhs and rhs are
@@ -223,10 +240,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// But don't do it if we will use each rhs only once. Locality seems to be
// more important in this case.
if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
+ // Also don't get in the way of parallelize_by_sharding_dim_only
+ // optimization.
+ if (parallelize_by_sharding_dim_only) parallel_pack = false;
- #define CONTEXT_ARGS \
+#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
- nn0, shard_by_col, parallel_pack) \
+ nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only) \
.run()
TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS);
@@ -260,7 +280,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
- bool parallel_pack)
+ bool parallel_pack, bool parallelize_by_sharding_dim_only)
: device_(self->m_device),
lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
self->m_i_strides, self->m_left_contracting_strides,
@@ -275,6 +295,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
num_threads_(num_threads),
shard_by_col_(shard_by_col),
parallel_pack_(parallel_pack),
+ parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
m_(tm),
n_(tn),
k_(tk),
@@ -289,6 +310,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
nm0_(nm0),
nn0_(nn0)
{
+ // These two options are mutually exclusive.
+ eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
+
for (Index x = 0; x < P; x++) {
// Normal number of notifications for k slice switch is
// nm_ + nn_ + nm_ * nn_. However, first P - 1 slices will receive only
@@ -336,22 +360,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
}
- // If there is enough available parallelism in sharding dimension we can
- // call kernels in sync mode and use thread local memory for packed data.
- const Index sharding_dim_tasks = shard_by_col ? nn : nm;
-
- const int num_worker_threads = device_.numThreadsInPool();
-
- // With small number of threads we want to make sure that we do not reduce
- // parallelism too much.
- const int oversharding_factor =
- num_worker_threads <= 4 ? 8 :
- num_worker_threads <= 8 ? 4 :
- num_worker_threads <= 16 ? 2 : 1;
-
- if (!parallel_pack_ &&
- sharding_dim_tasks >= oversharding_factor * num_worker_threads) {
- parallelize_by_sharding_dim_only_ = true;
+ if (parallelize_by_sharding_dim_only_) {
+ const int num_worker_threads = device_.numThreadsInPool();
if (shard_by_col) {
can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
@@ -422,6 +432,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const int num_threads_;
const bool shard_by_col_;
const bool parallel_pack_;
+ const bool parallelize_by_sharding_dim_only_;
// Matrix sizes.
const Index m_;
const Index n_;
@@ -481,12 +492,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
std::vector<LhsScalar*> packed_lhs_[P - 1];
std::vector<RhsScalar*> packed_rhs_[P - 1];
- // If there is enough concurrency in the sharding dimension, we choose not
- // to paralellize by the other dimension, and execute all kernels in sync
- // mode. This reduces parallelism from the nm_ x nn_ down to nn_
- // (shard_by_col==true) or nm_ (shard_by_col==false).
- bool parallelize_by_sharding_dim_only_ = false;
-
// If we choose to parallelize only by the sharding dimension, each thread
// will have it's own "thead local" (not a c++ thread local storage) memory
// for packed_lhs or packed_rhs (shard_by_col = false of true). This memory