diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 6fc1e4a6e..d1c659858 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -671,12 +671,20 @@ struct TensorContractionEvaluatorBase 0, k, 1); } - template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> - EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const { - // columns in left side, rows in right side - const Index k = this->m_k_size; + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, + bool rhs_inner_dim_reordered, int Alignment> + EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel( + Scalar* buffer, Index k_start, Index k_end, int num_threads) const { + evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, + rhs_inner_dim_reordered, Alignment, + /*use_output_kernel*/ false>(buffer, k_start, k_end, + num_threads); + } - eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k); + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment, bool use_output_kernel = true> + EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const { + eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= this->m_k_size); + // columns in slice on left side, rows on right side const Index k_slice = k_end - k_start; // rows in left side @@ -740,7 +748,7 @@ struct TensorContractionEvaluatorBase const Index actual_mc = numext::mini(i2+mc,m)-i2; for (Index k2 = k_start; k2 < k_end; k2 += kc) { // make sure we don't overshoot right edge of left matrix, then pack vertical panel - const Index actual_kc = numext::mini(k2 + kc, k) - k2; + const Index actual_kc = numext::mini(k2 + kc, k_end) - k2; TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc); @@ -759,7 +767,7 @@ struct TensorContractionEvaluatorBase Scalar(1)); // We are done with this [i2, j2] output block. - if (k2 + kc >= k) { + if (use_output_kernel && k2 + kc >= k_end) { m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2, actual_mc, actual_nc); } |