diff options
author | Eugene Zhulenev <eugene.zhulenev@gmail.com> | 2019-12-17 19:06:14 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-12-17 19:06:14 +0000 |
commit | 788bef6ab55bc2897e29be308996b8937da4a38d (patch) | |
tree | a0a44da78ca2a8f1156ade1473d1a3489784c803 /unsupported/Eigen/CXX11/src | |
parent | 7252163335f56f23fcc7381c1efdea47161005fa (diff) |
Reduce block evaluation overhead for small tensor expressions
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h | 40 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 28 |
2 files changed, 48 insertions(+), 20 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h index 222333847..dc9af3aa8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h @@ -282,19 +282,8 @@ class TensorBlockMapper { TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions, const TensorBlockResourceRequirements& requirements) : m_tensor_dimensions(dimensions), m_requirements(requirements) { - // Initialize `m_block_dimensions`. + // Compute block dimensions and the total number of blocks. InitializeBlockDimensions(); - - // Calculate block counts by dimension and total block count. - DSizes<IndexType, NumDims> block_count; - for (int i = 0; i < NumDims; ++i) { - block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]); - } - m_total_block_count = array_prod(block_count); - - // Calculate block strides (used for enumerating blocks). - m_tensor_strides = strides<Layout>(m_tensor_dimensions); - m_block_strides = strides<Layout>(block_count); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const { @@ -339,23 +328,33 @@ class TensorBlockMapper { void InitializeBlockDimensions() { // Requested block shape and size. const TensorBlockShapeType shape_type = m_requirements.shape_type; - const IndexType target_block_size = + IndexType target_block_size = numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size)); + IndexType tensor_size = m_tensor_dimensions.TotalSize(); + // Corner case: one of the dimensions is zero. Logic below is too complex // to handle this case on a general basis, just use unit block size. // Note: we must not yield blocks with zero dimensions (recipe for // overflows/underflows, divisions by zero and NaNs later). 
- if (m_tensor_dimensions.TotalSize() == 0) { + if (tensor_size == 0) { for (int i = 0; i < NumDims; ++i) { m_block_dimensions[i] = 1; } + m_total_block_count = 0; return; } // If tensor fits into a target block size, evaluate it as a single block. - if (m_tensor_dimensions.TotalSize() <= target_block_size) { + if (tensor_size <= target_block_size) { m_block_dimensions = m_tensor_dimensions; + m_total_block_count = 1; + // The only valid block index is `0`, and in this case we do not need + // to compute real strides for tensor or blocks (see blockDescriptor). + for (int i = 0; i < NumDims; ++i) { + m_tensor_strides[i] = 0; + m_block_strides[i] = 1; + } return; } @@ -418,6 +417,17 @@ class TensorBlockMapper { eigen_assert(m_block_dimensions.TotalSize() >= numext::mini<IndexType>(target_block_size, m_tensor_dimensions.TotalSize())); + + // Calculate block counts by dimension and total block count. + DSizes<IndexType, NumDims> block_count; + for (int i = 0; i < NumDims; ++i) { + block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]); + } + m_total_block_count = array_prod(block_count); + + // Calculate block strides (used for enumerating blocks). 
+ m_tensor_strides = strides<Layout>(m_tensor_dimensions); + m_block_strides = strides<Layout>(block_count); } DSizes<IndexType, NumDims> m_tensor_dimensions; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index e2f1806cb..b90791d8d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, IndexType lastBlockIdx) { TensorBlockScratch scratch(device); - for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) { + for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; + ++block_idx) { TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx); evaluator.evalBlock(desc, scratch); scratch.reset(); } }; - device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost, - eval_block); + // Evaluate small expressions directly as a single block. + if (tiling.block_mapper.blockCount() == 1) { + TensorBlockScratch scratch(device); + TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions()); + evaluator.evalBlock(desc, scratch); + } else { + device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost, + eval_block); + } } evaluator.cleanup(); } @@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, scratch.reset(); } }; - ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(), - ctx->tiling.cost, eval_block, [ctx]() { delete ctx; }); + + // Evaluate small expressions directly as a single block. 
+ if (ctx->tiling.block_mapper.blockCount() == 1) { + TensorBlockScratch scratch(ctx->device); + TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions()); + ctx->evaluator.evalBlock(desc, scratch); + delete ctx; + } else { + ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(), + ctx->tiling.cost, eval_block, + [ctx]() { delete ctx; }); + } }; ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs); |