about summary refs log tree commit diff homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h 15
1 files changed, 12 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 4553c3785..675201d23 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -798,14 +798,15 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
auto process_block = [=, &barrier](Scalar* buf, Index first, Index last) {
::memset(buf, 0, m * n * sizeof(Scalar));
TENSOR_CONTRACTION_DISPATCH(
- this->template evalGemmPartial, Alignment,
+ this->template evalGemmPartialWithoutOutputKernel, Alignment,
(buf, first, last, this->m_device.numThreads()));
barrier.Notify();
};
Index start = 0;
for (int blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
- // The underlying GEMM kernel assumes that k is a multiple of 8 and
- // subtle breakage occurs if this is violated.
+ // The underlying GEMM kernel assumes that k is a multiple of packet size
+ // (currently largest packet size is 8) and subtle breakage occurs if
+ // this is violated.
block_size = 8 * divup<Index>(k - start, 8 * blocks_left);
Scalar* buf;
if (start == 0) {
@@ -830,6 +831,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
addToBuffer<Alignment>(m * n, buf, result);
this->m_device.deallocate(buf);
}
+
+ // Finally call output kernel with finalized output buffer.
+ typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+ this->m_output_kernel(OutputMapper(result, m),
+ this->m_tensor_contraction_params,
+ static_cast<Eigen::Index>(0),
+ static_cast<Eigen::Index>(0),
+ m, n);
}
TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {