author     Gustavo Lima Chaves <gustavo.lima.chaves@intel.com>  2019-04-26 14:12:39 -0700
committer  Gustavo Lima Chaves <gustavo.lima.chaves@intel.com>  2019-04-26 14:12:39 -0700
commit     d4dcb71bcb6a9c05fc417edcdbaa9841bbd02400
tree       ec28b2bf142f70347ec194e7def6576fe93d503f
parent     665ac22cc6b1cb86a3f9200e0b6b9eb7dbdc834e
Speed up GEMV on AVX-512 builds, just as done for GEBP previously.
In that case, we take advantage of the smaller SIMD registers as well. Gains of up to 3x for select input sizes.
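The core idea, as a minimal standalone sketch (plain fixed-width loops stand in for the SIMD packets; the 16/8/4 lane counts are assumptions matching AVX-512 single precision, not values this patch hard-codes): once the main loop has consumed every full-width chunk of the result, one half-width and one quarter-width step mop up most of the remainder, so the scalar tail shrinks from up to 15 elements to at most 3.

    // y[0..n) += alpha * x[0..n), staged full -> half -> quarter -> scalar.
    // Each fixed-width inner loop models one SIMD packet update; the lane
    // counts (16/8/4) are assumed AVX-512 float widths, for illustration only.
    void axpy_staged(float* y, const float* x, float alpha, int n) {
      int i = 0;
      for (; i + 16 <= n; i += 16)                  // full packets
        for (int k = 0; k < 16; ++k) y[i + k] += alpha * x[i + k];
      if (i + 8 <= n) {                             // at most one half packet
        for (int k = 0; k < 8; ++k) y[i + k] += alpha * x[i + k];
        i += 8;
      }
      if (i + 4 <= n) {                             // at most one quarter packet
        for (int k = 0; k < 4; ++k) y[i + k] += alpha * x[i + k];
        i += 4;
      }
      for (; i < n; ++i) y[i] += alpha * x[i];      // scalar remainder (< 4)
    }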
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r--  Eigen/src/Core/products/GeneralBlockPanelKernel.h |  22
-rw-r--r--  Eigen/src/Core/products/GeneralMatrixVector.h     | 174
2 files changed, 155 insertions, 41 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index fdd0ec0e9..6c1d882fd 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -15,13 +15,13 @@ namespace Eigen {
namespace internal {
-enum PacketSizeType {
- PacketFull = 0,
- PacketHalf,
- PacketQuarter
+enum GEBPPacketSizeType {
+ GEBPPacketFull = 0,
+ GEBPPacketHalf,
+ GEBPPacketQuarter
};
-template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=PacketFull>
+template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=GEBPPacketFull>
class gebp_traits;
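The prefix matters because of what follows in GeneralMatrixVector.h: a second packet-size enum lands in the same namespace internal, and two unscoped enums there cannot both expose enumerators named PacketFull. A hedged illustration (the demo namespace is mine, not Eigen's):

    namespace demo {
      // Both kernels now define their own packet-size enum; the distinct
      // prefixes let the unscoped enumerators coexist in one namespace.
      enum GEBPPacketSizeType { GEBPPacketFull = 0, GEBPPacketHalf, GEBPPacketQuarter };
      enum GEMVPacketSizeType { GEMVPacketFull = 0, GEMVPacketHalf, GEMVPacketQuarter };
      // enum A { PacketFull }; enum B { PacketFull };  // error: redeclaration
    }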
@@ -375,10 +375,10 @@ template <int N, typename T1, typename T2, typename T3>
struct packet_conditional { typedef T3 type; };
template <typename T1, typename T2, typename T3>
-struct packet_conditional<PacketFull, T1, T2, T3> { typedef T1 type; };
+struct packet_conditional<GEBPPacketFull, T1, T2, T3> { typedef T1 type; };
template <typename T1, typename T2, typename T3>
-struct packet_conditional<PacketHalf, T1, T2, T3> { typedef T2 type; };
+struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; };
#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
typedef typename packet_conditional<packet_size, \
@@ -1054,8 +1054,8 @@ protected:
#if EIGEN_ARCH_ARM64 && defined EIGEN_VECTORIZE_NEON
template<>
-struct gebp_traits <float, float, false, false,Architecture::NEON,PacketFull>
- : gebp_traits<float,float,false,false,Architecture::Generic,PacketFull>
+struct gebp_traits <float, float, false, false,Architecture::NEON,GEBPPacketFull>
+ : gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
{
typedef float RhsPacket;
@@ -1203,8 +1203,8 @@ template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMa
struct gebp_kernel
{
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
- typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketHalf> HalfTraits;
- typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketQuarter> QuarterTraits;
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
typedef typename Traits::ResScalar ResScalar;
typedef typename Traits::LhsPacket LhsPacket;
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h
index 767feb99d..eb1d924e5 100644
--- a/Eigen/src/Core/products/GeneralMatrixVector.h
+++ b/Eigen/src/Core/products/GeneralMatrixVector.h
@@ -14,6 +14,54 @@ namespace Eigen {
namespace internal {
+enum GEMVPacketSizeType {
+ GEMVPacketFull = 0,
+ GEMVPacketHalf,
+ GEMVPacketQuarter
+};
+
+template <int N, typename T1, typename T2, typename T3>
+struct gemv_packet_cond { typedef T3 type; };
+
+template <typename T1, typename T2, typename T3>
+struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; };
+
+template <typename T1, typename T2, typename T3>
+struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; };
+
+template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull>
+class gemv_traits
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
+
+#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
+ typedef typename gemv_packet_cond<packet_size, \
+ typename packet_traits<name ## Scalar>::type, \
+ typename packet_traits<name ## Scalar>::half, \
+ typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
+ prefix ## name ## Packet
+
+ PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
+#undef PACKET_DECL_COND_PREFIX
+
+public:
+ enum {
+ Vectorizable = unpacket_traits<_LhsPacket>::vectorizable &&
+ unpacket_traits<_RhsPacket>::vectorizable &&
+ int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size),
+ LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
+ RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
+ ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1
+ };
+
+ typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
+ typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
+ typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+};
+
+
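To see what gemv_packet_cond selects, here is a hedged sanity check for float. It uses the same packet_traits/unpacket_traits members as the macro above; the concrete types and lane counts in the comments assume an AVX-512-enabled build and may differ elsewhere.

    #include <Eigen/Core>
    #include <cstdio>

    int main() {
      using namespace Eigen::internal;
      // The three-step selection PACKET_DECL_COND_PREFIX performs:
      typedef packet_traits<float>::type Full;        // e.g. Packet16f on AVX-512
      typedef packet_traits<float>::half Half;        // e.g. Packet8f
      typedef unpacket_traits<Half>::half Quarter;    // e.g. Packet4f
      std::printf("lanes: full=%d half=%d quarter=%d\n",
                  (int)unpacket_traits<Full>::size,
                  (int)unpacket_traits<Half>::size,
                  (int)unpacket_traits<Quarter>::size);  // expect 16 8 4
    }

On targets with no narrower register, half simply aliases the full packet type, so the widths compare equal; the HasHalf/HasQuarter guards introduced below rely on exactly that to compile the extra paths out.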
/* Optimized col-major matrix * vector product:
* This algorithm processes the matrix per vertical panels,
* which are then processed horizontally per chunk of 8*PacketSize x 1 vertical segments.
@@ -30,23 +78,23 @@ namespace internal {
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
+ typedef gemv_traits<LhsScalar,RhsScalar> Traits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
+
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
-enum {
- Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
- && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
-};
+ typedef typename Traits::LhsPacket LhsPacket;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
-typedef typename packet_traits<LhsScalar>::type _LhsPacket;
-typedef typename packet_traits<RhsScalar>::type _RhsPacket;
-typedef typename packet_traits<ResScalar>::type _ResPacket;
+ typedef typename HalfTraits::LhsPacket LhsPacketHalf;
+ typedef typename HalfTraits::RhsPacket RhsPacketHalf;
+ typedef typename HalfTraits::ResPacket ResPacketHalf;
-typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
-typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
-typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+ typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
+ typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
+ typedef typename QuarterTraits::ResPacket ResPacketQuarter;
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
Index rows, Index cols,
@@ -73,19 +121,33 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
+ conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
+ conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
+
const Index lhsStride = lhs.stride();
// TODO: for padded aligned inputs, we could enable aligned reads
- enum { LhsAlignment = Unaligned };
+ enum { LhsAlignment = Unaligned,
+ ResPacketSize = Traits::ResPacketSize,
+ ResPacketSizeHalf = HalfTraits::ResPacketSize,
+ ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
+ LhsPacketSize = Traits::LhsPacketSize,
+ HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
+ HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
+ };
const Index n8 = rows-8*ResPacketSize+1;
const Index n4 = rows-4*ResPacketSize+1;
const Index n3 = rows-3*ResPacketSize+1;
const Index n2 = rows-2*ResPacketSize+1;
const Index n1 = rows-1*ResPacketSize+1;
+ const Index n_half = rows-1*ResPacketSizeHalf+1;
+ const Index n_quarter = rows-1*ResPacketSizeQuarter+1;
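Worked numbers for the new bounds, assuming AVX-512 float lane counts (ResPacketSize 16, half 8, quarter 4) and rows = 45; this sketch keeps only the single-packet loop and drops the kernel's 8/4/3/2-packet unrolled blocks for brevity. Because fewer than ResPacketSize rows remain after the full-width loop, at most one half and one quarter step can ever fire, which is why the kernel guards them with an if rather than a loop.

    #include <cstdio>

    int main() {
      // Assumed AVX-512 float lane counts; rows chosen to exercise every stage.
      const int ResPacketSize = 16, ResPacketSizeHalf = 8, ResPacketSizeQuarter = 4;
      const int rows = 45;
      const int n1        = rows - 1*ResPacketSize        + 1;  // 30
      const int n_half    = rows - 1*ResPacketSizeHalf    + 1;  // 38
      const int n_quarter = rows - 1*ResPacketSizeQuarter + 1;  // 42
      int i = 0;
      for (; i < n1; i += ResPacketSize)
        std::printf("full    rows %d..%d\n", i, i + ResPacketSize - 1);
      if (i < n_half)    { std::printf("half    rows %d..%d\n", i, i + 7); i += 8; }
      if (i < n_quarter) { std::printf("quarter rows %d..%d\n", i, i + 3); i += 4; }
      for (; i < rows; ++i) std::printf("scalar  row  %d\n", i);
      // Prints: full 0..15, full 16..31, half 32..39, quarter 40..43, scalar 44.
    }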
// TODO: improve the following heuristic:
const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4);
ResPacket palpha = pset1<ResPacket>(alpha);
+ ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
+ ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
for(Index j2=0; j2<cols; j2+=block_cols)
{
@@ -190,6 +252,28 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
i+=ResPacketSize;
}
+ if(HasHalf && i<n_half)
+ {
+ ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0));
+ c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
+ }
+ pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
+ i+=ResPacketSizeHalf;
+ }
+ if(HasQuarter && i<n_quarter)
+ {
+ ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0));
+ c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
+ }
+ pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
+ i+=ResPacketSizeQuarter;
+ }
for(;i<rows;++i)
{
ResScalar c0(0);
@@ -213,23 +297,24 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
-typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
-
-enum {
- Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
- && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
-};
+ typedef gemv_traits<LhsScalar,RhsScalar> Traits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
+
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
-typedef typename packet_traits<LhsScalar>::type _LhsPacket;
-typedef typename packet_traits<RhsScalar>::type _RhsPacket;
-typedef typename packet_traits<ResScalar>::type _ResPacket;
+ typedef typename Traits::LhsPacket LhsPacket;
+ static const Index LhsPacketSize = Traits::LhsPacketSize;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
-typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
-typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
-typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+ typedef typename HalfTraits::LhsPacket LhsPacketHalf;
+ typedef typename HalfTraits::RhsPacket RhsPacketHalf;
+ typedef typename HalfTraits::ResPacket ResPacketHalf;
+
+ typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
+ typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
+ typedef typename QuarterTraits::ResPacket ResPacketQuarter;
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
Index rows, Index cols,
@@ -254,6 +339,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
eigen_internal_assert(rhs.stride()==1);
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
+ conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
+ conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
// TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
// processing 8 rows at once might be counterproductive wrt cache.
@@ -262,7 +349,16 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
const Index n2 = rows-1;
// TODO: for padded aligned inputs, we could enable aligned reads
- enum { LhsAlignment = Unaligned };
+ enum { LhsAlignment = Unaligned,
+ ResPacketSize = Traits::ResPacketSize,
+ ResPacketSizeHalf = HalfTraits::ResPacketSize,
+ ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
+ LhsPacketSize = Traits::LhsPacketSize,
+ LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
+ LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
+ HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
+ HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
+ };
Index i=0;
for(; i<n8; i+=8)
@@ -383,6 +479,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
for(; i<rows; ++i)
{
ResPacket c0 = pset1<ResPacket>(ResScalar(0));
+ ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
+ ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
@@ -390,6 +488,22 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
}
ResScalar cc0 = predux(c0);
+ if (HasHalf) {
+ for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
+ {
+ RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
+ c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
+ }
+ cc0 += predux(c0_h);
+ }
+ if (HasQuarter) {
+ for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
+ {
+ RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0);
+ c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
+ }
+ cc0 += predux(c0_q);
+ }
for(; j<cols; ++j)
{
cc0 += cj.pmul(lhs(i,j), rhs(j,0));
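For these trailing rows the narrower packets extend a dot product rather than a stored update, so each width keeps its own accumulator and the horizontally reduced partial sums (predux) are folded into cc0 in turn. The same staging in a standalone sketch, again with plain loops standing in for packets and assumed 16/8/4 lane widths:

    // Dot product with full/half/quarter/scalar staging, mirroring the
    // trailing-row loop above. Each stage owns an accumulator; with real
    // packets the final additions would be horizontal predux reductions.
    float dot_staged(const float* a, const float* b, int cols) {
      float c0 = 0.f, c0_h = 0.f, c0_q = 0.f;        // per-width partial sums
      int j = 0;
      for (; j + 16 <= cols; j += 16)                // full packets
        for (int k = 0; k < 16; ++k) c0 += a[j + k] * b[j + k];
      float cc0 = c0;                                 // predux(c0)
      for (; j + 8 <= cols; j += 8)                   // half packets
        for (int k = 0; k < 8; ++k) c0_h += a[j + k] * b[j + k];
      cc0 += c0_h;                                    // predux(c0_h)
      for (; j + 4 <= cols; j += 4)                   // quarter packets
        for (int k = 0; k < 4; ++k) c0_q += a[j + k] * b[j + k];
      cc0 += c0_q;                                    // predux(c0_q)
      for (; j < cols; ++j) cc0 += a[j] * b[j];       // scalar remainder
      return cc0;
    }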