diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-12-05 18:19:32 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-12-05 18:19:32 +0000 |
commit | 8a02883d58e6d7de385ca66e64eeda3c431bf36a (patch) | |
tree | 88c75ddff28fad49db214b2d070f07fabdfd60f9 /unsupported | |
parent | acc3459a49707c92ee96a710e05d7e18e144c780 (diff) | |
parent | 36f8f6d0be1543e12c87c6f33df46fe7bcecab87 (diff) |
Merged in markdryan/eigen/avx512-contraction-2 (pull request PR-554)
Fix tensor contraction on AVX512 builds
Approved-by: Rasmus Munk Larsen <rmlarsen@google.com>
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h | 4 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 12 |
2 files changed, 11 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h index 71fd19774..c51f3f8dd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h @@ -51,6 +51,10 @@ class TensorContractionBlocking { else { computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads); } + + const int rhs_packet_size = internal::packet_traits<RhsScalar>::size; + kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ? + kc_ : (kc_ / rhs_packet_size) * rhs_packet_size; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; } diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 24ba3e431..3946e2fc4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -788,9 +788,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT const Index m = this->m_i_size; const Index n = this->m_j_size; const Index k = this->m_k_size; - // The underlying GEMM kernel assumes that k is a multiple of 8 and - // subtle breakage occurs if this is violated. - Index block_size = 8 * divup<Index>(k, 8 * num_threads); + const Index packet_size = internal::packet_traits<RhsScalar>::size; + const Index kmultiple = packet_size <= 8 ? 8 : packet_size; + // The underlying GEMM kernel assumes that k is a multiple of + // the packet size and subtle breakage occurs if this is violated. + Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads); Index num_blocks = divup<Index>(k, block_size); // we use 'result' for the first block's partial result. MaxSizeVector<Scalar*> block_buffers(num_blocks - 1); @@ -805,9 +807,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT Index start = 0; for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) { // The underlying GEMM kernel assumes that k is a multiple of packet size - // (currently largest packet size is 8) and subtle breakage occurs if + // (currently largest packet size is 16) and subtle breakage occurs if // this is violated. - block_size = 8 * divup<Index>(k - start, 8 * blocks_left); + block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left); Scalar* buf; if (start == 0) { buf = result; |