about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
author: Benoit Steiner <benoit.steiner.goog@gmail.com> 2015-01-14 12:36:57 -0800
committer: Benoit Steiner <benoit.steiner.goog@gmail.com> 2015-01-14 12:36:57 -0800
commit 71676eaddd7fb6b8abdc5713f437750f3c963fcb (patch)
tree cf02d29f7c1ba6850cc6a53176ed4dfee1a9f7f4 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent 0a0ab6dd158e3f4471ba1fe20454de35b18fdce5 (diff)
Added support for RowMajor inputs to the contraction code.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h 43
1 files changed, 31 insertions, 12 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index f0e9bb616..5851e5adc 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -70,24 +70,43 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
- typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t;
- typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t;
-
- typedef array<Index, internal::array_size<Indices>::value> contract_t;
- typedef array<Index, max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> left_nocontract_t;
- typedef array<Index, max_n_1<TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> right_nocontract_t;
-
- static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size;
+ enum {
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ };
+
+ // Most of the code is assuming that both input tensors are ColMajor. If the
+ // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
+ // If we want to compute A * B = C, where A is LHS and B is RHS, the code
+ // will pretend B is LHS and A is RHS.
+ typedef typename internal::conditional<
+ Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType;
+ typedef typename internal::conditional<
+ Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType;
+
+ static const int LDims =
+ internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
+ static const int RDims =
+ internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
+ static const int ContractDims = internal::array_size<Indices>::value;
+
+ typedef array<Index, LDims> left_dim_mapper_t;
+ typedef array<Index, RDims> right_dim_mapper_t;
+
+ typedef array<Index, ContractDims> contract_t;
+ typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
+ typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
+
+ static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
typedef DSizes<Index, NumDims> Dimensions;
// typedefs needed in evalTo
- typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar;
- typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar;
+ typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
+ typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
- typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator;
- typedef TensorEvaluator<RightArgType, Device> RightEvaluator;
+ typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
+ typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
TensorEvaluator(const XprType& op, const Device& device) :
Base(op, device) {}