diff options
author | 2014-03-28 12:11:23 -0700 | |
---|---|---|
committer | 2014-03-28 12:11:23 -0700 | |
commit | ad59ade116969ca7b18409d690caf00c0b1c34c7 (patch) | |
tree | 98d3d37958231b86203e3b1303cf7a4dc3400fa6 /Eigen/src/Core/products/GeneralBlockPanelKernel.h | |
parent | 39bfbd43f05691874a78a5a6bf4504cf0e6ff452 (diff) |
Vectorized the loop peeling of the inner loop of the block-panel matrix multiplication code. This speeds up the multiplication of matrices which size is not a multiple of the packet size.
Diffstat (limited to 'Eigen/src/Core/products/GeneralBlockPanelKernel.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 224 |
1 files changed, 157 insertions, 67 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index 0f47f6de5..3ed1fc5a3 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -206,6 +206,11 @@ public: dest = pload<LhsPacket>(a); } + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + { + dest = ploadu<LhsPacket>(a); + } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const { // It would be a lot cleaner to call pmadd all the time. Unfortunately if we @@ -278,7 +283,12 @@ public: { dest = pload<LhsPacket>(a); } - + + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + { + dest = ploadu<LhsPacket>(a); + } + EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) { pbroadcast4(b, b0, b1, b2, b3); @@ -334,7 +344,9 @@ public: && packet_traits<Scalar>::Vectorizable, RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1, ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1, - + LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1, + RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1, + // FIXME: should depend on NumberOfRegisters nr = 4, mr = ResPacketSize, @@ -402,6 +414,11 @@ public: dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a)); } + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + { + dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a)); + } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const { c.first = padd(pmul(a,b.first), c.first); @@ -509,6 +526,11 @@ public: dest = ploaddup<LhsPacket>(a); } + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + { + dest = ploaddup<LhsPacket>(a); + } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const { madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type()); @@ -706,49 +728,84 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> const LhsScalar* blA = &blockA[i*strideA+offsetA]; prefetch(&blA[0]); - // gets a 1 x 8 res block as registers - ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0); // FIXME directly use blockB ??? const RhsScalar* blB = &blockB[j2*strideB+offsetB*8]; - // TODO peel this loop - for(Index k=0; k<depth; k++) - { - LhsScalar A0; - RhsScalar B_0, B_1; - A0 = blA[k]; + if(nr == Traits::RhsPacketSize) + { + EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows"); + + typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits; + typedef typename SwappedTraits::ResScalar SResScalar; + typedef typename SwappedTraits::LhsPacket SLhsPacket; + typedef typename SwappedTraits::RhsPacket SRhsPacket; + typedef typename SwappedTraits::ResPacket SResPacket; + typedef typename SwappedTraits::AccPacket SAccPacket; + SwappedTraits straits; + + SAccPacket C0; + straits.initAcc(C0); + for(Index k=0; k<depth; k++) + { + SLhsPacket A0; + straits.loadLhsUnaligned(blB, A0); + SRhsPacket B_0; + straits.loadRhs(&blA[k], B_0); + SRhsPacket T0; + straits.madd(A0,B_0,C0,T0); + blB += nr; + } + SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride); + SResPacket alphav = pset1<SResPacket>(alpha); + straits.acc(C0, alphav, R); + pscatter(&res[j2*resStride + i], R, resStride); + + EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows"); + } + else + { + // gets a 1 x 8 res block as registers + ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0); + + for(Index k=0; k<depth; k++) + { + LhsScalar A0; + RhsScalar B_0, B_1; + + A0 = blA[k]; + + B_0 = blB[0]; + B_1 = blB[1]; + MADD(cj,A0,B_0,C0, B_0); + MADD(cj,A0,B_1,C1, B_1); + + B_0 = blB[2]; + B_1 = blB[3]; + MADD(cj,A0,B_0,C2, B_0); + MADD(cj,A0,B_1,C3, B_1); + + B_0 = blB[4]; + B_1 = blB[5]; + MADD(cj,A0,B_0,C4, B_0); + MADD(cj,A0,B_1,C5, B_1); - B_0 = blB[0]; - B_1 = blB[1]; - MADD(cj,A0,B_0,C0, B_0); - MADD(cj,A0,B_1,C1, B_1); - - B_0 = blB[2]; - B_1 = blB[3]; - MADD(cj,A0,B_0,C2, B_0); - MADD(cj,A0,B_1,C3, B_1); - - B_0 = blB[4]; - B_1 = blB[5]; - MADD(cj,A0,B_0,C4, B_0); - MADD(cj,A0,B_1,C5, B_1); - - B_0 = blB[6]; - B_1 = blB[7]; - MADD(cj,A0,B_0,C6, B_0); - MADD(cj,A0,B_1,C7, B_1); - - blB += 8; - } - res[(j2+0)*resStride + i] += alpha*C0; - res[(j2+1)*resStride + i] += alpha*C1; - res[(j2+2)*resStride + i] += alpha*C2; - res[(j2+3)*resStride + i] += alpha*C3; - res[(j2+4)*resStride + i] += alpha*C4; - res[(j2+5)*resStride + i] += alpha*C5; - res[(j2+6)*resStride + i] += alpha*C6; - res[(j2+7)*resStride + i] += alpha*C7; - } + B_0 = blB[6]; + B_1 = blB[7]; + MADD(cj,A0,B_0,C6, B_0); + MADD(cj,A0,B_1,C7, B_1); + + blB += 8; + } + res[(j2+0)*resStride + i] += alpha*C0; + res[(j2+1)*resStride + i] += alpha*C1; + res[(j2+2)*resStride + i] += alpha*C2; + res[(j2+3)*resStride + i] += alpha*C3; + res[(j2+4)*resStride + i] += alpha*C4; + res[(j2+5)*resStride + i] += alpha*C5; + res[(j2+6)*resStride + i] += alpha*C6; + res[(j2+7)*resStride + i] += alpha*C7; + } + } } } @@ -839,35 +896,68 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> const LhsScalar* blA = &blockA[i*strideA+offsetA]; prefetch(&blA[0]); - // gets a 1 x 4 res block as registers - ResScalar C0(0), C1(0), C2(0), C3(0); // FIXME directly use blockB ??? const RhsScalar* blB = &blockB[j2*strideB+offsetB*4]; - // TODO peel this loop - for(Index k=0; k<depth; k++) - { - LhsScalar A0; - RhsScalar B_0, B_1; - - A0 = blA[k]; - - B_0 = blB[0]; - B_1 = blB[1]; - MADD(cj,A0,B_0,C0, B_0); - MADD(cj,A0,B_1,C1, B_1); - - B_0 = blB[2]; - B_1 = blB[3]; - MADD(cj,A0,B_0,C2, B_0); - MADD(cj,A0,B_1,C3, B_1); - blB += 4; - } - res[(j2+0)*resStride + i] += alpha*C0; - res[(j2+1)*resStride + i] += alpha*C1; - res[(j2+2)*resStride + i] += alpha*C2; - res[(j2+3)*resStride + i] += alpha*C3; - } + if(nr == Traits::RhsPacketSize) + { + EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows"); + + typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits; + typedef typename SwappedTraits::ResScalar SResScalar; + typedef typename SwappedTraits::LhsPacket SLhsPacket; + typedef typename SwappedTraits::RhsPacket SRhsPacket; + typedef typename SwappedTraits::ResPacket SResPacket; + typedef typename SwappedTraits::AccPacket SAccPacket; + SwappedTraits straits; + + SAccPacket C0; + straits.initAcc(C0); + for(Index k=0; k<depth; k++) + { + SLhsPacket A0; + straits.loadLhsUnaligned(blB, A0); + SRhsPacket B_0; + straits.loadRhs(&blA[k], B_0); + SRhsPacket T0; + straits.madd(A0,B_0,C0,T0); + blB += nr; + } + SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride); + SResPacket alphav = pset1<SResPacket>(alpha); + straits.acc(C0, alphav, R); + pscatter(&res[j2*resStride + i], R, resStride); + + EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows"); + } else { + // gets a 1 x 4 res block as registers + ResScalar C0(0), C1(0), C2(0), C3(0); + + for(Index k=0; k<depth; k++) + { + LhsScalar A0; + RhsScalar B_0, B_1; + + A0 = blA[k]; + + B_0 = blB[0]; + B_1 = blB[1]; + MADD(cj,A0,B_0,C0, B_0); + MADD(cj,A0,B_1,C1, B_1); + + B_0 = blB[2]; + B_1 = blB[3]; + MADD(cj,A0,B_0,C2, B_0); + MADD(cj,A0,B_1,C3, B_1); + + blB += 4; + } + res[(j2+0)*resStride + i] += alpha*C0; + res[(j2+1)*resStride + i] += alpha*C1; + res[(j2+2)*resStride + i] += alpha*C2; + res[(j2+3)*resStride + i] += alpha*C3; + } + } } } |