diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 3b22e43e7..ea17a897d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -590,6 +590,25 @@ struct TensorContractionEvaluatorBase // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); + this->template evalGemmPartial<lhs_inner_dim_contiguous, + rhs_inner_dim_contiguous, + rhs_inner_dim_reordered, Alignment>(buffer, + 0, k, 1); + } + + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> + EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const { + // columns in left side, rows in right side + const Index k = this->m_k_size; + + eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k); + const Index k_slice = k_end - k_start; + + // rows in left side + const Index m = this->m_i_size; + + // columns in right side + const Index n = this->m_j_size; // define mr, nr, and all of my data mapper types typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; @@ -620,7 +639,7 @@ struct TensorContractionEvaluatorBase typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; // Declare GEBP packing and kernel structs - internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor> pack_lhs; + internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor> pack_lhs; internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs; internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp; @@ -635,7 +654,7 @@ struct TensorContractionEvaluatorBase OutputMapper output(buffer, m); // Sizes of the blocks to load in cache. See the Goto paper for details. - internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1); + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k_slice, m, n, num_threads); const Index kc = blocking.kc(); const Index mc = numext::mini(m, blocking.mc()); const Index nc = numext::mini(n, blocking.nc()); @@ -648,7 +667,7 @@ struct TensorContractionEvaluatorBase for(Index i2=0; i2<m; i2+=mc) { const Index actual_mc = numext::mini(i2+mc,m)-i2; - for (Index k2 = 0; k2 < k; k2 += kc) { + for (Index k2 = k_start; k2 < k_end; k2 += kc) { // make sure we don't overshoot right edge of left matrix, then pack vertical panel const Index actual_kc = numext::mini(k2 + kc, k) - k2; pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0); |