aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products/GeneralBlockPanelKernel.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2018-12-06 15:58:06 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2018-12-06 15:58:06 +0100
commitc53eececb0415834b961cb61cd466907261b4b2f (patch)
tree2fda50e7549fda68501e3a33634dfc2ccfac743a /Eigen/src/Core/products/GeneralBlockPanelKernel.h
parent3fba59ea594eb26446352cb28813b38921439f23 (diff)
Implement AVX512 vectorization of std::complex<float/double>
Diffstat (limited to 'Eigen/src/Core/products/GeneralBlockPanelKernel.h')
-rw-r--r--Eigen/src/Core/products/GeneralBlockPanelKernel.h167
1 files changed, 119 insertions, 48 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index 9ca865bd1..9475a6ecc 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -506,13 +506,28 @@ public:
p = pset1<ResPacket>(ResScalar(0));
}
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
{
- dest = pset1<RhsPacket>(*b);
+ dest = pset1<RhsPacketType>(*b);
}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
+ loadRhsQuad_impl(b,dest, typename conditional<RhsPacketSize==16,true_type,false_type>::type());
+ }
+
+ EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const true_type&) const
+ {
+ // FIXME we can do better!
+ // what we want here is a ploadheight
+ RhsScalar tmp[4] = {b[0],b[0],b[1],b[1]};
+ dest = ploadquad<RhsPacket>(tmp);
+ }
+
+ EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const false_type&) const
+ {
+ eigen_internal_assert(RhsPacketSize<=8);
dest = pset1<RhsPacket>(*b);
}
@@ -521,9 +536,10 @@ public:
dest = pload<LhsPacket>(a);
}
- EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
+ template<typename LhsPacketType>
+ EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const
{
- dest = ploadu<LhsPacket>(a);
+ dest = ploadu<LhsPacketType>(a);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
@@ -536,12 +552,14 @@ public:
// pbroadcast2(b, b0, b1);
// }
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const
{
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
}
- EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const true_type&) const
{
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
EIGEN_UNUSED_VARIABLE(tmp);
@@ -556,13 +574,14 @@ public:
c += a * b;
}
- EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
+ template <typename ResPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const
{
+ const conj_helper<ResPacketType,ResPacketType,ConjLhs,false> cj;
r = cj.pmadd(c,alpha,r);
}
protected:
- conj_helper<ResPacket,ResPacket,ConjLhs,false> cj;
};
template<typename Packet>
@@ -581,13 +600,57 @@ DoublePacket<Packet> padd(const DoublePacket<Packet> &a, const DoublePacket<Pack
return res;
}
+// note that for DoublePacket<RealPacket> the "4" in "downto4"
+// corresponds to the number of complexes, so it means "8"
+// it terms of real coefficients.
+
template<typename Packet>
-const DoublePacket<Packet>& predux_half_dowto4(const DoublePacket<Packet> &a)
+const DoublePacket<Packet>&
+predux_half_dowto4(const DoublePacket<Packet> &a,
+ typename enable_if<unpacket_traits<Packet>::size<=8>::type* = 0)
{
return a;
}
-template<typename Packet> struct unpacket_traits<DoublePacket<Packet> > { typedef DoublePacket<Packet> half; };
+template<typename Packet>
+DoublePacket<typename unpacket_traits<Packet>::half>
+predux_half_dowto4(const DoublePacket<Packet> &a,
+ typename enable_if<unpacket_traits<Packet>::size==16>::type* = 0)
+{
+ // yes, that's pretty hackish :(
+ DoublePacket<typename unpacket_traits<Packet>::half> res;
+ typedef std::complex<typename unpacket_traits<Packet>::type> Cplx;
+ typedef typename packet_traits<Cplx>::type CplxPacket;
+ res.first = predux_half_dowto4(CplxPacket(a.first)).v;
+ res.second = predux_half_dowto4(CplxPacket(a.second)).v;
+ return res;
+}
+
+// same here, "quad" actually means "8" in terms of real coefficients
+template<typename Scalar, typename RealPacket>
+void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
+ typename enable_if<unpacket_traits<RealPacket>::size<=8>::type* = 0)
+{
+ dest.first = pset1<RealPacket>(real(*b));
+ dest.second = pset1<RealPacket>(imag(*b));
+}
+
+template<typename Scalar, typename RealPacket>
+void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
+ typename enable_if<unpacket_traits<RealPacket>::size==16>::type* = 0)
+{
+ // yes, that's pretty hackish too :(
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ RealScalar r[4] = {real(b[0]), real(b[0]), real(b[1]), real(b[1])};
+ RealScalar i[4] = {imag(b[0]), imag(b[0]), imag(b[1]), imag(b[1])};
+ dest.first = ploadquad<RealPacket>(r);
+ dest.second = ploadquad<RealPacket>(i);
+}
+
+
+template<typename Packet> struct unpacket_traits<DoublePacket<Packet> > {
+ typedef DoublePacket<typename unpacket_traits<Packet>::half> half;
+};
// template<typename Packet>
// DoublePacket<Packet> pmadd(const DoublePacket<Packet> &a, const DoublePacket<Packet> &b)
// {
@@ -611,10 +674,10 @@ public:
ConjRhs = _ConjRhs,
Vectorizable = packet_traits<RealScalar>::Vectorizable
&& packet_traits<Scalar>::Vectorizable,
- RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
+ ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
+ LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
+ RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
+ RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
// FIXME: should depend on NumberOfRegisters
nr = 4,
@@ -626,7 +689,7 @@ public:
typedef typename packet_traits<RealScalar>::type RealPacket;
typedef typename packet_traits<Scalar>::type ScalarPacket;
- typedef DoublePacket<RealPacket> DoublePacketType;
+ typedef DoublePacket<RealPacket> DoublePacketType;
typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type LhsPacket4Packing;
typedef typename conditional<Vectorizable,RealPacket, Scalar>::type LhsPacket;
@@ -643,16 +706,17 @@ public:
}
// Scalar path
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ResPacket& dest) const
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ScalarPacket& dest) const
{
- dest = pset1<ResPacket>(*b);
+ dest = pset1<ScalarPacket>(*b);
}
// Vectorized path
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacketType& dest) const
+ template<typename RealPacketType>
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket<RealPacketType>& dest) const
{
- dest.first = pset1<RealPacket>(real(*b));
- dest.second = pset1<RealPacket>(imag(*b));
+ dest.first = pset1<RealPacketType>(real(*b));
+ dest.second = pset1<RealPacketType>(imag(*b));
}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const
@@ -661,8 +725,7 @@ public:
}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const
{
- eigen_internal_assert(unpacket_traits<ScalarPacket>::size<=4);
- loadRhs(b,dest);
+ loadQuadToDoublePacket(b,dest);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
@@ -696,12 +759,14 @@ public:
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
}
- EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
+ template<typename LhsPacketType>
+ EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const
{
- dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
+ dest = ploadu<LhsPacketType>((const typename unpacket_traits<LhsPacketType>::type*)(a));
}
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacketType& c, RhsPacket& /*tmp*/) const
+ template<typename LhsPacketType, typename RhsPacketType, typename ResPacketType, typename TmpType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/) const
{
c.first = padd(pmul(a,b.first), c.first);
c.second = padd(pmul(a,b.second),c.second);
@@ -714,29 +779,30 @@ public:
EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; }
- EIGEN_STRONG_INLINE void acc(const DoublePacketType& c, const ResPacket& alpha, ResPacket& r) const
+ template<typename RealPacketType, typename ResPacketType>
+ EIGEN_STRONG_INLINE void acc(const DoublePacket<RealPacketType>& c, const ResPacketType& alpha, ResPacketType& r) const
{
// assemble c
- ResPacket tmp;
+ ResPacketType tmp;
if((!ConjLhs)&&(!ConjRhs))
{
- tmp = pcplxflip(pconj(ResPacket(c.second)));
- tmp = padd(ResPacket(c.first),tmp);
+ tmp = pcplxflip(pconj(ResPacketType(c.second)));
+ tmp = padd(ResPacketType(c.first),tmp);
}
else if((!ConjLhs)&&(ConjRhs))
{
- tmp = pconj(pcplxflip(ResPacket(c.second)));
- tmp = padd(ResPacket(c.first),tmp);
+ tmp = pconj(pcplxflip(ResPacketType(c.second)));
+ tmp = padd(ResPacketType(c.first),tmp);
}
else if((ConjLhs)&&(!ConjRhs))
{
- tmp = pcplxflip(ResPacket(c.second));
- tmp = padd(pconj(ResPacket(c.first)),tmp);
+ tmp = pcplxflip(ResPacketType(c.second));
+ tmp = padd(pconj(ResPacketType(c.first)),tmp);
}
else if((ConjLhs)&&(ConjRhs))
{
- tmp = pcplxflip(ResPacket(c.second));
- tmp = psub(pconj(ResPacket(c.first)),tmp);
+ tmp = pcplxflip(ResPacketType(c.second));
+ tmp = psub(pconj(ResPacketType(c.first)),tmp);
}
r = pmadd(tmp,alpha,r);
@@ -789,9 +855,10 @@ public:
p = pset1<ResPacket>(ResScalar(0));
}
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
{
- dest = pset1<RhsPacket>(*b);
+ dest = pset1<RhsPacketType>(*b);
}
void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
@@ -813,21 +880,23 @@ public:
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
- eigen_internal_assert(unpacket_traits<RhsPacket>::size<=4);
- loadRhs(b,dest);
+ dest = ploadquad<RhsPacket>(b);
}
- EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
+ template<typename LhsPacketType>
+ EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const
{
- dest = ploaddup<LhsPacket>(a);
+ dest = ploaddup<LhsPacketType>(a);
}
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const
{
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
}
- EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const true_type&) const
{
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
EIGEN_UNUSED_VARIABLE(tmp);
@@ -843,13 +912,15 @@ public:
c += a * b;
}
- EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
+ template <typename ResPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const
{
+ const conj_helper<ResPacketType,ResPacketType,false,ConjRhs> cj;
r = cj.pmadd(alpha,c,r);
}
protected:
- conj_helper<ResPacket,ResPacket,false,ConjRhs> cj;
+
};
@@ -1649,7 +1720,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
const int SResPacketQuarterSize = unpacket_traits<typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half>::size;
if ((SwappedTraits::LhsProgress % 4) == 0 &&
(SwappedTraits::LhsProgress<=16) &&
- (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) &&
+ (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) &&
(SwappedTraits::LhsProgress!=16 || SResPacketQuarterSize==nr))
{
SAccPacket C0, C1, C2, C3;
@@ -1704,7 +1775,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
// Special case where we have to first reduce the accumulation register C0
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
+ typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SRhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
@@ -1734,7 +1805,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
// template form, so that LhsProgress < 16 paths don't
// fail to compile
last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> p;
- p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0);
+ p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0);
}
else
{