about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
diff options
context:
space:
mode:
author: Eugene Zhulenev <eugene.zhulenev@gmail.com> 2019-12-17 19:06:14 +0000
committer: Rasmus Munk Larsen <rmlarsen@google.com> 2019-12-17 19:06:14 +0000
commit: 788bef6ab55bc2897e29be308996b8937da4a38d (patch)
tree: a0a44da78ca2a8f1156ade1473d1a3489784c803 /unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
parent: 7252163335f56f23fcc7381c1efdea47161005fa (diff)
Reduce block evaluation overhead for small tensor expressions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h | 40
1 file changed, 25 insertions, 15 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
index 222333847..dc9af3aa8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
@@ -282,19 +282,8 @@ class TensorBlockMapper {
TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
const TensorBlockResourceRequirements& requirements)
: m_tensor_dimensions(dimensions), m_requirements(requirements) {
- // Initialize `m_block_dimensions`.
+ // Compute block dimensions and the total number of blocks.
InitializeBlockDimensions();
-
- // Calculate block counts by dimension and total block count.
- DSizes<IndexType, NumDims> block_count;
- for (int i = 0; i < NumDims; ++i) {
- block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
- }
- m_total_block_count = array_prod(block_count);
-
- // Calculate block strides (used for enumerating blocks).
- m_tensor_strides = strides<Layout>(m_tensor_dimensions);
- m_block_strides = strides<Layout>(block_count);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
@@ -339,23 +328,33 @@ class TensorBlockMapper {
void InitializeBlockDimensions() {
// Requested block shape and size.
const TensorBlockShapeType shape_type = m_requirements.shape_type;
- const IndexType target_block_size =
+ IndexType target_block_size =
numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));
+ IndexType tensor_size = m_tensor_dimensions.TotalSize();
+
// Corner case: one of the dimensions is zero. Logic below is too complex
// to handle this case on a general basis, just use unit block size.
// Note: we must not yield blocks with zero dimensions (recipe for
// overflows/underflows, divisions by zero and NaNs later).
- if (m_tensor_dimensions.TotalSize() == 0) {
+ if (tensor_size == 0) {
for (int i = 0; i < NumDims; ++i) {
m_block_dimensions[i] = 1;
}
+ m_total_block_count = 0;
return;
}
// If tensor fits into a target block size, evaluate it as a single block.
- if (m_tensor_dimensions.TotalSize() <= target_block_size) {
+ if (tensor_size <= target_block_size) {
m_block_dimensions = m_tensor_dimensions;
+ m_total_block_count = 1;
+ // The only valid block index is `0`, and in this case we do not need
+ // to compute real strides for tensor or blocks (see blockDescriptor).
+ for (int i = 0; i < NumDims; ++i) {
+ m_tensor_strides[i] = 0;
+ m_block_strides[i] = 1;
+ }
return;
}
@@ -418,6 +417,17 @@ class TensorBlockMapper {
eigen_assert(m_block_dimensions.TotalSize() >=
numext::mini<IndexType>(target_block_size,
m_tensor_dimensions.TotalSize()));
+
+ // Calculate block counts by dimension and total block count.
+ DSizes<IndexType, NumDims> block_count;
+ for (int i = 0; i < NumDims; ++i) {
+ block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
+ }
+ m_total_block_count = array_prod(block_count);
+
+ // Calculate block strides (used for enumerating blocks).
+ m_tensor_strides = strides<Layout>(m_tensor_dimensions);
+ m_block_strides = strides<Layout>(block_count);
}
DSizes<IndexType, NumDims> m_tensor_dimensions;