aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h22
1 files changed, 15 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 6fc1e4a6e..d1c659858 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -671,12 +671,20 @@ struct TensorContractionEvaluatorBase
0, k, 1);
}
- template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
- // columns in left side, rows in right side
- const Index k = this->m_k_size;
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+ bool rhs_inner_dim_reordered, int Alignment>
+ EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(
+ Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
+ evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
+ rhs_inner_dim_reordered, Alignment,
+ /*use_output_kernel*/ false>(buffer, k_start, k_end,
+ num_threads);
+ }
- eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k);
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment, bool use_output_kernel = true>
+ EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
+ eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= this->m_k_size);
+ // columns in slice on left side, rows on right side
const Index k_slice = k_end - k_start;
// rows in left side
@@ -740,7 +748,7 @@ struct TensorContractionEvaluatorBase
const Index actual_mc = numext::mini(i2+mc,m)-i2;
for (Index k2 = k_start; k2 < k_end; k2 += kc) {
// make sure we don't overshoot right edge of left matrix, then pack vertical panel
- const Index actual_kc = numext::mini(k2 + kc, k) - k2;
+ const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2),
actual_kc, actual_mc);
@@ -759,7 +767,7 @@ struct TensorContractionEvaluatorBase
Scalar(1));
// We are done with this [i2, j2] output block.
- if (k2 + kc >= k) {
+ if (use_output_kernel && k2 + kc >= k_end) {
m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
actual_mc, actual_nc);
}