diff options
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 22 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixVector.h | 174 |
2 files changed, 155 insertions, 41 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index fdd0ec0e9..6c1d882fd 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -15,13 +15,13 @@ namespace Eigen { namespace internal { -enum PacketSizeType { - PacketFull = 0, - PacketHalf, - PacketQuarter +enum GEBPPacketSizeType { + GEBPPacketFull = 0, + GEBPPacketHalf, + GEBPPacketQuarter }; -template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=PacketFull> +template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=GEBPPacketFull> class gebp_traits; @@ -375,10 +375,10 @@ template <int N, typename T1, typename T2, typename T3> struct packet_conditional { typedef T3 type; }; template <typename T1, typename T2, typename T3> -struct packet_conditional<PacketFull, T1, T2, T3> { typedef T1 type; }; +struct packet_conditional<GEBPPacketFull, T1, T2, T3> { typedef T1 type; }; template <typename T1, typename T2, typename T3> -struct packet_conditional<PacketHalf, T1, T2, T3> { typedef T2 type; }; +struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; }; #define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \ typedef typename packet_conditional<packet_size, \ @@ -1054,8 +1054,8 @@ protected: #if EIGEN_ARCH_ARM64 && defined EIGEN_VECTORIZE_NEON template<> -struct gebp_traits <float, float, false, false,Architecture::NEON,PacketFull> - : gebp_traits<float,float,false,false,Architecture::Generic,PacketFull> +struct gebp_traits <float, float, false, false,Architecture::NEON,GEBPPacketFull> + : gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull> { typedef float RhsPacket; @@ -1203,8 +1203,8 @@ template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMa struct gebp_kernel { typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits; - typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketHalf> HalfTraits; - typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketQuarter> QuarterTraits; + typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits; + typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits; typedef typename Traits::ResScalar ResScalar; typedef typename Traits::LhsPacket LhsPacket; diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 767feb99d..eb1d924e5 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -14,6 +14,54 @@ namespace Eigen { namespace internal { +enum GEMVPacketSizeType { + GEMVPacketFull = 0, + GEMVPacketHalf, + GEMVPacketQuarter +}; + +template <int N, typename T1, typename T2, typename T3> +struct gemv_packet_cond { typedef T3 type; }; + +template <typename T1, typename T2, typename T3> +struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; }; + +template <typename T1, typename T2, typename T3> +struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; }; + +template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull> +class gemv_traits +{ + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; + +#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \ + typedef typename gemv_packet_cond<packet_size, \ + typename packet_traits<name ## Scalar>::type, \ + typename packet_traits<name ## Scalar>::half, \ + typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \ + prefix ## name ## Packet + + PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize); + PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize); + PACKET_DECL_COND_PREFIX(_, Res, _PacketSize); +#undef PACKET_DECL_COND_PREFIX + +public: + enum { + Vectorizable = unpacket_traits<_LhsPacket>::vectorizable && + unpacket_traits<_RhsPacket>::vectorizable && + int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size), + LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1, + RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1, + ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1 + }; + + typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; + typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; + typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; +}; + + /* Optimized col-major matrix * vector product: * This algorithm processes the matrix per vertical panels, * which are then processed horizontaly per chunck of 8*PacketSize x 1 vertical segments. @@ -30,23 +78,23 @@ namespace internal { template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version> { + typedef gemv_traits<LhsScalar,RhsScalar> Traits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits; + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; -enum { - Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable - && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size), - LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1, - RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1, - ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1 -}; + typedef typename Traits::LhsPacket LhsPacket; + typedef typename Traits::RhsPacket RhsPacket; + typedef typename Traits::ResPacket ResPacket; -typedef typename packet_traits<LhsScalar>::type _LhsPacket; -typedef typename packet_traits<RhsScalar>::type _RhsPacket; -typedef typename packet_traits<ResScalar>::type _ResPacket; + typedef typename HalfTraits::LhsPacket LhsPacketHalf; + typedef typename HalfTraits::RhsPacket RhsPacketHalf; + typedef typename HalfTraits::ResPacket ResPacketHalf; -typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; -typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; -typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; + typedef typename QuarterTraits::LhsPacket LhsPacketQuarter; + typedef typename QuarterTraits::RhsPacket RhsPacketQuarter; + typedef typename QuarterTraits::ResPacket ResPacketQuarter; EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( Index rows, Index cols, @@ -73,19 +121,33 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj; conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj; + conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half; + conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter; + const Index lhsStride = lhs.stride(); // TODO: for padded aligned inputs, we could enable aligned reads - enum { LhsAlignment = Unaligned }; + enum { LhsAlignment = Unaligned, + ResPacketSize = Traits::ResPacketSize, + ResPacketSizeHalf = HalfTraits::ResPacketSize, + ResPacketSizeQuarter = QuarterTraits::ResPacketSize, + LhsPacketSize = Traits::LhsPacketSize, + HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize, + HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf + }; const Index n8 = rows-8*ResPacketSize+1; const Index n4 = rows-4*ResPacketSize+1; const Index n3 = rows-3*ResPacketSize+1; const Index n2 = rows-2*ResPacketSize+1; const Index n1 = rows-1*ResPacketSize+1; + const Index n_half = rows-1*ResPacketSizeHalf+1; + const Index n_quarter = rows-1*ResPacketSizeQuarter+1; // TODO: improve the following heuristic: const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4); ResPacket palpha = pset1<ResPacket>(alpha); + ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha); + ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha); for(Index j2=0; j2<cols; j2+=block_cols) { @@ -190,6 +252,28 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0))); i+=ResPacketSize; } + if(HasHalf && i<n_half) + { + ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0)); + for(Index j=j2; j<jend; j+=1) + { + RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0)); + c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0); + } + pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0))); + i+=ResPacketSizeHalf; + } + if(HasQuarter && i<n_quarter) + { + ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0)); + for(Index j=j2; j<jend; j+=1) + { + RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0)); + c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0); + } + pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0))); + i+=ResPacketSizeQuarter; + } for(;i<rows;++i) { ResScalar c0(0); @@ -213,23 +297,24 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version> { -typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; - -enum { - Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable - && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size), - LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1, - RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1, - ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1 -}; + typedef gemv_traits<LhsScalar,RhsScalar> Traits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits; + + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; -typedef typename packet_traits<LhsScalar>::type _LhsPacket; -typedef typename packet_traits<RhsScalar>::type _RhsPacket; -typedef typename packet_traits<ResScalar>::type _ResPacket; + typedef typename Traits::LhsPacket LhsPacket; + static const Index LhsPacketSize = Traits::LhsPacketSize; + typedef typename Traits::RhsPacket RhsPacket; + typedef typename Traits::ResPacket ResPacket; -typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; -typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; -typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; + typedef typename HalfTraits::LhsPacket LhsPacketHalf; + typedef typename HalfTraits::RhsPacket RhsPacketHalf; + typedef typename HalfTraits::ResPacket ResPacketHalf; + + typedef typename QuarterTraits::LhsPacket LhsPacketQuarter; + typedef typename QuarterTraits::RhsPacket RhsPacketQuarter; + typedef typename QuarterTraits::ResPacket ResPacketQuarter; EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( Index rows, Index cols, @@ -254,6 +339,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs eigen_internal_assert(rhs.stride()==1); conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj; conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj; + conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half; + conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter; // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large, // processing 8 rows at once might be counter productive wrt cache. @@ -262,7 +349,16 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs const Index n2 = rows-1; // TODO: for padded aligned inputs, we could enable aligned reads - enum { LhsAlignment = Unaligned }; + enum { LhsAlignment = Unaligned, + ResPacketSize = Traits::ResPacketSize, + ResPacketSizeHalf = HalfTraits::ResPacketSize, + ResPacketSizeQuarter = QuarterTraits::ResPacketSize, + LhsPacketSize = Traits::LhsPacketSize, + LhsPacketSizeHalf = HalfTraits::LhsPacketSize, + LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize, + HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize, + HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf + }; Index i=0; for(; i<n8; i+=8) @@ -383,6 +479,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs for(; i<rows; ++i) { ResPacket c0 = pset1<ResPacket>(ResScalar(0)); + ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0)); + ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0)); Index j=0; for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { @@ -390,6 +488,22 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0); } ResScalar cc0 = predux(c0); + if (HasHalf) { + for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf) + { + RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0); + c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h); + } + cc0 += predux(c0_h); + } + if (HasQuarter) { + for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter) + { + RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0); + c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q); + } + cc0 += predux(c0_q); + } for(; j<cols; ++j) { cc0 += cj.pmul(lhs(i,j), rhs(j,0)); |