aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
authorGravatar Mark D Ryan <mark.d.ryan@intel.com>2018-12-05 12:29:03 +0100
committerGravatar Mark D Ryan <mark.d.ryan@intel.com>2018-12-05 12:29:03 +0100
commit36f8f6d0be1543e12c87c6f33df46fe7bcecab87 (patch)
tree27301b7ae9e48dda5861e27f5a66e97778aef5af /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent6f5b126e6d23f1339d15b26fe87916132397d619 (diff)
Fix evalShardedByInnerDim for AVX512 builds
evalShardedByInnerDim ensures that the values it passes for start_k and end_k to evalGemmPartialWithoutOutputKernel are multiples of 8 as the kernel does not work correctly when the values of k are not multiples of the packet_size. While this precaution works for AVX builds, it is insufficient for AVX512 builds where the maximum packet size is 16. The result is slightly incorrect float32 contractions on AVX512 builds. This commit fixes the problem by ensuring that k is always a multiple of the packet_size if the packet_size is > 8.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h12
1 files changed, 7 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 24ba3e431..3946e2fc4 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -788,9 +788,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index m = this->m_i_size;
const Index n = this->m_j_size;
const Index k = this->m_k_size;
- // The underlying GEMM kernel assumes that k is a multiple of 8 and
- // subtle breakage occurs if this is violated.
- Index block_size = 8 * divup<Index>(k, 8 * num_threads);
+ const Index packet_size = internal::packet_traits<RhsScalar>::size;
+ const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+ // The underlying GEMM kernel assumes that k is a multiple of
+ // the packet size and subtle breakage occurs if this is violated.
+ Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
Index num_blocks = divup<Index>(k, block_size);
// we use 'result' for the first block's partial result.
MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
@@ -805,9 +807,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Index start = 0;
for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
// The underlying GEMM kernel assumes that k is a multiple of packet size
- // (currently largest packet size is 8) and subtle breakage occurs if
+ // (currently largest packet size is 16) and subtle breakage occurs if
// this is violated.
- block_size = 8 * divup<Index>(k - start, 8 * blocks_left);
+ block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
Scalar* buf;
if (start == 0) {
buf = result;