diff options
author | 2010-07-07 19:49:09 +0200 | |
---|---|---|
committer | 2010-07-07 19:49:09 +0200 | |
commit | 31a36aa9c407d736075de8dc06f5af0d0fe912d5 (patch) | |
tree | 58febad97092226a768b89ad571d56ed7470a2b2 /Eigen | |
parent | 861962c55f728a1eb68c0b6915c77e8c9b424cff (diff) |
support for real * complex matrix product - step 1 (works for some special cases)
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/arch/SSE/Complex.h | 18 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 410 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 98 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixMatrix.h | 8 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointProduct.h | 6 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixMatrix.h | 8 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularSolverMatrix.h | 8 | ||||
-rw-r--r-- | Eigen/src/Core/util/BlasUtil.h | 41 |
8 files changed, 331 insertions, 266 deletions
diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h index 4ecfc2f43..259aebe2c 100644 --- a/Eigen/src/Core/arch/SSE/Complex.h +++ b/Eigen/src/Core/arch/SSE/Complex.h @@ -194,6 +194,15 @@ template<> struct ei_conj_helper<Packet2cf, Packet2cf, true,true> } }; +template<> struct ei_conj_helper<Packet4f, Packet2cf, false,false> +{ + EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet4f& x, const Packet2cf& y, const Packet2cf& c) const + { return ei_padd(c, pmul(x,y)); } + + EIGEN_STRONG_INLINE Packet2cf pmul(const Packet4f& x, const Packet2cf& y) const + { return Packet2cf(ei_pmul(x, y.v)); } +}; + template<> EIGEN_STRONG_INLINE Packet2cf ei_pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { // TODO optimize it for SSE3 and 4 @@ -359,6 +368,15 @@ template<> struct ei_conj_helper<Packet1cd, Packet1cd, true,true> } }; +template<> struct ei_conj_helper<Packet2d, Packet1cd, false,false> +{ + EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet2d& x, const Packet1cd& y, const Packet1cd& c) const + { return ei_padd(c, pmul(x,y)); } + + EIGEN_STRONG_INLINE Packet1cd pmul(const Packet2d& x, const Packet1cd& y) const + { return Packet1cd(ei_pmul(x, y.v)); } +}; + template<> EIGEN_STRONG_INLINE Packet1cd ei_pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { // TODO optimize it for SSE3 and 4 diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index cf133f68f..3dae26eee 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -114,9 +114,9 @@ void computeProductBlockingSizes(std::ptrdiff_t& k, std::ptrdiff_t& m, std::ptrd std::ptrdiff_t l1, l2; enum { - kdiv = KcFactor * 2 * ei_product_blocking_traits<RhsScalar>::nr + kdiv = KcFactor * 2 * ei_product_blocking_traits<LhsScalar,RhsScalar>::nr * ei_packet_traits<RhsScalar>::size * sizeof(RhsScalar), - mr = ei_product_blocking_traits<LhsScalar>::mr, + mr = ei_product_blocking_traits<LhsScalar,RhsScalar>::mr, mr_mask = (0xffffffff/mr)*mr }; @@ -140,35 +140,50 @@ inline void computeProductBlockingSizes(std::ptrdiff_t& k, std::ptrdiff_t& m, st #endif // optimized GEneral packed Block * packed Panel product kernel -template<typename Scalar, typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> struct ei_gebp_kernel { - void operator()(Scalar* res, Index resStride, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, Scalar* unpackedB = 0) + typedef typename ei_scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; + + enum { + Vectorizable = ei_packet_traits<LhsScalar>::Vectorizable && ei_packet_traits<RhsScalar>::Vectorizable, + LhsPacketSize = Vectorizable ? ei_packet_traits<LhsScalar>::size : 1, + RhsPacketSize = Vectorizable ? ei_packet_traits<RhsScalar>::size : 1, + ResPacketSize = Vectorizable ? ei_packet_traits<ResScalar>::size : 1 + }; + + typedef typename ei_packet_traits<LhsScalar>::type _LhsPacketType; + typedef typename ei_packet_traits<RhsScalar>::type _RhsPacketType; + typedef typename ei_packet_traits<ResScalar>::type _ResPacketType; + + typedef typename ei_meta_if<Vectorizable,_LhsPacketType,LhsScalar>::ret LhsPacketType; + typedef typename ei_meta_if<Vectorizable,_RhsPacketType,RhsScalar>::ret RhsPacketType; + typedef typename ei_meta_if<Vectorizable,_ResPacketType,ResScalar>::ret ResPacketType; + + void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, RhsScalar* unpackedB = 0) { - typedef typename ei_packet_traits<Scalar>::type PacketType; - enum { PacketSize = ei_packet_traits<Scalar>::size }; if(strideA==-1) strideA = depth; if(strideB==-1) strideB = depth; - ei_conj_helper<Scalar,Scalar,ConjugateLhs,ConjugateRhs> cj; - ei_conj_helper<PacketType,PacketType,ConjugateLhs,ConjugateRhs> pcj; + ei_conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj; + ei_conj_helper<LhsPacketType,RhsPacketType,ConjugateLhs,ConjugateRhs> pcj; Index packet_cols = (cols/nr) * nr; const Index peeled_mc = (rows/mr)*mr; - const Index peeled_mc2 = peeled_mc + (rows-peeled_mc >= PacketSize ? PacketSize : 0); + const Index peeled_mc2 = peeled_mc + (rows-peeled_mc >= LhsPacketSize ? LhsPacketSize : 0); const Index peeled_kc = (depth/4)*4; if(unpackedB==0) - unpackedB = const_cast<Scalar*>(blockB - strideB * nr * PacketSize); + unpackedB = const_cast<RhsScalar*>(blockB - strideB * nr * RhsPacketSize); // loops on each micro vertical panel of rhs (depth x nr) for(Index j2=0; j2<packet_cols; j2+=nr) { // unpack B { - const Scalar* blB = &blockB[j2*strideB+offsetB*nr]; + const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr]; Index n = depth*nr; for(Index k=0; k<n; k++) - ei_pstore(&unpackedB[k*PacketSize], ei_pset1(blB[k])); + ei_pstore(&unpackedB[k*RhsPacketSize], ei_pset1(blB[k])); /*Scalar* dest = unpackedB; for(Index k=0; k<n; k+=4*PacketSize) { @@ -222,26 +237,26 @@ struct ei_gebp_kernel // stored into mr/packet_size x nr registers. for(Index i=0; i<peeled_mc; i+=mr) { - const Scalar* blA = &blockA[i*strideA+offsetA*mr]; + const LhsScalar* blA = &blockA[i*strideA+offsetA*mr]; ei_prefetch(&blA[0]); // TODO move the res loads to the stores // gets res block as register - PacketType C0, C1, C2, C3, C4, C5, C6, C7; - C0 = ei_pset1(Scalar(0)); - C1 = ei_pset1(Scalar(0)); - if(nr==4) C2 = ei_pset1(Scalar(0)); - if(nr==4) C3 = ei_pset1(Scalar(0)); - C4 = ei_pset1(Scalar(0)); - C5 = ei_pset1(Scalar(0)); - if(nr==4) C6 = ei_pset1(Scalar(0)); - if(nr==4) C7 = ei_pset1(Scalar(0)); - - Scalar* r0 = &res[(j2+0)*resStride + i]; - Scalar* r1 = r0 + resStride; - Scalar* r2 = r1 + resStride; - Scalar* r3 = r2 + resStride; + ResPacketType C0, C1, C2, C3, C4, C5, C6, C7; + C0 = ei_pset1(ResScalar(0)); + C1 = ei_pset1(ResScalar(0)); + if(nr==4) C2 = ei_pset1(ResScalar(0)); + if(nr==4) C3 = ei_pset1(ResScalar(0)); + C4 = ei_pset1(ResScalar(0)); + C5 = ei_pset1(ResScalar(0)); + if(nr==4) C6 = ei_pset1(ResScalar(0)); + if(nr==4) C7 = ei_pset1(ResScalar(0)); + + ResScalar* r0 = &res[(j2+0)*resStride + i]; + ResScalar* r1 = r0 + resStride; + ResScalar* r2 = r1 + resStride; + ResScalar* r3 = r2 + resStride; ei_prefetch(r0+16); ei_prefetch(r1+16); @@ -251,110 +266,111 @@ struct ei_gebp_kernel // performs "inner" product // TODO let's check wether the folowing peeled loop could not be // optimized via optimal prefetching from one loop to the other - const Scalar* blB = unpackedB; + const RhsScalar* blB = unpackedB; for(Index k=0; k<peeled_kc; k+=4) { if(nr==2) { - PacketType B0, A0, A1; + LhsPacketType A0, A1; + RhsPacketType B0; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0; + ResPacketType T0; #endif EIGEN_ASM_COMMENT("mybegin"); - A0 = ei_pload(&blA[0*PacketSize]); - A1 = ei_pload(&blA[1*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + A1 = ei_pload(&blA[1*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[1*PacketSize]); + B0 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C1,T0); MADD(pcj,A1,B0,C5,B0); - A0 = ei_pload(&blA[2*PacketSize]); - A1 = ei_pload(&blA[3*PacketSize]); - B0 = ei_pload(&blB[2*PacketSize]); + A0 = ei_pload(&blA[2*LhsPacketSize]); + A1 = ei_pload(&blA[3*LhsPacketSize]); + B0 = ei_pload(&blB[2*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[3*PacketSize]); + B0 = ei_pload(&blB[3*RhsPacketSize]); MADD(pcj,A0,B0,C1,T0); MADD(pcj,A1,B0,C5,B0); - A0 = ei_pload(&blA[4*PacketSize]); - A1 = ei_pload(&blA[5*PacketSize]); - B0 = ei_pload(&blB[4*PacketSize]); + A0 = ei_pload(&blA[4*LhsPacketSize]); + A1 = ei_pload(&blA[5*LhsPacketSize]); + B0 = ei_pload(&blB[4*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[5*PacketSize]); + B0 = ei_pload(&blB[5*RhsPacketSize]); MADD(pcj,A0,B0,C1,T0); MADD(pcj,A1,B0,C5,B0); - A0 = ei_pload(&blA[6*PacketSize]); - A1 = ei_pload(&blA[7*PacketSize]); - B0 = ei_pload(&blB[6*PacketSize]); + A0 = ei_pload(&blA[6*LhsPacketSize]); + A1 = ei_pload(&blA[7*LhsPacketSize]); + B0 = ei_pload(&blB[6*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[7*PacketSize]); + B0 = ei_pload(&blB[7*RhsPacketSize]); MADD(pcj,A0,B0,C1,T0); MADD(pcj,A1,B0,C5,B0); EIGEN_ASM_COMMENT("myend"); } else { - PacketType B0, B1, B2, B3, A0, A1; + LhsPacketType A0, A1; + RhsPacketType B0, B1, B2, B3; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0; + ResPacketType T0; #endif -EIGEN_ASM_COMMENT("mybegin"); - A0 = ei_pload(&blA[0*PacketSize]); - A1 = ei_pload(&blA[1*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); - B1 = ei_pload(&blB[1*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + A1 = ei_pload(&blA[1*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); + B1 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); - B2 = ei_pload(&blB[2*PacketSize]); + B2 = ei_pload(&blB[2*RhsPacketSize]); MADD(pcj,A1,B0,C4,B0); - B3 = ei_pload(&blB[3*PacketSize]); - B0 = ei_pload(&blB[4*PacketSize]); + B3 = ei_pload(&blB[3*RhsPacketSize]); + B0 = ei_pload(&blB[4*RhsPacketSize]); MADD(pcj,A0,B1,C1,T0); MADD(pcj,A1,B1,C5,B1); - B1 = ei_pload(&blB[5*PacketSize]); + B1 = ei_pload(&blB[5*RhsPacketSize]); MADD(pcj,A0,B2,C2,T0); MADD(pcj,A1,B2,C6,B2); - B2 = ei_pload(&blB[6*PacketSize]); + B2 = ei_pload(&blB[6*RhsPacketSize]); MADD(pcj,A0,B3,C3,T0); - A0 = ei_pload(&blA[2*PacketSize]); + A0 = ei_pload(&blA[2*LhsPacketSize]); MADD(pcj,A1,B3,C7,B3); - A1 = ei_pload(&blA[3*PacketSize]); - B3 = ei_pload(&blB[7*PacketSize]); + A1 = ei_pload(&blA[3*LhsPacketSize]); + B3 = ei_pload(&blB[7*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[8*PacketSize]); + B0 = ei_pload(&blB[8*RhsPacketSize]); MADD(pcj,A0,B1,C1,T0); MADD(pcj,A1,B1,C5,B1); - B1 = ei_pload(&blB[9*PacketSize]); + B1 = ei_pload(&blB[9*RhsPacketSize]); MADD(pcj,A0,B2,C2,T0); MADD(pcj,A1,B2,C6,B2); - B2 = ei_pload(&blB[10*PacketSize]); + B2 = ei_pload(&blB[10*RhsPacketSize]); MADD(pcj,A0,B3,C3,T0); - A0 = ei_pload(&blA[4*PacketSize]); + A0 = ei_pload(&blA[4*LhsPacketSize]); MADD(pcj,A1,B3,C7,B3); - A1 = ei_pload(&blA[5*PacketSize]); - B3 = ei_pload(&blB[11*PacketSize]); + A1 = ei_pload(&blA[5*LhsPacketSize]); + B3 = ei_pload(&blB[11*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[12*PacketSize]); + B0 = ei_pload(&blB[12*RhsPacketSize]); MADD(pcj,A0,B1,C1,T0); MADD(pcj,A1,B1,C5,B1); - B1 = ei_pload(&blB[13*PacketSize]); + B1 = ei_pload(&blB[13*RhsPacketSize]); MADD(pcj,A0,B2,C2,T0); MADD(pcj,A1,B2,C6,B2); - B2 = ei_pload(&blB[14*PacketSize]); + B2 = ei_pload(&blB[14*RhsPacketSize]); MADD(pcj,A0,B3,C3,T0); - A0 = ei_pload(&blA[6*PacketSize]); + A0 = ei_pload(&blA[6*LhsPacketSize]); MADD(pcj,A1,B3,C7,B3); - A1 = ei_pload(&blA[7*PacketSize]); - B3 = ei_pload(&blB[15*PacketSize]); + A1 = ei_pload(&blA[7*LhsPacketSize]); + B3 = ei_pload(&blB[15*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); MADD(pcj,A0,B1,C1,T0); @@ -363,10 +379,9 @@ EIGEN_ASM_COMMENT("mybegin"); MADD(pcj,A1,B2,C6,B2); MADD(pcj,A0,B3,C3,T0); MADD(pcj,A1,B3,C7,B3); -EIGEN_ASM_COMMENT("myend"); } - blB += 4*nr*PacketSize; + blB += 4*nr*RhsPacketSize; blA += 4*mr; } // process remaining peeled loop @@ -374,36 +389,38 @@ EIGEN_ASM_COMMENT("myend"); { if(nr==2) { - PacketType B0, A0, A1; + LhsPacketType A0, A1; + RhsPacketType B0; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0; + ResPacketType T0; #endif - A0 = ei_pload(&blA[0*PacketSize]); - A1 = ei_pload(&blA[1*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + A1 = ei_pload(&blA[1*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,B0); - B0 = ei_pload(&blB[1*PacketSize]); + B0 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C1,T0); MADD(pcj,A1,B0,C5,B0); } else { - PacketType B0, B1, B2, B3, A0, A1; + LhsPacketType A0, A1; + RhsPacketType B0, B1, B2, B3; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0; + ResPacketType T0; #endif - A0 = ei_pload(&blA[0*PacketSize]); - A1 = ei_pload(&blA[1*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); - B1 = ei_pload(&blB[1*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + A1 = ei_pload(&blA[1*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); + B1 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); - B2 = ei_pload(&blB[2*PacketSize]); + B2 = ei_pload(&blB[2*RhsPacketSize]); MADD(pcj,A1,B0,C4,B0); - B3 = ei_pload(&blB[3*PacketSize]); + B3 = ei_pload(&blB[3*RhsPacketSize]); MADD(pcj,A0,B1,C1,T0); MADD(pcj,A1,B1,C5,B1); MADD(pcj,A0,B2,C2,T0); @@ -412,20 +429,20 @@ EIGEN_ASM_COMMENT("myend"); MADD(pcj,A1,B3,C7,B3); } - blB += nr*PacketSize; + blB += nr*RhsPacketSize; blA += mr; } - PacketType R0, R1, R2, R3, R4, R5, R6, R7; + ResPacketType R0, R1, R2, R3, R4, R5, R6, R7; R0 = ei_ploadu(r0); R1 = ei_ploadu(r1); if(nr==4) R2 = ei_ploadu(r2); if(nr==4) R3 = ei_ploadu(r3); - R4 = ei_ploadu(r0 + PacketSize); - R5 = ei_ploadu(r1 + PacketSize); - if(nr==4) R6 = ei_ploadu(r2 + PacketSize); - if(nr==4) R7 = ei_ploadu(r3 + PacketSize); + R4 = ei_ploadu(r0 + ResPacketSize); + R5 = ei_ploadu(r1 + ResPacketSize); + if(nr==4) R6 = ei_ploadu(r2 + ResPacketSize); + if(nr==4) R7 = ei_ploadu(r3 + ResPacketSize); C0 = ei_padd(R0, C0); C1 = ei_padd(R1, C1); @@ -440,129 +457,133 @@ EIGEN_ASM_COMMENT("myend"); ei_pstoreu(r1, C1); if(nr==4) ei_pstoreu(r2, C2); if(nr==4) ei_pstoreu(r3, C3); - ei_pstoreu(r0 + PacketSize, C4); - ei_pstoreu(r1 + PacketSize, C5); - if(nr==4) ei_pstoreu(r2 + PacketSize, C6); - if(nr==4) ei_pstoreu(r3 + PacketSize, C7); + ei_pstoreu(r0 + ResPacketSize, C4); + ei_pstoreu(r1 + ResPacketSize, C5); + if(nr==4) ei_pstoreu(r2 + ResPacketSize, C6); + if(nr==4) ei_pstoreu(r3 + ResPacketSize, C7); } - if(rows-peeled_mc>=PacketSize) + if(rows-peeled_mc>=LhsPacketSize) { Index i = peeled_mc; - const Scalar* blA = &blockA[i*strideA+offsetA*PacketSize]; + const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsPacketSize]; ei_prefetch(&blA[0]); // gets res block as register - PacketType C0, C1, C2, C3; + ResPacketType C0, C1, C2, C3; C0 = ei_ploadu(&res[(j2+0)*resStride + i]); C1 = ei_ploadu(&res[(j2+1)*resStride + i]); if(nr==4) C2 = ei_ploadu(&res[(j2+2)*resStride + i]); if(nr==4) C3 = ei_ploadu(&res[(j2+3)*resStride + i]); // performs "inner" product - const Scalar* blB = unpackedB; + const RhsScalar* blB = unpackedB; for(Index k=0; k<peeled_kc; k+=4) { if(nr==2) { - PacketType B0, B1, A0; + LhsPacketType A0; + RhsPacketType B0, B1; - A0 = ei_pload(&blA[0*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); - B1 = ei_pload(&blB[1*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); + B1 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); - B0 = ei_pload(&blB[2*PacketSize]); + B0 = ei_pload(&blB[2*RhsPacketSize]); MADD(pcj,A0,B1,C1,B1); - A0 = ei_pload(&blA[1*PacketSize]); - B1 = ei_pload(&blB[3*PacketSize]); + A0 = ei_pload(&blA[1*LhsPacketSize]); + B1 = ei_pload(&blB[3*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); - B0 = ei_pload(&blB[4*PacketSize]); + B0 = ei_pload(&blB[4*RhsPacketSize]); MADD(pcj,A0,B1,C1,B1); - A0 = ei_pload(&blA[2*PacketSize]); - B1 = ei_pload(&blB[5*PacketSize]); + A0 = ei_pload(&blA[2*LhsPacketSize]); + B1 = ei_pload(&blB[5*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); - B0 = ei_pload(&blB[6*PacketSize]); + B0 = ei_pload(&blB[6*RhsPacketSize]); MADD(pcj,A0,B1,C1,B1); - A0 = ei_pload(&blA[3*PacketSize]); - B1 = ei_pload(&blB[7*PacketSize]); + A0 = ei_pload(&blA[3*LhsPacketSize]); + B1 = ei_pload(&blB[7*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); MADD(pcj,A0,B1,C1,B1); } else { - PacketType B0, B1, B2, B3, A0; + LhsPacketType A0; + RhsPacketType B0, B1, B2, B3; - A0 = ei_pload(&blA[0*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); - B1 = ei_pload(&blB[1*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); + B1 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); - B2 = ei_pload(&blB[2*PacketSize]); - B3 = ei_pload(&blB[3*PacketSize]); - B0 = ei_pload(&blB[4*PacketSize]); + B2 = ei_pload(&blB[2*RhsPacketSize]); + B3 = ei_pload(&blB[3*RhsPacketSize]); + B0 = ei_pload(&blB[4*RhsPacketSize]); MADD(pcj,A0,B1,C1,B1); - B1 = ei_pload(&blB[5*PacketSize]); + B1 = ei_pload(&blB[5*RhsPacketSize]); MADD(pcj,A0,B2,C2,B2); - B2 = ei_pload(&blB[6*PacketSize]); + B2 = ei_pload(&blB[6*RhsPacketSize]); MADD(pcj,A0,B3,C3,B3); - A0 = ei_pload(&blA[1*PacketSize]); - B3 = ei_pload(&blB[7*PacketSize]); + A0 = ei_pload(&blA[1*LhsPacketSize]); + B3 = ei_pload(&blB[7*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); - B0 = ei_pload(&blB[8*PacketSize]); + B0 = ei_pload(&blB[8*RhsPacketSize]); MADD(pcj,A0,B1,C1,B1); - B1 = ei_pload(&blB[9*PacketSize]); + B1 = ei_pload(&blB[9*RhsPacketSize]); MADD(pcj,A0,B2,C2,B2); - B2 = ei_pload(&blB[10*PacketSize]); + B2 = ei_pload(&blB[10*RhsPacketSize]); MADD(pcj,A0,B3,C3,B3); - A0 = ei_pload(&blA[2*PacketSize]); - B3 = ei_pload(&blB[11*PacketSize]); + A0 = ei_pload(&blA[2*LhsPacketSize]); + B3 = ei_pload(&blB[11*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); - B0 = ei_pload(&blB[12*PacketSize]); + B0 = ei_pload(&blB[12*RhsPacketSize]); MADD(pcj,A0,B1,C1,B1); - B1 = ei_pload(&blB[13*PacketSize]); + B1 = ei_pload(&blB[13*RhsPacketSize]); MADD(pcj,A0,B2,C2,B2); - B2 = ei_pload(&blB[14*PacketSize]); + B2 = ei_pload(&blB[14*RhsPacketSize]); MADD(pcj,A0,B3,C3,B3); - A0 = ei_pload(&blA[3*PacketSize]); - B3 = ei_pload(&blB[15*PacketSize]); + A0 = ei_pload(&blA[3*LhsPacketSize]); + B3 = ei_pload(&blB[15*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0); MADD(pcj,A0,B1,C1,B1); MADD(pcj,A0,B2,C2,B2); MADD(pcj,A0,B3,C3,B3); } - blB += 4*nr*PacketSize; - blA += 4*PacketSize; + blB += 4*nr*RhsPacketSize; + blA += 4*LhsPacketSize; } // process remaining peeled loop for(Index k=peeled_kc; k<depth; k++) { if(nr==2) { - PacketType B0, A0; + LhsPacketType A0; + RhsPacketType B0; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0; + ResPacketType T0; #endif - A0 = ei_pload(&blA[0*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); - B0 = ei_pload(&blB[1*PacketSize]); + B0 = ei_pload(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C1,T0); } else { - PacketType B0, B1, B2, B3, A0; + LhsPacketType A0; + RhsPacketType B0, B1, B2, B3; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0, T1; + ResPacketType T0, T1; #endif - A0 = ei_pload(&blA[0*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); - B1 = ei_pload(&blB[1*PacketSize]); - B2 = ei_pload(&blB[2*PacketSize]); - B3 = ei_pload(&blB[3*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); + B1 = ei_pload(&blB[1*RhsPacketSize]); + B2 = ei_pload(&blB[2*RhsPacketSize]); + B3 = ei_pload(&blB[3*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A0,B1,C1,T1); @@ -570,8 +591,8 @@ EIGEN_ASM_COMMENT("myend"); MADD(pcj,A0,B3,C3,T1); } - blB += nr*PacketSize; - blA += PacketSize; + blB += nr*RhsPacketSize; + blA += LhsPacketSize; } ei_pstoreu(&res[(j2+0)*resStride + i], C0); @@ -581,40 +602,42 @@ EIGEN_ASM_COMMENT("myend"); } for(Index i=peeled_mc2; i<rows; i++) { - const Scalar* blA = &blockA[i*strideA+offsetA]; + const LhsScalar* blA = &blockA[i*strideA+offsetA]; ei_prefetch(&blA[0]); // gets a 1 x nr res block as registers - Scalar C0(0), C1(0), C2(0), C3(0); + ResScalar C0(0), C1(0), C2(0), C3(0); // TODO directly use blockB ??? - const Scalar* blB = unpackedB;//&blockB[j2*strideB+offsetB*nr]; + const RhsScalar* blB = unpackedB;//&blockB[j2*strideB+offsetB*nr]; for(Index k=0; k<depth; k++) { if(nr==2) { - Scalar B0, A0; + LhsScalar A0; + RhsScalar B0; #ifndef EIGEN_HAS_FUSE_CJMADD - Scalar T0; + ResScalar T0; #endif A0 = blA[k]; - B0 = blB[0*PacketSize]; + B0 = blB[0*RhsPacketSize]; MADD(cj,A0,B0,C0,T0); - B0 = blB[1*PacketSize]; + B0 = blB[1*RhsPacketSize]; MADD(cj,A0,B0,C1,T0); } else { - Scalar B0, B1, B2, B3, A0; + LhsScalar A0; + RhsScalar B0, B1, B2, B3; #ifndef EIGEN_HAS_FUSE_CJMADD - Scalar T0, T1; + ResScalar T0, T1; #endif A0 = blA[k]; - B0 = blB[0*PacketSize]; - B1 = blB[1*PacketSize]; - B2 = blB[2*PacketSize]; - B3 = blB[3*PacketSize]; + B0 = blB[0*RhsPacketSize]; + B1 = blB[1*RhsPacketSize]; + B2 = blB[2*RhsPacketSize]; + B3 = blB[3*RhsPacketSize]; MADD(cj,A0,B0,C0,T0); MADD(cj,A0,B1,C1,T1); @@ -622,7 +645,7 @@ EIGEN_ASM_COMMENT("myend"); MADD(cj,A0,B3,C3,T1); } - blB += nr*PacketSize; + blB += nr*RhsPacketSize; } res[(j2+0)*resStride + i] += C0; res[(j2+1)*resStride + i] += C1; @@ -637,78 +660,79 @@ EIGEN_ASM_COMMENT("myend"); { // unpack B { - const Scalar* blB = &blockB[j2*strideB+offsetB]; + const RhsScalar* blB = &blockB[j2*strideB+offsetB]; for(Index k=0; k<depth; k++) - ei_pstore(&unpackedB[k*PacketSize], ei_pset1(blB[k])); + ei_pstore(&unpackedB[k*RhsPacketSize], ei_pset1(blB[k])); } for(Index i=0; i<peeled_mc; i+=mr) { - const Scalar* blA = &blockA[i*strideA+offsetA*mr]; + const LhsScalar* blA = &blockA[i*strideA+offsetA*mr]; ei_prefetch(&blA[0]); // TODO move the res loads to the stores // get res block as registers - PacketType C0, C4; + ResPacketType C0, C4; C0 = ei_ploadu(&res[(j2+0)*resStride + i]); - C4 = ei_ploadu(&res[(j2+0)*resStride + i + PacketSize]); + C4 = ei_ploadu(&res[(j2+0)*resStride + i + ResPacketSize]); - const Scalar* blB = unpackedB; + const RhsScalar* blB = unpackedB; for(Index k=0; k<depth; k++) { - PacketType B0, A0, A1; + LhsPacketType A0, A1; + RhsPacketType B0; #ifndef EIGEN_HAS_FUSE_CJMADD - PacketType T0, T1; + ResPacketType T0, T1; #endif - A0 = ei_pload(&blA[0*PacketSize]); - A1 = ei_pload(&blA[1*PacketSize]); - B0 = ei_pload(&blB[0*PacketSize]); + A0 = ei_pload(&blA[0*LhsPacketSize]); + A1 = ei_pload(&blA[1*LhsPacketSize]); + B0 = ei_pload(&blB[0*RhsPacketSize]); MADD(pcj,A0,B0,C0,T0); MADD(pcj,A1,B0,C4,T1); - blB += PacketSize; + blB += RhsPacketSize; blA += mr; } ei_pstoreu(&res[(j2+0)*resStride + i], C0); - ei_pstoreu(&res[(j2+0)*resStride + i + PacketSize], C4); + ei_pstoreu(&res[(j2+0)*resStride + i + ResPacketSize], C4); } - if(rows-peeled_mc>=PacketSize) + if(rows-peeled_mc>=LhsPacketSize) { Index i = peeled_mc; - const Scalar* blA = &blockA[i*strideA+offsetA*PacketSize]; + const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsPacketSize]; ei_prefetch(&blA[0]); - PacketType C0 = ei_ploadu(&res[(j2+0)*resStride + i]); + ResPacketType C0 = ei_ploadu(&res[(j2+0)*resStride + i]); - const Scalar* blB = unpackedB; + const RhsScalar* blB = unpackedB; for(Index k=0; k<depth; k++) { - PacketType T0; + ResPacketType T0; MADD(pcj,ei_pload(blA), ei_pload(blB), C0, T0); - blB += PacketSize; - blA += PacketSize; + blB += RhsPacketSize; + blA += LhsPacketSize; } ei_pstoreu(&res[(j2+0)*resStride + i], C0); } for(Index i=peeled_mc2; i<rows; i++) { - const Scalar* blA = &blockA[i*strideA+offsetA]; + const LhsScalar* blA = &blockA[i*strideA+offsetA]; ei_prefetch(&blA[0]); // gets a 1 x 1 res block as registers - Scalar C0(0); + ResScalar C0(0); // FIXME directly use blockB ?? - const Scalar* blB = unpackedB; + const RhsScalar* blB = unpackedB; for(Index k=0; k<depth; k++) { #ifndef EIGEN_HAS_FUSE_CJMADD - Scalar T0; + ResScalar T0; #endif - MADD(cj,blA[k], blB[k*PacketSize], C0, T0); + MADD(cj,blA[k], blB[k*RhsPacketSize], C0, T0); } res[(j2+0)*resStride + i] += C0; } diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 2ae78c1e7..c480ce14d 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -29,26 +29,25 @@ template<typename _LhsScalar, typename _RhsScalar> class ei_level3_blocking; /* Specialization for a row-major destination matrix => simple transposition of the product */ template< - typename Scalar, typename Index, - int LhsStorageOrder, bool ConjugateLhs, - int RhsStorageOrder, bool ConjugateRhs> -struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,RowMajor> + typename Index, + typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs> +struct ei_general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor> { + typedef typename ei_scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run( Index rows, Index cols, Index depth, - const Scalar* lhs, Index lhsStride, - const Scalar* rhs, Index rhsStride, - Scalar* res, Index resStride, - Scalar alpha, - ei_level3_blocking<Scalar,Scalar>& blocking, + const LhsScalar* lhs, Index lhsStride, + const RhsScalar* rhs, Index rhsStride, + ResScalar* res, Index resStride, + ResScalar alpha, + ei_level3_blocking<RhsScalar,LhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) { // transpose the product such that the result is column major - ei_general_matrix_matrix_product<Scalar, Index, - RhsStorageOrder==RowMajor ? ColMajor : RowMajor, - ConjugateRhs, - LhsStorageOrder==RowMajor ? ColMajor : RowMajor, - ConjugateLhs, + ei_general_matrix_matrix_product<Index, + RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs, + LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, ColMajor> ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info); } @@ -57,24 +56,24 @@ struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLh /* Specialization for a col-major destination matrix * => Blocking algorithm following Goto's paper */ template< - typename Scalar, typename Index, - int LhsStorageOrder, bool ConjugateLhs, - int RhsStorageOrder, bool ConjugateRhs> -struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> + typename Index, + typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs> +struct ei_general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor> { +typedef typename ei_scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; static void run(Index rows, Index cols, Index depth, - const Scalar* _lhs, Index lhsStride, - const Scalar* _rhs, Index rhsStride, - Scalar* res, Index resStride, - Scalar alpha, - ei_level3_blocking<Scalar,Scalar>& blocking, + const LhsScalar* _lhs, Index lhsStride, + const RhsScalar* _rhs, Index rhsStride, + ResScalar* res, Index resStride, + ResScalar alpha, + ei_level3_blocking<LhsScalar,RhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) { - ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); - ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); + ei_const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); + ei_const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); - typedef typename ei_packet_traits<Scalar>::type PacketType; - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<LhsScalar,RhsScalar> Blocking; Index kc = blocking.kc(); // cache block size along the K direction Index mc = std::min(rows,blocking.mc()); // cache block size along the M direction @@ -83,9 +82,9 @@ static void run(Index rows, Index cols, Index depth, // FIXME starting from SSE3, normal complex product cannot be optimized as well as // conjugate product, therefore it is better to conjugate during the copies. // With SSE2, this is the other way round. - ei_gemm_pack_lhs<Scalar, Index, Blocking::mr, LhsStorageOrder, ConjugateLhs> pack_lhs; - ei_gemm_pack_rhs<Scalar, Index, Blocking::nr, RhsStorageOrder, ConjugateRhs> pack_rhs; - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr> gebp; + ei_gemm_pack_lhs<LhsScalar, Index, Blocking::mr, LhsStorageOrder, ConjugateLhs> pack_lhs; + ei_gemm_pack_rhs<RhsScalar, Index, Blocking::nr, RhsStorageOrder, ConjugateRhs> pack_rhs; + ei_gebp_kernel<LhsScalar, RhsScalar, Index, Blocking::mr, Blocking::nr> gebp; // if (ConjugateRhs) // alpha = ei_conj(alpha); @@ -173,10 +172,10 @@ static void run(Index rows, Index cols, Index depth, // this is the sequential version! std::size_t sizeA = kc*mc; std::size_t sizeB = kc*cols; - std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr; - Scalar *blockA = blocking.blockA()==0 ? ei_aligned_stack_new(Scalar, sizeA) : blocking.blockA(); - Scalar *blockB = blocking.blockB()==0 ? ei_aligned_stack_new(Scalar, sizeB) : blocking.blockB(); - Scalar *blockW = blocking.blockW()==0 ? ei_aligned_stack_new(Scalar, sizeW) : blocking.blockW(); + std::size_t sizeW = kc*ei_packet_traits<RhsScalar>::size*Blocking::nr; + LhsScalar *blockA = blocking.blockA()==0 ? ei_aligned_stack_new(LhsScalar, sizeA) : blocking.blockA(); + RhsScalar *blockB = blocking.blockB()==0 ? ei_aligned_stack_new(RhsScalar, sizeB) : blocking.blockB(); + RhsScalar *blockW = blocking.blockW()==0 ? ei_aligned_stack_new(RhsScalar, sizeW) : blocking.blockW(); // For each horizontal panel of the rhs, and corresponding panel of the lhs... // (==GEMM_VAR1) @@ -208,9 +207,9 @@ static void run(Index rows, Index cols, Index depth, } } - if(blocking.blockA()==0) ei_aligned_stack_delete(Scalar, blockA, kc*mc); - if(blocking.blockB()==0) ei_aligned_stack_delete(Scalar, blockB, sizeB); - if(blocking.blockW()==0) ei_aligned_stack_delete(Scalar, blockW, sizeW); + if(blocking.blockA()==0) ei_aligned_stack_delete(LhsScalar, blockA, kc*mc); + if(blocking.blockB()==0) ei_aligned_stack_delete(RhsScalar, blockB, sizeB); + if(blocking.blockW()==0) ei_aligned_stack_delete(RhsScalar, blockW, sizeW); } } @@ -245,8 +244,8 @@ struct ei_gemm_functor cols = m_rhs.cols(); Gemm::run(rows, cols, m_lhs.cols(), - (const Scalar*)&(m_lhs.const_cast_derived().coeffRef(row,0)), m_lhs.outerStride(), - (const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,col)), m_rhs.outerStride(), + /*(const Scalar*)*/&(m_lhs.const_cast_derived().coeffRef(row,0)), m_lhs.outerStride(), + /*(const Scalar*)*/&(m_rhs.const_cast_derived().coeffRef(0,col)), m_rhs.outerStride(), (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(), m_actualAlpha, m_blocking, info); } @@ -305,7 +304,7 @@ class ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols }; typedef typename ei_meta_if<Transpose,_RhsScalar,_LhsScalar>::ret LhsScalar; typedef typename ei_meta_if<Transpose,_LhsScalar,_RhsScalar>::ret RhsScalar; - typedef ei_product_blocking_traits<RhsScalar> Blocking; + typedef ei_product_blocking_traits<LhsScalar,RhsScalar> Blocking; enum { SizeA = ActualRows * MaxDepth, SizeB = ActualCols * MaxDepth, @@ -345,7 +344,7 @@ class ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols }; typedef typename ei_meta_if<Transpose,_RhsScalar,_LhsScalar>::ret LhsScalar; typedef typename ei_meta_if<Transpose,_LhsScalar,_RhsScalar>::ret RhsScalar; - typedef ei_product_blocking_traits<RhsScalar> Blocking; + typedef ei_product_blocking_traits<LhsScalar,RhsScalar> Blocking; DenseIndex m_sizeA; DenseIndex m_sizeB; @@ -410,10 +409,15 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) { - EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret), - YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) + // TODO add a weak static assert +// EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret), +// YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) } + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef Scalar ResScalar; + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); @@ -424,15 +428,15 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); - typedef ei_gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, + typedef ei_gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; typedef ei_gemm_functor< Scalar, Index, ei_general_matrix_matrix_product< - Scalar, Index, - (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), - (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), + Index, + LhsScalar, (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), + RhsScalar, (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor; diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index d8fa1bd9c..3bedafb68 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -256,7 +256,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate if (ConjugateRhs) alpha = ei_conj(alpha); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; Index kc = size; // cache block size along the K direction Index mc = rows; // cache block size along the M direction @@ -270,7 +270,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB); Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr; - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; ei_symm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed; @@ -341,7 +341,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,Conjugat if (ConjugateRhs) alpha = ei_conj(alpha); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; Index kc = size; // cache block size along the K direction Index mc = rows; // cache block size along the M direction @@ -353,7 +353,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,Conjugat Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB); Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr; - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs; ei_symm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h index 40c0c9aac..c45a3bac7 100644 --- a/Eigen/src/Core/products/SelfadjointProduct.h +++ b/Eigen/src/Core/products/SelfadjointProduct.h @@ -68,7 +68,7 @@ struct ei_selfadjoint_product<Scalar, Index, MatStorageOrder, ColMajor, AAT, UpL if(AAT) alpha = ei_conj(alpha); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; Index kc = depth; // cache block size along the K direction Index mc = size; // cache block size along the M direction @@ -89,7 +89,7 @@ struct ei_selfadjoint_product<Scalar, Index, MatStorageOrder, ColMajor, AAT, UpL ConjRhs = NumTraits<Scalar>::IsComplex && AAT }; - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, ConjLhs, ConjRhs> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, Blocking::mr, Blocking::nr, ConjLhs, ConjRhs> gebp_kernel; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,MatStorageOrder==RowMajor ? ColMajor : RowMajor> pack_rhs; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,MatStorageOrder, false> pack_lhs; ei_sybb_kernel<Scalar, Index, Blocking::mr, Blocking::nr, ConjLhs, ConjRhs, UpLo> sybb; @@ -175,7 +175,7 @@ struct ei_sybb_kernel }; void operator()(Scalar* res, Index resStride, const Scalar* blockA, const Scalar* blockB, Index size, Index depth, Scalar* workspace) { - ei_gebp_kernel<Scalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel; Matrix<Scalar,BlockSize,BlockSize,ColMajor> buffer; // let's process the block per panel of actual_mc x BlockSize, diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index be9362958..ce9a76654 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -108,7 +108,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true, if (ConjugateRhs) alpha = ei_conj(alpha); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; enum { SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), IsLower = (Mode&Lower) == Lower @@ -129,7 +129,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true, triangularBuffer.setZero(); triangularBuffer.diagonal().setOnes(); - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; @@ -234,7 +234,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false, if (ConjugateRhs) alpha = ei_conj(alpha); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; enum { SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), IsLower = (Mode&Lower) == Lower @@ -254,7 +254,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false, triangularBuffer.setZero(); triangularBuffer.diagonal().setOnes(); - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder,false,true> pack_rhs_panel; diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index 0fce7159e..d6ae2131d 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -57,7 +57,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStora ei_const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride); ei_blas_data_mapper<Scalar, Index, ColMajor> other(_other,otherStride); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; enum { SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), IsLower = (Mode&Lower) == Lower @@ -74,7 +74,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStora Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr; ei_conj_if<Conjugate> conj; - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, Conjugate, false> gebp_kernel; + ei_gebp_kernel<Scalar, Scalar, Index, Blocking::mr, Blocking::nr, Conjugate, false> gebp_kernel; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,TriStorageOrder> pack_lhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr, ColMajor, false, true> pack_rhs; @@ -191,7 +191,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStor ei_const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride); ei_blas_data_mapper<Scalar, Index, ColMajor> lhs(_other,otherStride); - typedef ei_product_blocking_traits<Scalar> Blocking; + typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; enum { RhsStorageOrder = TriStorageOrder, SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), @@ -212,7 +212,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStor Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr; ei_conj_if<Conjugate> conj; - ei_gebp_kernel<Scalar, Index, Blocking::mr, Blocking::nr, false, Conjugate> gebp_kernel; + ei_gebp_kernel<Scalar,Scalar, Index, Blocking::mr, Blocking::nr, false, Conjugate> gebp_kernel; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder,false,true> pack_rhs_panel; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr, ColMajor, false, true> pack_lhs_panel; diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index 38c86511c..1b7d03722 100644 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -29,7 +29,7 @@ // implement and control fast level 2 and level 3 BLAS-like routines. // forward declarations -template<typename Scalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> +template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> struct ei_gebp_kernel; template<typename Scalar, typename Index, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false> @@ -39,9 +39,9 @@ template<typename Scalar, typename Index, int mr, int StorageOrder, bool Conjuga struct ei_gemm_pack_lhs; template< - typename Scalar, typename Index, - int LhsStorageOrder, bool ConjugateLhs, - int RhsStorageOrder, bool ConjugateRhs, + typename Index, + typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder> struct ei_general_matrix_matrix_product; @@ -89,6 +89,25 @@ template<typename RealScalar> struct ei_conj_helper<std::complex<RealScalar>, st { return Scalar(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } }; +template<typename RealScalar> struct ei_conj_helper<std::complex<RealScalar>, RealScalar, false,false> +{ + typedef std::complex<RealScalar> Scalar; + EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const { return ei_padd(c, ei_pmul(x,y)); } + + EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const + { return ei_pmul(x,y); } +}; + +template<typename RealScalar> struct ei_conj_helper<RealScalar, std::complex<RealScalar>, false,false> +{ + typedef std::complex<RealScalar> Scalar; + EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const { return ei_padd(c, pmul(x,y)); } + + EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const + { return x * y; } +}; + + // Lightweight helper class to access matrix coefficients. // Yes, this is somehow redundant with Map<>, but this version is much much lighter, // and so I hope better compilation performance (time and code quality). @@ -118,29 +137,29 @@ class ei_const_blas_data_mapper }; // Defines various constant controlling register blocking for matrix-matrix algorithms. -template<typename Scalar> +template<typename LhsScalar, typename RhsScalar> struct ei_product_blocking_traits; + +template<typename LhsScalar, typename RhsScalar> struct ei_product_blocking_traits { - typedef typename ei_packet_traits<Scalar>::type PacketType; enum { - PacketSize = sizeof(PacketType)/sizeof(Scalar), + LhsPacketSize = ei_packet_traits<LhsScalar>::size, NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS, // register block size along the N direction (must be either 2 or 4) nr = NumberOfRegisters/4, // register block size along the M direction (currently, this one cannot be modified) - mr = 2 * PacketSize + mr = 2 * LhsPacketSize }; }; template<typename Real> -struct ei_product_blocking_traits<std::complex<Real> > +struct ei_product_blocking_traits<std::complex<Real>, std::complex<Real> > { typedef std::complex<Real> Scalar; - typedef typename ei_packet_traits<Scalar>::type PacketType; enum { - PacketSize = sizeof(PacketType)/sizeof(Scalar), + PacketSize = ei_packet_traits<Scalar>::size, nr = 2, mr = 2 * PacketSize }; |