diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 4553c3785..675201d23 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -798,14 +798,15 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     auto process_block = [=, &barrier](Scalar* buf, Index first, Index last) {
       ::memset(buf, 0, m * n * sizeof(Scalar));
       TENSOR_CONTRACTION_DISPATCH(
-          this->template evalGemmPartial, Alignment,
+          this->template evalGemmPartialWithoutOutputKernel, Alignment,
           (buf, first, last, this->m_device.numThreads()));
       barrier.Notify();
     };
     Index start = 0;
     for (int blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
-      // The underlying GEMM kernel assumes that k is a multiple of 8 and
-      // subtle breakage occurs if this is violated.
+      // The underlying GEMM kernel assumes that k is a multiple of packet size
+      // (currently largest packet size is 8) and subtle breakage occurs if
+      // this is violated.
       block_size = 8 * divup<Index>(k - start, 8 * blocks_left);
       Scalar* buf;
       if (start == 0) {
@@ -830,6 +831,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
       addToBuffer<Alignment>(m * n, buf, result);
       this->m_device.deallocate(buf);
     }
+
+    // Finally call output kernel with finalized output buffer.
+    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+    this->m_output_kernel(OutputMapper(result, m),
+                          this->m_tensor_contraction_params,
+                          static_cast<Eigen::Index>(0),
+                          static_cast<Eigen::Index>(0),
+                          m, n);
   }
 
   TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {