diff options
author | Eugene Zhulenev <eugene.zhulenev@gmail.com> | 2019-12-17 19:06:14 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-12-17 19:06:14 +0000 |
commit | 788bef6ab55bc2897e29be308996b8937da4a38d (patch) | |
tree | a0a44da78ca2a8f1156ade1473d1a3489784c803 /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | |
parent | 7252163335f56f23fcc7381c1efdea47161005fa (diff) |
Reduce block evaluation overhead for small tensor expressions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index e2f1806cb..b90791d8d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, IndexType lastBlockIdx) { TensorBlockScratch scratch(device); - for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) { + for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; + ++block_idx) { TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx); evaluator.evalBlock(desc, scratch); scratch.reset(); } }; - device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost, - eval_block); + // Evaluate small expressions directly as a single block. + if (tiling.block_mapper.blockCount() == 1) { + TensorBlockScratch scratch(device); + TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions()); + evaluator.evalBlock(desc, scratch); + } else { + device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost, + eval_block); + } } evaluator.cleanup(); } @@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, scratch.reset(); } }; - ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(), - ctx->tiling.cost, eval_block, [ctx]() { delete ctx; }); + + // Evaluate small expressions directly as a single block. + if (ctx->tiling.block_mapper.blockCount() == 1) { + TensorBlockScratch scratch(ctx->device); + TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions()); + ctx->evaluator.evalBlock(desc, scratch); + delete ctx; + } else { + ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(), + ctx->tiling.cost, eval_block, + [ctx]() { delete ctx; }); + } }; ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs); |