diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:36:57 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:36:57 -0800 |
commit | 71676eaddd7fb6b8abdc5713f437750f3c963fcb (patch) | |
tree | cf02d29f7c1ba6850cc6a53176ed4dfee1a9f7f4 /unsupported/Eigen/CXX11/src | |
parent | 0a0ab6dd158e3f4471ba1fe20454de35b18fdce5 (diff) |
Added support for RowMajor inputs to the contraction code.
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
3 files changed, 220 insertions, 94 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index c5ec42cf4..a02a273e7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -320,6 +320,8 @@ class TensorContractionInputMapper }; + + template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t, @@ -362,6 +364,14 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co }; +template <size_t n> struct max_n_1 { + static const size_t size = n; +}; +template <> struct max_n_1<0> { + static const size_t size = 1; +}; + + template<typename Dimensions, typename LhsXprType, typename RhsXprType> struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > { @@ -378,6 +388,10 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + // From NumDims below. + static const int NumDimensions = max_n_1<traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size; + static const int Layout = traits<LhsXprType>::Layout; + enum { Flags = 0, }; @@ -401,19 +415,19 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, typedef LeftArgType_ LeftArgType; typedef RightArgType_ RightArgType; typedef Device_ Device; + + // From NumDims below. + static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size; }; } // end namespace internal - - template<typename Indices, typename LhsXprType, typename RhsXprType> class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType> > { public: typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorContractionOp>::Packet Packet; - typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType, @@ -422,20 +436,21 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( + const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} - EIGEN_DEVICE_FUNC - const Indices& indices() const { return m_indices; } + EIGEN_DEVICE_FUNC + const Indices& indices() const { return m_indices; } - /** \returns the nested expressions */ - EIGEN_DEVICE_FUNC - const typename internal::remove_all<typename LhsXprType::Nested>::type& - lhsExpression() const { return m_lhs_xpr; } + /** \returns the nested expressions */ + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename LhsXprType::Nested>::type& + lhsExpression() const { return m_lhs_xpr; } - EIGEN_DEVICE_FUNC - const typename internal::remove_all<typename RhsXprType::Nested>::type& - rhsExpression() const { return m_rhs_xpr; } + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename RhsXprType::Nested>::type& + rhsExpression() const { return m_rhs_xpr; } protected: typename LhsXprType::Nested m_lhs_xpr; @@ -444,12 +459,17 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp }; -template <size_t n> struct max_n_1 { - static const size_t size = n; -}; -template <> struct max_n_1<0> { - static const size_t size = 1; -}; +template<bool cond> struct Cond {}; + +template<typename T1, typename T2> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +const T1& choose(Cond<true>, const T1& first, const T2&) { + return first; +} + +template<typename T1, typename T2> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +const T2& choose(Cond<false>, const T1&, const T2& second) { + return second; +} template<typename Derived> @@ -467,37 +487,94 @@ struct TensorContractionEvaluatorBase typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t; - typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t; - - typedef array<Index, internal::array_size<Indices>::value> contract_t; - typedef array<Index, max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> left_nocontract_t; - typedef array<Index, max_n_1<TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> right_nocontract_t; - - static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size; - - typedef DSizes<Index, NumDims> Dimensions; - enum { IsAligned = true, PacketAccess = (internal::packet_traits<Scalar>::size > 1), + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType& op, const Device& device) - : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_device(device), m_result(NULL) - { + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; + + static const int LDims = + internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value; + static const int RDims = + internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; + static const int ContractDims = internal::array_size<Indices>::value; + static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; + + typedef array<Index, LDims> left_dim_mapper_t; + typedef array<Index, RDims> right_dim_mapper_t; + typedef array<Index, ContractDims> contract_t; + typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; + + typedef DSizes<Index, NumDims> Dimensions; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TensorContractionEvaluatorBase(const XprType& op, const Device& device) + : m_leftImpl(choose(Cond<Layout == ColMajor>(), + op.lhsExpression(), op.rhsExpression()), device), + m_rightImpl(choose(Cond<Layout == ColMajor>(), + op.rhsExpression(), op.lhsExpression()), device), + m_device(device), + m_result(NULL) { + EIGEN_STATIC_ASSERT((TensorEvaluator<LeftArgType, Device>::Layout == + TensorEvaluator<RightArgType, Device>::Layout), + YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert((internal::array_size<contract_t>::value > 0) && "Must contract on some indices"); - array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> lhs_strides; + + DSizes<Index, LDims> eval_left_dims; + DSizes<Index, RDims> eval_right_dims; + array<IndexPair<Index>, ContractDims> eval_op_indices; + if (Layout == ColMajor) { + // For ColMajor, we keep using the existing dimensions + for (int i = 0; i < LDims; i++) { + eval_left_dims[i] = m_leftImpl.dimensions()[i]; + } + for (int i = 0; i < RDims; i++) { + eval_right_dims[i] = m_rightImpl.dimensions()[i]; + } + // We keep the pairs of contracting indices. + for (int i = 0; i < ContractDims; i++) { + eval_op_indices[i].first = op.indices()[i].first; + eval_op_indices[i].second = op.indices()[i].second; + } + } else { + // For RowMajor, we need to reverse the existing dimensions + for (int i = 0; i < LDims; i++) { + eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1]; + } + for (int i = 0; i < RDims; i++) { + eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1]; + } + // We need to flip all the pairs of contracting indices as well as + // reversing the dimensions. + for (int i = 0; i < ContractDims; i++) { + eval_op_indices[i].first = LDims - 1 - op.indices()[i].second; + eval_op_indices[i].second = RDims - 1 - op.indices()[i].first; + } + } + + array<Index, LDims> lhs_strides; lhs_strides[0] = 1; - for (int i = 0; i < TensorEvaluator<LeftArgType, Device>::Dimensions::count-1; ++i) { - lhs_strides[i+1] = lhs_strides[i] * m_leftImpl.dimensions()[i]; + for (int i = 0; i < LDims-1; ++i) { + lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i]; } - array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> rhs_strides; + array<Index, RDims> rhs_strides; rhs_strides[0] = 1; - for (int i = 0; i < TensorEvaluator<RightArgType, Device>::Dimensions::count-1; ++i) { - rhs_strides[i+1] = rhs_strides[i] * m_rightImpl.dimensions()[i]; + for (int i = 0; i < RDims-1; ++i) { + rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i]; } m_i_strides[0] = 1; @@ -515,27 +592,28 @@ struct TensorContractionEvaluatorBase m_lhs_inner_dim_contiguous = true; int dim_idx = 0; int nocontract_idx = 0; - const typename TensorEvaluator<LeftArgType, Device>::Dimensions& left_dims = m_leftImpl.dimensions(); - for (int i = 0; i < TensorEvaluator<LeftArgType, Device>::Dimensions::count; i++) { + + for (int i = 0; i < LDims; i++) { // find if we are contracting on index i of left tensor bool contracting = false; - for (int j = 0; j < internal::array_size<Indices>::value; j++) { - if (op.indices()[j].first == i) { + for (int j = 0; j < ContractDims; j++) { + if (eval_op_indices[j].first == i) { contracting = true; break; } } if (!contracting) { // add dimension size to output dimensions - m_dimensions[dim_idx] = left_dims[i]; + m_dimensions[dim_idx] = eval_left_dims[i]; m_left_nocontract_strides[nocontract_idx] = lhs_strides[i]; if (dim_idx != i) { m_lhs_inner_dim_contiguous = false; } if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) { - m_i_strides[nocontract_idx+1] = m_i_strides[nocontract_idx] * left_dims[i]; + m_i_strides[nocontract_idx+1] = + m_i_strides[nocontract_idx] * eval_left_dims[i]; } else { - m_i_size = m_i_strides[nocontract_idx] * left_dims[i]; + m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i]; } dim_idx++; nocontract_idx++; @@ -543,22 +621,22 @@ struct TensorContractionEvaluatorBase } nocontract_idx = 0; - const typename TensorEvaluator<RightArgType, Device>::Dimensions& right_dims = m_rightImpl.dimensions(); - for (int i = 0; i < TensorEvaluator<RightArgType, Device>::Dimensions::count; i++) { + for (int i = 0; i < RDims; i++) { bool contracting = false; // find if we are contracting on index i of right tensor - for (int j = 0; j < internal::array_size<Indices>::value; j++) { - if (op.indices()[j].second == i) { + for (int j = 0; j < ContractDims; j++) { + if (eval_op_indices[j].second == i) { contracting = true; break; } } if (!contracting) { - m_dimensions[dim_idx] = right_dims[i]; + m_dimensions[dim_idx] = eval_right_dims[i]; if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) { - m_j_strides[nocontract_idx+1] = m_j_strides[nocontract_idx] * right_dims[i]; + m_j_strides[nocontract_idx+1] = + m_j_strides[nocontract_idx] * eval_right_dims[i]; } else { - m_j_size = m_j_strides[nocontract_idx] * right_dims[i]; + m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i]; } m_right_nocontract_strides[nocontract_idx] = rhs_strides[i]; dim_idx++; @@ -573,12 +651,13 @@ struct TensorContractionEvaluatorBase // each tensor, we'll only look at the first tensor here. m_rhs_inner_dim_contiguous = true; m_rhs_inner_dim_reordered = false; - for (int i = 0; i < internal::array_size<Indices>::value; i++) { - Index left = op.indices()[i].first; - Index right = op.indices()[i].second; + for (int i = 0; i < ContractDims; i++) { + Index left = eval_op_indices[i].first; + Index right = eval_op_indices[i].second; - Index size = left_dims[left]; - eigen_assert(size == right_dims[right] && "Contraction axes must be same size"); + Index size = eval_left_dims[left]; + eigen_assert(size == eval_right_dims[right] && + "Contraction axes must be same size"); if (i+1 < internal::array_size<contract_t>::value) { m_k_strides[i+1] = m_k_strides[i] * size; @@ -588,7 +667,7 @@ struct TensorContractionEvaluatorBase m_left_contracting_strides[i] = lhs_strides[left]; m_right_contracting_strides[i] = rhs_strides[right]; - if (i > 0 && right < op.indices()[i-1].second) { + if (i > 0 && right < eval_op_indices[i-1].second) { m_rhs_inner_dim_reordered = true; } if (right != i) { @@ -597,9 +676,16 @@ struct TensorContractionEvaluatorBase } // Scalar case. We represent the result as a 1d tensor of size 1. - if (TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count == 2 * internal::array_size<Indices>::value) { + if (LDims + RDims == 2 * ContractDims) { m_dimensions[0] = 1; } + + // If the layout is RowMajor, we need to reverse the m_dimensions + if (Layout == RowMajor) { + for (int i = 0, j = NumDims - 1; i < j; i++, j--) { + std::swap(m_dimensions[i], m_dimensions[j]); + } + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -661,10 +747,10 @@ struct TensorContractionEvaluatorBase const Index rows = m_i_size; const Index cols = m_k_size; - typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar; - typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar; - typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator; - typedef TensorEvaluator<RightArgType, Device> RightEvaluator; + typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; + typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; + typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; const int lhs_packet_size = internal::packet_traits<LhsScalar>::size; const int rhs_packet_size = internal::packet_traits<RhsScalar>::size; typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, @@ -719,7 +805,6 @@ struct TensorContractionEvaluatorBase protected: // Prevent assignment TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&); - Dimensions m_dimensions; contract_t m_k_strides; @@ -739,16 +824,18 @@ struct TensorContractionEvaluatorBase Index m_j_size; Index m_k_size; - TensorEvaluator<LeftArgType, Device> m_leftImpl; - TensorEvaluator<RightArgType, Device> m_rightImpl; + TensorEvaluator<EvalLeftArgType, Device> m_leftImpl; + TensorEvaluator<EvalRightArgType, Device> m_rightImpl; const Device& m_device; Scalar* m_result; }; +// evaluator for default device template<typename Indices, typename LeftArgType, typename RightArgType, typename Device> struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> : - public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > { + public TensorContractionEvaluatorBase< + TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > { typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; @@ -759,15 +846,35 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t; - typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t; + enum { + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + }; + + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; + + static const int LDims = + internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value; + static const int RDims = + internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; + static const int ContractDims = internal::array_size<Indices>::value; - typedef array<Index, internal::array_size<Indices>::value> contract_t; - typedef array<Index, max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> left_nocontract_t; - typedef array<Index, max_n_1<TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> right_nocontract_t; + typedef array<Index, LDims> left_dim_mapper_t; + typedef array<Index, RDims> right_dim_mapper_t; - static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size; + typedef array<Index, ContractDims> contract_t; + typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; + static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; + + // Could we use NumDimensions here? typedef DSizes<Index, NumDims> Dimensions; @@ -799,15 +906,15 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); // define mr, nr, and all of my data mapper types - typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar; - typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar; + typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits; const Index nr = Traits::nr; const Index mr = Traits::mr; - typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator; - typedef TensorEvaluator<RightArgType, Device> RightEvaluator; + typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; + typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; const int lhs_packet_size = internal::packet_traits<LhsScalar>::size; const int rhs_packet_size = internal::packet_traits<RhsScalar>::size; @@ -826,10 +933,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; - // Declare GEBP packing and kernel structs internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs; internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs; + internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp; // initialize data mappers diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h index f6bd949bd..588770bb4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h @@ -1241,10 +1241,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef array<Index, RDims> right_dim_mapper_t; typedef array<Index, ContractDims> contract_t; - typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; - typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; + typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; - static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; typedef DSizes<Index, NumDims> Dimensions; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index f0e9bb616..5851e5adc 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -70,24 +70,43 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t; - typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t; - - typedef array<Index, internal::array_size<Indices>::value> contract_t; - typedef array<Index, max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> left_nocontract_t; - typedef array<Index, max_n_1<TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value>::size> right_nocontract_t; - - static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size; + enum { + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + }; + + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; + + static const int LDims = + internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value; + static const int RDims = + internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; + static const int ContractDims = internal::array_size<Indices>::value; + + typedef array<Index, LDims> left_dim_mapper_t; + typedef array<Index, RDims> right_dim_mapper_t; + + typedef array<Index, ContractDims> contract_t; + typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; + + static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; typedef DSizes<Index, NumDims> Dimensions; // typedefs needed in evalTo - typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar; - typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar; + typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits; - typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator; - typedef TensorEvaluator<RightArgType, Device> RightEvaluator; + typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; + typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {} |