aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-11-03 08:51:33 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-11-03 08:51:33 -0800
commitb1789c112b5cf8d478a03786c6c1243320aefd47 (patch)
tree16d91868395df93f058bda1acb469b93070e7e1a /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent2dde63499c4ef836a0d9dfd443494d863ad62b16 (diff)
Improved handling of 1d tensors
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h12
1 files changed, 11 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index cf1352a31..f0e9bb616 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -93,7 +93,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Base(op, device) {}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- void evalTyped(Scalar* buffer) const {
+ void evalProduct(Scalar* buffer) const {
+ if (this->m_j_size == 1) {
+ this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+ return;
+ }
+
+ evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+ }
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ void evalGemm(Scalar* buffer) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;