Diffstat (limited to 'Eigen/src/Core/products')

 -rw-r--r--  Eigen/src/Core/products/GeneralBlockPanelKernel.h | 673
 -rw-r--r--  Eigen/src/Core/products/SelfadjointMatrixMatrix.h |  48
 2 files changed, 291 insertions, 430 deletions
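In short, this commit reshapes the GEBP register blocking: the micro kernel now works on a single lhs packet per step (mr = LhsPacketSize instead of two packets) and on 4 or 8 rhs columns (nr = NumberOfRegisters/2 instead of NumberOfRegisters/4), with rhs coefficients broadcast into full packets through the new broadcastRhs/pbroadcast helpers. Before reading the patch, a minimal sketch of that 1-packet x 4 accumulation pattern may help; everything below (Packet4f, pload, pset1, pmadd, gebp_1px4_tile) is an illustrative stand-in for Eigen's internal packet primitives, not Eigen's actual code:

#include <cstddef>

// Hypothetical 4-lane packet emulated with a plain struct; Eigen uses real
// SIMD types (e.g. __m128) behind wrappers with these names.
struct Packet4f { float v[4]; };

static inline Packet4f pload(const float* p) { Packet4f r; for (int l = 0; l < 4; ++l) r.v[l] = p[l]; return r; }
static inline Packet4f pset1(float s)        { Packet4f r; for (int l = 0; l < 4; ++l) r.v[l] = s;    return r; }
static inline Packet4f pmadd(Packet4f a, Packet4f b, Packet4f c)
{ for (int l = 0; l < 4; ++l) c.v[l] += a.v[l] * b.v[l]; return c; }

// One 4 (rows) x 4 (cols) tile of C += alpha * A * B. blA is packed so each
// depth step k holds one packet of lhs rows; blB holds the nr=4 rhs
// coefficients of step k contiguously. Column n of the result starts at
// r0 + n*resStride, exactly as in the patched kernel's stores.
void gebp_1px4_tile(const float* blA, const float* blB, std::size_t depth,
                    float* r0, std::size_t resStride, float alpha)
{
  Packet4f C0 = pset1(0.f), C1 = pset1(0.f), C2 = pset1(0.f), C3 = pset1(0.f);
  for (std::size_t k = 0; k < depth; ++k) {
    Packet4f A0 = pload(blA + 4*k);          // one packet of lhs rows
    C0 = pmadd(A0, pset1(blB[4*k + 0]), C0); // each rhs coeff is broadcast
    C1 = pmadd(A0, pset1(blB[4*k + 1]), C1); // and feeds one accumulator
    C2 = pmadd(A0, pset1(blB[4*k + 2]), C2); // column
    C3 = pmadd(A0, pset1(blB[4*k + 3]), C3);
  }
  Packet4f av = pset1(alpha), R;
  R = pload(r0 + 0*resStride); R = pmadd(C0, av, R); for (int l = 0; l < 4; ++l) r0[0*resStride + l] = R.v[l];
  R = pload(r0 + 1*resStride); R = pmadd(C1, av, R); for (int l = 0; l < 4; ++l) r0[1*resStride + l] = R.v[l];
  R = pload(r0 + 2*resStride); R = pmadd(C2, av, R); for (int l = 0; l < 4; ++l) r0[2*resStride + l] = R.v[l];
  R = pload(r0 + 3*resStride); R = pmadd(C3, av, R); for (int l = 0; l < 4; ++l) r0[3*resStride + l] = R.v[l];
}

Broadcasting the rhs rather than keeping two lhs packets in flight frees registers for more accumulator columns, which is the whole point of the nr = 4 -> 8 change.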
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index ba6fad246..8a398d912 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -161,11 +161,11 @@ public:
     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
 
-    // register block size along the N direction (must be either 2 or 4)
-    nr = NumberOfRegisters/4,
+    // register block size along the N direction (must be either 4 or 8)
+    nr = NumberOfRegisters/2,
 
     // register block size along the M direction (currently, this one cannot be modified)
-    mr = 2 * LhsPacketSize,
+    mr = LhsPacketSize,
 
     WorkSpaceFactor = nr * RhsPacketSize,
 
@@ -187,6 +187,16 @@ public:
   {
     p = pset1<ResPacket>(ResScalar(0));
   }
+
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
+  {
+    pbroadcast4(b, b0, b1, b2, b3);
+  }
+
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
+  {
+    pbroadcast2(b, b0, b1);
+  }
 
   EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
   {
@@ -230,8 +240,8 @@ public:
     ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
 
-    nr = NumberOfRegisters/4,
-    mr = 2 * LhsPacketSize,
+    nr = NumberOfRegisters/2,
+    mr = LhsPacketSize,
 
     WorkSpaceFactor = nr*RhsPacketSize,
     LhsProgress = LhsPacketSize,
@@ -262,6 +272,16 @@ public:
   {
     dest = pload<LhsPacket>(a);
   }
+
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
+  {
+    pbroadcast4(b, b0, b1, b2, b3);
+  }
+
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
+  {
+    pbroadcast2(b, b0, b1);
+  }
 
   EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
   {
@@ -304,8 +324,9 @@ public:
     RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
     ResPacketSize  = Vectorizable ? packet_traits<ResScalar>::size : 1,
 
-    nr = 2,
-    mr = 2 * ResPacketSize,
+    // FIXME: should depend on NumberOfRegisters
+    nr = 4,
+    mr = ResPacketSize,
 
     WorkSpaceFactor = Vectorizable ? 2*nr*RealPacketSize : nr,
     LhsProgress = ResPacketSize,
@@ -333,16 +354,37 @@ public:
     p.second = pset1<RealPacket>(RealScalar(0));
   }
 
+  // Scalar path
   EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ResPacket& dest) const
   {
     dest = pset1<ResPacket>(*b);
   }
 
+  // Vectorized path
   EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket& dest) const
   {
     dest.first  = pset1<RealPacket>(real(*b));
     dest.second = pset1<RealPacket>(imag(*b));
   }
+
+  // linking error if instantiated without being optimized out:
+  void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
+
+  // Vectorized path
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacket& b0, DoublePacket& b1)
+  {
+    // FIXME not sure that's the best way to implement it!
+    loadRhs(b+0, b0);
+    loadRhs(b+1, b1);
+  }
+
+  // Scalar path
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsScalar& b0, RhsScalar& b1)
+  {
+    // FIXME not sure that's the best way to implement it!
+    loadRhs(b+0, b0);
+    loadRhs(b+1, b1);
+  }
 
   // nothing special here
   EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
@@ -414,8 +456,9 @@ public:
     ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
+    // FIXME: should depend on NumberOfRegisters
     nr = 4,
-    mr = 2*ResPacketSize,
+    mr = ResPacketSize,
 
     WorkSpaceFactor = nr*RhsPacketSize,
     LhsProgress = ResPacketSize,
@@ -441,6 +484,16 @@ public:
   {
     dest = pset1<RhsPacket>(*b);
   }
+
+  // linking error if instantiated without being optimized out:
+  void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
+
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
+  {
+    // FIXME not sure that's the best way to implement it!
+    b0 = pload1<RhsPacket>(b+0);
+    b1 = pload1<RhsPacket>(b+1);
+  }
 
   EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
   {
@@ -511,11 +564,9 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
     if(strideA==-1) strideA = depth;
     if(strideB==-1) strideB = depth;
     conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
-//     conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
     Index packet_cols = (cols/nr) * nr;
+    // Here we assume that mr==LhsProgress
     const Index peeled_mc = (rows/mr)*mr;
-    // FIXME:
-    const Index peeled_mc2 = peeled_mc + (rows-peeled_mc >= LhsProgress ? LhsProgress : 0);
     const Index peeled_kc = (depth/4)*4;
 
     // loops on each micro vertical panel of rhs (depth x nr)
@@ -527,144 +578,88 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
       for(Index i=0; i<peeled_mc; i+=mr)
       {
         const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
-        prefetch(&blA[0]);
+        // prefetch(&blA[0]);
 
         // gets res block as register
         AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
         traits.initAcc(C0);
         traits.initAcc(C1);
-        if(nr==4) traits.initAcc(C2);
-        if(nr==4) traits.initAcc(C3);
-        traits.initAcc(C4);
-        traits.initAcc(C5);
-        if(nr==4) traits.initAcc(C6);
-        if(nr==4) traits.initAcc(C7);
+        traits.initAcc(C2);
+        traits.initAcc(C3);
+        if(nr==8) traits.initAcc(C4);
+        if(nr==8) traits.initAcc(C5);
+        if(nr==8) traits.initAcc(C6);
+        if(nr==8) traits.initAcc(C7);
 
         ResScalar* r0 = &res[(j2+0)*resStride + i];
-        ResScalar* r1 = r0 + resStride;
-        ResScalar* r2 = r1 + resStride;
-        ResScalar* r3 = r2 + resStride;
-
-        prefetch(r0+16);
-        prefetch(r1+16);
-        prefetch(r2+16);
-        prefetch(r3+16);
-
-        // performs "inner" product
-        // TODO let's check wether the folowing peeled loop could not be
-        //      optimized via optimal prefetching from one loop to the other
+
+        // performs "inner" products
         const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+        LhsPacket A0, A1;
+        // uncomment for register prefetching
+        // traits.loadLhs(blA, A0);
         for(Index k=0; k<peeled_kc; k+=4)
         {
-          if(nr==2)
+          if(nr==4)
           {
-            LhsPacket A0, A1;
-            RhsPacket B_0;
-            RhsPacket T0;
+            EIGEN_ASM_COMMENT("begin gegp micro kernel 1p x 4");
 
-EIGEN_ASM_COMMENT("mybegin2");
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadLhs(&blA[1*LhsProgress], A1);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B_0);
-            traits.madd(A0,B_0,C1,T0);
-            traits.madd(A1,B_0,C5,B_0);
-
-            traits.loadLhs(&blA[2*LhsProgress], A0);
-            traits.loadLhs(&blA[3*LhsProgress], A1);
-            traits.loadRhs(&blB[2*RhsProgress], B_0);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[3*RhsProgress], B_0);
-            traits.madd(A0,B_0,C1,T0);
-            traits.madd(A1,B_0,C5,B_0);
-
-            traits.loadLhs(&blA[4*LhsProgress], A0);
-            traits.loadLhs(&blA[5*LhsProgress], A1);
-            traits.loadRhs(&blB[4*RhsProgress], B_0);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[5*RhsProgress], B_0);
-            traits.madd(A0,B_0,C1,T0);
-            traits.madd(A1,B_0,C5,B_0);
-
-            traits.loadLhs(&blA[6*LhsProgress], A0);
-            traits.loadLhs(&blA[7*LhsProgress], A1);
-            traits.loadRhs(&blB[6*RhsProgress], B_0);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[7*RhsProgress], B_0);
-            traits.madd(A0,B_0,C1,T0);
-            traits.madd(A1,B_0,C5,B_0);
-EIGEN_ASM_COMMENT("myend");
+            RhsPacket B_0, B1;
+
+#define EIGEN_GEBGP_ONESTEP4(K) \
+            traits.loadLhs(&blA[K*LhsProgress], A0); \
+            traits.broadcastRhs(&blB[0+4*K*RhsProgress], B_0, B1); \
+            traits.madd(A0, B_0,C0, B_0); \
+            traits.madd(A0, B1, C1, B1); \
+            traits.broadcastRhs(&blB[2+4*K*RhsProgress], B_0, B1); \
+            traits.madd(A0, B_0,C2, B_0); \
+            traits.madd(A0, B1, C3, B1)
+
+            EIGEN_GEBGP_ONESTEP4(0);
+            EIGEN_GEBGP_ONESTEP4(1);
+            EIGEN_GEBGP_ONESTEP4(2);
+            EIGEN_GEBGP_ONESTEP4(3);
           }
-          else
+          else // nr==8
           {
-EIGEN_ASM_COMMENT("mybegin4");
-            LhsPacket A0, A1;
+            EIGEN_ASM_COMMENT("begin gegp micro kernel 1p x 8");
             RhsPacket B_0, B1, B2, B3;
-            RhsPacket T0;
 
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadLhs(&blA[1*LhsProgress], A1);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B1);
-
-            traits.madd(A0,B_0,C0,T0);
-            traits.loadRhs(&blB[2*RhsProgress], B2);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[3*RhsProgress], B3);
-            traits.loadRhs(&blB[4*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,T0);
-            traits.madd(A1,B1,C5,B1);
-            traits.loadRhs(&blB[5*RhsProgress], B1);
-            traits.madd(A0,B2,C2,T0);
-            traits.madd(A1,B2,C6,B2);
-            traits.loadRhs(&blB[6*RhsProgress], B2);
-            traits.madd(A0,B3,C3,T0);
-            traits.loadLhs(&blA[2*LhsProgress], A0);
-            traits.madd(A1,B3,C7,B3);
-            traits.loadLhs(&blA[3*LhsProgress], A1);
-            traits.loadRhs(&blB[7*RhsProgress], B3);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[8*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,T0);
-            traits.madd(A1,B1,C5,B1);
-            traits.loadRhs(&blB[9*RhsProgress], B1);
-            traits.madd(A0,B2,C2,T0);
-            traits.madd(A1,B2,C6,B2);
-            traits.loadRhs(&blB[10*RhsProgress], B2);
-            traits.madd(A0,B3,C3,T0);
-            traits.loadLhs(&blA[4*LhsProgress], A0);
-            traits.madd(A1,B3,C7,B3);
-            traits.loadLhs(&blA[5*LhsProgress], A1);
-            traits.loadRhs(&blB[11*RhsProgress], B3);
-
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[12*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,T0);
-            traits.madd(A1,B1,C5,B1);
-            traits.loadRhs(&blB[13*RhsProgress], B1);
-            traits.madd(A0,B2,C2,T0);
-            traits.madd(A1,B2,C6,B2);
-            traits.loadRhs(&blB[14*RhsProgress], B2);
-            traits.madd(A0,B3,C3,T0);
-            traits.loadLhs(&blA[6*LhsProgress], A0);
-            traits.madd(A1,B3,C7,B3);
-            traits.loadLhs(&blA[7*LhsProgress], A1);
-            traits.loadRhs(&blB[15*RhsProgress], B3);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.madd(A0,B1,C1,T0);
-            traits.madd(A1,B1,C5,B1);
-            traits.madd(A0,B2,C2,T0);
-            traits.madd(A1,B2,C6,B2);
-            traits.madd(A0,B3,C3,T0);
-            traits.madd(A1,B3,C7,B3);
+            // The following version is faster on some architures
+            // but sometimes leads to segfaults because it might read one packet outside the bounds
+            // To test it, you also need to uncomment the initialization of A0 above and the copy of A1 to A0 below.
+#if 0
+#define EIGEN_GEBGP_ONESTEP8(K,L,M) \
+            traits.loadLhs(&blA[(K+1)*LhsProgress], L); \
+            traits.broadcastRhs(&blB[0+8*K*RhsProgress], B_0, B1, B2, B3); \
+            traits.madd(M, B_0,C0, B_0); \
+            traits.madd(M, B1, C1, B1); \
+            traits.madd(M, B2, C2, B2); \
+            traits.madd(M, B3, C3, B3); \
+            traits.broadcastRhs(&blB[4+8*K*RhsProgress], B_0, B1, B2, B3); \
+            traits.madd(M, B_0,C4, B_0); \
+            traits.madd(M, B1, C5, B1); \
+            traits.madd(M, B2, C6, B2); \
+            traits.madd(M, B3, C7, B3)
+#endif
+
+#define EIGEN_GEBGP_ONESTEP8(K,L,M) \
+            traits.loadLhs(&blA[K*LhsProgress], A0); \
+            traits.broadcastRhs(&blB[0+8*K*RhsProgress], B_0, B1, B2, B3); \
+            traits.madd(A0, B_0,C0, B_0); \
+            traits.madd(A0, B1, C1, B1); \
+            traits.madd(A0, B2, C2, B2); \
+            traits.madd(A0, B3, C3, B3); \
+            traits.broadcastRhs(&blB[4+8*K*RhsProgress], B_0, B1, B2, B3); \
+            traits.madd(A0, B_0,C4, B_0); \
+            traits.madd(A0, B1, C5, B1); \
+            traits.madd(A0, B2, C6, B2); \
+            traits.madd(A0, B3, C7, B3)
+
+            EIGEN_GEBGP_ONESTEP8(0,A1,A0);
+            EIGEN_GEBGP_ONESTEP8(1,A0,A1);
+            EIGEN_GEBGP_ONESTEP8(2,A1,A0);
+            EIGEN_GEBGP_ONESTEP8(3,A0,A1);
          }
 
          blB += 4*nr*RhsProgress;
@@ -673,63 +668,40 @@ EIGEN_ASM_COMMENT("mybegin4");
        // process remaining peeled loop
        for(Index k=peeled_kc; k<depth; k++)
        {
-          if(nr==2)
+          if(nr==4)
          {
-            LhsPacket A0, A1;
-            RhsPacket B_0;
-            RhsPacket T0;
-
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadLhs(&blA[1*LhsProgress], A1);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.madd(A0,B_0,C0,T0);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B_0);
-            traits.madd(A0,B_0,C1,T0);
-            traits.madd(A1,B_0,C5,B_0);
+            RhsPacket B_0, B1;
+            EIGEN_GEBGP_ONESTEP4(0);
          }
-          else
+          else // nr == 8
          {
-            LhsPacket A0, A1;
            RhsPacket B_0, B1, B2, B3;
-            RhsPacket T0;
-
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadLhs(&blA[1*LhsProgress], A1);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B1);
-
-            traits.madd(A0,B_0,C0,T0);
-            traits.loadRhs(&blB[2*RhsProgress], B2);
-            traits.madd(A1,B_0,C4,B_0);
-            traits.loadRhs(&blB[3*RhsProgress], B3);
-            traits.madd(A0,B1,C1,T0);
-            traits.madd(A1,B1,C5,B1);
-            traits.madd(A0,B2,C2,T0);
-            traits.madd(A1,B2,C6,B2);
-            traits.madd(A0,B3,C3,T0);
-            traits.madd(A1,B3,C7,B3);
+            EIGEN_GEBGP_ONESTEP8(0,A1,A0);
+            // uncomment for register prefetching
+            // A0 = A1;
          }
 
          blB += nr*RhsProgress;
          blA += mr;
        }
+#undef EIGEN_GEBGP_ONESTEP4
+#undef EIGEN_GEBGP_ONESTEP8
 
-        if(nr==4)
+        if(nr==8)
        {
          ResPacket R0, R1, R2, R3, R4, R5, R6;
          ResPacket alphav = pset1<ResPacket>(alpha);
 
-          R0 = ploadu<ResPacket>(r0);
-          R1 = ploadu<ResPacket>(r1);
-          R2 = ploadu<ResPacket>(r2);
-          R3 = ploadu<ResPacket>(r3);
-          R4 = ploadu<ResPacket>(r0 + ResPacketSize);
-          R5 = ploadu<ResPacket>(r1 + ResPacketSize);
-          R6 = ploadu<ResPacket>(r2 + ResPacketSize);
+          R0 = ploadu<ResPacket>(r0+0*resStride);
+          R1 = ploadu<ResPacket>(r0+1*resStride);
+          R2 = ploadu<ResPacket>(r0+2*resStride);
+          R3 = ploadu<ResPacket>(r0+3*resStride);
+          R4 = ploadu<ResPacket>(r0+4*resStride);
+          R5 = ploadu<ResPacket>(r0+5*resStride);
+          R6 = ploadu<ResPacket>(r0+6*resStride);
          traits.acc(C0, alphav, R0);
-          pstoreu(r0, R0);
-          R0 = ploadu<ResPacket>(r3 + ResPacketSize);
+          pstoreu(r0+0*resStride, R0);
+          R0 = ploadu<ResPacket>(r0+7*resStride);
 
          traits.acc(C1, alphav, R1);
          traits.acc(C2, alphav, R2);
@@ -739,232 +711,107 @@ EIGEN_ASM_COMMENT("mybegin4");
          traits.acc(C6, alphav, R6);
          traits.acc(C7, alphav, R0);
 
-          pstoreu(r1, R1);
-          pstoreu(r2, R2);
-          pstoreu(r3, R3);
-          pstoreu(r0 + ResPacketSize, R4);
-          pstoreu(r1 + ResPacketSize, R5);
-          pstoreu(r2 + ResPacketSize, R6);
-          pstoreu(r3 + ResPacketSize, R0);
+          pstoreu(r0+1*resStride, R1);
+          pstoreu(r0+2*resStride, R2);
+          pstoreu(r0+3*resStride, R3);
+          pstoreu(r0+4*resStride, R4);
+          pstoreu(r0+5*resStride, R5);
+          pstoreu(r0+6*resStride, R6);
+          pstoreu(r0+7*resStride, R0);
        }
-        else
+        else // nr==4
        {
-          ResPacket R0, R1, R4;
+          ResPacket R0, R1, R2;
          ResPacket alphav = pset1<ResPacket>(alpha);
 
-          R0 = ploadu<ResPacket>(r0);
-          R1 = ploadu<ResPacket>(r1);
-          R4 = ploadu<ResPacket>(r0 + ResPacketSize);
+          R0 = ploadu<ResPacket>(r0+0*resStride);
+          R1 = ploadu<ResPacket>(r0+1*resStride);
+          R2 = ploadu<ResPacket>(r0+2*resStride);
          traits.acc(C0, alphav, R0);
-          pstoreu(r0, R0);
-          R0 = ploadu<ResPacket>(r1 + ResPacketSize);
+          pstoreu(r0+0*resStride, R0);
+          R0 = ploadu<ResPacket>(r0+3*resStride);
+
          traits.acc(C1, alphav, R1);
-          traits.acc(C4, alphav, R4);
-          traits.acc(C5, alphav, R0);
-          pstoreu(r1, R1);
-          pstoreu(r0 + ResPacketSize, R4);
-          pstoreu(r1 + ResPacketSize, R0);
+          traits.acc(C2, alphav, R2);
+          traits.acc(C3, alphav, R0);
+
+          pstoreu(r0+1*resStride, R1);
+          pstoreu(r0+2*resStride, R2);
+          pstoreu(r0+3*resStride, R0);
        }
      }
 
-      if(rows-peeled_mc>=LhsProgress)
-      {
-        Index i = peeled_mc;
-        const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
-        prefetch(&blA[0]);
-
-        // gets res block as register
-        AccPacket C0, C1, C2, C3;
-        traits.initAcc(C0);
-        traits.initAcc(C1);
-        if(nr==4) traits.initAcc(C2);
-        if(nr==4) traits.initAcc(C3);
-
-        // performs "inner" product
-        const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
-        for(Index k=0; k<peeled_kc; k+=4)
-        {
-          if(nr==2)
-          {
-            LhsPacket A0;
-            RhsPacket B_0, B1;
-
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B1);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.loadRhs(&blB[2*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.loadLhs(&blA[1*LhsProgress], A0);
-            traits.loadRhs(&blB[3*RhsProgress], B1);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.loadRhs(&blB[4*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.loadLhs(&blA[2*LhsProgress], A0);
-            traits.loadRhs(&blB[5*RhsProgress], B1);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.loadRhs(&blB[6*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.loadLhs(&blA[3*LhsProgress], A0);
-            traits.loadRhs(&blB[7*RhsProgress], B1);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.madd(A0,B1,C1,B1);
-          }
-          else
-          {
-            LhsPacket A0;
-            RhsPacket B_0, B1, B2, B3;
-
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B1);
-
-            traits.madd(A0,B_0,C0,B_0);
-            traits.loadRhs(&blB[2*RhsProgress], B2);
-            traits.loadRhs(&blB[3*RhsProgress], B3);
-            traits.loadRhs(&blB[4*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.loadRhs(&blB[5*RhsProgress], B1);
-            traits.madd(A0,B2,C2,B2);
-            traits.loadRhs(&blB[6*RhsProgress], B2);
-            traits.madd(A0,B3,C3,B3);
-            traits.loadLhs(&blA[1*LhsProgress], A0);
-            traits.loadRhs(&blB[7*RhsProgress], B3);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.loadRhs(&blB[8*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.loadRhs(&blB[9*RhsProgress], B1);
-            traits.madd(A0,B2,C2,B2);
-            traits.loadRhs(&blB[10*RhsProgress], B2);
-            traits.madd(A0,B3,C3,B3);
-            traits.loadLhs(&blA[2*LhsProgress], A0);
-            traits.loadRhs(&blB[11*RhsProgress], B3);
-
-            traits.madd(A0,B_0,C0,B_0);
-            traits.loadRhs(&blB[12*RhsProgress], B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.loadRhs(&blB[13*RhsProgress], B1);
-            traits.madd(A0,B2,C2,B2);
-            traits.loadRhs(&blB[14*RhsProgress], B2);
-            traits.madd(A0,B3,C3,B3);
-
-            traits.loadLhs(&blA[3*LhsProgress], A0);
-            traits.loadRhs(&blB[15*RhsProgress], B3);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.madd(A0,B2,C2,B2);
-            traits.madd(A0,B3,C3,B3);
-          }
-
-          blB += nr*4*RhsProgress;
-          blA += 4*LhsProgress;
-        }
-        // process remaining peeled loop
-        for(Index k=peeled_kc; k<depth; k++)
-        {
-          if(nr==2)
-          {
-            LhsPacket A0;
-            RhsPacket B_0, B1;
-
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B1);
-            traits.madd(A0,B_0,C0,B_0);
-            traits.madd(A0,B1,C1,B1);
-          }
-          else
-          {
-            LhsPacket A0;
-            RhsPacket B_0, B1, B2, B3;
-
-            traits.loadLhs(&blA[0*LhsProgress], A0);
-            traits.loadRhs(&blB[0*RhsProgress], B_0);
-            traits.loadRhs(&blB[1*RhsProgress], B1);
-            traits.loadRhs(&blB[2*RhsProgress], B2);
-            traits.loadRhs(&blB[3*RhsProgress], B3);
-
-            traits.madd(A0,B_0,C0,B_0);
-            traits.madd(A0,B1,C1,B1);
-            traits.madd(A0,B2,C2,B2);
-            traits.madd(A0,B3,C3,B3);
-          }
-
-          blB += nr*RhsProgress;
-          blA += LhsProgress;
-        }
-
-        ResPacket R0, R1, R2, R3;
-        ResPacket alphav = pset1<ResPacket>(alpha);
-
-        ResScalar* r0 = &res[(j2+0)*resStride + i];
-        ResScalar* r1 = r0 + resStride;
-        ResScalar* r2 = r1 + resStride;
-        ResScalar* r3 = r2 + resStride;
-
-        R0 = ploadu<ResPacket>(r0);
-        R1 = ploadu<ResPacket>(r1);
-        if(nr==4) R2 = ploadu<ResPacket>(r2);
-        if(nr==4) R3 = ploadu<ResPacket>(r3);
-
-        traits.acc(C0, alphav, R0);
-        traits.acc(C1, alphav, R1);
-        if(nr==4) traits.acc(C2, alphav, R2);
-        if(nr==4) traits.acc(C3, alphav, R3);
-
-        pstoreu(r0, R0);
-        pstoreu(r1, R1);
-        if(nr==4) pstoreu(r2, R2);
-        if(nr==4) pstoreu(r3, R3);
-      }
-      for(Index i=peeled_mc2; i<rows; i++)
+      for(Index i=peeled_mc; i<rows; i++)
      {
        const LhsScalar* blA = &blockA[i*strideA+offsetA];
        prefetch(&blA[0]);
 
        // gets a 1 x nr res block as registers
-        ResScalar C0(0), C1(0), C2(0), C3(0);
-        // TODO directly use blockB ???
+        ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
+        // FIXME directly use blockB ???
        const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+        // TODO peel this loop
        for(Index k=0; k<depth; k++)
        {
-          if(nr==2)
+          if(nr==4)
          {
            LhsScalar A0;
-            RhsScalar B_0, B1;
+            RhsScalar B_0, B_1;
 
            A0 = blA[k];
+
            B_0 = blB[0];
-            B1 = blB[1];
-            MADD(cj,A0,B_0,C0,B_0);
-            MADD(cj,A0,B1,C1,B1);
+            B_1 = blB[1];
+            MADD(cj,A0,B_0,C0, B_0);
+            MADD(cj,A0,B_1,C1, B_1);
+
+            B_0 = blB[2];
+            B_1 = blB[3];
+            MADD(cj,A0,B_0,C2, B_0);
+            MADD(cj,A0,B_1,C3, B_1);
          }
-          else
+          else // nr==8
          {
            LhsScalar A0;
-            RhsScalar B_0, B1, B2, B3;
+            RhsScalar B_0, B_1;
 
            A0 = blA[k];
+
            B_0 = blB[0];
-            B1 = blB[1];
-            B2 = blB[2];
-            B3 = blB[3];
-
-            MADD(cj,A0,B_0,C0,B_0);
-            MADD(cj,A0,B1,C1,B1);
-            MADD(cj,A0,B2,C2,B2);
-            MADD(cj,A0,B3,C3,B3);
+            B_1 = blB[1];
+            MADD(cj,A0,B_0,C0, B_0);
+            MADD(cj,A0,B_1,C1, B_1);
+
+            B_0 = blB[2];
+            B_1 = blB[3];
+            MADD(cj,A0,B_0,C2, B_0);
+            MADD(cj,A0,B_1,C3, B_1);
+
+            B_0 = blB[4];
+            B_1 = blB[5];
+            MADD(cj,A0,B_0,C4, B_0);
+            MADD(cj,A0,B_1,C5, B_1);
+
+            B_0 = blB[6];
+            B_1 = blB[7];
+            MADD(cj,A0,B_0,C6, B_0);
+            MADD(cj,A0,B_1,C7, B_1);
          }
 
          blB += nr;
        }
        res[(j2+0)*resStride + i] += alpha*C0;
        res[(j2+1)*resStride + i] += alpha*C1;
-        if(nr==4) res[(j2+2)*resStride + i] += alpha*C2;
-        if(nr==4) res[(j2+3)*resStride + i] += alpha*C3;
+        res[(j2+2)*resStride + i] += alpha*C2;
+        res[(j2+3)*resStride + i] += alpha*C3;
+        if(nr==8) res[(j2+4)*resStride + i] += alpha*C4;
+        if(nr==8) res[(j2+5)*resStride + i] += alpha*C5;
+        if(nr==8) res[(j2+6)*resStride + i] += alpha*C6;
+        if(nr==8) res[(j2+7)*resStride + i] += alpha*C7;
      }
    }
+
    // process remaining rhs/res columns one at a time
    // => do the same but with nr==1
    for(Index j2=packet_cols; j2<cols; j2++)
    {
@@ -977,67 +824,31 @@ EIGEN_ASM_COMMENT("mybegin4");
        // TODO move the res loads to the stores
 
        // get res block as registers
-        AccPacket C0, C4;
+        AccPacket C0;
        traits.initAcc(C0);
-        traits.initAcc(C4);
 
        const RhsScalar* blB = &blockB[j2*strideB+offsetB];
        for(Index k=0; k<depth; k++)
        {
-          LhsPacket A0, A1;
+          LhsPacket A0;
          RhsPacket B_0;
          RhsPacket T0;
 
          traits.loadLhs(&blA[0*LhsProgress], A0);
-          traits.loadLhs(&blA[1*LhsProgress], A1);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.madd(A0,B_0,C0,T0);
-          traits.madd(A1,B_0,C4,B_0);
 
          blB += RhsProgress;
-          blA += 2*LhsProgress;
+          blA += LhsProgress;
        }
 
-        ResPacket R0, R4;
+        ResPacket R0;
        ResPacket alphav = pset1<ResPacket>(alpha);
 
        ResScalar* r0 = &res[(j2+0)*resStride + i];
        R0 = ploadu<ResPacket>(r0);
-        R4 = ploadu<ResPacket>(r0+ResPacketSize);
-
        traits.acc(C0, alphav, R0);
-        traits.acc(C4, alphav, R4);
-
-        pstoreu(r0, R0);
-        pstoreu(r0+ResPacketSize, R4);
+        pstoreu(r0, R0);
      }
 
-      if(rows-peeled_mc>=LhsProgress)
-      {
-        Index i = peeled_mc;
-        const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
-        prefetch(&blA[0]);
-
-        AccPacket C0;
-        traits.initAcc(C0);
-
-        const RhsScalar* blB = &blockB[j2*strideB+offsetB];
-        for(Index k=0; k<depth; k++)
-        {
-          LhsPacket A0;
-          RhsPacket B_0;
-          traits.loadLhs(blA, A0);
-          traits.loadRhs(blB, B_0);
-          traits.madd(A0, B_0, C0, B_0);
-          blB += RhsProgress;
-          blA += LhsProgress;
-        }
-
-        ResPacket alphav = pset1<ResPacket>(alpha);
-        ResPacket R0 = ploadu<ResPacket>(&res[(j2+0)*resStride + i]);
-        traits.acc(C0, alphav, R0);
-        pstoreu(&res[(j2+0)*resStride + i], R0);
-      }
-      for(Index i=peeled_mc2; i<rows; i++)
+      for(Index i=peeled_mc; i<rows; i++)
      {
        const LhsScalar* blA = &blockA[i*strideA+offsetA];
        prefetch(&blA[0]);
@@ -1091,7 +902,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, StorageOrder,
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
-  eigen_assert( (StorageOrder==RowMajor) || ((Pack1%PacketSize)==0 && Pack1<=4*PacketSize) );
+  eigen_assert( (StorageOrder==RowMajor) || ((Pack1%PacketSize)==0 && Pack1<=4*PacketSize) || (Pack1<=4) );
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride);
  Index count = 0;
@@ -1104,15 +915,25 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, StorageOrder,
    {
      for(Index k=0; k<depth; k++)
      {
-        Packet A, B, C, D;
-        if(Pack1>=1*PacketSize) A = ploadu<Packet>(&lhs(i+0*PacketSize, k));
-        if(Pack1>=2*PacketSize) B = ploadu<Packet>(&lhs(i+1*PacketSize, k));
-        if(Pack1>=3*PacketSize) C = ploadu<Packet>(&lhs(i+2*PacketSize, k));
-        if(Pack1>=4*PacketSize) D = ploadu<Packet>(&lhs(i+3*PacketSize, k));
-        if(Pack1>=1*PacketSize) { pstore(blockA+count, cj.pconj(A)); count+=PacketSize; }
-        if(Pack1>=2*PacketSize) { pstore(blockA+count, cj.pconj(B)); count+=PacketSize; }
-        if(Pack1>=3*PacketSize) { pstore(blockA+count, cj.pconj(C)); count+=PacketSize; }
-        if(Pack1>=4*PacketSize) { pstore(blockA+count, cj.pconj(D)); count+=PacketSize; }
+        if((Pack1%PacketSize)==0)
+        {
+          Packet A, B, C, D;
+          if(Pack1>=1*PacketSize) A = ploadu<Packet>(&lhs(i+0*PacketSize, k));
+          if(Pack1>=2*PacketSize) B = ploadu<Packet>(&lhs(i+1*PacketSize, k));
+          if(Pack1>=3*PacketSize) C = ploadu<Packet>(&lhs(i+2*PacketSize, k));
+          if(Pack1>=4*PacketSize) D = ploadu<Packet>(&lhs(i+3*PacketSize, k));
+          if(Pack1>=1*PacketSize) { pstore(blockA+count, cj.pconj(A)); count+=PacketSize; }
+          if(Pack1>=2*PacketSize) { pstore(blockA+count, cj.pconj(B)); count+=PacketSize; }
+          if(Pack1>=3*PacketSize) { pstore(blockA+count, cj.pconj(C)); count+=PacketSize; }
+          if(Pack1>=4*PacketSize) { pstore(blockA+count, cj.pconj(D)); count+=PacketSize; }
+        }
+        else
+        {
+          if(Pack1>=1) blockA[count++] = cj(lhs(i+0, k));
+          if(Pack1>=2) blockA[count++] = cj(lhs(i+1, k));
+          if(Pack1>=3) blockA[count++] = cj(lhs(i+2, k));
+          if(Pack1>=4) blockA[count++] = cj(lhs(i+3, k));
+        }
      }
    }
    else
@@ -1191,12 +1012,20 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan
      const Scalar* b1 = &rhs[(j2+1)*rhsStride];
      const Scalar* b2 = &rhs[(j2+2)*rhsStride];
      const Scalar* b3 = &rhs[(j2+3)*rhsStride];
+      const Scalar* b4 = &rhs[(j2+4)*rhsStride];
+      const Scalar* b5 = &rhs[(j2+5)*rhsStride];
+      const Scalar* b6 = &rhs[(j2+6)*rhsStride];
+      const Scalar* b7 = &rhs[(j2+7)*rhsStride];
      for(Index k=0; k<depth; k++)
      {
        blockB[count+0] = cj(b0[k]);
        blockB[count+1] = cj(b1[k]);
-        if(nr==4) blockB[count+2] = cj(b2[k]);
-        if(nr==4) blockB[count+3] = cj(b3[k]);
+        if(nr>=4) blockB[count+2] = cj(b2[k]);
+        if(nr>=4) blockB[count+3] = cj(b3[k]);
+        if(nr>=8) blockB[count+4] = cj(b4[k]);
+        if(nr>=8) blockB[count+5] = cj(b5[k]);
+        if(nr>=8) blockB[count+6] = cj(b6[k]);
+        if(nr>=8) blockB[count+7] = cj(b7[k]);
        count += nr;
      }
      // skip what we have after
@@ -1251,8 +1080,12 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, Pan
        const Scalar* b0 = &rhs[k*rhsStride + j2];
        blockB[count+0] = cj(b0[0]);
        blockB[count+1] = cj(b0[1]);
-        if(nr==4) blockB[count+2] = cj(b0[2]);
-        if(nr==4) blockB[count+3] = cj(b0[3]);
+        if(nr>=4) blockB[count+2] = cj(b0[2]);
+        if(nr>=4) blockB[count+3] = cj(b0[3]);
+        if(nr>=8) blockB[count+4] = cj(b0[4]);
+        if(nr>=8) blockB[count+5] = cj(b0[5]);
+        if(nr>=8) blockB[count+6] = cj(b0[6]);
+        if(nr>=8) blockB[count+7] = cj(b0[7]);
        count += nr;
      }
    }
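A note on the new broadcastRhs helpers above: they wrap pbroadcast4/pbroadcast2, whose job is to fill each output packet with one rhs coefficient replicated across all lanes. The real implementations are per-architecture (e.g. SSE/AVX shuffles); the following standalone sketch only pins down the intended semantics, with Packet4f and pset1 as hypothetical stand-ins:

#include <cstddef>

// Stand-in for a 4-lane SIMD packet.
struct Packet4f { float v[4]; };

static inline Packet4f pset1(float s)
{ Packet4f r; for (int l = 0; l < 4; ++l) r.v[l] = s; return r; }

// After the call, b0..b3 each hold one coefficient replicated in every
// lane: {b[0],b[0],b[0],b[0]}, {b[1],...}, {b[2],...}, {b[3],...}.
static inline void pbroadcast4(const float* b,
                               Packet4f& b0, Packet4f& b1,
                               Packet4f& b2, Packet4f& b3)
{
  b0 = pset1(b[0]);
  b1 = pset1(b[1]);
  b2 = pset1(b[2]);
  b3 = pset1(b[3]);
}

// Two-output variant used by the 1p x 4 micro kernel.
static inline void pbroadcast2(const float* b, Packet4f& b0, Packet4f& b1)
{
  b0 = pset1(b[0]);
  b1 = pset1(b[1]);
}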
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
index 99cf9e0ae..d9fd9f556 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
@@ -63,7 +63,7 @@ struct symm_pack_lhs
    for(Index i=peeled_mc; i<rows; i++)
    {
      for(Index k=0; k<i; k++)
-        blockA[count++] = lhs(i, k);          // normal
+        blockA[count++] = lhs(i, k);           // normal
 
      blockA[count++] = numext::real(lhs(i, i)); // real (diagonal)
 
@@ -91,11 +91,18 @@ struct symm_pack_rhs
      {
        blockB[count+0] = rhs(k,j2+0);
        blockB[count+1] = rhs(k,j2+1);
-        if (nr==4)
+        if (nr>=4)
        {
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
        }
+        if (nr>=8)
+        {
+          blockB[count+4] = rhs(k,j2+4);
+          blockB[count+5] = rhs(k,j2+5);
+          blockB[count+6] = rhs(k,j2+6);
+          blockB[count+7] = rhs(k,j2+7);
+        }
        count += nr;
      }
    }
@@ -109,11 +116,18 @@ struct symm_pack_rhs
      {
        blockB[count+0] = numext::conj(rhs(j2+0,k));
        blockB[count+1] = numext::conj(rhs(j2+1,k));
-        if (nr==4)
+        if (nr>=4)
        {
          blockB[count+2] = numext::conj(rhs(j2+2,k));
          blockB[count+3] = numext::conj(rhs(j2+3,k));
        }
+        if (nr>=8)
+        {
+          blockB[count+4] = numext::conj(rhs(j2+4,k));
+          blockB[count+5] = numext::conj(rhs(j2+5,k));
+          blockB[count+6] = numext::conj(rhs(j2+6,k));
+          blockB[count+7] = numext::conj(rhs(j2+7,k));
+        }
        count += nr;
      }
      // symmetric
@@ -137,11 +151,18 @@ struct symm_pack_rhs
      {
        blockB[count+0] = rhs(k,j2+0);
        blockB[count+1] = rhs(k,j2+1);
-        if (nr==4)
+        if (nr>=4)
        {
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
        }
+        if (nr>=8)
+        {
+          blockB[count+4] = rhs(k,j2+4);
+          blockB[count+5] = rhs(k,j2+5);
+          blockB[count+6] = rhs(k,j2+6);
+          blockB[count+7] = rhs(k,j2+7);
+        }
        count += nr;
      }
    }
@@ -153,11 +174,18 @@ struct symm_pack_rhs
      {
        blockB[count+0] = numext::conj(rhs(j2+0,k));
        blockB[count+1] = numext::conj(rhs(j2+1,k));
-        if (nr==4)
+        if (nr>=4)
        {
          blockB[count+2] = numext::conj(rhs(j2+2,k));
          blockB[count+3] = numext::conj(rhs(j2+3,k));
        }
+        if (nr>=8)
+        {
+          blockB[count+4] = numext::conj(rhs(j2+4,k));
+          blockB[count+5] = numext::conj(rhs(j2+5,k));
+          blockB[count+6] = numext::conj(rhs(j2+6,k));
+          blockB[count+7] = numext::conj(rhs(j2+7,k));
+        }
        count += nr;
      }
    }
@@ -422,11 +450,11 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
        NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)),
        internal::traits<Dest>::Flags&RowMajorBit  ? RowMajor : ColMajor>
      ::run(
-        lhs.rows(), rhs.cols(),                 // sizes
-        &lhs.coeffRef(0,0), lhs.outerStride(),  // lhs info
-        &rhs.coeffRef(0,0), rhs.outerStride(),  // rhs info
-        &dst.coeffRef(0,0), dst.outerStride(),  // result info
-        actualAlpha                             // alpha
+        lhs.rows(), rhs.cols(),                     // sizes
+        &lhs.coeffRef(0,0), lhs.outerStride(),      // lhs info
+        &rhs.coeffRef(0,0), rhs.outerStride(),      // rhs info
+        &dst.coeffRef(0,0), dst.outerStride(),      // result info
+        actualAlpha                                 // alpha
      );
  }
};
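The if (nr>=4)/if (nr>=8) ladders added throughout both packing files are manual unrollings of a copy over nr columns; since nr is a compile-time constant, the compiler folds away the dead branches. Functionally they amount to the loop below (a sketch only; pack_rhs_panel and RhsMapper are made-up names, not part of Eigen):

#include <cstddef>

// Hypothetical column-major accessor standing in for Eigen's rhs(k, j)
// expression; only here so the sketch compiles on its own.
struct RhsMapper {
  const float* data;
  std::ptrdiff_t stride;  // distance between consecutive columns
  float operator()(std::ptrdiff_t k, std::ptrdiff_t j) const
  { return data[j*stride + k]; }
};

// Loop form of the unrolled rhs packing: for each depth step k, the nr
// coefficients of columns j2..j2+nr-1 are stored contiguously in blockB.
template <int nr>  // nr is 4 or 8
void pack_rhs_panel(float* blockB, const RhsMapper& rhs,
                    std::ptrdiff_t depth, std::ptrdiff_t j2,
                    std::ptrdiff_t& count)
{
  for (std::ptrdiff_t k = 0; k < depth; ++k)
    for (int w = 0; w < nr; ++w)
      blockB[count++] = rhs(k, j2 + w);
}

Keeping the hand-written ladder rather than this loop preserves fully branch-free generated code even at -O0 and matches the style of the surrounding kernels.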