aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h12
1 files changed, 7 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 24ba3e431..3946e2fc4 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -788,9 +788,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index m = this->m_i_size;
const Index n = this->m_j_size;
const Index k = this->m_k_size;
- // The underlying GEMM kernel assumes that k is a multiple of 8 and
- // subtle breakage occurs if this is violated.
- Index block_size = 8 * divup<Index>(k, 8 * num_threads);
+ const Index packet_size = internal::packet_traits<RhsScalar>::size;
+ const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+ // The underlying GEMM kernel assumes that k is a multiple of the packet
+ // size (kmultiple, never less than 8); subtle breakage occurs otherwise.
+ Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
Index num_blocks = divup<Index>(k, block_size);
// we use 'result' for the first block's partial result.
MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
@@ -805,9 +807,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Index start = 0;
for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
// The underlying GEMM kernel assumes that k is a multiple of packet size
- // (currently largest packet size is 8) and subtle breakage occurs if
+ // (kmultiple: the RhsScalar packet size, but at least 8) and subtle breakage occurs if
// this is violated.
- block_size = 8 * divup<Index>(k - start, 8 * blocks_left);
+ block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
Scalar* buf;
if (start == 0) {
buf = result;