author	Chip Kerchner <chip.kerchner@ibm.com>	2021-02-17 17:49:23 +0000
committer	Rasmus Munk Larsen <rmlarsen@google.com>	2021-02-17 17:49:23 +0000
commit	9b51dc7972c9f64727e9c8e8db0c60aaf9aae532 (patch)
tree	fc74a4266657205346b26ae7a2c78a06a9cb505e /Eigen/src/Core/arch/AltiVec/MatrixProduct.h
parent	be0574e2159ce3d6a1748ba6060bea5dedccdbc9 (diff)
Fixed performance issues for VSX and P10 MMA in general_matrix_matrix_product
Diffstat (limited to 'Eigen/src/Core/arch/AltiVec/MatrixProduct.h')
-rw-r--r--	Eigen/src/Core/arch/AltiVec/MatrixProduct.h	1306
1 file changed, 698 insertions, 608 deletions
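
The recurring changes in the hunks below are: packing loops that previously stopped one full vector early (j + vectorSize < n) now include the last full vector (j + vectorSize <= n); pger receives pre-broadcast RHS packets and accumulates with vec_madd/vec_nmsub instead of scalar operator overloads; and the hand-copied 6/5/4/3/2/1-way unrolled GEMM bodies collapse into one templated gemm_unrolled_iteration selected through a small jump table. Below is a minimal scalar sketch of the first two patterns only; packRows and rank1Update are hypothetical names for illustration, not Eigen API, and the scalar loops stand in for the VSX vector operations.

#include <cstddef>

// Floats per SIMD vector (models Packet4f); sketch only.
constexpr std::size_t kVectorSize = 4;

// Old bound:  for (j = 0; j + kVectorSize <  rows; ...)  // skips the last full block
// New bound:  for (j = 0; j + kVectorSize <= rows; ...)  // keeps only a true remainder
std::size_t packRows(float* dst, const float* src, std::size_t rows) {
  std::size_t j = 0;
  for (; j + kVectorSize <= rows; j += kVectorSize) {
    for (std::size_t k = 0; k < kVectorSize; ++k)
      dst[j + k] = src[j + k];           // stand-in for the vector load/store
  }
  return j;                              // rows - j < kVectorSize is left for scalar code
}

// Scalar model of the reworked pger: RHS values are broadcast once by the caller
// and the update is a fused multiply-add, mirroring vec_madd(lhsV, rhsV[i], acc).
void rank1Update(float acc[4][kVectorSize], const float* lhsV, const float rhsV[4]) {
  for (int i = 0; i < 4; ++i)
    for (std::size_t k = 0; k < kVectorSize; ++k)
      acc[i][k] += lhsV[k] * rhsV[i];
}

With the <= bound, a row or depth count that is an exact multiple of the vector size no longer falls through to the scalar remainder path, which appears to be one source of the VSX slowdown the commit message refers to.
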
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 53116ad89..9d9bbebe5 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -70,7 +70,7 @@ struct quad_traits<double>
// MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector; this turned out
// to be faster than Eigen's usual approach of keeping real/imaginary pairs in a single vector. These constants
// are responsible for converting between Eigen's layout and the MatrixProduct approach.
-const static Packet4f p4f_CONJUGATE = {-1.0f, -1.0f, -1.0f, -1.0f};
+const static Packet4f p4f_CONJUGATE = {float(-1.0), float(-1.0), float(-1.0), float(-1.0)};
const static Packet2d p2d_CONJUGATE = {-1.0, -1.0};
@@ -122,7 +122,7 @@ EIGEN_STRONG_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_b
v.imag(dt(i,j).imag());
} else {
v.real(dt(i,j).real());
- v.imag((Scalar)0.0f);
+ v.imag((Scalar)0.0);
}
return v;
}
@@ -136,7 +136,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar> *bloc
Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);
Index ri = 0, j = 0;
- for(; j + vectorSize < cols; j+=vectorSize)
+ for(; j + vectorSize <= cols; j+=vectorSize)
{
Index i = k2;
for(; i < depth; i++)
@@ -192,7 +192,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar> *bloc
Index ri = 0, j = 0;
Scalar *blockAf = (Scalar *)(blockA);
- for(; j + vectorSize < rows; j+=vectorSize)
+ for(; j + vectorSize <= rows; j+=vectorSize)
{
Index i = 0;
@@ -247,7 +247,7 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs
const int vectorSize = quad_traits<Scalar>::vectorsize;
Index ri = 0, j = 0;
- for(; j + N*vectorSize < cols; j+=N*vectorSize)
+ for(; j + N*vectorSize <= cols; j+=N*vectorSize)
{
Index i = k2;
for(; i < depth; i++)
@@ -284,7 +284,7 @@ EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar* _lhs
const int vectorSize = quad_traits<Scalar>::vectorsize;
Index ri = 0, j = 0;
- for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ for(j = 0; j + vectorSize <= rows; j+=vectorSize)
{
Index i = 0;
@@ -410,15 +410,15 @@ struct lhs_cpack {
const int vectorSize = quad_traits<Scalar>::vectorsize;
Index ri = 0, j = 0;
Scalar *blockAt = reinterpret_cast<Scalar *>(blockA);
- Packet conj = pset1<Packet>((Scalar)-1.0f);
+ Packet conj = pset1<Packet>((Scalar)-1.0);
- for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ for(j = 0; j + vectorSize <= rows; j+=vectorSize)
{
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet, 4> block;
@@ -446,10 +446,10 @@ struct lhs_cpack {
cblock.packet[7] = lhs.template loadPacket<PacketC>(j + 3, i + 2);
}
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32);
- block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32);
- block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32);
- block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
+ block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
+ block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
+ block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);
if(StorageOrder == RowMajor) ptranspose(block);
@@ -475,7 +475,7 @@ struct lhs_cpack {
if(PanelMode) ri += vectorSize*offset;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<PacketC, 8> cblock;
if(StorageOrder == ColMajor)
@@ -502,10 +502,10 @@ struct lhs_cpack {
}
PacketBlock<Packet, 4> block;
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32);
- block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32);
- block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32);
- block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
+ block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
+ block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
+ block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);
if(Conjugate)
{
@@ -585,13 +585,13 @@ struct lhs_pack{
const int vectorSize = quad_traits<Scalar>::vectorsize;
Index ri = 0, j = 0;
- for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ for(j = 0; j + vectorSize <= rows; j+=vectorSize)
{
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet, 4> block;
@@ -637,13 +637,16 @@ struct lhs_pack{
if(PanelMode) ri += offset*(rows - j);
- for(Index i = 0; i < depth; i++)
+ if (j < rows)
{
- Index k = j;
- for(; k < rows; k++)
+ for(Index i = 0; i < depth; i++)
{
- blockA[ri] = lhs(k, i);
- ri += 1;
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockA[ri] = lhs(k, i);
+ ri += 1;
+ }
}
}
@@ -659,16 +662,16 @@ struct rhs_cpack
{
const int vectorSize = quad_traits<Scalar>::vectorsize;
Scalar *blockBt = reinterpret_cast<Scalar *>(blockB);
- Packet conj = pset1<Packet>((Scalar)-1.0f);
+ Packet conj = pset1<Packet>((Scalar)-1.0);
Index ri = 0, j = 0;
- for(; j + vectorSize < cols; j+=vectorSize)
+ for(; j + vectorSize <= cols; j+=vectorSize)
{
Index i = 0;
if(PanelMode) ri += offset*vectorSize;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<PacketC, 8> cblock;
if(StorageOrder == ColMajor)
@@ -695,10 +698,10 @@ struct rhs_cpack
}
PacketBlock<Packet, 4> block;
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32);
- block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32);
- block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32);
- block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
+ block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
+ block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
+ block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);
if(StorageOrder == ColMajor) ptranspose(block);
@@ -724,7 +727,7 @@ struct rhs_cpack
if(PanelMode) ri += offset*vectorSize;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<PacketC, 8> cblock;
if(StorageOrder == ColMajor)
@@ -752,10 +755,10 @@ struct rhs_cpack
}
PacketBlock<Packet, 4> block;
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32);
- block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32);
- block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32);
- block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
+ block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
+ block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
+ block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);
if(Conjugate)
{
@@ -832,13 +835,14 @@ struct rhs_pack {
{
const int vectorSize = quad_traits<Scalar>::vectorsize;
Index ri = 0, j = 0;
- for(; j + vectorSize < cols; j+=vectorSize)
+
+ for(; j + vectorSize <= cols; j+=vectorSize)
{
Index i = 0;
if(PanelMode) ri += offset*vectorSize;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet, 4> block;
if(StorageOrder == ColMajor)
@@ -883,13 +887,16 @@ struct rhs_pack {
if(PanelMode) ri += offset*(cols - j);
- for(Index i = 0; i < depth; i++)
+ if (j < cols)
{
- Index k = j;
- for(; k < cols; k++)
+ for(Index i = 0; i < depth; i++)
{
- blockB[ri] = rhs(i, k);
- ri += 1;
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockB[ri] = rhs(i, k);
+ ri += 1;
+ }
}
}
if(PanelMode) ri += (cols - j)*(stride - offset - depth);
@@ -905,13 +912,13 @@ struct lhs_pack<double,Index, DataMapper, Packet2d, StorageOrder, PanelMode>
const int vectorSize = quad_traits<double>::vectorsize;
Index ri = 0, j = 0;
- for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ for(j = 0; j + vectorSize <= rows; j+=vectorSize)
{
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet2d, 2> block;
if(StorageOrder == RowMajor)
@@ -970,12 +977,12 @@ struct rhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode>
{
const int vectorSize = quad_traits<double>::vectorsize;
Index ri = 0, j = 0;
- for(; j + 2*vectorSize < cols; j+=2*vectorSize)
+ for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
{
Index i = 0;
if(PanelMode) ri += offset*(2*vectorSize);
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet2d, 4> block;
if(StorageOrder == ColMajor)
@@ -1059,13 +1066,13 @@ struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageO
double *blockAt = reinterpret_cast<double *>(blockA);
Packet conj = pset1<Packet>(-1.0);
- for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ for(j = 0; j + vectorSize <= rows; j+=vectorSize)
{
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet, 2> block;
@@ -1078,8 +1085,8 @@ struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageO
cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
- block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
+ block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
} else {
cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]
@@ -1087,8 +1094,8 @@ struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageO
cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
- block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
+ block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
}
pstore<double>(blockAt + ri , block.packet[0]);
@@ -1108,7 +1115,7 @@ struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageO
if(PanelMode) ri += vectorSize*offset;
- for(; i + vectorSize < depth; i+=vectorSize)
+ for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet, 2> block;
@@ -1121,8 +1128,8 @@ struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageO
cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0);
cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1);
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETIMAG64);
- block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETIMAG64);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
+ block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
} else {
cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
@@ -1130,8 +1137,8 @@ struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageO
cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1);
cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1);
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64);
- block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+ block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
}
if(Conjugate)
@@ -1205,7 +1212,7 @@ struct rhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
Packet conj = pset1<Packet>(-1.0);
Index ri = 0, j = 0;
- for(; j + 2*vectorSize < cols; j+=2*vectorSize)
+ for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
{
Index i = 0;
@@ -1221,8 +1228,8 @@ struct rhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
cblock.packet[2] = rhs.template loadPacket<PacketC>(i, j + 2);
cblock.packet[3] = rhs.template loadPacket<PacketC>(i, j + 3);
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64);
- block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
+ block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
pstore<double>(blockBt + ri , block.packet[0]);
pstore<double>(blockBt + ri + 2, block.packet[1]);
@@ -1246,8 +1253,8 @@ struct rhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
cblock.packet[2] = rhs.template loadPacket<PacketC>(i, j + 2); //[c1 c1i]
cblock.packet[3] = rhs.template loadPacket<PacketC>(i, j + 3); //[d1 d1i]
- block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64);
- block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64);
+ block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+ block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
if(Conjugate)
{
@@ -1300,25 +1307,84 @@ struct rhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex gemm).
template<typename Scalar, typename Packet, bool NegativeAccumulate>
-EIGEN_STRONG_INLINE void pger(PacketBlock<Packet, 4> *acc, const Scalar* lhs, const Scalar* rhs)
+EIGEN_STRONG_INLINE void pger(PacketBlock<Packet,4>* acc, const Scalar* lhs, const Packet* rhsV)
{
- Packet lhsV = *((Packet *) lhs);
- Packet rhsV1 = pset1<Packet>(rhs[0]);
- Packet rhsV2 = pset1<Packet>(rhs[1]);
- Packet rhsV3 = pset1<Packet>(rhs[2]);
- Packet rhsV4 = pset1<Packet>(rhs[3]);
+ asm("#pger begin");
+ Packet lhsV = pload<Packet>(lhs);
if(NegativeAccumulate)
{
- acc->packet[0] -= lhsV*rhsV1;
- acc->packet[1] -= lhsV*rhsV2;
- acc->packet[2] -= lhsV*rhsV3;
- acc->packet[3] -= lhsV*rhsV4;
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
} else {
- acc->packet[0] += lhsV*rhsV1;
- acc->packet[1] += lhsV*rhsV2;
- acc->packet[2] += lhsV*rhsV3;
- acc->packet[3] += lhsV*rhsV4;
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ }
+ asm("#pger end");
+}
+
+template<typename Scalar, typename Packet, bool NegativeAccumulate>
+EIGEN_STRONG_INLINE void pger(PacketBlock<Packet,1>* acc, const Scalar* lhs, const Packet* rhsV)
+{
+ Packet lhsV = pload<Packet>(lhs);
+
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ }
+}
+
+template<typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
+EIGEN_STRONG_INLINE void pger(PacketBlock<Packet,4>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
+{
+#ifdef _ARCH_PWR9
+ Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
+#else
+ Packet lhsV;
+ Index i = 0;
+ do {
+ lhsV[i] = lhs[i];
+ } while (++i < remaining_rows);
+#endif
+
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ }
+}
+
+template<typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
+EIGEN_STRONG_INLINE void pger(PacketBlock<Packet,1>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
+{
+#ifdef _ARCH_PWR9
+ Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
+#else
+ Packet lhsV;
+ Index i = 0;
+ do {
+ lhsV[i] = lhs[i];
+ } while (++i < remaining_rows);
+#endif
+
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
}
}
@@ -1399,7 +1465,7 @@ EIGEN_STRONG_INLINE void pgerc(PacketBlock<Packet, 4>& accReal, PacketBlock<Pack
template<typename Scalar, typename Packet>
EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs)
{
- return *((Packet *)lhs);
+ return *((Packet *)lhs);
}
// Zero the accumulator on PacketBlock.
@@ -1412,14 +1478,26 @@ EIGEN_STRONG_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
acc.packet[3] = pset1<Packet>((Scalar)0);
}
+template<typename Scalar, typename Packet>
+EIGEN_STRONG_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
+{
+ acc.packet[0] = pset1<Packet>((Scalar)0);
+}
+
// Scale the PacketBlock vectors by alpha.
template<typename Packet>
EIGEN_STRONG_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
{
- acc.packet[0] = pmadd(pAlpha,accZ.packet[0], acc.packet[0]);
- acc.packet[1] = pmadd(pAlpha,accZ.packet[1], acc.packet[1]);
- acc.packet[2] = pmadd(pAlpha,accZ.packet[2], acc.packet[2]);
- acc.packet[3] = pmadd(pAlpha,accZ.packet[3], acc.packet[3]);
+ acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+ acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
+ acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
+ acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
+}
+
+template<typename Packet>
+EIGEN_STRONG_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
}
// Complex version of PacketBlock scaling.
@@ -1471,534 +1549,546 @@ EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res
acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
}
+const static Packet4i mask41 = { -1, 0, 0, 0 };
+const static Packet4i mask42 = { -1, -1, 0, 0 };
+const static Packet4i mask43 = { -1, -1, -1, 0 };
-// PEEL loop factor.
-#define PEEL 10
+const static Packet2l mask21 = { -1, 0 };
-/****************
- * GEMM kernels *
- * **************/
-template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper>
-EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
- Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols)
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows)
{
- const Index remaining_rows = rows % accCols;
- const Index remaining_cols = cols % accRows;
-
- if( strideA == -1 ) strideA = depth;
- if( strideB == -1 ) strideB = depth;
+ if (remaining_rows == 0) {
+ return pset1<Packet>(float(0.0));
+ } else {
+ switch (remaining_rows) {
+ case 1: return Packet(mask41);
+ case 2: return Packet(mask42);
+ default: return Packet(mask43);
+ }
+ }
+}
- const Packet pAlpha = pset1<Packet>(alpha);
- Index col = 0;
- for(; col + accRows <= cols; col += accRows)
- {
- const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows;
- const Scalar *lhs_base = blockA;
+template<>
+EIGEN_STRONG_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
+{
+ if (remaining_rows == 0) {
+ return pset1<Packet2d>(double(0.0));
+ } else {
+ return Packet2d(mask21);
+ }
+}
- Index row = 0;
- for(; row + 6*accCols <= rows; row += 6*accCols)
- {
-#define MICRO() \
- pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
- lhs_ptr1 += accCols; \
- pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
- lhs_ptr2 += accCols; \
- pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
- lhs_ptr3 += accCols; \
- pger<Scalar, Packet, false>(&accZero4, lhs_ptr4, rhs_ptr); \
- lhs_ptr4 += accCols; \
- pger<Scalar, Packet, false>(&accZero5, lhs_ptr5, rhs_ptr); \
- lhs_ptr5 += accCols; \
- pger<Scalar, Packet, false>(&accZero6, lhs_ptr6, rhs_ptr); \
- lhs_ptr6 += accCols; \
- rhs_ptr += accRows;
+template<typename Packet>
+EIGEN_STRONG_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
+{
+ acc.packet[0] = pmadd(pAlpha, pand(accZ.packet[0], pMask), acc.packet[0]);
+ acc.packet[1] = pmadd(pAlpha, pand(accZ.packet[1], pMask), acc.packet[1]);
+ acc.packet[2] = pmadd(pAlpha, pand(accZ.packet[2], pMask), acc.packet[2]);
+ acc.packet[3] = pmadd(pAlpha, pand(accZ.packet[3], pMask), acc.packet[3]);
+}
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols;
- const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
- const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
- const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols;
- const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols;
- const Scalar *lhs_ptr6 = lhs_base + ((row/accCols) + 5)*strideA*accCols;
-
- PacketBlock<Packet,4> acc1, accZero1;
- PacketBlock<Packet,4> acc2, accZero2;
- PacketBlock<Packet,4> acc3, accZero3;
- PacketBlock<Packet,4> acc4, accZero4;
- PacketBlock<Packet,4> acc5, accZero5;
- PacketBlock<Packet,4> acc6, accZero6;
-
- bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero1);
- bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero2);
- bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero3);
- bload<DataMapper, Packet, Index, 3>(acc4, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero4);
- bload<DataMapper, Packet, Index, 4>(acc5, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero5);
- bload<DataMapper, Packet, Index, 5>(acc6, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero6);
+// PEEL loop factor.
+#define PEEL 10
- lhs_ptr1 += accCols*offsetA;
- lhs_ptr2 += accCols*offsetA;
- lhs_ptr3 += accCols*offsetA;
- lhs_ptr4 += accCols*offsetA;
- lhs_ptr5 += accCols*offsetA;
- lhs_ptr6 += accCols*offsetA;
- rhs_ptr += accRows*offsetB;
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_STRONG_INLINE void MICRO_EXTRA_COL(
+ const Scalar* &lhs_ptr,
+ const Scalar* &rhs_ptr,
+ PacketBlock<Packet,1> &accZero,
+ Index remaining_rows,
+ Index remaining_cols)
+{
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ pger<Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += remaining_cols;
+}
- Index k = 0;
- for(; k + PEEL < depth; k+= PEEL)
- {
- prefetch(rhs_ptr);
- prefetch(lhs_ptr1);
- prefetch(lhs_ptr2);
- prefetch(lhs_ptr3);
- prefetch(lhs_ptr4);
- prefetch(lhs_ptr5);
- prefetch(lhs_ptr6);
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
-#if PEEL > 8
- MICRO();
- MICRO();
-#endif
- }
- for(; k < depth; k++)
- {
- MICRO();
- }
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void gemm_extra_col(
+ const DataMapper& res,
+ const Scalar *lhs_base,
+ const Scalar *rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
+ PacketBlock<Packet,1> accZero, acc;
+
+ bsetzero<Scalar, Packet>(accZero);
+
+ Index remaining_depth = (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL <= remaining_depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr);
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+ }
+ for(; k < depth; k++)
+ {
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ pger<Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += remaining_cols;
+ }
- bscale<Packet>(acc1,accZero1, pAlpha);
- bscale<Packet>(acc2,accZero2, pAlpha);
- bscale<Packet>(acc3,accZero3, pAlpha);
- bscale<Packet>(acc4,accZero4, pAlpha);
- bscale<Packet>(acc5,accZero5, pAlpha);
- bscale<Packet>(acc6,accZero6, pAlpha);
-
- res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
- res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
- res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
- res.template storePacketBlock<Packet, 4>(row + 3*accCols, col, acc4);
- res.template storePacketBlock<Packet, 4>(row + 4*accCols, col, acc5);
- res.template storePacketBlock<Packet, 4>(row + 5*accCols, col, acc6);
-#undef MICRO
- }
- for(; row + 5*accCols <= rows; row += 5*accCols)
- {
-#define MICRO() \
- pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
- lhs_ptr1 += accCols; \
- pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
- lhs_ptr2 += accCols; \
- pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
- lhs_ptr3 += accCols; \
- pger<Scalar, Packet, false>(&accZero4, lhs_ptr4, rhs_ptr); \
- lhs_ptr4 += accCols; \
- pger<Scalar, Packet, false>(&accZero5, lhs_ptr5, rhs_ptr); \
- lhs_ptr5 += accCols; \
- rhs_ptr += accRows;
+ acc.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
+ for(Index i = 0; i < remaining_rows; i++){
+ res(row + i, col) += acc.packet[0][i];
+ }
+}
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
- const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
- const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
- const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols;
- const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols;
-
- PacketBlock<Packet,4> acc1, accZero1;
- PacketBlock<Packet,4> acc2, accZero2;
- PacketBlock<Packet,4> acc3, accZero3;
- PacketBlock<Packet,4> acc4, accZero4;
- PacketBlock<Packet,4> acc5, accZero5;
-
- bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero1);
- bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero2);
- bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero3);
- bload<DataMapper, Packet, Index, 3>(acc4, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero4);
- bload<DataMapper, Packet, Index, 4>(acc5, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero5);
+template<typename Scalar, typename Packet, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void MICRO_EXTRA_ROW(
+ const Scalar* &lhs_ptr,
+ const Scalar* &rhs_ptr,
+ PacketBlock<Packet,4> &accZero,
+ Index remaining_rows)
+{
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += accRows;
+}
- lhs_ptr1 += accCols*offsetA;
- lhs_ptr2 += accCols*offsetA;
- lhs_ptr3 += accCols*offsetA;
- lhs_ptr4 += accCols*offsetA;
- lhs_ptr5 += accCols*offsetA;
- rhs_ptr += accRows*offsetB;
- Index k = 0;
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar *lhs_base,
+ const Scalar *rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
+{
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
+ PacketBlock<Packet,4> accZero, acc;
+
+ bsetzero<Scalar, Packet>(accZero);
+
+ Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL <= remaining_depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr);
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ }
- for(; k + PEEL < depth; k+= PEEL)
- {
- prefetch(rhs_ptr);
- prefetch(lhs_ptr1);
- prefetch(lhs_ptr2);
- prefetch(lhs_ptr3);
- prefetch(lhs_ptr4);
- prefetch(lhs_ptr5);
+ if (remaining_depth == depth)
+ {
+ for(Index j = 0; j < 4; j++){
+ acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
+ }
+ bscale(acc, accZero, pAlpha, pMask);
+ res.template storePacketBlock<Packet, 4>(row, col, acc);
+ } else {
+ for(; k < depth; k++)
+ {
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += accRows;
+ }
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
-#if PEEL > 8
- MICRO();
- MICRO();
-#endif
- }
- for(; k < depth; k++)
- {
- MICRO();
- }
+ for(Index j = 0; j < 4; j++){
+ acc.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
+ }
+ for(Index j = 0; j < 4; j++){
+ for(Index i = 0; i < remaining_rows; i++){
+ res(row + i, col + j) += acc.packet[j][i];
+ }
+ }
+ }
+}
- bscale<Packet>(acc1,accZero1, pAlpha);
- bscale<Packet>(acc2,accZero2, pAlpha);
- bscale<Packet>(acc3,accZero3, pAlpha);
- bscale<Packet>(acc4,accZero4, pAlpha);
- bscale<Packet>(acc5,accZero5, pAlpha);
-
- res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
- res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
- res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
- res.template storePacketBlock<Packet, 4>(row + 3*accCols, col, acc4);
- res.template storePacketBlock<Packet, 4>(row + 4*accCols, col, acc5);
-#undef MICRO
- }
- for(; row + 4*accCols <= rows; row += 4*accCols)
- {
-#define MICRO() \
- pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
- lhs_ptr1 += accCols; \
- pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
- lhs_ptr2 += accCols; \
- pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
- lhs_ptr3 += accCols; \
- pger<Scalar, Packet, false>(&accZero4, lhs_ptr4, rhs_ptr); \
- lhs_ptr4 += accCols; \
- rhs_ptr += accRows;
+#define MICRO_DST \
+ PacketBlock<Packet,4> *accZero0, PacketBlock<Packet,4> *accZero1, PacketBlock<Packet,4> *accZero2, \
+ PacketBlock<Packet,4> *accZero3, PacketBlock<Packet,4> *accZero4, PacketBlock<Packet,4> *accZero5, \
+ PacketBlock<Packet,4> *accZero6, PacketBlock<Packet,4> *accZero7
+
+#define MICRO_COL_DST \
+ PacketBlock<Packet,1> *accZero0, PacketBlock<Packet,1> *accZero1, PacketBlock<Packet,1> *accZero2, \
+ PacketBlock<Packet,1> *accZero3, PacketBlock<Packet,1> *accZero4, PacketBlock<Packet,1> *accZero5, \
+ PacketBlock<Packet,1> *accZero6, PacketBlock<Packet,1> *accZero7
+
+#define MICRO_SRC \
+ const Scalar **lhs_ptr0, const Scalar **lhs_ptr1, const Scalar **lhs_ptr2, \
+ const Scalar **lhs_ptr3, const Scalar **lhs_ptr4, const Scalar **lhs_ptr5, \
+ const Scalar **lhs_ptr6, const Scalar **lhs_ptr7
+
+#define MICRO_ONE \
+ MICRO<unroll_factor, Scalar, Packet, accRows, accCols>(\
+ &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \
+ rhs_ptr, \
+ &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7);
+
+#define MICRO_COL_ONE \
+ MICRO_COL<unroll_factor, Scalar, Packet, Index, accCols>(\
+ &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \
+ rhs_ptr, \
+ &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7, \
+ remaining_cols);
+
+#define MICRO_WORK_ONE(iter) \
+ if (N > iter) { \
+ pger<Scalar, Packet, false>(accZero##iter, *lhs_ptr##iter, rhsV); \
+ *lhs_ptr##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
- const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
- const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
- const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols;
-
- PacketBlock<Packet,4> acc1, accZero1;
- PacketBlock<Packet,4> acc2, accZero2;
- PacketBlock<Packet,4> acc3, accZero3;
- PacketBlock<Packet,4> acc4, accZero4;
-
- bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero1);
- bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero2);
- bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero3);
- bload<DataMapper, Packet, Index, 3>(acc4, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero4);
+#define MICRO_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
- lhs_ptr1 += accCols*offsetA;
- lhs_ptr2 += accCols*offsetA;
- lhs_ptr3 += accCols*offsetA;
- lhs_ptr4 += accCols*offsetA;
- rhs_ptr += accRows*offsetB;
- Index k = 0;
+#define MICRO_WORK MICRO_UNROLL(MICRO_WORK_ONE)
- for(; k + PEEL < depth; k+= PEEL)
- {
- prefetch(rhs_ptr);
- prefetch(lhs_ptr1);
- prefetch(lhs_ptr2);
- prefetch(lhs_ptr3);
- prefetch(lhs_ptr4);
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
-#if PEEL > 8
- MICRO();
- MICRO();
-#endif
- }
- for(; k < depth; k++)
- {
- MICRO();
- }
+#define MICRO_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter){ \
+ bsetzero<Scalar, Packet>(accZero##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
+ }
- bscale<Packet>(acc1,accZero1, pAlpha);
- bscale<Packet>(acc2,accZero2, pAlpha);
- bscale<Packet>(acc3,accZero3, pAlpha);
- bscale<Packet>(acc4,accZero4, pAlpha);
+#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
- res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
- res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
- res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
- res.template storePacketBlock<Packet, 4>(row + 3*accCols, col, acc4);
-#undef MICRO
- }
- for(; row + 3*accCols <= rows; row += 3*accCols)
- {
-#define MICRO() \
- pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
- lhs_ptr1 += accCols; \
- pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
- lhs_ptr2 += accCols; \
- pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
- lhs_ptr3 += accCols; \
- rhs_ptr += accRows;
+#define MICRO_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
- const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
- const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
+#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
- PacketBlock<Packet,4> acc1, accZero1;
- PacketBlock<Packet,4> acc2, accZero2;
- PacketBlock<Packet,4> acc3, accZero3;
+#define MICRO_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter){ \
+ prefetch(lhs_ptr##iter); \
+ }
- bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero1);
- bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero2);
- bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero3);
+#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
- lhs_ptr1 += accCols*offsetA;
- lhs_ptr2 += accCols*offsetA;
- lhs_ptr3 += accCols*offsetA;
- rhs_ptr += accRows*offsetB;
- Index k = 0;
- for(; k + PEEL < depth; k+= PEEL)
- {
- prefetch(rhs_ptr);
- prefetch(lhs_ptr1);
- prefetch(lhs_ptr2);
- prefetch(lhs_ptr3);
+#define MICRO_STORE_ONE(iter) \
+ if (unroll_factor > iter){ \
+ acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
+ acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
+ acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
+ acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
+ bscale(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet, 4>(row + iter*accCols, col, acc); \
+ }
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
-#if PEEL > 8
- MICRO();
- MICRO();
-#endif
- }
- for(; k < depth; k++)
- {
- MICRO();
- }
+#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
- bscale<Packet>(acc1,accZero1, pAlpha);
- bscale<Packet>(acc2,accZero2, pAlpha);
- bscale<Packet>(acc3,accZero3, pAlpha);
+#define MICRO_COL_STORE_ONE(iter) \
+ if (unroll_factor > iter){ \
+ acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
+ bscale(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet, 1>(row + iter*accCols, col, acc); \
+ }
- res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
- res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
- res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
-#undef MICRO
- }
- for(; row + 2*accCols <= rows; row += 2*accCols)
- {
-#define MICRO() \
- pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
- lhs_ptr1 += accCols; \
- pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
- lhs_ptr2 += accCols; \
- rhs_ptr += accRows;
+#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
- const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
-
- PacketBlock<Packet,4> acc1, accZero1;
- PacketBlock<Packet,4> acc2, accZero2;
+template<int N, typename Scalar, typename Packet, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void MICRO(
+ MICRO_SRC,
+ const Scalar* &rhs_ptr,
+ MICRO_DST)
+ {
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ asm("#unrolled pger? begin");
+ MICRO_WORK
+ asm("#unrolled pger? end");
+ rhs_ptr += accRows;
+ }
- bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero1);
- bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero2);
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
+ const DataMapper& res,
+ const Scalar *lhs_base,
+ const Scalar *rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ const Packet& pAlpha)
+{
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7;
+ PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,4> acc;
+
+ asm("#unrolled start");
+ MICRO_SRC_PTR
+ asm("#unrolled zero?");
+ MICRO_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL <= depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ MICRO_PREFETCH
+ asm("#unrolled inner loop?");
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_ONE
+ }
+ asm("#unrolled inner loop end?");
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_ONE
+ }
+ MICRO_STORE
- lhs_ptr1 += accCols*offsetA;
- lhs_ptr2 += accCols*offsetA;
- rhs_ptr += accRows*offsetB;
- Index k = 0;
- for(; k + PEEL < depth; k+= PEEL)
- {
- prefetch(rhs_ptr);
- prefetch(lhs_ptr1);
- prefetch(lhs_ptr2);
+ row += unroll_factor*accCols;
+}
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
-#if PEEL > 8
- MICRO();
- MICRO();
-#endif
- }
- for(; k < depth; k++)
- {
- MICRO();
- }
+template<int N, typename Scalar, typename Packet, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void MICRO_COL(
+ MICRO_SRC,
+ const Scalar* &rhs_ptr,
+ MICRO_COL_DST,
+ Index remaining_rows)
+ {
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ asm("#unrolled pger? begin");
+ MICRO_WORK
+ asm("#unrolled pger? end");
+ rhs_ptr += remaining_rows;
+ }
- bscale<Packet>(acc1,accZero1, pAlpha);
- bscale<Packet>(acc2,accZero2, pAlpha);
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
+ const DataMapper& res,
+ const Scalar *lhs_base,
+ const Scalar *rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7;
+ PacketBlock<Packet,1> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,1> acc;
+
+ MICRO_SRC_PTR
+ MICRO_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL <= depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ MICRO_PREFETCH
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_COL_ONE
+ }
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COL_ONE
+ }
+ MICRO_COL_STORE
- res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
- res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
-#undef MICRO
- }
+ row += unroll_factor*accCols;
+}
- for(; row + accCols <= rows; row += accCols)
- {
-#define MICRO() \
- pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
- lhs_ptr1 += accCols; \
- rhs_ptr += accRows;
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col(
+ const DataMapper& res,
+ const Scalar *lhs_base,
+ const Scalar *rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+#define MAX_UNROLL 6
+ while(row + MAX_UNROLL*accCols <= rows){
+ gemm_unrolled_col_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ }
+ switch( (rows-row)/accCols ){
+#if MAX_UNROLL > 7
+ case 7:
+ gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 6
+ case 6:
+ gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 5
+ case 5:
+ gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 4
+ case 4:
+ gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 3
+ case 3:
+ gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 2
+ case 2:
+ gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 1
+ case 1:
+ gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_UNROLL
+}
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols;
+/****************
+ * GEMM kernels *
+ * **************/
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
- PacketBlock<Packet,4> acc1, accZero1;
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
- bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
- bsetzero<Scalar, Packet>(accZero1);
+ const Packet pAlpha = pset1<Packet>(alpha);
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
- lhs_ptr1 += accCols*offsetA;
- rhs_ptr += accRows*offsetB;
- Index k = 0;
- for(; k + PEEL < depth; k+= PEEL)
- {
- prefetch(rhs_ptr);
- prefetch(lhs_ptr1);
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar *rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar *lhs_base = blockA;
+ Index row = 0;
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
- MICRO();
-#if PEEL > 8
- MICRO();
- MICRO();
+ asm("#jump table");
+#define MAX_UNROLL 6
+ while(row + MAX_UNROLL*accCols <= rows){
+ gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ }
+ switch( (rows-row)/accCols ){
+#if MAX_UNROLL > 7
+ case 7:
+ gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
#endif
- }
- for(; k < depth; k++)
- {
- MICRO();
- }
-
- bscale<Packet>(acc1,accZero1, pAlpha);
-
- res.template storePacketBlock<Packet, 4>(row, col, acc1);
-#undef MICRO
+#if MAX_UNROLL > 6
+ case 6:
+ gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 5
+ case 5:
+ gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 4
+ case 4:
+ gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 3
+ case 3:
+ gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 2
+ case 2:
+ gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 1
+ case 1:
+ gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+ default:
+ break;
}
+#undef MAX_UNROLL
+ asm("#jump table end");
+
if(remaining_rows > 0)
{
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols;
-
- lhs_ptr += remaining_rows*offsetA;
- rhs_ptr += accRows*offsetB;
- for(Index k = 0; k < depth; k++)
- {
- for(Index arow = 0; arow < remaining_rows; arow++)
- {
- for(Index acol = 0; acol < accRows; acol++ )
- {
- res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol];
- }
- }
- rhs_ptr += accRows;
- lhs_ptr += remaining_rows;
- }
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, cols, remaining_rows, pAlpha, pMask);
}
}
if(remaining_cols > 0)
{
- const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows;
+ const Scalar *rhs_base = blockB + col*strideB + remaining_cols*offsetB;
const Scalar *lhs_base = blockA;
- Index row = 0;
- for(; row + accCols <= rows; row += accCols)
+ for(; col < cols; col++)
{
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols;
+ Index row = 0;
- lhs_ptr += accCols*offsetA;
- rhs_ptr += remaining_cols*offsetB;
- for(Index k = 0; k < depth; k++)
- {
- for(Index arow = 0; arow < accCols; arow++)
- {
- for(Index acol = 0; acol < remaining_cols; acol++ )
- {
- res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol];
- }
- }
- rhs_ptr += remaining_cols;
- lhs_ptr += accCols;
- }
- }
-
- if(remaining_rows > 0 )
- {
- const Scalar *rhs_ptr = rhs_base;
- const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols;
+ gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
- lhs_ptr += remaining_rows*offsetA;
- rhs_ptr += remaining_cols*offsetB;
- for(Index k = 0; k < depth; k++)
+ if (remaining_rows > 0)
{
- for(Index arow = 0; arow < remaining_rows; arow++)
- {
- for(Index acol = 0; acol < remaining_cols; acol++ )
- {
- res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol];
- }
- }
- rhs_ptr += remaining_cols;
- lhs_ptr += remaining_rows;
+ gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
}
+ rhs_base++;
}
}
}
-template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const int accRows, const int accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc,
- Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols)
+ Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const int remaining_rows = rows % accCols;
const int remaining_cols = cols % accRows;
@@ -2018,7 +2108,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl
const Scalar *blockA = (Scalar *) blockAc;
const Scalar *blockB = (Scalar *) blockBc;
- Packet conj = pset1<Packet>((Scalar)-1.0f);
+ Packet conj = pset1<Packet>((Scalar)-1.0);
Index col = 0;
for(; col + accRows <= cols; col += accRows)
@@ -2054,7 +2144,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl
if(!RhsIsReal)
rhs_ptr_imag += accRows*offsetB;
Index k = 0;
- for(; k + PEEL < depth; k+=PEEL)
+ for(; k + PEEL <= depth; k+=PEEL)
{
prefetch(rhs_ptr);
prefetch(rhs_ptr_imag);
@@ -2180,8 +2270,8 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl
{
for(Index acol = 0; acol < 4; acol++ )
{
- scalarAcc[arow][acol].real((Scalar)0.0f);
- scalarAcc[arow][acol].imag((Scalar)0.0f);
+ scalarAcc[arow][acol].real((Scalar)0.0);
+ scalarAcc[arow][acol].imag((Scalar)0.0);
}
}
for(Index k = 0; k < depth; k++)
@@ -2550,24 +2640,24 @@ void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, Conjugat
Index rows, Index depth, Index cols, float alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
- const int accRows = quad_traits<float>::rows;
- const int accCols = quad_traits<float>::size;
- void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index, const int, const int);
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
}
else{
- gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
}
#else
- gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2591,22 +2681,22 @@ void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr
const int accRows = quad_traits<float>::rows;
const int accCols = quad_traits<float>::size;
void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
- Index, Index, Index, std::complex<float>, Index, Index , Index, Index, const int, const int);
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
}
else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
}
#else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
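[Editor's note] In the complex specializations above and below, the two trailing boolean template arguments (false/false, true/false, false/true) appear to mark whether the LHS and RHS operands are purely real, so the mixed real-complex kernels can drop the corresponding imaginary partial products at compile time. A minimal sketch of that convention, with illustrative names only:

// Hedged sketch of the LhsIsReal/RhsIsReal flag convention (not Eigen code).
template<bool LhsIsReal, bool RhsIsReal>
struct complex_kernel_traits {
  // A purely real operand contributes no imaginary partial products,
  // so those multiply-adds can be elided at compile time.
  static const bool skip_lhs_imag = LhsIsReal;
  static const bool skip_rhs_imag = RhsIsReal;
};

// complex * complex -> complex_kernel_traits<false, false>
// real    * complex -> complex_kernel_traits<true,  false>
// complex * real    -> complex_kernel_traits<false, true>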
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2630,21 +2720,21 @@ void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, Conjugat
const int accRows = quad_traits<float>::rows;
const int accCols = quad_traits<float>::size;
void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
- Index, Index, Index, std::complex<float>, Index, Index , Index, Index, const int, const int);
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
}
else{
- gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
}
#else
- gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2668,21 +2758,21 @@ void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, Conjugat
const int accRows = quad_traits<float>::rows;
const int accCols = quad_traits<float>::size;
void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
- Index, Index, Index, std::complex<float>, Index, Index , Index, Index, const int, const int);
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
}
else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
}
#else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2702,24 +2792,24 @@ void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, Conjug
Index rows, Index depth, Index cols, double alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
- const int accRows = quad_traits<double>::rows;
- const int accCols = quad_traits<double>::size;
- void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index, const int, const int);
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
}
else{
- gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
}
#else
- gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper>;
+ gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2743,21 +2833,21 @@ void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper,
const int accRows = quad_traits<double>::rows;
const int accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
- Index, Index, Index, std::complex<double>, Index, Index , Index, Index, const int, const int);
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
}
else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
}
#else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2781,21 +2871,21 @@ void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, Conjug
const int accRows = quad_traits<double>::rows;
const int accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
- Index, Index, Index, std::complex<double>, Index, Index , Index, Index, const int, const int);
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
}
else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
}
#else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>;
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2819,21 +2909,21 @@ void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, Conjug
const int accRows = quad_traits<double>::rows;
const int accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
- Index, Index, Index, std::complex<double>, Index, Index , Index, Index, const int, const int);
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
#ifdef EIGEN_ALTIVEC_MMA_ONLY
//generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
#elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
}
else{
- gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
}
#else
- gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>;
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
} // end namespace internal