path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
author    Eugene Zhulenev <ezhulenev@google.com>    2019-12-05 14:50:19 -0800
committer    Eugene Zhulenev <ezhulenev@google.com>    2019-12-05 14:51:49 -0800
commit    bb7ccac3af90acb15e1bdc3943758ebb2ae22790 (patch)
tree    224f1abc757de515616bdb296dc3216b4c569514 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent    25230d1862ecfe3f1bf91c12eefe52dbdc0179b9 (diff)
Add recursive work splitting to EvalShardedByInnerDimContext
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--    unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h    106
1 file changed, 64 insertions(+), 42 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 26c9fac17..21be6ea42 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -1159,16 +1159,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <int Alignment>
void run() {
Barrier barrier(internal::convert_index<int>(num_blocks));
- for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
- evaluator->m_device.enqueueNoNotification(
- [this, block_idx, &barrier]() {
- Index block_start = block_idx * block_size;
- Index block_end = block_start + actualBlockSize(block_idx);
-
- processBlock<Alignment>(block_idx, block_start, block_end);
- barrier.Notify();
- });
- }
+ eval<Alignment>(barrier, 0, num_blocks);
barrier.Wait();
// Aggregate partial sums from l0 ranges.
@@ -1180,38 +1171,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <int Alignment>
void runAsync() {
- for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
- evaluator->m_device.enqueueNoNotification([this, block_idx]() {
- Index block_start = block_idx * block_size;
- Index block_end = block_start + actualBlockSize(block_idx);
-
- processBlock<Alignment>(block_idx, block_start, block_end);
-
- int v = num_pending_blocks.fetch_sub(1);
- eigen_assert(v >= 1);
-
- if (v == 1) {
- // Aggregate partial sums from l0 ranges.
- aggregateL0Blocks<Alignment>();
-
- // Apply output kernel.
- applyOutputKernel();
-
- // NOTE: If we call `done` callback before deleting this (context),
- // it might deallocate Self* pointer captured by context, and we'll
- // fail in destructor trying to deallocate temporary buffers.
-
- // Move done call back from context before it will be destructed.
- DoneCallback done_copy = std::move(done);
-
- // We are confident that we are the last one who touches context.
- delete this;
-
- // Now safely call the done callback.
- done_copy();
- }
- });
- }
+ evalAsync<Alignment>(0, num_blocks);
}
private:
@@ -1405,6 +1365,68 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
}
+ template <int Alignment>
+ void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
+ while (end_block_idx - start_block_idx > 1) {
+ Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
+ evaluator->m_device.enqueueNoNotification(
+ [this, &barrier, mid_block_idx, end_block_idx]() {
+ eval<Alignment>(barrier, mid_block_idx, end_block_idx);
+ });
+ end_block_idx = mid_block_idx;
+ }
+
+ Index block_idx = start_block_idx;
+ Index block_start = block_idx * block_size;
+ Index block_end = block_start + actualBlockSize(block_idx);
+
+ processBlock<Alignment>(block_idx, block_start, block_end);
+ barrier.Notify();
+ }
+
+ template <int Alignment>
+ void evalAsync(Index start_block_idx, Index end_block_idx) {
+ while (end_block_idx - start_block_idx > 1) {
+ Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
+ evaluator->m_device.enqueueNoNotification(
+ [this, mid_block_idx, end_block_idx]() {
+ evalAsync<Alignment>(mid_block_idx, end_block_idx);
+ });
+ end_block_idx = mid_block_idx;
+ }
+
+ Index block_idx = start_block_idx;
+
+ Index block_start = block_idx * block_size;
+ Index block_end = block_start + actualBlockSize(block_idx);
+
+ processBlock<Alignment>(block_idx, block_start, block_end);
+
+ int v = num_pending_blocks.fetch_sub(1);
+ eigen_assert(v >= 1);
+
+ if (v == 1) {
+ // Aggregate partial sums from l0 ranges.
+ aggregateL0Blocks<Alignment>();
+
+ // Apply output kernel.
+ applyOutputKernel();
+
+ // NOTE: If we call `done` callback before deleting this (context),
+ // it might deallocate Self* pointer captured by context, and we'll
+ // fail in destructor trying to deallocate temporary buffers.
+
+ // Move done call back from context before it will be destructed.
+ DoneCallback done_copy = std::move(done);
+
+ // We are confident that we are the last one who touches context.
+ delete this;
+
+ // Now safely call the done callback.
+ done_copy();
+ }
+ }
+
// Cost model doesn't capture well the cost associated with constructing
// tensor contraction mappers and computing loop bounds in gemm_pack_lhs
// and gemm_pack_rhs, so we specify minimum desired block size.
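
For reference, below is a minimal standalone sketch of the recursive work-splitting pattern this patch introduces, written against the C++ standard library rather than Eigen's ThreadPoolDevice. The names eval_range and process_block are illustrative stand-ins (not Eigen APIs), std::async stands in for evaluator->m_device.enqueueNoNotification, and waiting on child futures stands in for the shared Barrier / num_pending_blocks bookkeeping used in the patch.

#include <cstdio>
#include <future>
#include <vector>

// Stand-in for processBlock<Alignment>: do the work for one block.
static void process_block(int block_idx) {
  std::printf("processed block %d\n", block_idx);
}

// Recursive work splitting over [start_block, end_block): repeatedly offload
// the upper half of the range to another task, keep halving the lower half,
// and process the single remaining block on the current thread. This mirrors
// the while-loop in eval()/evalAsync() above.
static void eval_range(int start_block, int end_block) {
  std::vector<std::future<void>> children;
  while (end_block - start_block > 1) {
    int mid_block = (start_block + end_block) / 2;
    children.push_back(std::async(std::launch::async,
                                  [mid_block, end_block] {
                                    eval_range(mid_block, end_block);
                                  }));
    end_block = mid_block;  // keep splitting the remaining lower half
  }
  process_block(start_block);
  // In this sketch each call waits for the sub-ranges it spawned; the patch
  // instead notifies a shared Barrier (sync) or decrements an atomic pending
  // counter (async) once per block.
  for (auto& child : children) child.wait();
}

int main() {
  // Seed the recursion with the full block range, analogous to
  // eval<Alignment>(barrier, 0, num_blocks) in run().
  eval_range(0, 8);
  return 0;
}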