aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-27 12:05:06 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-27 12:05:06 -0700
commita7a3e9f2b6dfa97887fd44b6d8f658c4928c799d (patch)
treedf58dc8f6c1414abd73fc1d7887020ef47d38492 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent9f4988959f1b0394ee027f474f49916543ad2f3c (diff)
parent1e5750a5b896089b4455cf4940b4fe88d99b3293 (diff)
Merge with eigen/eigen default
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h20
1 files changed, 19 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index d220f82be..a4df45098 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -665,6 +665,24 @@ struct TensorContractionEvaluatorBase
// zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+ this->template evalGemmPartial<lhs_inner_dim_contiguous,
+ rhs_inner_dim_contiguous,
+ rhs_inner_dim_reordered, Alignment>(buffer,
+ 0, k, 1);
+ }
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int /*num_threads*/) const {
+ // columns in left side, rows in right side
+ const Index k = this->m_k_size;
+
+ eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k);
+
+ // rows in left side
+ const Index m = this->m_i_size;
+
+ // columns in right side
+ const Index n = this->m_j_size;
// define data mappers for Lhs and Rhs
typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
@@ -717,7 +735,7 @@ struct TensorContractionEvaluatorBase
for(Index i2=0; i2<m; i2+=mc)
{
const Index actual_mc = numext::mini(i2+mc,m)-i2;
- for (Index k2 = 0; k2 < k; k2 += kc) {
+ for (Index k2 = k_start; k2 < k_end; k2 += kc) {
// make sure we don't overshoot right edge of left matrix, then pack vertical panel
const Index actual_kc = numext::mini(k2 + kc, k) - k2;
TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2),