aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products/GeneralBlockPanelKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/products/GeneralBlockPanelKernel.h')
-rw-r--r--Eigen/src/Core/products/GeneralBlockPanelKernel.h19
1 files changed, 10 insertions, 9 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index 54e118395..4c1a63d40 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -595,7 +595,7 @@ DoublePacket<Packet> padd(const DoublePacket<Packet> &a, const DoublePacket<Pack
}
template<typename Packet>
-const DoublePacket<Packet>& predux4(const DoublePacket<Packet> &a)
+const DoublePacket<Packet>& predux_half(const DoublePacket<Packet> &a)
{
return a;
}
@@ -1628,9 +1628,10 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
prefetch(&blA[0]);
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
- if( (SwappedTraits::LhsProgress % 4)==0 )
+ // NOTE The following piece of code doesn't work for 512 bit registers,
+ // so we don't call it for registers that contain more than 8 values.
+ if( ((SwappedTraits::LhsProgress % 4)==0) && (SwappedTraits::LhsProgress <= 8))
{
- // NOTE The following piece of code wont work for 512 bit registers
SAccPacket C0, C1, C2, C3;
straits.initAcc(C0);
straits.initAcc(C1);
@@ -1681,10 +1682,10 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
if(SwappedTraits::LhsProgress==8)
{
// Special case where we have to first reduce the accumulation register C0
- typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
+ typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
+ typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
+ typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
+ typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
@@ -1696,13 +1697,13 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
SRhsPacketHalf b0;
straits.loadLhsUnaligned(blB, a0);
straits.loadRhs(blA, b0);
- SAccPacketHalf c0 = predux4(C0);
+ SAccPacketHalf c0 = predux_half(C0);
straits.madd(a0,b0,c0,b0);
straits.acc(c0, alphav, R);
}
else
{
- straits.acc(predux4(C0), alphav, R);
+ straits.acc(predux_half(C0), alphav, R);
}
res.scatterPacket(i, j2, R);
}