diff options
author | Renjie Liu <renjieliu@google.com> | 2019-01-16 12:51:36 +0800 |
---|---|---|
committer | Renjie Liu <renjieliu@google.com> | 2019-01-16 12:51:36 +0800 |
commit | dbfcceabf50db9c1dc6d82863aa9670a1b53c0a4 (patch) | |
tree | 040d043da4ec629bd23753c369b6be70a425d27e /Eigen | |
parent | 2b70b2f5708fdedf24c5d47768c2b24019c48311 (diff) |
Bug: 1633: refactor gebp kernel and optimize for neon
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 435 |
1 files changed, 274 insertions, 161 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index d6dd9dc17..bfc7d1979 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -347,6 +347,14 @@ inline void computeProductBlockingSizes(Index& k, Index& m, Index& n, Index num_ // #define CJMADD(CJ,A,B,C,T) T = B; T = CJ.pmul(A,T); C = padd(C,T); #endif +template <typename RhsPacket, typename RhsPacketx4, int registers_taken> +struct RhsPanelHelper { + private: + typedef typename conditional<(registers_taken < 15), RhsPacket, RhsPacketx4>::type inter_type; + public: + typedef typename conditional<(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS < 32), RhsPacket, inter_type>::type type; +}; + /* Vectorization logic * real*real: unpack rhs to constant packets, ... * @@ -404,29 +412,42 @@ public: typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; typedef LhsPacket LhsPacket4Packing; + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; + typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); } - - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) - { - pbroadcast4(b, b0, b1, b2, b3); - } - -// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) -// { -// pbroadcast2(b, b0, b1); -// } - + template<typename RhsPacketType> EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const { dest = pset1<RhsPacketType>(*b); } - + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3); + } + + template<typename RhsPacketType> + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const + { + loadRhs(b, dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + { + } + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = ploadquad<RhsPacket>(b); @@ -444,8 +465,8 @@ public: dest = ploadu<LhsPacketType>(a); } - template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, AccPacketType& tmp) const + template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const FixedInt&) const { conj_helper<LhsPacketType,RhsPacketType,ConjLhs,ConjRhs> cj; // It would be a lot cleaner to call pmadd all the time. Unfortunately if we @@ -460,6 +481,13 @@ public: #endif } + template<typename LhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } + EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const { r = pmadd(c,alpha,r); @@ -511,6 +539,14 @@ public: typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; typedef LhsPacket LhsPacket4Packing; + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; + typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) @@ -523,6 +559,20 @@ public: { dest = pset1<RhsPacketType>(*b); } + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3); + } + + template<typename RhsPacketType> + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const + { + loadRhs(b, dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { @@ -554,18 +604,8 @@ public: dest = ploadu<LhsPacketType>(a); } - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) - { - pbroadcast4(b, b0, b1, b2, b3); - } - -// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) -// { -// pbroadcast2(b, b0, b1); -// } - - template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const + template <typename LhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const FixedInt&) const { madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type()); } @@ -586,6 +626,13 @@ public: c += a * b; } + template<typename LhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } + template <typename ResPacketType, typename AccPacketType> EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const { @@ -708,6 +755,14 @@ public: typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type RhsPacket; typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket; typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type AccPacket; + + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); } @@ -730,39 +785,39 @@ public: dest.first = pset1<RealPacketType>(real(*b)); dest.second = pset1<RealPacketType>(imag(*b)); } - - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const { - loadRhs(b,dest); + loadRhs(b, dest.B_0); + loadRhs(b + 1, dest.B1); + loadRhs(b + 2, dest.B2); + loadRhs(b + 3, dest.B3); } - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const + + // Scalar path + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, ScalarPacket& dest) const { - loadQuadToDoublePacket(b,dest); + loadRhs(b, dest); } - - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + + // Vectorized path + template<typename RealPacketType> + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, DoublePacket<RealPacketType>& dest) const { - // FIXME not sure that's the best way to implement it! - loadRhs(b+0, b0); - loadRhs(b+1, b1); - loadRhs(b+2, b2); - loadRhs(b+3, b3); + loadRhs(b, dest); } - - // Vectorized path - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacketType& b0, DoublePacketType& b1) + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const { - // FIXME not sure that's the best way to implement it! - loadRhs(b+0, b0); - loadRhs(b+1, b1); } - // Scalar path - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsScalar& b0, RhsScalar& b1) + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const + { + loadRhs(b,dest); + } + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const { - // FIXME not sure that's the best way to implement it! - loadRhs(b+0, b0); - loadRhs(b+1, b1); + loadQuadToDoublePacket(b,dest); } // nothing special here @@ -777,17 +832,25 @@ public: dest = ploadu<LhsPacketType>((const typename unpacket_traits<LhsPacketType>::type*)(a)); } - template<typename LhsPacketType, typename RhsPacketType, typename ResPacketType, typename TmpType> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/) const + template<typename LhsPacketType, typename ResPacketType, typename TmpType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/, const FixedInt&) const { c.first = padd(pmul(a,b.first), c.first); c.second = padd(pmul(a,b.second),c.second); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/) const + template <typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/, const FixedInt&) const { c = cj.pmadd(a,b,c); } + + template<typename LhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; } @@ -860,6 +923,14 @@ public: typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; typedef LhsPacket LhsPacket4Packing; + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; + typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) @@ -872,18 +943,20 @@ public: { dest = pset1<RhsPacketType>(*b); } - - void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const { - pbroadcast4(b, b0, b1, b2, b3); + pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3); } - -// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) -// { -// // FIXME not sure that's the best way to implement it! -// b0 = pload1<RhsPacket>(b+0); -// b1 = pload1<RhsPacket>(b+1); -// } + + template<typename RhsPacketType> + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const + { + loadRhs(b, dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const { @@ -901,8 +974,8 @@ public: dest = ploaddup<LhsPacketType>(a); } - template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const + template <typename LhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const FixedInt&) const { madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type()); } @@ -924,6 +997,13 @@ public: c += a * b; } + template<typename LhsPacketType, typename AccPacketType, typename FixedInt> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } + template <typename ResPacketType, typename AccPacketType> EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const { @@ -932,7 +1012,7 @@ public: } protected: - + }; @@ -944,27 +1024,54 @@ struct gebp_traits <float, float, false, false,Architecture::NEON> { typedef float RhsPacket; - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + typedef float32x4_t RhsPacketx4; + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { - loadRhs(b+0, b0); - loadRhs(b+1, b1); - loadRhs(b+2, b2); - loadRhs(b+3, b3); + dest = *b; } - EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + dest = vld1q_f32(b); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const { dest = *b; } + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketx4& dest) const + {} + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { loadRhs(b,dest); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/) const + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { + c += a * b; + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<0>&) const + { + c = vfmaq_lane_f32(c, a, vget_low_f32(b), 0); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<1>&) const + { + c = vfmaq_lane_f32(c, a, vget_low_f32(b), 1); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<2>&) const + { + c = vfmaq_lane_f32(c, a, vget_high_f32(b), 0); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<3>&) const { - c = vfmaq_n_f32(c, a, b); + c = vfmaq_lane_f32(c, a, vget_high_f32(b), 1); } }; @@ -986,6 +1093,9 @@ struct gebp_kernel typedef typename Traits::RhsPacket RhsPacket; typedef typename Traits::ResPacket ResPacket; typedef typename Traits::AccPacket AccPacket; + typedef typename Traits::RhsPacketx4 RhsPacketx4; + + typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15; typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits; typedef typename SwappedTraits::ResScalar SResScalar; @@ -1075,7 +1185,7 @@ struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, SRhsPacketQuarter b0; straits.loadLhsUnaligned(blB, a0); straits.loadRhs(blA, b0); - straits.madd(a0,b0,c0,b0); + straits.madd(a0,b0,c0,b0, fix<0>); blB += SwappedTraits::LhsProgress/4; blA += 1; } @@ -1166,36 +1276,39 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga for(Index k=0; k<peeled_kc; k+=pk) { EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4"); - RhsPacket B_0, T0; + // 15 registers are taken (12 for acc, 2 for lhs). + RhsPanel15 rhs_panel, T0; LhsPacket A2; -#define EIGEN_GEBP_ONESTEP(K) \ - do { \ - EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX4"); \ +#define EIGEN_GEBP_ONESTEP(K) \ + do { \ + EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX4"); \ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \ - internal::prefetch(blA+(3*K+16)*LhsProgress); \ - if (EIGEN_ARCH_ARM || EIGEN_ARCH_MIPS) { internal::prefetch(blB+(4*K+16)*RhsProgress); } /* Bug 953 */ \ - traits.loadLhs(&blA[(0+3*K)*LhsProgress], A0); \ - traits.loadLhs(&blA[(1+3*K)*LhsProgress], A1); \ - traits.loadLhs(&blA[(2+3*K)*LhsProgress], A2); \ - traits.loadRhs(blB + (0+4*K)*Traits::RhsProgress, B_0); \ - traits.madd(A0, B_0, C0, T0); \ - traits.madd(A1, B_0, C4, T0); \ - traits.madd(A2, B_0, C8, B_0); \ - traits.loadRhs(blB + (1+4*K)*Traits::RhsProgress, B_0); \ - traits.madd(A0, B_0, C1, T0); \ - traits.madd(A1, B_0, C5, T0); \ - traits.madd(A2, B_0, C9, B_0); \ - traits.loadRhs(blB + (2+4*K)*Traits::RhsProgress, B_0); \ - traits.madd(A0, B_0, C2, T0); \ - traits.madd(A1, B_0, C6, T0); \ - traits.madd(A2, B_0, C10, B_0); \ - traits.loadRhs(blB + (3+4*K)*Traits::RhsProgress, B_0); \ - traits.madd(A0, B_0, C3 , T0); \ - traits.madd(A1, B_0, C7, T0); \ - traits.madd(A2, B_0, C11, B_0); \ - EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \ - } while(false) + internal::prefetch(blA + (3 * K + 16) * LhsProgress); \ + if (EIGEN_ARCH_ARM || EIGEN_ARCH_MIPS) { \ + internal::prefetch(blB + (4 * K + 16) * RhsProgress); \ + } /* Bug 953 */ \ + traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \ + traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \ + traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \ + traits.loadRhs(blB + (0+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C0, T0, fix<0>); \ + traits.madd(A1, rhs_panel, C4, T0, fix<0>); \ + traits.madd(A2, rhs_panel, C8, T0, fix<0>); \ + traits.updateRhs(blB + (1+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C5, T0, fix<1>); \ + traits.madd(A2, rhs_panel, C9, T0, fix<1>); \ + traits.updateRhs(blB + (2+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C6, T0, fix<2>); \ + traits.madd(A2, rhs_panel, C10, T0, fix<2>); \ + traits.updateRhs(blB + (3+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C7, T0, fix<3>); \ + traits.madd(A2, rhs_panel, C11, T0, fix<3>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \ + } while (false) internal::prefetch(blB); EIGEN_GEBP_ONESTEP(0); @@ -1215,7 +1328,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga // process remaining peeled loop for(Index k=peeled_kc; k<depth; k++) { - RhsPacket B_0, T0; + RhsPanel15 rhs_panel, T0; LhsPacket A2; EIGEN_GEBP_ONESTEP(0); blB += 4*RhsProgress; @@ -1295,20 +1408,20 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga { EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX1"); RhsPacket B_0; -#define EIGEN_GEBGP_ONESTEP(K) \ - do { \ - EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX1"); \ +#define EIGEN_GEBGP_ONESTEP(K) \ + do { \ + EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX1"); \ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \ - traits.loadLhs(&blA[(0+3*K)*LhsProgress], A0); \ - traits.loadLhs(&blA[(1+3*K)*LhsProgress], A1); \ - traits.loadLhs(&blA[(2+3*K)*LhsProgress], A2); \ - traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \ - traits.madd(A0, B_0, C0, B_0); \ - traits.madd(A1, B_0, C4, B_0); \ - traits.madd(A2, B_0, C8, B_0); \ - EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \ - } while(false) - + traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \ + traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \ + traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \ + traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \ + traits.madd(A0, B_0, C0, B_0, fix<0>); \ + traits.madd(A1, B_0, C4, B_0, fix<0>); \ + traits.madd(A2, B_0, C8, B_0, fix<0>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \ + } while (false) + EIGEN_GEBGP_ONESTEP(0); EIGEN_GEBGP_ONESTEP(1); EIGEN_GEBGP_ONESTEP(2); @@ -1397,7 +1510,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga for(Index k=0; k<peeled_kc; k+=pk) { EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX4"); - RhsPacket B_0, B1, B2, B3, T0; + RhsPacketx4 rhs_panel, T0; // NOTE: the begin/end asm comments below work around bug 935! // but they are not enough for gcc>=6 without FMA (bug 1637) @@ -1406,24 +1519,24 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga #else #define EIGEN_GEBP_2PX4_SPILLING_WORKAROUND #endif - #define EIGEN_GEBGP_ONESTEP(K) \ - do { \ - EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX4"); \ - traits.loadLhs(&blA[(0+2*K)*LhsProgress], A0); \ - traits.loadLhs(&blA[(1+2*K)*LhsProgress], A1); \ - traits.broadcastRhs(&blB[(0+4*K)*RhsProgress], B_0, B1, B2, B3); \ - traits.madd(A0, B_0, C0, T0); \ - traits.madd(A1, B_0, C4, B_0); \ - traits.madd(A0, B1, C1, T0); \ - traits.madd(A1, B1, C5, B1); \ - traits.madd(A0, B2, C2, T0); \ - traits.madd(A1, B2, C6, B2); \ - traits.madd(A0, B3, C3, T0); \ - traits.madd(A1, B3, C7, B3); \ - EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \ - EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \ - } while(false) - +#define EIGEN_GEBGP_ONESTEP(K) \ + do { \ + EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX4"); \ + traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \ + traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \ + traits.loadRhs(&blB[(0 + 4 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C0, T0, fix<0>); \ + traits.madd(A1, rhs_panel, C4, T0, fix<0>); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C5, T0, fix<1>); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C6, T0, fix<2>); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C7, T0, fix<3>); \ + EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \ + } while (false) + internal::prefetch(blB+(48+0)); EIGEN_GEBGP_ONESTEP(0); EIGEN_GEBGP_ONESTEP(1); @@ -1443,7 +1556,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga // process remaining peeled loop for(Index k=peeled_kc; k<depth; k++) { - RhsPacket B_0, B1, B2, B3, T0; + RhsPacketx4 rhs_panel, T0; EIGEN_GEBGP_ONESTEP(0); blB += 4*RhsProgress; blA += 2*Traits::LhsProgress; @@ -1514,8 +1627,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga traits.loadLhs(&blA[(0+2*K)*LhsProgress], A0); \ traits.loadLhs(&blA[(1+2*K)*LhsProgress], A1); \ traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \ - traits.madd(A0, B_0, C0, B1); \ - traits.madd(A1, B_0, C4, B_0); \ + traits.madd(A0, B_0, C0, B1, fix<0>); \ + traits.madd(A1, B_0, C4, B_0, fix<0>); \ EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX1"); \ } while(false) @@ -1596,19 +1709,19 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga for(Index k=0; k<peeled_kc; k+=pk) { EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX4"); - RhsPacket B_0, B1, B2, B3; - + RhsPacketx4 rhs_panel, T0; + #define EIGEN_GEBGP_ONESTEP(K) \ - do { \ - EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX4"); \ + do { \ + EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX4"); \ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \ - traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \ - traits.broadcastRhs(&blB[(0+4*K)*RhsProgress], B_0, B1, B2, B3); \ - traits.madd(A0, B_0, C0, B_0); \ - traits.madd(A0, B1, C1, B1); \ - traits.madd(A0, B2, C2, B2); \ - traits.madd(A0, B3, C3, B3); \ - EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX4"); \ + traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \ + traits.loadRhs(&blB[(0+4*K)*RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C0, T0, fix<0>); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX4"); \ } while(false) internal::prefetch(blB+(48+0)); @@ -1630,7 +1743,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga // process remaining peeled loop for(Index k=peeled_kc; k<depth; k++) { - RhsPacket B_0, B1, B2, B3; + RhsPacketx4 rhs_panel, T0; EIGEN_GEBGP_ONESTEP(0); blB += 4*RhsProgress; blA += 1*LhsProgress; @@ -1678,13 +1791,13 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga RhsPacket B_0; #define EIGEN_GEBGP_ONESTEP(K) \ - do { \ - EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX1"); \ + do { \ + EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX1"); \ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \ - traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \ - traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \ - traits.madd(A0, B_0, C0, B_0); \ - EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX1"); \ + traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \ + traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \ + traits.madd(A0, B_0, C0, B_0, fix<0>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX1"); \ } while(false); EIGEN_GEBGP_ONESTEP(0); @@ -1763,15 +1876,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga straits.loadRhsQuad(blA+0*spk, B_0); straits.loadRhsQuad(blA+1*spk, B_1); - straits.madd(A0,B_0,C0,B_0); - straits.madd(A1,B_1,C1,B_1); + straits.madd(A0,B_0,C0,B_0, fix<0>); + straits.madd(A1,B_1,C1,B_1, fix<0>); straits.loadLhsUnaligned(blB+2*SwappedTraits::LhsProgress, A0); straits.loadLhsUnaligned(blB+3*SwappedTraits::LhsProgress, A1); straits.loadRhsQuad(blA+2*spk, B_0); straits.loadRhsQuad(blA+3*spk, B_1); - straits.madd(A0,B_0,C2,B_0); - straits.madd(A1,B_1,C3,B_1); + straits.madd(A0,B_0,C2,B_0, fix<0>); + straits.madd(A1,B_1,C3,B_1, fix<0>); blB += 4*SwappedTraits::LhsProgress; blA += 4*spk; @@ -1784,7 +1897,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga straits.loadLhsUnaligned(blB, A0); straits.loadRhsQuad(blA, B_0); - straits.madd(A0,B_0,C0,B_0); + straits.madd(A0,B_0,C0,B_0, fix<0>); blB += SwappedTraits::LhsProgress; blA += spk; @@ -1808,7 +1921,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga straits.loadLhsUnaligned(blB, a0); straits.loadRhs(blA, b0); SAccPacketHalf c0 = predux_half_dowto4(C0); - straits.madd(a0,b0,c0,b0); + straits.madd(a0,b0,c0,b0, fix<0>); straits.acc(c0, alphav, R); } else |