From 96e7d9f8969395db702775eaa0907b4aa941b2ba Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 8 Jul 2009 18:24:37 +0200 Subject: ok now all the complex mat-mat and mat-vec products involving conjugate, adjoint, -, and scalar multiple seems to be well handled. It only remains the simpler case: C = alpha*(A*B) ... for the next commit --- Eigen/src/Core/products/GeneralMatrixMatrix.h | 170 ++++++++++++++------------ Eigen/src/Core/products/GeneralMatrixVector.h | 101 +++++++++------ 2 files changed, 155 insertions(+), 116 deletions(-) (limited to 'Eigen/src/Core/products') diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index db63eadf9..afd97b340 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -30,30 +30,46 @@ struct ei_L2_block_traits { enum {width = 8 * ei_meta_sqrt::ret }; }; -template struct ei_conj_pmadd; +template struct ei_conj_helper; -template<> struct ei_conj_pmadd +template<> struct ei_conj_helper { template - EIGEN_STRONG_INLINE T operator()(const T& x, const T& y, T& c) const { return ei_pmadd(x,y,c); } + EIGEN_STRONG_INLINE T pmadd(const T& x, const T& y, const T& c) const { return ei_pmadd(x,y,c); } + template + EIGEN_STRONG_INLINE T pmul(const T& x, const T& y) const { return ei_pmul(x,y); } }; -template<> struct ei_conj_pmadd +template<> struct ei_conj_helper { - template std::complex operator()(const std::complex& x, const std::complex& y, std::complex& c) const - { return c + std::complex(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } + template std::complex + pmadd(const std::complex& x, const std::complex& y, const std::complex& c) const + { return c + pmul(x,y); } + + template std::complex pmul(const std::complex& x, const std::complex& y) const + //{ return std::complex(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } + { return x * ei_conj(y); } }; -template<> struct ei_conj_pmadd +template<> struct ei_conj_helper { - template std::complex operator()(const std::complex& x, const std::complex& y, std::complex& c) const - { return c + std::complex(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + template std::complex + pmadd(const std::complex& x, const std::complex& y, const std::complex& c) const + { return c + pmul(x,y); } + + template std::complex pmul(const std::complex& x, const std::complex& y) const + { return std::complex(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } }; -template<> struct ei_conj_pmadd +template<> struct ei_conj_helper { - template std::complex operator()(const std::complex& x, const std::complex& y, std::complex& c) const - { return c + std::complex(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + template std::complex + pmadd(const std::complex& x, const std::complex& y, const std::complex& c) const + { return c + pmul(x,y); } + + template std::complex pmul(const std::complex& x, const std::complex& y) const +// { return std::complex(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + { return ei_conj(x) * ei_conj(y); } }; #ifndef EIGEN_EXTERN_INSTANTIATIONS @@ -74,7 +90,9 @@ static void ei_cache_friendly_product( int lhsStride, rhsStride, rows, cols; bool lhsRowMajor; - ei_conj_pmadd cj_pmadd; + ei_conj_helper cj; + if (ConjugateRhs) + alpha = ei_conj(alpha); bool hasAlpha = alpha != Scalar(1); if (resRowMajor) @@ -261,59 +279,59 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(A1, B0, C4); + C4 = cj.pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[6*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[2*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[7*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[10*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[4*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[5*PacketSize]); if(nr==4) B3 = ei_pload(&blB[11*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[14*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[6*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[7*PacketSize]); if(nr==4) B3 = ei_pload(&blB[15*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); blB += 4*nr*PacketSize; blA += 4*mr; @@ -327,16 +345,16 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(A1, B0, C4); + C4 = cj.pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); blB += nr*PacketSize; blA += mr; @@ -368,12 +386,12 @@ static void ei_cache_friendly_product( A0 = blA[k]; B0 = blB[0*PacketSize]; B1 = blB[1*PacketSize]; - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = blB[2*PacketSize]; if(nr==4) B3 = blB[3*PacketSize]; - C1 = cj_pmadd(A0, B1, C1); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + C1 = cj.pmadd(A0, B1, C1); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); blB += nr*PacketSize; } @@ -391,11 +409,11 @@ static void ei_cache_friendly_product( Scalar c0 = Scalar(0); if (lhsRowMajor) for(int k=0; k -static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( +template +static EIGEN_DONT_INLINE +void ei_cache_friendly_product_colmajor_times_vector( int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, @@ -47,10 +48,14 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( ei_pstore(&res[j], \ ei_padd(ei_pload(&res[j]), \ ei_padd( \ - ei_padd(ei_pmul(ptmp0,EIGEN_CAT(ei_ploa , A0)(&lhs0[j])), \ - ei_pmul(ptmp1,EIGEN_CAT(ei_ploa , A13)(&lhs1[j]))), \ - ei_padd(ei_pmul(ptmp2,EIGEN_CAT(ei_ploa , A2)(&lhs2[j])), \ - ei_pmul(ptmp3,EIGEN_CAT(ei_ploa , A13)(&lhs3[j]))) ))) + ei_padd(cj.pmul(EIGEN_CAT(ei_ploa , A0)(&lhs0[j]), ptmp0), \ + cj.pmul(EIGEN_CAT(ei_ploa , A13)(&lhs1[j]), ptmp1)), \ + ei_padd(cj.pmul(EIGEN_CAT(ei_ploa , A2)(&lhs2[j]), ptmp2), \ + cj.pmul(EIGEN_CAT(ei_ploa , A13)(&lhs3[j]), ptmp3)) ))) + + ei_conj_helper cj; + if(ConjugateRhs) + alpha = ei_conj(alpha); typedef typename ei_packet_traits::type Packet; const int PacketSize = sizeof(Packet)/sizeof(Scalar); @@ -109,7 +114,7 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( ptmp2 = ei_pset1(alpha*rhs[i+2]), ptmp3 = ei_pset1(alpha*rhs[i+offset3]); // this helps a lot generating better binary code - const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, + const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride; if (PacketSize>1) @@ -117,7 +122,13 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( /* explicit vectorization */ // process initial unaligned coeffs for (int j=0; jalignedStart) { @@ -148,19 +159,19 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( A00 = ei_pload (&lhs0[j]); A10 = ei_pload (&lhs0[j+PacketSize]); - A00 = ei_pmadd(ptmp0, A00, ei_pload(&res[j])); - A10 = ei_pmadd(ptmp0, A10, ei_pload(&res[j+PacketSize])); + A00 = cj.pmadd(A00, ptmp0, ei_pload(&res[j])); + A10 = cj.pmadd(A10, ptmp0, ei_pload(&res[j+PacketSize])); - A00 = ei_pmadd(ptmp1, A01, A00); + A00 = cj.pmadd(A01, ptmp1, A00); A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01); - A00 = ei_pmadd(ptmp2, A02, A00); + A00 = cj.pmadd(A02, ptmp2, A00); A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02); - A00 = ei_pmadd(ptmp3, A03, A00); + A00 = cj.pmadd(A03, ptmp3, A00); ei_pstore(&res[j],A00); A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03); - A10 = ei_pmadd(ptmp1, A11, A10); - A10 = ei_pmadd(ptmp2, A12, A10); - A10 = ei_pmadd(ptmp3, A13, A10); + A10 = cj.pmadd(A11, ptmp1, A10); + A10 = cj.pmadd(A12, ptmp2, A10); + A10 = cj.pmadd(A13, ptmp3, A10); ei_pstore(&res[j+PacketSize],A10); } } @@ -177,7 +188,13 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( /* process remaining coeffs (or all if there is no explicit vectorization) */ for (int j=alignedSize; j +template static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, @@ -236,10 +253,12 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\ Packet b = ei_pload(&rhs[j]); \ - ptmp0 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A0) (&lhs0[j]), ptmp0); \ - ptmp1 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A13)(&lhs1[j]), ptmp1); \ - ptmp2 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A2) (&lhs2[j]), ptmp2); \ - ptmp3 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A13)(&lhs3[j]), ptmp3); } + ptmp0 = cj.pmadd(EIGEN_CAT(ei_ploa,A0) (&lhs0[j]), b, ptmp0); \ + ptmp1 = cj.pmadd(EIGEN_CAT(ei_ploa,A13)(&lhs1[j]), b, ptmp1); \ + ptmp2 = cj.pmadd(EIGEN_CAT(ei_ploa,A2) (&lhs2[j]), b, ptmp2); \ + ptmp3 = cj.pmadd(EIGEN_CAT(ei_ploa,A13)(&lhs3[j]), b, ptmp3); } + + ei_conj_helper cj; typedef typename ei_packet_traits::type Packet; const int PacketSize = sizeof(Packet)/sizeof(Scalar); @@ -311,7 +330,8 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( for (int j=0; jalignedStart) @@ -347,19 +367,19 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( A12 = ei_pload(&lhs2[j-2+PacketSize]); ei_palign<2>(A02,A12); A13 = ei_pload(&lhs3[j-3+PacketSize]); ei_palign<3>(A03,A13); - ptmp0 = ei_pmadd(b, ei_pload (&lhs0[j]), ptmp0); - ptmp1 = ei_pmadd(b, A01, ptmp1); + ptmp0 = cj.pmadd(ei_pload (&lhs0[j]), b, ptmp0); + ptmp1 = cj.pmadd(A01, b, ptmp1); A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01); - ptmp2 = ei_pmadd(b, A02, ptmp2); + ptmp2 = cj.pmadd(A02, b, ptmp2); A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02); - ptmp3 = ei_pmadd(b, A03, ptmp3); + ptmp3 = cj.pmadd(A03, b, ptmp3); A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03); b = ei_pload(&rhs[j+PacketSize]); - ptmp0 = ei_pmadd(b, ei_pload (&lhs0[j+PacketSize]), ptmp0); - ptmp1 = ei_pmadd(b, A11, ptmp1); - ptmp2 = ei_pmadd(b, A12, ptmp2); - ptmp3 = ei_pmadd(b, A13, ptmp3); + ptmp0 = cj.pmadd(ei_pload (&lhs0[j+PacketSize]), b, ptmp0); + ptmp1 = cj.pmadd(A11, b, ptmp1); + ptmp2 = cj.pmadd(A12, b, ptmp2); + ptmp3 = cj.pmadd(A13, b, ptmp3); } } for (int j = peeledSize; jalignedStart) { // process aligned rhs coeffs if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0) for (int j = alignedStart;j