author    Eugene Zhulenev <eugene.zhulenev@gmail.com>  2019-12-17 19:06:14 +0000
committer Rasmus Munk Larsen <rmlarsen@google.com>     2019-12-17 19:06:14 +0000
commit    788bef6ab55bc2897e29be308996b8937da4a38d (patch)
tree      a0a44da78ca2a8f1156ade1473d1a3489784c803 /unsupported/Eigen/CXX11
parent    7252163335f56f23fcc7381c1efdea47161005fa (diff)
Reduce block evaluation overhead for small tensor expressions
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h     | 40
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h  | 28
2 files changed, 48 insertions(+), 20 deletions(-)
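The core of the patch: when an expression is small enough to fit in a single block, both executors now skip the `parallelFor`/`parallelForAsync` dispatch entirely, and `TensorBlockMapper` no longer computes real tensor/block strides for it. What follows is a minimal standalone sketch of that dispatch decision; `BlockDesc`, `parallel_for`, and `eval_tiled` are simplified stand-ins, not Eigen's actual types.

#include <cstdint>
#include <functional>
#include <iostream>

// Simplified stand-in for Eigen's TensorBlockDesc: where a block starts in
// the output and how many coefficients it covers.
struct BlockDesc {
  std::int64_t offset;
  std::int64_t size;
};

// Stand-in for ThreadPoolDevice::parallelFor; runs sequentially here.
void parallel_for(std::int64_t count,
                  const std::function<void(std::int64_t)>& fn) {
  for (std::int64_t i = 0; i < count; ++i) fn(i);
}

void eval_tiled(std::int64_t block_count, std::int64_t block_size,
                const std::function<void(const BlockDesc&)>& eval_block) {
  if (block_count == 1) {
    // Fast path introduced by the patch: one block covers the whole output,
    // so skip parallel dispatch and block enumeration entirely.
    eval_block(BlockDesc{0, block_size});
  } else {
    parallel_for(block_count, [&](std::int64_t idx) {
      eval_block(BlockDesc{idx * block_size, block_size});
    });
  }
}

int main() {
  auto eval = [](const BlockDesc& d) {
    std::cout << "block at " << d.offset << ", size " << d.size << "\n";
  };
  eval_tiled(1, 100, eval);  // small expression: single inline block
  eval_tiled(4, 25, eval);   // larger expression: dispatched per block
  return 0;
}

For a tiny expression, the scheduling round trip through the thread pool can dominate the actual evaluation work; that is the overhead the commit title refers to.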
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
index 222333847..dc9af3aa8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
@@ -282,19 +282,8 @@ class TensorBlockMapper {
TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
const TensorBlockResourceRequirements& requirements)
: m_tensor_dimensions(dimensions), m_requirements(requirements) {
- // Initialize `m_block_dimensions`.
+ // Compute block dimensions and the total number of blocks.
InitializeBlockDimensions();
-
- // Calculate block counts by dimension and total block count.
- DSizes<IndexType, NumDims> block_count;
- for (int i = 0; i < NumDims; ++i) {
- block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
- }
- m_total_block_count = array_prod(block_count);
-
- // Calculate block strides (used for enumerating blocks).
- m_tensor_strides = strides<Layout>(m_tensor_dimensions);
- m_block_strides = strides<Layout>(block_count);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
@@ -339,23 +328,33 @@ class TensorBlockMapper {
void InitializeBlockDimensions() {
// Requested block shape and size.
const TensorBlockShapeType shape_type = m_requirements.shape_type;
- const IndexType target_block_size =
+ IndexType target_block_size =
numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));
+ IndexType tensor_size = m_tensor_dimensions.TotalSize();
+
    // Corner case: one of the dimensions is zero. The logic below is too
    // complex to handle this case in general, so just use unit block size.
// Note: we must not yield blocks with zero dimensions (recipe for
// overflows/underflows, divisions by zero and NaNs later).
- if (m_tensor_dimensions.TotalSize() == 0) {
+ if (tensor_size == 0) {
for (int i = 0; i < NumDims; ++i) {
m_block_dimensions[i] = 1;
}
+ m_total_block_count = 0;
return;
}
// If tensor fits into a target block size, evaluate it as a single block.
- if (m_tensor_dimensions.TotalSize() <= target_block_size) {
+ if (tensor_size <= target_block_size) {
m_block_dimensions = m_tensor_dimensions;
+ m_total_block_count = 1;
+ // The only valid block index is `0`, and in this case we do not need
+ // to compute real strides for tensor or blocks (see blockDescriptor).
+ for (int i = 0; i < NumDims; ++i) {
+ m_tensor_strides[i] = 0;
+ m_block_strides[i] = 1;
+ }
return;
}
@@ -418,6 +417,17 @@ class TensorBlockMapper {
eigen_assert(m_block_dimensions.TotalSize() >=
numext::mini<IndexType>(target_block_size,
m_tensor_dimensions.TotalSize()));
+
+ // Calculate block counts by dimension and total block count.
+ DSizes<IndexType, NumDims> block_count;
+ for (int i = 0; i < NumDims; ++i) {
+ block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
+ }
+ m_total_block_count = array_prod(block_count);
+
+ // Calculate block strides (used for enumerating blocks).
+ m_tensor_strides = strides<Layout>(m_tensor_dimensions);
+ m_block_strides = strides<Layout>(block_count);
}
DSizes<IndexType, NumDims> m_tensor_dimensions;
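For context, the block-count and stride computation that the first hunk removes from the constructor and the last hunk re-adds after the early returns boils down to one `divup` per dimension plus two stride arrays. Below is a self-contained sketch with local stand-ins for `divup` and `strides<Layout>`, assuming ColMajor layout:

#include <array>
#include <cstdint>
#include <iostream>

constexpr int kNumDims = 3;
using Dims = std::array<std::int64_t, kNumDims>;

std::int64_t divup(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

// Column-major strides, matching Eigen's default ColMajor layout.
Dims col_major_strides(const Dims& dims) {
  Dims s{};
  s[0] = 1;
  for (int i = 1; i < kNumDims; ++i) s[i] = s[i - 1] * dims[i - 1];
  return s;
}

int main() {
  const Dims tensor_dims = {128, 64, 32};
  const Dims block_dims = {32, 64, 32};  // as chosen by InitializeBlockDimensions

  // Blocks per dimension and total block count, as in the patched code.
  Dims block_count{};
  std::int64_t total = 1;
  for (int i = 0; i < kNumDims; ++i) {
    block_count[i] = divup(tensor_dims[i], block_dims[i]);
    total *= block_count[i];
  }

  const Dims tensor_strides = col_major_strides(tensor_dims);
  const Dims block_strides = col_major_strides(block_count);

  std::cout << "total blocks = " << total << "\n";  // 4 * 1 * 1 = 4
  return 0;
}

Moving this work after the early returns means a tensor that fits in one block never pays for it: the zero-size and single-block cases now set `m_total_block_count` (0 and 1 respectively) directly.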
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index e2f1806cb..b90791d8d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable,
IndexType lastBlockIdx) {
TensorBlockScratch scratch(device);
- for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) {
+ for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx;
+ ++block_idx) {
TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx);
evaluator.evalBlock(desc, scratch);
scratch.reset();
}
};
- device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
- eval_block);
+ // Evaluate small expressions directly as a single block.
+ if (tiling.block_mapper.blockCount() == 1) {
+ TensorBlockScratch scratch(device);
+ TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions());
+ evaluator.evalBlock(desc, scratch);
+ } else {
+ device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
+ eval_block);
+ }
}
evaluator.cleanup();
}
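The `blockCount() == 1` branch can construct its descriptor directly at offset `0` with the full block dimensions because there is nothing to enumerate. On the multi-block path, `blockDescriptor` must map each linear block index back to per-dimension block coordinates; here is a hypothetical sketch of that decode step, reusing the column-major `block_strides` convention from the earlier sketch:

#include <array>
#include <cstdint>
#include <iostream>

constexpr int kNumDims = 3;
using Dims = std::array<std::int64_t, kNumDims>;

// Decode a linear block index into per-dimension block coordinates, assuming
// column-major block strides (innermost dimension fastest). Eigen's
// blockDescriptor performs the equivalent walk internally.
Dims decode_block_index(std::int64_t block_idx, const Dims& block_strides) {
  Dims coords{};
  for (int i = kNumDims - 1; i >= 0; --i) {
    coords[i] = block_idx / block_strides[i];
    block_idx -= coords[i] * block_strides[i];
  }
  return coords;
}

int main() {
  // Strides for the 4x1x1 block grid from the previous sketch: {1, 4, 4}.
  const Dims block_strides = {1, 4, 4};
  const Dims c = decode_block_index(3, block_strides);
  std::cout << c[0] << "," << c[1] << "," << c[2] << "\n";  // 3,0,0
  return 0;
}

This per-block arithmetic, plus the `parallelFor` scheduling, is exactly what the single-block fast path skips.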
@@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
scratch.reset();
}
};
- ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
- ctx->tiling.cost, eval_block, [ctx]() { delete ctx; });
+
+ // Evaluate small expressions directly as a single block.
+ if (ctx->tiling.block_mapper.blockCount() == 1) {
+ TensorBlockScratch scratch(ctx->device);
+ TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions());
+ ctx->evaluator.evalBlock(desc, scratch);
+ delete ctx;
+ } else {
+ ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
+ ctx->tiling.cost, eval_block,
+ [ctx]() { delete ctx; });
+ }
};
ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
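The async variant carries one extra subtlety: deleting the heap-allocated context normally happens in `parallelForAsync`'s done callback, so the inline fast path must perform that completion step itself, exactly once. A hedged sketch of the control flow follows; `Context`, `eval_async`, and `parallel_for_async` are invented stand-ins, not Eigen's API:

#include <cstdint>
#include <functional>
#include <iostream>

// Stand-in for ThreadPoolDevice::parallelForAsync: runs sequentially here,
// then invokes the completion callback exactly once.
void parallel_for_async(std::int64_t count,
                        const std::function<void(std::int64_t)>& fn,
                        const std::function<void()>& done) {
  for (std::int64_t i = 0; i < count; ++i) fn(i);
  done();
}

// Stand-in for the executor's heap-allocated async context.
struct Context {
  std::int64_t block_count;
  ~Context() { std::cout << "context destroyed\n"; }
};

void eval_async(Context* ctx) {
  auto eval_block = [](std::int64_t idx) {
    std::cout << "eval block " << idx << "\n";
  };
  if (ctx->block_count == 1) {
    // Fast path mirroring the patch: evaluate inline, then run the
    // completion step (here, deleting the context) ourselves.
    eval_block(0);
    delete ctx;
  } else {
    // Read everything needed before dispatch; the done callback owns `ctx`.
    const std::int64_t n = ctx->block_count;
    parallel_for_async(n, eval_block, [ctx] { delete ctx; });
  }
}

int main() {
  eval_async(new Context{1});  // single block: inline evaluation
  eval_async(new Context{4});  // multiple blocks: async dispatch
  return 0;
}

Running the callback's work (`delete ctx`) on the fast path keeps the ownership semantics of the two branches identical.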