From 524c81f3fad1548a92504d92326f3622075ed77b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 28 Sep 2018 11:24:08 -0700 Subject: Add tests for evalShardedByInnerDim contraction + fix bugs --- .../Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 4553c3785..675201d23 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -798,14 +798,15 @@ struct TensorEvaluatortemplate evalGemmPartial, Alignment, + this->template evalGemmPartialWithoutOutputKernel, Alignment, (buf, first, last, this->m_device.numThreads())); barrier.Notify(); }; Index start = 0; for (int blocks_left = num_blocks; blocks_left > 0; --blocks_left) { - // The underlying GEMM kernel assumes that k is a multiple of 8 and - // subtle breakage occurs if this is violated. + // The underlying GEMM kernel assumes that k is a multiple of packet size + // (currently largest packet size is 8) and subtle breakage occurs if + // this is violated. block_size = 8 * divup(k - start, 8 * blocks_left); Scalar* buf; if (start == 0) { @@ -830,6 +831,14 @@ struct TensorEvaluator(m * n, buf, result); this->m_device.deallocate(buf); } + + // Finally call output kernel with finalized output buffer. + typedef internal::blas_data_mapper OutputMapper; + this->m_output_kernel(OutputMapper(result, m), + this->m_tensor_contraction_params, + static_cast(0), + static_cast(0), + m, n); } TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const { -- cgit v1.2.3