diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-12-05 14:50:19 -0800 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-12-05 14:51:49 -0800 |
commit | bb7ccac3af90acb15e1bdc3943758ebb2ae22790 (patch) | |
tree | 224f1abc757de515616bdb296dc3216b4c569514 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | |
parent | 25230d1862ecfe3f1bf91c12eefe52dbdc0179b9 (diff) |
Add recursive work splitting to EvalShardedByInnerDimContext
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 106 |
1 file changed, 64 insertions, 42 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 26c9fac17..21be6ea42 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -1159,16 +1159,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT template <int Alignment> void run() { Barrier barrier(internal::convert_index<int>(num_blocks)); - for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) { - evaluator->m_device.enqueueNoNotification( - [this, block_idx, &barrier]() { - Index block_start = block_idx * block_size; - Index block_end = block_start + actualBlockSize(block_idx); - - processBlock<Alignment>(block_idx, block_start, block_end); - barrier.Notify(); - }); - } + eval<Alignment>(barrier, 0, num_blocks); barrier.Wait(); // Aggregate partial sums from l0 ranges. @@ -1180,38 +1171,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT template <int Alignment> void runAsync() { - for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) { - evaluator->m_device.enqueueNoNotification([this, block_idx]() { - Index block_start = block_idx * block_size; - Index block_end = block_start + actualBlockSize(block_idx); - - processBlock<Alignment>(block_idx, block_start, block_end); - - int v = num_pending_blocks.fetch_sub(1); - eigen_assert(v >= 1); - - if (v == 1) { - // Aggregate partial sums from l0 ranges. - aggregateL0Blocks<Alignment>(); - - // Apply output kernel. - applyOutputKernel(); - - // NOTE: If we call `done` callback before deleting this (context), - // it might deallocate Self* pointer captured by context, and we'll - // fail in destructor trying to deallocate temporary buffers. - - // Move done call back from context before it will be destructed. 
- DoneCallback done_copy = std::move(done); - - // We are confident that we are the last one who touches context. - delete this; - - // Now safely call the done callback. - done_copy(); - } - }); - } + evalAsync<Alignment>(0, num_blocks); } private: @@ -1405,6 +1365,68 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } } + template <int Alignment> + void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) { + while (end_block_idx - start_block_idx > 1) { + Index mid_block_idx = (start_block_idx + end_block_idx) / 2; + evaluator->m_device.enqueueNoNotification( + [this, &barrier, mid_block_idx, end_block_idx]() { + eval<Alignment>(barrier, mid_block_idx, end_block_idx); + }); + end_block_idx = mid_block_idx; + } + + Index block_idx = start_block_idx; + Index block_start = block_idx * block_size; + Index block_end = block_start + actualBlockSize(block_idx); + + processBlock<Alignment>(block_idx, block_start, block_end); + barrier.Notify(); + } + + template <int Alignment> + void evalAsync(Index start_block_idx, Index end_block_idx) { + while (end_block_idx - start_block_idx > 1) { + Index mid_block_idx = (start_block_idx + end_block_idx) / 2; + evaluator->m_device.enqueueNoNotification( + [this, mid_block_idx, end_block_idx]() { + evalAsync<Alignment>(mid_block_idx, end_block_idx); + }); + end_block_idx = mid_block_idx; + } + + Index block_idx = start_block_idx; + + Index block_start = block_idx * block_size; + Index block_end = block_start + actualBlockSize(block_idx); + + processBlock<Alignment>(block_idx, block_start, block_end); + + int v = num_pending_blocks.fetch_sub(1); + eigen_assert(v >= 1); + + if (v == 1) { + // Aggregate partial sums from l0 ranges. + aggregateL0Blocks<Alignment>(); + + // Apply output kernel. 
+ applyOutputKernel(); + + // NOTE: If we call `done` callback before deleting this (context), + // it might deallocate Self* pointer captured by context, and we'll + // fail in destructor trying to deallocate temporary buffers. + + // Move done call back from context before it will be destructed. + DoneCallback done_copy = std::move(done); + + // We are confident that we are the last one who touches context. + delete this; + + // Now safely call the done callback. + done_copy(); + } + } + // Cost model doesn't capture well the cost associated with constructing // tensor contraction mappers and computing loop bounds in gemm_pack_lhs // and gemm_pack_rhs, so we specify minimum desired block size. |