From dbfcceabf50db9c1dc6d82863aa9670a1b53c0a4 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 16 Jan 2019 12:51:36 +0800 Subject: Bug: 1633: refactor gebp kernel and optimize for neon --- Eigen/src/Core/products/GeneralBlockPanelKernel.h | 435 ++++++++++++++-------- 1 file changed, 274 insertions(+), 161 deletions(-) (limited to 'Eigen') diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index d6dd9dc17..bfc7d1979 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -347,6 +347,14 @@ inline void computeProductBlockingSizes(Index& k, Index& m, Index& n, Index num_ // #define CJMADD(CJ,A,B,C,T) T = B; T = CJ.pmul(A,T); C = padd(C,T); #endif +template +struct RhsPanelHelper { + private: + typedef typename conditional<(registers_taken < 15), RhsPacket, RhsPacketx4>::type inter_type; + public: + typedef typename conditional<(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS < 32), RhsPacket, inter_type>::type type; +}; + /* Vectorization logic * real*real: unpack rhs to constant packets, ... * @@ -404,29 +412,42 @@ public: typedef typename conditional::type ResPacket; typedef LhsPacket LhsPacket4Packing; + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; + typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) { p = pset1(ResScalar(0)); } - - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) - { - pbroadcast4(b, b0, b1, b2, b3); - } - -// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) -// { -// pbroadcast2(b, b0, b1); -// } - + template EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const { dest = pset1(*b); } - + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3); + } + + template + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const + { + loadRhs(b, dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + { + } + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = ploadquad(b); @@ -444,8 +465,8 @@ public: dest = ploadu(a); } - template - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, AccPacketType& tmp) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const FixedInt&) const { conj_helper cj; // It would be a lot cleaner to call pmadd all the time. Unfortunately if we @@ -460,6 +481,13 @@ public: #endif } + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } + EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const { r = pmadd(c,alpha,r); @@ -511,6 +539,14 @@ public: typedef typename conditional::type ResPacket; typedef LhsPacket LhsPacket4Packing; + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; + typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) @@ -523,6 +559,20 @@ public: { dest = pset1(*b); } + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3); + } + + template + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const + { + loadRhs(b, dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { @@ -554,18 +604,8 @@ public: dest = ploadu(a); } - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) - { - pbroadcast4(b, b0, b1, b2, b3); - } - -// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) -// { -// pbroadcast2(b, b0, b1); -// } - - template - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const FixedInt&) const { madd_impl(a, b, c, tmp, typename conditional::type()); } @@ -586,6 +626,13 @@ public: c += a * b; } + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } + template EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const { @@ -708,6 +755,14 @@ public: typedef typename conditional::type RhsPacket; typedef typename conditional::type ResPacket; typedef typename conditional::type AccPacket; + + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); } @@ -730,39 +785,39 @@ public: dest.first = pset1(real(*b)); dest.second = pset1(imag(*b)); } - - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const { - loadRhs(b,dest); + loadRhs(b, dest.B_0); + loadRhs(b + 1, dest.B1); + loadRhs(b + 2, dest.B2); + loadRhs(b + 3, dest.B3); } - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const + + // Scalar path + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, ScalarPacket& dest) const { - loadQuadToDoublePacket(b,dest); + loadRhs(b, dest); } - - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + + // Vectorized path + template + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, DoublePacket& dest) const { - // FIXME not sure that's the best way to implement it! - loadRhs(b+0, b0); - loadRhs(b+1, b1); - loadRhs(b+2, b2); - loadRhs(b+3, b3); + loadRhs(b, dest); } - - // Vectorized path - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacketType& b0, DoublePacketType& b1) + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const { - // FIXME not sure that's the best way to implement it! - loadRhs(b+0, b0); - loadRhs(b+1, b1); } - // Scalar path - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsScalar& b0, RhsScalar& b1) + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const + { + loadRhs(b,dest); + } + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const { - // FIXME not sure that's the best way to implement it! - loadRhs(b+0, b0); - loadRhs(b+1, b1); + loadQuadToDoublePacket(b,dest); } // nothing special here @@ -777,17 +832,25 @@ public: dest = ploadu((const typename unpacket_traits::type*)(a)); } - template - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, DoublePacket& c, TmpType& /*tmp*/) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, DoublePacket& c, TmpType& /*tmp*/, const FixedInt&) const { c.first = padd(pmul(a,b.first), c.first); c.second = padd(pmul(a,b.second),c.second); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/, const FixedInt&) const { c = cj.pmadd(a,b,c); } + + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; } @@ -860,6 +923,14 @@ public: typedef typename conditional::type ResPacket; typedef LhsPacket LhsPacket4Packing; + typedef struct { + RhsPacket B_0, B1, B2, B3; + const RhsPacket& get(const FixedInt<0>&) const { return B_0; } + const RhsPacket& get(const FixedInt<1>&) const { return B1; } + const RhsPacket& get(const FixedInt<2>&) const { return B2; } + const RhsPacket& get(const FixedInt<3>&) const { return B3; } + } RhsPacketx4; + typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) @@ -872,18 +943,20 @@ public: { dest = pset1(*b); } - - void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const { - pbroadcast4(b, b0, b1, b2, b3); + pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3); } - -// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) -// { -// // FIXME not sure that's the best way to implement it! -// b0 = pload1(b+0); -// b1 = pload1(b+1); -// } + + template + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const + { + loadRhs(b, dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const { @@ -901,8 +974,8 @@ public: dest = ploaddup(a); } - template - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const FixedInt&) const { madd_impl(a, b, c, tmp, typename conditional::type()); } @@ -924,6 +997,13 @@ public: c += a * b; } + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + { + RhsPacket tmp; + madd(a, b.get(lane), c, tmp, lane); + } + template EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const { @@ -932,7 +1012,7 @@ public: } protected: - + }; @@ -944,27 +1024,54 @@ struct gebp_traits { typedef float RhsPacket; - EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + typedef float32x4_t RhsPacketx4; + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { - loadRhs(b+0, b0); - loadRhs(b+1, b1); - loadRhs(b+2, b2); - loadRhs(b+3, b3); + dest = *b; } - EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + dest = vld1q_f32(b); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const { dest = *b; } + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketx4& dest) const + {} + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { loadRhs(b,dest); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/) const + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { + c += a * b; + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<0>&) const + { + c = vfmaq_lane_f32(c, a, vget_low_f32(b), 0); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<1>&) const + { + c = vfmaq_lane_f32(c, a, vget_low_f32(b), 1); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<2>&) const + { + c = vfmaq_lane_f32(c, a, vget_high_f32(b), 0); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<3>&) const { - c = vfmaq_n_f32(c, a, b); + c = vfmaq_lane_f32(c, a, vget_high_f32(b), 1); } }; @@ -986,6 +1093,9 @@ struct gebp_kernel typedef typename Traits::RhsPacket RhsPacket; typedef typename Traits::ResPacket ResPacket; typedef typename Traits::AccPacket AccPacket; + typedef typename Traits::RhsPacketx4 RhsPacketx4; + + typedef typename RhsPanelHelper::type RhsPanel15; typedef gebp_traits SwappedTraits; typedef typename SwappedTraits::ResScalar SResScalar; @@ -1075,7 +1185,7 @@ struct last_row_process_16_packets); blB += SwappedTraits::LhsProgress/4; blA += 1; } @@ -1166,36 +1276,39 @@ void gebp_kernel); \ + traits.madd(A1, rhs_panel, C4, T0, fix<0>); \ + traits.madd(A2, rhs_panel, C8, T0, fix<0>); \ + traits.updateRhs(blB + (1+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C5, T0, fix<1>); \ + traits.madd(A2, rhs_panel, C9, T0, fix<1>); \ + traits.updateRhs(blB + (2+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C6, T0, fix<2>); \ + traits.madd(A2, rhs_panel, C10, T0, fix<2>); \ + traits.updateRhs(blB + (3+4*K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C7, T0, fix<3>); \ + traits.madd(A2, rhs_panel, C11, T0, fix<3>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \ + } while (false) internal::prefetch(blB); EIGEN_GEBP_ONESTEP(0); @@ -1215,7 +1328,7 @@ void gebp_kernel); \ + traits.madd(A1, B_0, C4, B_0, fix<0>); \ + traits.madd(A2, B_0, C8, B_0, fix<0>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \ + } while (false) + EIGEN_GEBGP_ONESTEP(0); EIGEN_GEBGP_ONESTEP(1); EIGEN_GEBGP_ONESTEP(2); @@ -1397,7 +1510,7 @@ void gebp_kernel=6 without FMA (bug 1637) @@ -1406,24 +1519,24 @@ void gebp_kernel); \ + traits.madd(A1, rhs_panel, C4, T0, fix<0>); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C5, T0, fix<1>); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C6, T0, fix<2>); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C7, T0, fix<3>); \ + EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \ + } while (false) + internal::prefetch(blB+(48+0)); EIGEN_GEBGP_ONESTEP(0); EIGEN_GEBGP_ONESTEP(1); @@ -1443,7 +1556,7 @@ void gebp_kernel); \ + traits.madd(A1, B_0, C4, B_0, fix<0>); \ EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX1"); \ } while(false) @@ -1596,19 +1709,19 @@ void gebp_kernel); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX4"); \ } while(false) internal::prefetch(blB+(48+0)); @@ -1630,7 +1743,7 @@ void gebp_kernel); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX1"); \ } while(false); EIGEN_GEBGP_ONESTEP(0); @@ -1763,15 +1876,15 @@ void gebp_kernel); + straits.madd(A1,B_1,C1,B_1, fix<0>); straits.loadLhsUnaligned(blB+2*SwappedTraits::LhsProgress, A0); straits.loadLhsUnaligned(blB+3*SwappedTraits::LhsProgress, A1); straits.loadRhsQuad(blA+2*spk, B_0); straits.loadRhsQuad(blA+3*spk, B_1); - straits.madd(A0,B_0,C2,B_0); - straits.madd(A1,B_1,C3,B_1); + straits.madd(A0,B_0,C2,B_0, fix<0>); + straits.madd(A1,B_1,C3,B_1, fix<0>); blB += 4*SwappedTraits::LhsProgress; blA += 4*spk; @@ -1784,7 +1897,7 @@ void gebp_kernel); blB += SwappedTraits::LhsProgress; blA += spk; @@ -1808,7 +1921,7 @@ void gebp_kernel); straits.acc(c0, alphav, R); } else -- cgit v1.2.3