From 9b51dc7972c9f64727e9c8e8db0c60aaf9aae532 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 17 Feb 2021 17:49:23 +0000 Subject: Fixed performance issues for VSX and P10 MMA in general_matrix_matrix_product --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 1306 ++++++++++++++------------- 1 file changed, 698 insertions(+), 608 deletions(-) (limited to 'Eigen/src/Core/arch/AltiVec/MatrixProduct.h') diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 53116ad89..9d9bbebe5 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -70,7 +70,7 @@ struct quad_traits // MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector, this turned out // to be faster than Eigen's usual approach of having real/imaginary pairs on a single vector. This constants then // are responsible to extract from convert between Eigen's and MatrixProduct approach. -const static Packet4f p4f_CONJUGATE = {-1.0f, -1.0f, -1.0f, -1.0f}; +const static Packet4f p4f_CONJUGATE = {float(-1.0), float(-1.0), float(-1.0), float(-1.0)}; const static Packet2d p2d_CONJUGATE = {-1.0, -1.0}; @@ -122,7 +122,7 @@ EIGEN_STRONG_INLINE std::complex getAdjointVal(Index i, Index j, const_b v.imag(dt(i,j).imag()); } else { v.real(dt(i,j).real()); - v.imag((Scalar)0.0f); + v.imag((Scalar)0.0); } return v; } @@ -136,7 +136,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex *bloc Scalar* blockBf = reinterpret_cast(blockB); Index ri = 0, j = 0; - for(; j + vectorSize < cols; j+=vectorSize) + for(; j + vectorSize <= cols; j+=vectorSize) { Index i = k2; for(; i < depth; i++) @@ -192,7 +192,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex *bloc Index ri = 0, j = 0; Scalar *blockAf = (Scalar *)(blockA); - for(; j + vectorSize < rows; j+=vectorSize) + for(; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; @@ -247,7 +247,7 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + N*vectorSize < cols; j+=N*vectorSize) + for(; j + N*vectorSize <= cols; j+=N*vectorSize) { Index i = k2; for(; i < depth; i++) @@ -284,7 +284,7 @@ EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar* _lhs const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; @@ -410,15 +410,15 @@ struct lhs_cpack { const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; Scalar *blockAt = reinterpret_cast(blockA); - Packet conj = pset1((Scalar)-1.0f); + Packet conj = pset1((Scalar)-1.0); - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; @@ -446,10 +446,10 @@ struct lhs_cpack { cblock.packet[7] = lhs.template loadPacket(j + 3, i + 2); } - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32); + block.packet[0] = 
vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); if(StorageOrder == RowMajor) ptranspose(block); @@ -475,7 +475,7 @@ struct lhs_cpack { if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock cblock; if(StorageOrder == ColMajor) @@ -502,10 +502,10 @@ struct lhs_cpack { } PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); if(Conjugate) { @@ -585,13 +585,13 @@ struct lhs_pack{ const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; @@ -637,13 +637,16 @@ struct lhs_pack{ if(PanelMode) ri += offset*(rows - j); - for(Index i = 0; i < depth; i++) + if (j < rows) { - Index k = j; - for(; k < rows; k++) + for(Index i = 0; i < depth; i++) { - blockA[ri] = lhs(k, i); - ri += 1; + Index k = j; + for(; k < rows; k++) + { + blockA[ri] = lhs(k, i); + ri += 1; + } } } @@ -659,16 +662,16 @@ struct rhs_cpack { const int vectorSize = quad_traits::vectorsize; Scalar *blockBt = reinterpret_cast(blockB); - Packet conj = pset1((Scalar)-1.0f); + Packet conj = pset1((Scalar)-1.0); Index ri = 0, j = 0; - for(; j + vectorSize < cols; j+=vectorSize) + for(; j + vectorSize <= cols; j+=vectorSize) { Index i = 0; if(PanelMode) ri += offset*vectorSize; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock cblock; if(StorageOrder == ColMajor) @@ -695,10 +698,10 @@ struct rhs_cpack } PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); if(StorageOrder == ColMajor) ptranspose(block); @@ -724,7 +727,7 @@ struct rhs_cpack if(PanelMode) ri += offset*vectorSize; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + 
vectorSize <= depth; i+=vectorSize) { PacketBlock cblock; if(StorageOrder == ColMajor) @@ -752,10 +755,10 @@ struct rhs_cpack } PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); if(Conjugate) { @@ -832,13 +835,14 @@ struct rhs_pack { { const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + vectorSize < cols; j+=vectorSize) + + for(; j + vectorSize <= cols; j+=vectorSize) { Index i = 0; if(PanelMode) ri += offset*vectorSize; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; if(StorageOrder == ColMajor) @@ -883,13 +887,16 @@ struct rhs_pack { if(PanelMode) ri += offset*(cols - j); - for(Index i = 0; i < depth; i++) + if (j < cols) { - Index k = j; - for(; k < cols; k++) + for(Index i = 0; i < depth; i++) { - blockB[ri] = rhs(i, k); - ri += 1; + Index k = j; + for(; k < cols; k++) + { + blockB[ri] = rhs(i, k); + ri += 1; + } } } if(PanelMode) ri += (cols - j)*(stride - offset - depth); @@ -905,13 +912,13 @@ struct lhs_pack const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; if(StorageOrder == RowMajor) @@ -970,12 +977,12 @@ struct rhs_pack { const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + 2*vectorSize < cols; j+=2*vectorSize) + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { Index i = 0; if(PanelMode) ri += offset*(2*vectorSize); - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; if(StorageOrder == ColMajor) @@ -1059,13 +1066,13 @@ struct lhs_cpack(blockA); Packet conj = pset1(-1.0); - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; @@ -1078,8 +1085,8 @@ struct lhs_cpack(j + 1, i + 0); //[a2 a2i] cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] } else { cblock.packet[0] = lhs.template loadPacket(j + 0, i); //[a1 a1i] cblock.packet[1] = lhs.template loadPacket(j + 1, i); //[a2 a2i] @@ -1087,8 +1094,8 @@ struct lhs_cpack(j + 0, i + 1); //[b1 
b1i] cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] } pstore(blockAt + ri , block.packet[0]); @@ -1108,7 +1115,7 @@ struct lhs_cpack block; @@ -1121,8 +1128,8 @@ struct lhs_cpack(j + 1, i + 0); cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETIMAG64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64); } else { cblock.packet[0] = lhs.template loadPacket(j + 0, i); cblock.packet[1] = lhs.template loadPacket(j + 1, i); @@ -1130,8 +1137,8 @@ struct lhs_cpack(j + 0, i + 1); cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); } if(Conjugate) @@ -1205,7 +1212,7 @@ struct rhs_cpack(-1.0); Index ri = 0, j = 0; - for(; j + 2*vectorSize < cols; j+=2*vectorSize) + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { Index i = 0; @@ -1221,8 +1228,8 @@ struct rhs_cpack(i, j + 2); cblock.packet[3] = rhs.template loadPacket(i, j + 3); - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); pstore(blockBt + ri , block.packet[0]); pstore(blockBt + ri + 2, block.packet[1]); @@ -1246,8 +1253,8 @@ struct rhs_cpack(i, j + 2); //[c1 c1i] cblock.packet[3] = rhs.template loadPacket(i, j + 3); //[d1 d1i] - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); if(Conjugate) { @@ -1300,25 +1307,84 @@ struct rhs_cpack -EIGEN_STRONG_INLINE void pger(PacketBlock *acc, const Scalar* lhs, const Scalar* rhs) +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) { - Packet lhsV = *((Packet *) lhs); - Packet rhsV1 = pset1(rhs[0]); - Packet rhsV2 = pset1(rhs[1]); - Packet rhsV3 = pset1(rhs[2]); - Packet rhsV4 = pset1(rhs[3]); + asm("#pger begin"); + Packet lhsV = pload(lhs); if(NegativeAccumulate) { - acc->packet[0] -= lhsV*rhsV1; - acc->packet[1] -= lhsV*rhsV2; - acc->packet[2] -= lhsV*rhsV3; - acc->packet[3] -= lhsV*rhsV4; + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = 
vec_nmsub(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); } else { - acc->packet[0] += lhsV*rhsV1; - acc->packet[1] += lhsV*rhsV2; - acc->packet[2] += lhsV*rhsV3; - acc->packet[3] += lhsV*rhsV4; + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + } + asm("#pger end"); +} + +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) +{ + Packet lhsV = pload(lhs); + + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + } +} + +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +{ +#ifdef _ARCH_PWR9 + Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); +#else + Packet lhsV; + Index i = 0; + do { + lhsV[i] = lhs[i]; + } while (++i < remaining_rows); +#endif + + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + } +} + +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +{ +#ifdef _ARCH_PWR9 + Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); +#else + Packet lhsV; + Index i = 0; + do { + lhsV[i] = lhs[i]; + } while (++i < remaining_rows); +#endif + + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); } } @@ -1399,7 +1465,7 @@ EIGEN_STRONG_INLINE void pgerc(PacketBlock& accReal, PacketBlock EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs) { - return *((Packet *)lhs); + return *((Packet *)lhs); } // Zero the accumulator on PacketBlock. @@ -1412,14 +1478,26 @@ EIGEN_STRONG_INLINE void bsetzero(PacketBlock& acc) acc.packet[3] = pset1((Scalar)0); } +template +EIGEN_STRONG_INLINE void bsetzero(PacketBlock& acc) +{ + acc.packet[0] = pset1((Scalar)0); +} + // Scale the PacketBlock vectors by alpha. template EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) { - acc.packet[0] = pmadd(pAlpha,accZ.packet[0], acc.packet[0]); - acc.packet[1] = pmadd(pAlpha,accZ.packet[1], acc.packet[1]); - acc.packet[2] = pmadd(pAlpha,accZ.packet[2], acc.packet[2]); - acc.packet[3] = pmadd(pAlpha,accZ.packet[3], acc.packet[3]); + acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); + acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); + acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); + acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); +} + +template +EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); } // Complex version of PacketBlock scaling. 
@@ -1471,534 +1549,546 @@ EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); } +const static Packet4i mask41 = { -1, 0, 0, 0 }; +const static Packet4i mask42 = { -1, -1, 0, 0 }; +const static Packet4i mask43 = { -1, -1, -1, 0 }; -// PEEL loop factor. -#define PEEL 10 +const static Packet2l mask21 = { -1, 0 }; -/**************** - * GEMM kernels * - * **************/ -template -EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, - Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +template +EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows) { - const Index remaining_rows = rows % accCols; - const Index remaining_cols = cols % accRows; - - if( strideA == -1 ) strideA = depth; - if( strideB == -1 ) strideB = depth; + if (remaining_rows == 0) { + return pset1(float(0.0)); + } else { + switch (remaining_rows) { + case 1: return Packet(mask41); + case 2: return Packet(mask42); + default: return Packet(mask43); + } + } +} - const Packet pAlpha = pset1(alpha); - Index col = 0; - for(; col + accRows <= cols; col += accRows) - { - const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows; - const Scalar *lhs_base = blockA; +template<> +EIGEN_STRONG_INLINE Packet2d bmask(const int remaining_rows) +{ + if (remaining_rows == 0) { + return pset1(double(0.0)); + } else { + return Packet2d(mask21); + } +} - Index row = 0; - for(; row + 6*accCols <= rows; row += 6*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - lhs_ptr3 += accCols; \ - pger(&accZero4, lhs_ptr4, rhs_ptr); \ - lhs_ptr4 += accCols; \ - pger(&accZero5, lhs_ptr5, rhs_ptr); \ - lhs_ptr5 += accCols; \ - pger(&accZero6, lhs_ptr6, rhs_ptr); \ - lhs_ptr6 += accCols; \ - rhs_ptr += accRows; +template +EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) +{ + acc.packet[0] = pmadd(pAlpha, pand(accZ.packet[0], pMask), acc.packet[0]); + acc.packet[1] = pmadd(pAlpha, pand(accZ.packet[1], pMask), acc.packet[1]); + acc.packet[2] = pmadd(pAlpha, pand(accZ.packet[2], pMask), acc.packet[2]); + acc.packet[3] = pmadd(pAlpha, pand(accZ.packet[3], pMask), acc.packet[3]); +} - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; - const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols; - const Scalar *lhs_ptr6 = lhs_base + ((row/accCols) + 5)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - PacketBlock acc4, accZero4; - PacketBlock acc5, accZero5; - PacketBlock acc6, accZero6; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - bload(acc4, res, row, col, accCols); - bsetzero(accZero4); - bload(acc5, res, row, col, accCols); - bsetzero(accZero5); - bload(acc6, res, row, col, accCols); - bsetzero(accZero6); +// PEEL loop factor. 
+#define PEEL 10 - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - lhs_ptr4 += accCols*offsetA; - lhs_ptr5 += accCols*offsetA; - lhs_ptr6 += accCols*offsetA; - rhs_ptr += accRows*offsetB; +template +EIGEN_STRONG_INLINE void MICRO_EXTRA_COL( + const Scalar* &lhs_ptr, + const Scalar* &rhs_ptr, + PacketBlock &accZero, + Index remaining_rows, + Index remaining_cols) +{ + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + pger(&accZero, lhs_ptr, rhsV); + lhs_ptr += remaining_rows; + rhs_ptr += remaining_cols; +} - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - prefetch(lhs_ptr4); - prefetch(lhs_ptr5); - prefetch(lhs_ptr6); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } +template +EIGEN_STRONG_INLINE void gemm_extra_col( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero, acc; + + bsetzero(accZero); + + Index remaining_depth = (depth & -accRows); + Index k = 0; + for(; k + PEEL <= remaining_depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr); + for (int l = 0; l < PEEL; l++) { + MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); + } + } + for(; k < remaining_depth; k++) + { + MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); + } + for(; k < depth; k++) + { + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + pger(&accZero, lhs_ptr, rhsV, remaining_rows); + lhs_ptr += remaining_rows; + rhs_ptr += remaining_cols; + } - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - bscale(acc4,accZero4, pAlpha); - bscale(acc5,accZero5, pAlpha); - bscale(acc6,accZero6, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); - res.template storePacketBlock(row + 3*accCols, col, acc4); - res.template storePacketBlock(row + 4*accCols, col, acc5); - res.template storePacketBlock(row + 5*accCols, col, acc6); -#undef MICRO - } - for(; row + 5*accCols <= rows; row += 5*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - lhs_ptr3 += accCols; \ - pger(&accZero4, lhs_ptr4, rhs_ptr); \ - lhs_ptr4 += accCols; \ - pger(&accZero5, lhs_ptr5, rhs_ptr); \ - lhs_ptr5 += accCols; \ - rhs_ptr += accRows; + acc.packet[0] = vec_mul(pAlpha, accZero.packet[0]); + for(Index i = 0; i < remaining_rows; i++){ + res(row + i, col) += acc.packet[0][i]; + } +} - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; - const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols; - - PacketBlock acc1, 
accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - PacketBlock acc4, accZero4; - PacketBlock acc5, accZero5; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - bload(acc4, res, row, col, accCols); - bsetzero(accZero4); - bload(acc5, res, row, col, accCols); - bsetzero(accZero5); +template +EIGEN_STRONG_INLINE void MICRO_EXTRA_ROW( + const Scalar* &lhs_ptr, + const Scalar* &rhs_ptr, + PacketBlock &accZero, + Index remaining_rows) +{ + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger(&accZero, lhs_ptr, rhsV); + lhs_ptr += remaining_rows; + rhs_ptr += accRows; +} - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - lhs_ptr4 += accCols*offsetA; - lhs_ptr5 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; +template +EIGEN_STRONG_INLINE void gemm_extra_row( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero, acc; + + bsetzero(accZero); + + Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows); + Index k = 0; + for(; k + PEEL <= remaining_depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr); + for (int l = 0; l < PEEL; l++) { + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + } + } + for(; k < remaining_depth; k++) + { + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + } - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - prefetch(lhs_ptr4); - prefetch(lhs_ptr5); + if (remaining_depth == depth) + { + for(Index j = 0; j < 4; j++){ + acc.packet[j] = res.template loadPacket(row, col + j); + } + bscale(acc, accZero, pAlpha, pMask); + res.template storePacketBlock(row, col, acc); + } else { + for(; k < depth; k++) + { + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger(&accZero, lhs_ptr, rhsV, remaining_rows); + lhs_ptr += remaining_rows; + rhs_ptr += accRows; + } - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } + for(Index j = 0; j < 4; j++){ + acc.packet[j] = vec_mul(pAlpha, accZero.packet[j]); + } + for(Index j = 0; j < 4; j++){ + for(Index i = 0; i < remaining_rows; i++){ + res(row + i, col + j) += acc.packet[j][i]; + } + } + } +} - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - bscale(acc4,accZero4, pAlpha); - bscale(acc5,accZero5, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); - res.template storePacketBlock(row + 3*accCols, col, acc4); - res.template storePacketBlock(row + 4*accCols, col, acc5); -#undef MICRO - } - for(; row + 4*accCols <= rows; row += 4*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, 
rhs_ptr); \ - lhs_ptr3 += accCols; \ - pger(&accZero4, lhs_ptr4, rhs_ptr); \ - lhs_ptr4 += accCols; \ - rhs_ptr += accRows; +#define MICRO_DST \ + PacketBlock *accZero0, PacketBlock *accZero1, PacketBlock *accZero2, \ + PacketBlock *accZero3, PacketBlock *accZero4, PacketBlock *accZero5, \ + PacketBlock *accZero6, PacketBlock *accZero7 + +#define MICRO_COL_DST \ + PacketBlock *accZero0, PacketBlock *accZero1, PacketBlock *accZero2, \ + PacketBlock *accZero3, PacketBlock *accZero4, PacketBlock *accZero5, \ + PacketBlock *accZero6, PacketBlock *accZero7 + +#define MICRO_SRC \ + const Scalar **lhs_ptr0, const Scalar **lhs_ptr1, const Scalar **lhs_ptr2, \ + const Scalar **lhs_ptr3, const Scalar **lhs_ptr4, const Scalar **lhs_ptr5, \ + const Scalar **lhs_ptr6, const Scalar **lhs_ptr7 + +#define MICRO_ONE \ + MICRO(\ + &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ + rhs_ptr, \ + &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7); + +#define MICRO_COL_ONE \ + MICRO_COL(\ + &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ + rhs_ptr, \ + &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7, \ + remaining_cols); + +#define MICRO_WORK_ONE(iter) \ + if (N > iter) { \ + pger(accZero##iter, *lhs_ptr##iter, rhsV); \ + *lhs_ptr##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - PacketBlock acc4, accZero4; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - bload(acc4, res, row, col, accCols); - bsetzero(accZero4); +#define MICRO_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - lhs_ptr4 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; +#define MICRO_WORK MICRO_UNROLL(MICRO_WORK_ONE) - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - prefetch(lhs_ptr4); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } +#define MICRO_DST_PTR_ONE(iter) \ + if (unroll_factor > iter){ \ + bsetzero(accZero##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + } - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - bscale(acc4,accZero4, pAlpha); +#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE) - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); - res.template storePacketBlock(row + 3*accCols, col, acc4); -#undef MICRO - } - for(; row + 3*accCols <= rows; row += 3*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, 
rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - lhs_ptr3 += accCols; \ - rhs_ptr += accRows; +#define MICRO_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; +#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE) - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; +#define MICRO_PREFETCH_ONE(iter) \ + if (unroll_factor > iter){ \ + prefetch(lhs_ptr##iter); \ + } - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); +#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE) - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); +#define MICRO_STORE_ONE(iter) \ + if (unroll_factor > iter){ \ + acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ + acc.packet[1] = res.template loadPacket(row + iter*accCols, col + 1); \ + acc.packet[2] = res.template loadPacket(row + iter*accCols, col + 2); \ + acc.packet[3] = res.template loadPacket(row + iter*accCols, col + 3); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ + } - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } +#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE) - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); +#define MICRO_COL_STORE_ONE(iter) \ + if (unroll_factor > iter){ \ + acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ + } - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); -#undef MICRO - } - for(; row + 2*accCols <= rows; row += 2*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - rhs_ptr += accRows; +#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE) - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; +template +EIGEN_STRONG_INLINE void MICRO( + MICRO_SRC, + const Scalar* &rhs_ptr, + MICRO_DST) + { + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + asm("#unrolled pger? begin"); + MICRO_WORK + asm("#unrolled pger? 
end"); + rhs_ptr += accRows; + } - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); +template +EIGEN_STRONG_INLINE void gemm_unrolled_iteration( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; + + asm("#unrolled start"); + MICRO_SRC_PTR + asm("#unrolled zero?"); + MICRO_DST_PTR + + Index k = 0; + for(; k + PEEL <= depth; k+= PEEL) + { + prefetch(rhs_ptr); + MICRO_PREFETCH + asm("#unrolled inner loop?"); + for (int l = 0; l < PEEL; l++) { + MICRO_ONE + } + asm("#unrolled inner loop end?"); + } + for(; k < depth; k++) + { + MICRO_ONE + } + MICRO_STORE - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); + row += unroll_factor*accCols; +} - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } +template +EIGEN_STRONG_INLINE void MICRO_COL( + MICRO_SRC, + const Scalar* &rhs_ptr, + MICRO_COL_DST, + Index remaining_rows) + { + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + asm("#unrolled pger? begin"); + MICRO_WORK + asm("#unrolled pger? end"); + rhs_ptr += remaining_rows; + } - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); +template +EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + Index remaining_cols, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; + + MICRO_SRC_PTR + MICRO_DST_PTR + + Index k = 0; + for(; k + PEEL <= depth; k+= PEEL) + { + prefetch(rhs_ptr); + MICRO_PREFETCH + for (int l = 0; l < PEEL; l++) { + MICRO_COL_ONE + } + } + for(; k < depth; k++) + { + MICRO_COL_ONE + } + MICRO_COL_STORE - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); -#undef MICRO - } + row += unroll_factor*accCols; +} - for(; row + accCols <= rows; row += accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - rhs_ptr += accRows; +template +EIGEN_STRONG_INLINE void gemm_unrolled_col( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlpha) +{ +#define MAX_UNROLL 6 + while(row + MAX_UNROLL*accCols <= rows){ + gemm_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + } + switch( (rows-row)/accCols ){ +#if MAX_UNROLL > 7 + case 7: + gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, 
remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 6 + case 6: + gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 5 + case 5: + gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 4 + case 4: + gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 3 + case 3: + gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 2 + case 2: + gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 1 + case 1: + gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif + default: + break; + } +#undef MAX_UNROLL +} - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; +/**************** + * GEMM kernels * + * **************/ +template +EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; - PacketBlock acc1, accZero1; + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); + const Packet pAlpha = pset1(alpha); + const Packet pMask = bmask((const int)(remaining_rows)); - lhs_ptr1 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar *rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar *lhs_base = blockA; + Index row = 0; - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); + asm("#jump table"); +#define MAX_UNROLL 6 + while(row + MAX_UNROLL*accCols <= rows){ + gemm_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + } + switch( (rows-row)/accCols ){ +#if MAX_UNROLL > 7 + case 7: + gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; #endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - - res.template storePacketBlock(row, col, acc1); -#undef MICRO +#if MAX_UNROLL > 6 + case 6: + gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 5 + case 5: + gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 4 + case 4: + 
gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 3 + case 3: + gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 2 + case 2: + gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 1 + case 1: + gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif + default: + break; } +#undef MAX_UNROLL + asm("#jump table end"); + if(remaining_rows > 0) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; - - lhs_ptr += remaining_rows*offsetA; - rhs_ptr += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < remaining_rows; arow++) - { - for(Index acol = 0; acol < accRows; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += accRows; - lhs_ptr += remaining_rows; - } + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, cols, remaining_rows, pAlpha, pMask); } } if(remaining_cols > 0) { - const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows; + const Scalar *rhs_base = blockB + col*strideB + remaining_cols*offsetB; const Scalar *lhs_base = blockA; - Index row = 0; - for(; row + accCols <= rows; row += accCols) + for(; col < cols; col++) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + Index row = 0; - lhs_ptr += accCols*offsetA; - rhs_ptr += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < accCols; arow++) - { - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += accCols; - } - } - - if(remaining_rows > 0 ) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); - lhs_ptr += remaining_rows*offsetA; - rhs_ptr += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) + if (remaining_rows > 0) { - for(Index arow = 0; arow < remaining_rows; arow++) - { - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += remaining_rows; + gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); } + rhs_base++; } } } -template +template EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, - Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) + Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const int remaining_rows = rows % accCols; const int remaining_cols = cols % accRows; @@ -2018,7 +2108,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl const 
Scalar *blockA = (Scalar *) blockAc; const Scalar *blockB = (Scalar *) blockBc; - Packet conj = pset1((Scalar)-1.0f); + Packet conj = pset1((Scalar)-1.0); Index col = 0; for(; col + accRows <= cols; col += accRows) @@ -2054,7 +2144,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl if(!RhsIsReal) rhs_ptr_imag += accRows*offsetB; Index k = 0; - for(; k + PEEL < depth; k+=PEEL) + for(; k + PEEL <= depth; k+=PEEL) { prefetch(rhs_ptr); prefetch(rhs_ptr_imag); @@ -2180,8 +2270,8 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl { for(Index acol = 0; acol < 4; acol++ ) { - scalarAcc[arow][acol].real((Scalar)0.0f); - scalarAcc[arow][acol].imag((Scalar)0.0f); + scalarAcc[arow][acol].real((Scalar)0.0); + scalarAcc[arow][acol].imag((Scalar)0.0); } } for(Index k = 0; k < depth; k++) @@ -2550,24 +2640,24 @@ void gebp_kernel::rows; - const int accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index, const int, const int); + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; } else{ - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; } #else - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2591,22 +2681,22 @@ void gebp_kernel, std::complex, Index, DataMapper, mr const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, 
Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2630,21 +2720,21 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjugat const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const float*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2668,21 +2758,21 @@ void gebp_kernel, float, Index, DataMapper, mr, nr, Conjugat const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const float*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - 
gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } else{ - gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } #else - gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2702,24 +2792,24 @@ void gebp_kernel::rows; - const int accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index, const int, const int); + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; } else{ - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; } #else - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2743,21 +2833,21 @@ void gebp_kernel, std::complex, Index, DataMapper, const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = 
&Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2781,21 +2871,21 @@ void gebp_kernel, double, Index, DataMapper, mr, nr, Conjug const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const double*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } else{ - gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, 
false, true>; } #else - gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2819,21 +2909,21 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjug const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const double*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } } // end namespace internal -- cgit v1.2.3
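For reference, the change that recurs through almost every packing and kernel loop above is the loop-bound fix from `j + vectorSize < rows` to `j + vectorSize <= rows` (and likewise for `depth` and `cols`), plus a new `if (j < rows)` guard around the scalar tail. With the old `<` bound, a dimension that is an exact multiple of the vector width sent its final full block through the slow scalar remainder path. The following is a minimal standalone sketch of that pattern in plain C++, not the actual Eigen packers: `vectorSize`, `process_block`, `process_tail`, and `pack` are hypothetical stand-ins for the packetized copy loops in this patch.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: mirrors the packing-loop structure changed in MatrixProduct.h.
// With '<' as the bound, a 'rows' that is an exact multiple of vectorSize sends its
// last block to the scalar tail; '<=' keeps it on the vector path.
constexpr std::ptrdiff_t vectorSize = 4;  // e.g. quad_traits<float>::vectorsize on VSX

void process_block(const float* src, std::ptrdiff_t j) {
  (void)src;  // the real code loads/stores a full Packet starting at row j here
  std::cout << "vector block at row " << j << "\n";
}

void process_tail(const float* src, std::ptrdiff_t j, std::ptrdiff_t rows) {
  (void)src;  // the real code falls back to a scalar copy of the remaining rows
  std::cout << "scalar tail of " << (rows - j) << " rows\n";
}

void pack(const float* src, std::ptrdiff_t rows) {
  std::ptrdiff_t j = 0;
  for (; j + vectorSize <= rows; j += vectorSize)  // '<=' keeps the last full block vectorized
    process_block(src, j);
  if (j < rows)                                    // guard added in the patch before the tail loop
    process_tail(src, j, rows);
}

int main() {
  std::vector<float> data(8);
  pack(data.data(), 8);  // two vector blocks, no scalar tail
  pack(data.data(), 6);  // one vector block plus a 2-row tail
}

The same idea drives the new `bmask`/`bscale` remainder handling in the kernel: full blocks go through the unrolled packet path, and only a genuinely partial block (fewer than `accCols` rows) is masked or handled element by element.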