From bb7ccac3af90acb15e1bdc3943758ebb2ae22790 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Thu, 5 Dec 2019 14:50:19 -0800
Subject: Add recursive work splitting to EvalShardedByInnerDimContext

---
 .../CXX11/src/Tensor/TensorContractionThreadPool.h | 106 +++++++++++++--------
 1 file changed, 64 insertions(+), 42 deletions(-)

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 26c9fac17..21be6ea42 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -1159,16 +1159,7 @@ struct TensorEvaluator
     void run() {
       Barrier barrier(internal::convert_index<int>(num_blocks));
-      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
-        evaluator->m_device.enqueueNoNotification(
-            [this, block_idx, &barrier]() {
-              Index block_start = block_idx * block_size;
-              Index block_end = block_start + actualBlockSize(block_idx);
-
-              processBlock<Alignment>(block_idx, block_start, block_end);
-              barrier.Notify();
-            });
-      }
+      eval<Alignment>(barrier, 0, num_blocks);
       barrier.Wait();
 
       // Aggregate partial sums from l0 ranges.
@@ -1180,38 +1171,7 @@ struct TensorEvaluator
     void runAsync() {
-      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
-        evaluator->m_device.enqueueNoNotification([this, block_idx]() {
-          Index block_start = block_idx * block_size;
-          Index block_end = block_start + actualBlockSize(block_idx);
-
-          processBlock<Alignment>(block_idx, block_start, block_end);
-
-          int v = num_pending_blocks.fetch_sub(1);
-          eigen_assert(v >= 1);
-
-          if (v == 1) {
-            // Aggregate partial sums from l0 ranges.
-            aggregateL0Blocks<Alignment>();
-
-            // Apply output kernel.
-            applyOutputKernel();
-
-            // NOTE: If we call `done` callback before deleting this (context),
-            // it might deallocate Self* pointer captured by context, and we'll
-            // fail in destructor trying to deallocate temporary buffers.
-
-            // Move done call back from context before it will be destructed.
-            DoneCallback done_copy = std::move(done);
-
-            // We are confident that we are the last one who touches context.
-            delete this;
-
-            // Now safely call the done callback.
-            done_copy();
-          }
-        });
-      }
+      evalAsync<Alignment>(0, num_blocks);
     }
 
    private:
@@ -1405,6 +1365,68 @@ struct TensorEvaluator
+    template <int Alignment>
+    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
+      while (end_block_idx - start_block_idx > 1) {
+        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
+        evaluator->m_device.enqueueNoNotification(
+            [this, &barrier, mid_block_idx, end_block_idx]() {
+              eval<Alignment>(barrier, mid_block_idx, end_block_idx);
+            });
+        end_block_idx = mid_block_idx;
+      }
+
+      Index block_idx = start_block_idx;
+      Index block_start = block_idx * block_size;
+      Index block_end = block_start + actualBlockSize(block_idx);
+
+      processBlock<Alignment>(block_idx, block_start, block_end);
+      barrier.Notify();
+    }
+
+    template <int Alignment>
+    void evalAsync(Index start_block_idx, Index end_block_idx) {
+      while (end_block_idx - start_block_idx > 1) {
+        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
+        evaluator->m_device.enqueueNoNotification(
+            [this, mid_block_idx, end_block_idx]() {
+              evalAsync<Alignment>(mid_block_idx, end_block_idx);
+            });
+        end_block_idx = mid_block_idx;
+      }
+
+      Index block_idx = start_block_idx;
+
+      Index block_start = block_idx * block_size;
+      Index block_end = block_start + actualBlockSize(block_idx);
+
+      processBlock<Alignment>(block_idx, block_start, block_end);
+
+      int v = num_pending_blocks.fetch_sub(1);
+      eigen_assert(v >= 1);
+
+      if (v == 1) {
+        // Aggregate partial sums from l0 ranges.
+        aggregateL0Blocks<Alignment>();
+
+        // Apply output kernel.
+        applyOutputKernel();
+
+        // NOTE: If we call `done` callback before deleting this (context),
+        // it might deallocate Self* pointer captured by context, and we'll
+        // fail in destructor trying to deallocate temporary buffers.
+
+        // Move done call back from context before it will be destructed.
+        DoneCallback done_copy = std::move(done);
+
+        // We are confident that we are the last one who touches context.
+        delete this;
+
+        // Now safely call the done callback.
+        done_copy();
+      }
+    }
+
     // Cost model doesn't capture well the cost associated with constructing
     // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
     // and gemm_pack_rhs, so we specify minimum desired block size.
--
cgit v1.2.3
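
The new eval<Alignment> path replaces the sequential enqueue loop with recursive range splitting: a task that owns blocks [start, end) hands the upper half [mid, end) to the thread pool and keeps halving the lower half locally until a single block remains, which it processes itself, so task submission is no longer serialized on one thread. The following is a minimal, self-contained sketch of that pattern, not Eigen code: Latch, enqueue, and process_block are hypothetical stand-ins for Eigen's Barrier, m_device.enqueueNoNotification, and processBlock<Alignment>.

// Standalone sketch of the recursive work-splitting pattern from the patch above.
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <thread>

// Countdown latch standing in for Eigen's Barrier:
// Wait() returns once Notify() has been called `count` times.
class Latch {
 public:
  explicit Latch(int count) : count_(count) {}
  void Notify() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Stand-in for the device's task queue: run the task on its own thread.
void enqueue(std::function<void()> task) {
  std::thread(std::move(task)).detach();
}

// Stand-in for processBlock<Alignment>: the per-block work.
void process_block(int block_idx) { std::printf("block %d done\n", block_idx); }

// Recursive splitting, as in eval<Alignment>: offload the upper half of the
// range, keep halving the lower half, process the last remaining block, and
// notify the latch. Every block index is processed exactly once, so the latch
// receives exactly num_blocks notifications.
void eval(Latch& latch, int start_block, int end_block) {
  while (end_block - start_block > 1) {
    int mid_block = (start_block + end_block) / 2;
    enqueue([&latch, mid_block, end_block] { eval(latch, mid_block, end_block); });
    end_block = mid_block;
  }
  process_block(start_block);
  latch.Notify();
}

int main() {
  const int num_blocks = 8;
  Latch latch(num_blocks);
  eval(latch, 0, num_blocks);  // corresponds to eval<Alignment>(barrier, 0, num_blocks)
  latch.Wait();                // corresponds to barrier.Wait()
  return 0;
}

The evalAsync<Alignment> variant in the patch follows the same splitting scheme but replaces the barrier with the num_pending_blocks counter, so the task that finishes the last block aggregates the l0 partial sums, applies the output kernel, deletes the context, and then invokes the done callback.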