about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
diff options
context:
space:
mode:
author: Eugene Zhulenev <eugene.zhulenev@gmail.com> 2019-12-17 19:06:14 +0000
committer: Rasmus Munk Larsen <rmlarsen@google.com> 2019-12-17 19:06:14 +0000
commit 788bef6ab55bc2897e29be308996b8937da4a38d (patch)
tree a0a44da78ca2a8f1156ade1473d1a3489784c803 /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
parent 7252163335f56f23fcc7381c1efdea47161005fa (diff)
Reduce block evaluation overhead for small tensor expressions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h 28
1 files changed, 23 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index e2f1806cb..b90791d8d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable,
IndexType lastBlockIdx) {
TensorBlockScratch scratch(device);
- for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) {
+ for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx;
+ ++block_idx) {
TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx);
evaluator.evalBlock(desc, scratch);
scratch.reset();
}
};
- device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
- eval_block);
+ // Evaluate small expressions directly as a single block.
+ if (tiling.block_mapper.blockCount() == 1) {
+ TensorBlockScratch scratch(device);
+ TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions());
+ evaluator.evalBlock(desc, scratch);
+ } else {
+ device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
+ eval_block);
+ }
}
evaluator.cleanup();
}
@@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
scratch.reset();
}
};
- ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
- ctx->tiling.cost, eval_block, [ctx]() { delete ctx; });
+
+ // Evaluate small expressions directly as a single block.
+ if (ctx->tiling.block_mapper.blockCount() == 1) {
+ TensorBlockScratch scratch(ctx->device);
+ TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions());
+ ctx->evaluator.evalBlock(desc, scratch);
+ delete ctx;
+ } else {
+ ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
+ ctx->tiling.cost, eval_block,
+ [ctx]() { delete ctx; });
+ }
};
ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);