aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-27 12:08:17 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-27 12:08:17 -0700
commit9f33e71e9d33b51735841e40dfa49bda9d7fe5ff (patch)
tree6b518faf3facbd815e98651bbfc7bddef516b5fa /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parenta7a3e9f2b6dfa97887fd44b6d8f658c4928c799d (diff)
Revert code lost in merge
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h7
1 files changed, 5 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index a4df45098..b92753c44 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -672,11 +672,12 @@ struct TensorContractionEvaluatorBase
}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int /*num_threads*/) const {
+ EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;
eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k);
+ const Index k_slice = k_end - k_start;
// rows in left side
const Index m = this->m_i_size;
@@ -722,7 +723,9 @@ struct TensorContractionEvaluatorBase
OutputMapper output(buffer, m);
// Sizes of the blocks to load in cache. See the Goto paper for details.
- internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1);
+ internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar,
+ Index, internal::ShardByCol>
+ blocking(k_slice, m, n, num_threads);
const Index kc = blocking.kc();
const Index mc = numext::mini(m, blocking.mc());
const Index nc = numext::mini(n, blocking.nc());