diff options
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 113 |
1 files changed, 48 insertions, 65 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index bfc7d1979..04b7bfa7e 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -355,6 +355,16 @@ struct RhsPanelHelper { typedef typename conditional<(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS < 32), RhsPacket, inter_type>::type type; }; +template <typename Packet> +struct QuadPacket +{ + Packet B_0, B1, B2, B3; + const Packet& get(const FixedInt<0>&) const { return B_0; } + const Packet& get(const FixedInt<1>&) const { return B1; } + const Packet& get(const FixedInt<2>&) const { return B2; } + const Packet& get(const FixedInt<3>&) const { return B3; } +}; + /* Vectorization logic * real*real: unpack rhs to constant packets, ... * @@ -412,14 +422,7 @@ public: typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; typedef LhsPacket LhsPacket4Packing; - typedef struct { - RhsPacket B_0, B1, B2, B3; - const RhsPacket& get(const FixedInt<0>&) const { return B_0; } - const RhsPacket& get(const FixedInt<1>&) const { return B1; } - const RhsPacket& get(const FixedInt<2>&) const { return B2; } - const RhsPacket& get(const FixedInt<3>&) const { return B3; } - } RhsPacketx4; - + typedef QuadPacket<RhsPacket> RhsPacketx4; typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) @@ -465,8 +468,8 @@ public: dest = ploadu<LhsPacketType>(a); } - template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const FixedInt&) const + template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const LaneIdType&) const { conj_helper<LhsPacketType,RhsPacketType,ConjLhs,ConjRhs> cj; // It would be a lot cleaner to call pmadd all the time. Unfortunately if we @@ -481,10 +484,9 @@ public: #endif } - template<typename LhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + template<typename LhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const { - RhsPacket tmp; madd(a, b.get(lane), c, tmp, lane); } @@ -539,13 +541,7 @@ public: typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; typedef LhsPacket LhsPacket4Packing; - typedef struct { - RhsPacket B_0, B1, B2, B3; - const RhsPacket& get(const FixedInt<0>&) const { return B_0; } - const RhsPacket& get(const FixedInt<1>&) const { return B1; } - const RhsPacket& get(const FixedInt<2>&) const { return B2; } - const RhsPacket& get(const FixedInt<3>&) const { return B3; } - } RhsPacketx4; + typedef QuadPacket<RhsPacket> RhsPacketx4; typedef ResPacket AccPacket; @@ -604,8 +600,8 @@ public: dest = ploadu<LhsPacketType>(a); } - template <typename LhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const FixedInt&) const + template <typename LhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType&) const { madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type()); } @@ -626,10 +622,9 @@ public: c += a * b; } - template<typename LhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + template<typename LhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const { - RhsPacket tmp; madd(a, b.get(lane), c, tmp, lane); } @@ -756,13 +751,8 @@ public: typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket; typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type AccPacket; - typedef struct { - RhsPacket B_0, B1, B2, B3; - const RhsPacket& get(const FixedInt<0>&) const { return B_0; } - const RhsPacket& get(const FixedInt<1>&) const { return B1; } - const RhsPacket& get(const FixedInt<2>&) const { return B2; } - const RhsPacket& get(const FixedInt<3>&) const { return B3; } - } RhsPacketx4; + // this actualy holds 8 packets! + typedef QuadPacket<RhsPacket> RhsPacketx4; EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); } @@ -807,9 +797,7 @@ public: loadRhs(b, dest); } - EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const - { - } + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {} EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const { @@ -832,23 +820,22 @@ public: dest = ploadu<LhsPacketType>((const typename unpacket_traits<LhsPacketType>::type*)(a)); } - template<typename LhsPacketType, typename ResPacketType, typename TmpType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/, const FixedInt&) const + template<typename LhsPacketType, typename ResPacketType, typename TmpType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/, const LaneIdType&) const { c.first = padd(pmul(a,b.first), c.first); c.second = padd(pmul(a,b.second),c.second); } - template <typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/, const FixedInt&) const + template<typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/, const LaneIdType&) const { c = cj.pmadd(a,b,c); } - template<typename LhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + template<typename LhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const { - RhsPacket tmp; madd(a, b.get(lane), c, tmp, lane); } @@ -922,15 +909,7 @@ public: typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; typedef LhsPacket LhsPacket4Packing; - - typedef struct { - RhsPacket B_0, B1, B2, B3; - const RhsPacket& get(const FixedInt<0>&) const { return B_0; } - const RhsPacket& get(const FixedInt<1>&) const { return B1; } - const RhsPacket& get(const FixedInt<2>&) const { return B2; } - const RhsPacket& get(const FixedInt<3>&) const { return B3; } - } RhsPacketx4; - + typedef QuadPacket<RhsPacket> RhsPacketx4; typedef ResPacket AccPacket; EIGEN_STRONG_INLINE void initAcc(AccPacket& p) @@ -974,8 +953,8 @@ public: dest = ploaddup<LhsPacketType>(a); } - template <typename LhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const FixedInt&) const + template <typename LhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacket& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType&) const { madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type()); } @@ -997,10 +976,9 @@ public: c += a * b; } - template<typename LhsPacketType, typename AccPacketType, typename FixedInt> - EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacketx4&, const FixedInt& lane) const + template<typename LhsPacketType, typename AccPacketType, typename LaneIdType> + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const { - RhsPacket tmp; madd(a, b.get(lane), c, tmp, lane); } @@ -1054,22 +1032,22 @@ struct gebp_traits <float, float, false, false,Architecture::NEON> c += a * b; } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<0>&) const + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const { c = vfmaq_lane_f32(c, a, vget_low_f32(b), 0); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<1>&) const + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const { c = vfmaq_lane_f32(c, a, vget_low_f32(b), 1); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<2>&) const + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const { c = vfmaq_lane_f32(c, a, vget_high_f32(b), 0); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacketx4& /*tmp*/, const FixedInt<3>&) const + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const { c = vfmaq_lane_f32(c, a, vget_high_f32(b), 1); } @@ -1277,7 +1255,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga { EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4"); // 15 registers are taken (12 for acc, 2 for lhs). - RhsPanel15 rhs_panel, T0; + RhsPanel15 rhs_panel; + RhsPacket T0; LhsPacket A2; #define EIGEN_GEBP_ONESTEP(K) \ @@ -1510,7 +1489,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga for(Index k=0; k<peeled_kc; k+=pk) { EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX4"); - RhsPacketx4 rhs_panel, T0; + RhsPacketx4 rhs_panel; + RhsPacket T0; // NOTE: the begin/end asm comments below work around bug 935! // but they are not enough for gcc>=6 without FMA (bug 1637) @@ -1556,7 +1536,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga // process remaining peeled loop for(Index k=peeled_kc; k<depth; k++) { - RhsPacketx4 rhs_panel, T0; + RhsPacketx4 rhs_panel; + RhsPacket T0; EIGEN_GEBGP_ONESTEP(0); blB += 4*RhsProgress; blA += 2*Traits::LhsProgress; @@ -1709,7 +1690,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga for(Index k=0; k<peeled_kc; k+=pk) { EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX4"); - RhsPacketx4 rhs_panel, T0; + RhsPacketx4 rhs_panel; + RhsPacket T0; #define EIGEN_GEBGP_ONESTEP(K) \ do { \ @@ -1743,7 +1725,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga // process remaining peeled loop for(Index k=peeled_kc; k<depth; k++) { - RhsPacketx4 rhs_panel, T0; + RhsPacketx4 rhs_panel; + RhsPacket T0; EIGEN_GEBGP_ONESTEP(0); blB += 4*RhsProgress; blA += 1*LhsProgress; |