diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:36:57 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:36:57 -0800 |
commit | 71676eaddd7fb6b8abdc5713f437750f3c963fcb (patch) | |
tree | cf02d29f7c1ba6850cc6a53176ed4dfee1a9f7f4 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | |
parent | 0a0ab6dd158e3f4471ba1fe20454de35b18fdce5 (diff) |
Added support for RowMajor inputs to the contraction code.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index f0e9bb616..5851e5adc 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -70,24 +70,43 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t; - typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t; - - typedef array<Index, internal::array_size<Indices>::value> contract_t; - typedef array<Index, max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> left_nocontract_t; - typedef array<Index, max_n_1<TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> right_nocontract_t; - - static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size; + enum { + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + }; + + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; + + static const int LDims = + internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value; + static const int RDims = + internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; + static const int ContractDims = internal::array_size<Indices>::value; + + typedef array<Index, LDims> left_dim_mapper_t; + typedef array<Index, RDims> right_dim_mapper_t; + + typedef array<Index, ContractDims> contract_t; + typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; + + static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; typedef DSizes<Index, NumDims> Dimensions; // typedefs needed in evalTo - typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar; - typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar; + typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits; - typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator; - typedef TensorEvaluator<RightArgType, Device> RightEvaluator; + typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; + typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {} |