diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 170 |
1 files changed, 94 insertions, 76 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index db63eadf9..afd97b340 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -30,30 +30,46 @@ struct ei_L2_block_traits { enum {width = 8 * ei_meta_sqrt<L2MemorySize/(64*sizeof(Scalar))>::ret }; }; -template<bool ConjLhs, bool ConjRhs> struct ei_conj_pmadd; +template<bool ConjLhs, bool ConjRhs> struct ei_conj_helper; -template<> struct ei_conj_pmadd<false,false> +template<> struct ei_conj_helper<false,false> { template<typename T> - EIGEN_STRONG_INLINE T operator()(const T& x, const T& y, T& c) const { return ei_pmadd(x,y,c); } + EIGEN_STRONG_INLINE T pmadd(const T& x, const T& y, const T& c) const { return ei_pmadd(x,y,c); } + template<typename T> + EIGEN_STRONG_INLINE T pmul(const T& x, const T& y) const { return ei_pmul(x,y); } }; -template<> struct ei_conj_pmadd<false,true> +template<> struct ei_conj_helper<false,true> { - template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const - { return c + std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } + template<typename T> std::complex<T> + pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const + { return c + pmul(x,y); } + + template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const + //{ return std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } + { return x * ei_conj(y); } }; -template<> struct ei_conj_pmadd<true,false> +template<> struct ei_conj_helper<true,false> { - template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const - { return c + std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + template<typename T> std::complex<T> + pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const + { return c + pmul(x,y); } + + template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const + { return std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } }; -template<> struct ei_conj_pmadd<true,true> +template<> struct ei_conj_helper<true,true> { - template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const - { return c + std::complex<T>(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + template<typename T> std::complex<T> + pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const + { return c + pmul(x,y); } + + template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const +// { return std::complex<T>(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + { return ei_conj(x) * ei_conj(y); } }; #ifndef EIGEN_EXTERN_INSTANTIATIONS @@ -74,7 +90,9 @@ static void ei_cache_friendly_product( int lhsStride, rhsStride, rows, cols; bool lhsRowMajor; - ei_conj_pmadd<ConjugateLhs,ConjugateRhs> cj_pmadd; + ei_conj_helper<ConjugateLhs,ConjugateRhs> cj; + if (ConjugateRhs) + alpha = ei_conj(alpha); bool hasAlpha = alpha != Scalar(1); if (resRowMajor) @@ -261,59 +279,59 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(A1, B0, C4); + C4 = cj.pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[6*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[2*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[7*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[10*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[4*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[5*PacketSize]); if(nr==4) B3 = ei_pload(&blB[11*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[14*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[6*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[7*PacketSize]); if(nr==4) B3 = ei_pload(&blB[15*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); blB += 4*nr*PacketSize; blA += 4*mr; @@ -327,16 +345,16 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(A1, B0, C4); + C4 = cj.pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); blB += nr*PacketSize; blA += mr; @@ -368,12 +386,12 @@ static void ei_cache_friendly_product( A0 = blA[k]; B0 = blB[0*PacketSize]; B1 = blB[1*PacketSize]; - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = blB[2*PacketSize]; if(nr==4) B3 = blB[3*PacketSize]; - C1 = cj_pmadd(A0, B1, C1); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + C1 = cj.pmadd(A0, B1, C1); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); blB += nr*PacketSize; } @@ -391,11 +409,11 @@ static void ei_cache_friendly_product( Scalar c0 = Scalar(0); if (lhsRowMajor) for(int k=0; k<actual_kc; k++) - c0 = cj_pmadd(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k], c0); + c0 += cj.pmul(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k]); else for(int k=0; k<actual_kc; k++) - c0 = cj_pmadd(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k], c0); - res[(j2)*resStride + i2+i] += alpha * c0; + c0 += cj.pmul(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k]); + res[(j2)*resStride + i2+i] += (ConjugateRhs ? ei_conj(alpha) : alpha) * c0; } } } @@ -493,39 +511,39 @@ static void ei_cache_friendly_product( L0 = ei_pload(&lb[1*PacketSize]); R1 = ei_pload(&lb[2*PacketSize]); L1 = ei_pload(&lb[3*PacketSize]); - T0 = cj_pmadd(A0, R0, T0); - T1 = cj_pmadd(A0, L0, T1); + T0 = cj.pmadd(A0, R0, T0); + T1 = cj.pmadd(A0, L0, T1); R0 = ei_pload(&lb[4*PacketSize]); L0 = ei_pload(&lb[5*PacketSize]); - T0 = cj_pmadd(A1, R1, T0); - T1 = cj_pmadd(A1, L1, T1); + T0 = cj.pmadd(A1, R1, T0); + T1 = cj.pmadd(A1, L1, T1); R1 = ei_pload(&lb[6*PacketSize]); L1 = ei_pload(&lb[7*PacketSize]); - T0 = cj_pmadd(A2, R0, T0); - T1 = cj_pmadd(A2, L0, T1); + T0 = cj.pmadd(A2, R0, T0); + T1 = cj.pmadd(A2, L0, T1); if(MaxBlockRows==8) { R0 = ei_pload(&lb[8*PacketSize]); L0 = ei_pload(&lb[9*PacketSize]); } - T0 = cj_pmadd(A3, R1, T0); - T1 = cj_pmadd(A3, L1, T1); + T0 = cj.pmadd(A3, R1, T0); + T1 = cj.pmadd(A3, L1, T1); if(MaxBlockRows==8) { R1 = ei_pload(&lb[10*PacketSize]); L1 = ei_pload(&lb[11*PacketSize]); - T0 = cj_pmadd(A4, R0, T0); - T1 = cj_pmadd(A4, L0, T1); + T0 = cj.pmadd(A4, R0, T0); + T1 = cj.pmadd(A4, L0, T1); R0 = ei_pload(&lb[12*PacketSize]); L0 = ei_pload(&lb[13*PacketSize]); - T0 = cj_pmadd(A5, R1, T0); - T1 = cj_pmadd(A5, L1, T1); + T0 = cj.pmadd(A5, R1, T0); + T1 = cj.pmadd(A5, L1, T1); R1 = ei_pload(&lb[14*PacketSize]); L1 = ei_pload(&lb[15*PacketSize]); - T0 = cj_pmadd(A6, R0, T0); - T1 = cj_pmadd(A6, L0, T1); - T0 = cj_pmadd(A7, R1, T0); - T1 = cj_pmadd(A7, L1, T1); + T0 = cj.pmadd(A6, R0, T0); + T1 = cj.pmadd(A6, L0, T1); + T0 = cj.pmadd(A7, R1, T0); + T1 = cj.pmadd(A7, L1, T1); } lb += MaxBlockRows*2*PacketSize; |