diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-22 18:04:16 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-22 18:04:16 +0200 |
commit | e7f8e939e282a64025203a7a22e511165e1e3647 (patch) | |
tree | 6e44ea084b0df5b9201731434bd3af58e976f9f6 /Eigen | |
parent | d6475ea390b5a4beeef64e71a247b3f72573d768 (diff) |
* GEMM enhencement: no need to pre-transpose the rhs
=> faster a * b.transpose() product
=> this also fix a bug in a so far untested situation
* SYMM is now ready for use => still have to write the high level
stuff to convert natural expressions into a call to SYMM
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/Product.h | 114 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 73 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixMatrix.h | 316 | ||||
-rw-r--r-- | Eigen/src/Core/util/BlasUtil.h | 77 |
4 files changed, 405 insertions, 175 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 78cb88c33..754ce4c24 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -73,79 +73,6 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; }; -/* Helper class to analyze the factors of a Product expression. - * In particular it allows to pop out operator-, scalar multiples, - * and conjugate */ -template<typename XprType> struct ei_blas_traits -{ - typedef typename ei_traits<XprType>::Scalar Scalar; - typedef XprType ActualXprType; - enum { - IsComplex = NumTraits<Scalar>::IsComplex, - NeedToConjugate = false, - ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess - }; - typedef typename ei_meta_if<int(ActualAccess)==HasDirectAccess, - const ActualXprType&, - typename ActualXprType::PlainMatrixType - >::ret DirectLinearAccessType; - static inline const ActualXprType& extract(const XprType& x) { return x; } - static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); } -}; - -// pop conjugate -template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> > - : ei_blas_traits<NestedXpr> -{ - typedef ei_blas_traits<NestedXpr> Base; - typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - - enum { - IsComplex = NumTraits<Scalar>::IsComplex, - NeedToConjugate = IsComplex - }; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } - static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); } -}; - -// pop scalar multiple -template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> > - : ei_blas_traits<NestedXpr> -{ - typedef ei_blas_traits<NestedXpr> Base; - typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } - static inline Scalar extractScalarFactor(const XprType& x) - { return x._functor().m_other * Base::extractScalarFactor(x._expression()); } -}; - -// pop opposite -template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> > - : ei_blas_traits<NestedXpr> -{ - typedef ei_blas_traits<NestedXpr> Base; - typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } - static inline Scalar extractScalarFactor(const XprType& x) - { return - Base::extractScalarFactor(x._expression()); } -}; - -// pop opposite -template<typename NestedXpr> struct ei_blas_traits<NestByValue<NestedXpr> > - : ei_blas_traits<NestedXpr> -{ - typedef typename NestedXpr::Scalar Scalar; - typedef ei_blas_traits<NestedXpr> Base; - typedef NestByValue<NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(static_cast<const NestedXpr&>(x)); } - static inline Scalar extractScalarFactor(const XprType& x) - { return Base::extractScalarFactor(static_cast<const NestedXpr&>(x)); } -}; - /* Helper class to determine the type of the product, can be either: * - NormalProduct * - CacheFriendlyProduct @@ -869,25 +796,6 @@ inline Derived& MatrixBase<Derived>::lazyAssign(const Product<Lhs,Rhs,CacheFrien return derived(); } -template<typename T> struct ei_product_copy_rhs -{ - typedef typename ei_meta_if< - (ei_traits<T>::Flags & RowMajorBit) - || (!(ei_traits<T>::Flags & DirectAccessBit)), - typename ei_plain_matrix_type_column_major<T>::type, - const T& - >::ret type; -}; - -template<typename T> struct ei_product_copy_lhs -{ - typedef typename ei_meta_if< - (!(int(ei_traits<T>::Flags) & DirectAccessBit)), - typename ei_plain_matrix_type<T>::type, - const T& - >::ret type; -}; - template<typename Lhs, typename Rhs, int ProductMode> template<typename DestDerived> inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const @@ -895,26 +803,22 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& typedef ei_blas_traits<_LhsNested> LhsProductTraits; typedef ei_blas_traits<_RhsNested> RhsProductTraits; - typedef typename LhsProductTraits::ActualXprType ActualLhsType; - typedef typename RhsProductTraits::ActualXprType ActualRhsType; + typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; + typedef typename RhsProductTraits::DirectLinearAccessType ActualRhsType; + + typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType; + typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType; - const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); - const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); + const ActualLhsType lhs = LhsProductTraits::extract(m_lhs); + const ActualRhsType rhs = RhsProductTraits::extract(m_rhs); Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs) * RhsProductTraits::extractScalarFactor(m_rhs); - typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy; - typedef typename ei_unref<LhsCopy>::type _LhsCopy; - typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy; - typedef typename ei_unref<RhsCopy>::type _RhsCopy; - LhsCopy lhs(actualLhs); - RhsCopy rhs(actualRhs); - ei_general_matrix_matrix_product< Scalar, - (_LhsCopy::Flags&RowMajorBit)?RowMajor:ColMajor, bool(LhsProductTraits::NeedToConjugate), - (_RhsCopy::Flags&RowMajorBit)?RowMajor:ColMajor, bool(RhsProductTraits::NeedToConjugate), + (_ActualLhsType::Flags&RowMajorBit)?RowMajor:ColMajor, bool(LhsProductTraits::NeedToConjugate), + (_ActualRhsType::Flags&RowMajorBit)?RowMajor:ColMajor, bool(RhsProductTraits::NeedToConjugate), (DestDerived::Flags&RowMajorBit)?RowMajor:ColMajor> ::run( rows(), cols(), lhs.cols(), diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 68949499a..1c48a5ed4 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -89,16 +89,16 @@ static void run(int rows, int cols, int depth, // we have selected one row panel of rhs and one column panel of lhs // pack rhs's panel into a sequential chunk of memory // and expand each coeff to a constant packet for further reuse - ei_gemm_pack_rhs<Scalar, Blocking::PacketSize, Blocking::nr>()(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); + ei_gemm_pack_rhs<Scalar, Blocking::nr, RhsStorageOrder>()(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); // => GEPP_VAR1 for(int i2=0; i2<rows; i2+=mc) { const int actual_mc = std::min(i2+mc,rows)-i2; - + ei_gemm_pack_lhs<Scalar, Blocking::mr, LhsStorageOrder>()(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc); - ei_gebp_kernel<Scalar, PacketType, Blocking::PacketSize, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> >() + ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> >() (res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); } } @@ -110,11 +110,13 @@ static void run(int rows, int cols, int depth, }; // optimized GEneral packed Block * packed Panel product kernel -template<typename Scalar, typename PacketType, int PacketSize, int mr, int nr, typename Conj> +template<typename Scalar, int mr, int nr, typename Conj> struct ei_gebp_kernel { void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int actual_mc, int actual_kc, int packet_cols, int i2, int cols) { + typedef typename ei_packet_traits<Scalar>::type PacketType; + enum { PacketSize = ei_packet_traits<Scalar>::size }; Conj cj; const int peeled_mc = (actual_mc/mr)*mr; // loops on each cache friendly block of the result/rhs @@ -276,7 +278,7 @@ struct ei_gebp_kernel if(nr==4) res[(j2+3)*resStride + i2 + i] += C3; } } - + // process remaining rhs/res columns one at a time // => do the same but with nr==1 for(int j2=packet_cols; j2<cols; j2++) @@ -353,9 +355,11 @@ struct ei_gemm_pack_lhs }; // copy a complete panel of the rhs while expending each coefficient into a packet form -template<typename Scalar, int PacketSize, int nr> -struct ei_gemm_pack_rhs +// this version is optimized for column major matrices +template<typename Scalar, int nr> +struct ei_gemm_pack_rhs<Scalar, nr, ColMajor> { + enum { PacketSize = ei_packet_traits<Scalar>::size }; void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols) { bool hasAlpha = alpha != Scalar(1); @@ -419,6 +423,61 @@ struct ei_gemm_pack_rhs } }; +// this version is optimized for row major matrices +template<typename Scalar, int nr> +struct ei_gemm_pack_rhs<Scalar, nr, RowMajor> +{ + enum { PacketSize = ei_packet_traits<Scalar>::size }; + void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols) + { + bool hasAlpha = alpha != Scalar(1); + int count = 0; + for(int j2=0; j2<packet_cols; j2+=nr) + { + if (hasAlpha) + { + for(int k=0; k<actual_kc; k++) + { + const Scalar* b0 = &rhs[k*rhsStride + j2]; + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[0])); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b0[1])); + if (nr==4) + { + ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b0[2])); + ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b0[3])); + } + count += nr*PacketSize; + } + } + else + { + for(int k=0; k<actual_kc; k++) + { + const Scalar* b0 = &rhs[k*rhsStride + j2]; + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(b0[0])); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(b0[1])); + if (nr==4) + { + ei_pstore(&blockB[count+2*PacketSize], ei_pset1(b0[2])); + ei_pstore(&blockB[count+3*PacketSize], ei_pset1(b0[3])); + } + count += nr*PacketSize; + } + } + } + // copy the remaining columns one at a time (nr==1) + for(int j2=packet_cols; j2<cols; ++j2) + { + const Scalar* b0 = &rhs[j2]; + for(int k=0; k<actual_kc; k++) + { + ei_pstore(&blockB[count], ei_pset1(alpha*b0[k*rhsStride])); + count += PacketSize; + } + } + } +}; + #endif // EIGEN_EXTERN_INSTANTIATIONS #endif // EIGEN_GENERAL_MATRIX_MATRIX_H diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index 5008d227e..af3767e18 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -31,11 +31,12 @@ struct ei_symm_pack_lhs { void operator()(Scalar* blockA, const Scalar* _lhs, int lhsStride, int actual_kc, int actual_mc) { - ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride); + ei_const_blas_data_mapper<Scalar,StorageOrder> lhs(_lhs,lhsStride); int count = 0; const int peeled_mc = (actual_mc/mr)*mr; for(int i=0; i<peeled_mc; i+=mr) { + // normal copy for(int k=0; k<i; k++) for(int w=0; w<mr; w++) blockA[count++] = lhs(i+w,k); @@ -55,6 +56,7 @@ struct ei_symm_pack_lhs for(int w=0; w<mr; w++) blockA[count++] = lhs(k, i+w); } + // do the same with mr==1 for(int i=peeled_mc; i<actual_mc; i++) { @@ -67,86 +69,278 @@ struct ei_symm_pack_lhs } }; +template<typename Scalar, int nr, int StorageOrder> +struct ei_symm_pack_rhs +{ + enum { PacketSize = ei_packet_traits<Scalar>::size }; + void operator()(Scalar* blockB, const Scalar* _rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols, int k2) + { + int end_k = k2 + actual_kc; + int count = 0; + ei_const_blas_data_mapper<Scalar,StorageOrder> rhs(_rhs,rhsStride); + + // first part: standard case + for(int j2=0; j2<k2; j2+=nr) + { + for(int k=k2; k<end_k; k++) + { + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(k,j2+0))); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(k,j2+1))); + if (nr==4) + { + ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(k,j2+2))); + ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(k,j2+3))); + } + count += nr*PacketSize; + } + } + + // second part: diagonal block + for(int j2=k2; j2<std::min(k2+actual_kc,packet_cols); j2+=nr) + { + // again we can split vertically in three different parts (transpose, symmetric, normal) + // transpose + for(int k=k2; k<j2; k++) + { + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k))); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k))); + if (nr==4) + { + ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k))); + ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k))); + } + count += nr*PacketSize; + } + // symmetric + int h = 0; + for(int k=j2; k<j2+nr; k++) + { + // normal + for (int w=0 ; w<h; ++w) + ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(k,j2+w))); + // transpose + for (int w=h ; w<nr; ++w) + ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(j2+w,k))); + count += nr*PacketSize; + ++h; + } + // normal + for(int k=j2+nr; k<end_k; k++) + { + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(k,j2+0))); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(k,j2+1))); + if (nr==4) + { + ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(k,j2+2))); + ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(k,j2+3))); + } + count += nr*PacketSize; + } + } + + // third part: transpose + for(int j2=k2+actual_kc; j2<packet_cols; j2+=nr) + { + for(int k=k2; k<end_k; k++) + { + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k))); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k))); + if (nr==4) + { + ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k))); + ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k))); + } + count += nr*PacketSize; + } + } + + // copy the remaining columns one at a time (=> the same with nr==1) + for(int j2=packet_cols; j2<cols; ++j2) + { + // transpose + int half = std::min(end_k,j2); + for(int k=k2; k<half; k++) + { + ei_pstore(&blockB[count], ei_pset1(alpha*rhs(j2,k))); + count += PacketSize; + } + // normal + for(int k=half; k<k2+actual_kc; k++) + { + ei_pstore(&blockB[count], ei_pset1(alpha*rhs(k,j2))); + count += PacketSize; + } + } + } +}; + /* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of * the general matrix matrix product. */ -template<typename Scalar, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> -static EIGEN_DONT_INLINE void ei_product_selfadjoint_matrix( - int size, - const Scalar* _lhs, int lhsStride, - const Scalar* _rhs, int rhsStride, bool rhsRowMajor, int cols, - Scalar* res, int resStride, - Scalar alpha) +template <typename Scalar, + int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs, + int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs, + int ResStorageOrder> +struct ei_product_selfadjoint_matrix; + +template <typename Scalar, + int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs, + int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs> +struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor> { - typedef typename ei_packet_traits<Scalar>::type Packet; - ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride); - ei_const_blas_data_mapper<Scalar, ColMajor> rhs(_rhs,rhsStride); + static EIGEN_STRONG_INLINE void run( + int rows, int cols, + const Scalar* lhs, int lhsStride, + const Scalar* rhs, int rhsStride, + Scalar* res, int resStride, + Scalar alpha) + { + ei_product_selfadjoint_matrix<Scalar, + RhsStorageOrder==RowMajor ? ColMajor : RowMajor, RhsSelfAdjoint, ConjugateRhs, + LhsStorageOrder==RowMajor ? ColMajor : RowMajor, LhsSelfAdjoint, ConjugateLhs, ColMajor> + ::run(rows, cols, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha); + } +}; + +template <typename Scalar, + int LhsStorageOrder, bool ConjugateLhs, + int RhsStorageOrder, bool ConjugateRhs> +struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor> +{ - if (ConjugateRhs) - alpha = ei_conj(alpha); + static EIGEN_DONT_INLINE void run( + int rows, int cols, + const Scalar* _lhs, int lhsStride, + const Scalar* _rhs, int rhsStride, + Scalar* res, int resStride, + Scalar alpha) + { + int size = rows; - typedef typename ei_packet_traits<Scalar>::type PacketType; + ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride); + ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride); - const bool lhsRowMajor = (StorageOrder==RowMajor); + if (ConjugateRhs) + alpha = ei_conj(alpha); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar> Blocking; - int kc = std::min<int>(Blocking::Max_kc,size); // cache block size along the K direction - int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction + int kc = std::min<int>(Blocking::Max_kc,size); // cache block size along the K direction + int mc = std::min<int>(Blocking::Max_mc,rows); // cache block size along the M direction - Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); - Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); + Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); + Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); - // number of columns which can be processed by packet of nr columns - int packet_cols = (cols/Blocking::nr)*Blocking::nr; + // number of columns which can be processed by packet of nr columns + int packet_cols = (cols/Blocking::nr)*Blocking::nr; - ei_gebp_kernel<Scalar, PacketType, Blocking::PacketSize, - Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel; + ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel; - for(int k2=0; k2<size; k2+=kc) - { - const int actual_kc = std::min(k2+kc,size)-k2; - - // we have selected one row panel of rhs and one column panel of lhs - // pack rhs's panel into a sequential chunk of memory - // and expand each coeff to a constant packet for further reuse - ei_gemm_pack_rhs<Scalar,Blocking::PacketSize,Blocking::nr>() - (blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); - - // the select lhs's panel has to be split in three different parts: - // 1 - the transposed panel above the diagonal block => transposed packed copy - // 2 - the diagonal block => special packed copy - // 3 - the panel below the diagonal block => generic packed copy - for(int i2=0; i2<k2; i2+=mc) + for(int k2=0; k2<size; k2+=kc) { - const int actual_mc = std::min(i2+mc,k2)-i2; - // transposed packed copy - ei_gemm_pack_lhs<Scalar,Blocking::mr,StorageOrder==RowMajor?ColMajor:RowMajor>() - (blockA, &lhs(k2,i2), lhsStride, actual_kc, actual_mc); + const int actual_kc = std::min(k2+kc,size)-k2; - gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); - } - // the block diagonal - { - const int actual_mc = std::min(k2+kc,size)-k2; - // symmetric packed copy - ei_symm_pack_lhs<Scalar,Blocking::mr,StorageOrder>() - (blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc); - gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, k2, cols); + // we have selected one row panel of rhs and one column panel of lhs + // pack rhs's panel into a sequential chunk of memory + // and expand each coeff to a constant packet for further reuse + ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder>() + (blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); + + // the select lhs's panel has to be split in three different parts: + // 1 - the transposed panel above the diagonal block => transposed packed copy + // 2 - the diagonal block => special packed copy + // 3 - the panel below the diagonal block => generic packed copy + for(int i2=0; i2<k2; i2+=mc) + { + const int actual_mc = std::min(i2+mc,k2)-i2; + // transposed packed copy if Lower part + ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor>() + (blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + } + // the block diagonal + { + const int actual_mc = std::min(k2+kc,size)-k2; + // symmetric packed copy + ei_symm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>() + (blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, k2, cols); + } + + for(int i2=k2+kc; i2<size; i2+=mc) + { + const int actual_mc = std::min(i2+mc,size)-i2; + ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>() + (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + } } - for(int i2=k2+kc; i2<size; i2+=mc) + ei_aligned_stack_delete(Scalar, blockA, kc*mc); + ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize); + } +}; + +// matrix * selfadjoint product +template <typename Scalar, + int LhsStorageOrder, bool ConjugateLhs, + int RhsStorageOrder, bool ConjugateRhs> +struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor> +{ + + static EIGEN_DONT_INLINE void run( + int rows, int cols, + const Scalar* _lhs, int lhsStride, + const Scalar* _rhs, int rhsStride, + Scalar* res, int resStride, + Scalar alpha) + { + int size = cols; + + ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride); + ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride); + + if (ConjugateRhs) + alpha = ei_conj(alpha); + + typedef ei_product_blocking_traits<Scalar> Blocking; + + int kc = std::min<int>(Blocking::Max_kc,size); // cache block size along the K direction + int mc = std::min<int>(Blocking::Max_mc,rows); // cache block size along the M direction + + Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); + Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); + + // number of columns which can be processed by packet of nr columns + int packet_cols = (cols/Blocking::nr)*Blocking::nr; + + ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel; + + for(int k2=0; k2<size; k2+=kc) { - const int actual_mc = std::min(i2+mc,size)-i2; - ei_gemm_pack_lhs<Scalar,Blocking::mr,StorageOrder>() - (blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc); - gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + const int actual_kc = std::min(k2+kc,size)-k2; + + ei_symm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder>() + (blockB, _rhs, rhsStride, alpha, actual_kc, packet_cols, cols, k2); + + // => GEPP + for(int i2=0; i2<rows; i2+=mc) + { + const int actual_mc = std::min(i2+mc,rows)-i2; + ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>() + (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + } } - } - ei_aligned_stack_delete(Scalar, blockA, kc*mc); - ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize); -} + ei_aligned_stack_delete(Scalar, blockA, kc*mc); + ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize); + } +}; #endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index 6e4b21e6a..25829652f 100644 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -29,10 +29,10 @@ // implement and control fast level 2 and level 3 BLAS-like routines. // forward declarations -template<typename Scalar, typename Packet, int PacketSize, int mr, int nr, typename Conj> +template<typename Scalar, int mr, int nr, typename Conj> struct ei_gebp_kernel; -template<typename Scalar, int PacketSize, int nr> +template<typename Scalar, int nr, int StorageOrder> struct ei_gemm_pack_rhs; template<typename Scalar, int mr, int StorageOrder> @@ -154,4 +154,77 @@ struct ei_product_blocking_traits }; }; +/* Helper class to analyze the factors of a Product expression. + * In particular it allows to pop out operator-, scalar multiples, + * and conjugate */ +template<typename XprType> struct ei_blas_traits +{ + typedef typename ei_traits<XprType>::Scalar Scalar; + typedef XprType ActualXprType; + enum { + IsComplex = NumTraits<Scalar>::IsComplex, + NeedToConjugate = false, + ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess + }; + typedef typename ei_meta_if<int(ActualAccess)==HasDirectAccess, + const ActualXprType&, + typename ActualXprType::PlainMatrixType + >::ret DirectLinearAccessType; + static inline const ActualXprType& extract(const XprType& x) { return x; } + static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); } +}; + +// pop conjugate +template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> > + : ei_blas_traits<NestedXpr> +{ + typedef ei_blas_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + + enum { + IsComplex = NumTraits<Scalar>::IsComplex, + NeedToConjugate = IsComplex + }; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); } +}; + +// pop scalar multiple +template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> > + : ei_blas_traits<NestedXpr> +{ + typedef ei_blas_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return x._functor().m_other * Base::extractScalarFactor(x._expression()); } +}; + +// pop opposite +template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> > + : ei_blas_traits<NestedXpr> +{ + typedef ei_blas_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return - Base::extractScalarFactor(x._expression()); } +}; + +// pop opposite +template<typename NestedXpr> struct ei_blas_traits<NestByValue<NestedXpr> > + : ei_blas_traits<NestedXpr> +{ + typedef typename NestedXpr::Scalar Scalar; + typedef ei_blas_traits<NestedXpr> Base; + typedef NestByValue<NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(static_cast<const NestedXpr&>(x)); } + static inline Scalar extractScalarFactor(const XprType& x) + { return Base::extractScalarFactor(static_cast<const NestedXpr&>(x)); } +}; + #endif // EIGEN_BLASUTIL_H |