From c29935b323ffb0b903f640111f0a0b0440e94a2e Mon Sep 17 00:00:00 2001 From: Pedro Caldeira Date: Wed, 9 Sep 2020 12:16:44 -0500 Subject: Add support for dynamic dispatch of MMA instructions for POWER 10 --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 451 +++++---------- Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h | 80 +++ Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h | 638 ++++++++++++++++++++++ 3 files changed, 867 insertions(+), 302 deletions(-) create mode 100644 Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h create mode 100644 Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h (limited to 'Eigen/src/Core/arch/AltiVec') diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 57227e23b..b86367571 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -10,6 +10,19 @@ #ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H #define EIGEN_MATRIX_PRODUCT_ALTIVEC_H +#include "MatrixProductCommon.h" + +#if __GNUC__ > 10 || \ + (__GNUC__ == 10 && (__GNUC_MINOR__ > 2 || \ + (__GNUC_MINOR__ == 2 && \ + __GNUC_PATCHLEVEL__ >= 1))) + #define ALTIVEC_MMA_SUPPORT +#endif + +#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + #include "MatrixProductMMA.h" +#endif + /************************************************************************************************** * TODO * * - Check StorageOrder on lhs_pack (the innermost second loop seems unvectorized when it could). * @@ -26,18 +39,6 @@ namespace internal { **************************/ const int QuadRegisterCount = 8; -#ifdef __MMA__ - -template -union Packetx2u -{ - __vector_pair vectorpair; - PacketBlock pair; -}; - -#endif - - template struct quad_traits { @@ -82,17 +83,6 @@ const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - -const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, - 16, 17, 18, 19, - 4, 5, 6, 7, - 20, 21, 22, 23}; - -const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, - 24, 25, 26, 27, - 12, 13, 14, 15, - 28, 29, 30, 31}; -//[a,ai],[b,bi] = [a,b] const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; @@ -100,14 +90,6 @@ const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7, const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; -//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 -const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, - 16, 17, 18, 19, 20, 21, 22, 23}; - -//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 -const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, - 24, 25, 26, 27, 28, 29, 30, 31}; - /********************************************* * Single precision real and complex packing * * *******************************************/ @@ -1316,154 +1298,6 @@ struct rhs_cpack -EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); - acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); - acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST); - acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST); - - acc2.packet[0].v = 
vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); - acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND); - acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND); - acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND); - - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); - - acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); -} - -template<> -EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); - acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); - acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST); - acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); - acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); - acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); - acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); - - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); - - acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); -} - -#ifdef __MMA__ -template -EIGEN_STRONG_INLINE PacketBlock pmul (const PacketBlock& a, const Packet& b) -{ - PacketBlock pb; - pb.packet[0] = a.packet[0]*b; - pb.packet[1] = a.packet[1]*b; - return pb; -} -template -EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad *acc) -{ - PacketBlock result; - __builtin_mma_disassemble_acc(&result.packet, acc); - - PacketBlock block; - block.packet[0] = data.template loadPacket(i, j + 0) + pmul(alpha, result.packet[0]); - block.packet[1] = data.template loadPacket(i, j + 1) + pmul(alpha, result.packet[1]); - block.packet[2] = data.template loadPacket(i, j + 2) + pmul(alpha, result.packet[2]); - block.packet[3] = data.template loadPacket(i, j + 3) + pmul(alpha, result.packet[3]); - - data.template storePacketBlock(i, j, block); -} - -template -EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad *accReal, __vector_quad *accImag, const int accColsC) -{ - PacketBlock resultReal, resultImag; - __builtin_mma_disassemble_acc(&resultReal.packet, accReal); - __builtin_mma_disassemble_acc(&resultImag.packet, accImag); - - PacketBlock taccReal, taccImag; - taccReal.packet[0] = pmul(resultReal.packet[0], 
alphaReal); - taccReal.packet[1] = pmul(resultReal.packet[1], alphaReal); - taccReal.packet[2] = pmul(resultReal.packet[2], alphaReal); - taccReal.packet[3] = pmul(resultReal.packet[3], alphaReal); - - taccImag.packet[0] = pmul(resultImag.packet[0], alphaReal); - taccImag.packet[1] = pmul(resultImag.packet[1], alphaReal); - taccImag.packet[2] = pmul(resultImag.packet[2], alphaReal); - taccImag.packet[3] = pmul(resultImag.packet[3], alphaReal); - - taccReal.packet[0] = psub(taccReal.packet[0], pmul(resultImag.packet[0], alphaImag)); - taccReal.packet[1] = psub(taccReal.packet[1], pmul(resultImag.packet[1], alphaImag)); - taccReal.packet[2] = psub(taccReal.packet[2], pmul(resultImag.packet[2], alphaImag)); - taccReal.packet[3] = psub(taccReal.packet[3], pmul(resultImag.packet[3], alphaImag)); - - taccImag.packet[0] = pmadd(resultReal.packet[0], alphaImag, taccImag.packet[0]); - taccImag.packet[1] = pmadd(resultReal.packet[1], alphaImag, taccImag.packet[1]); - taccImag.packet[2] = pmadd(resultReal.packet[2], alphaImag, taccImag.packet[2]); - taccImag.packet[3] = pmadd(resultReal.packet[3], alphaImag, taccImag.packet[3]); - - PacketBlock tRes; - tRes.packet[0] = data.template loadPacket(i + N*accColsC, j + 0); - tRes.packet[1] = data.template loadPacket(i + N*accColsC, j + 1); - tRes.packet[2] = data.template loadPacket(i + N*accColsC, j + 2); - tRes.packet[3] = data.template loadPacket(i + N*accColsC, j + 3); - - tRes.packet[4] = data.template loadPacket(i + (N+1)*accColsC, j + 0); - tRes.packet[5] = data.template loadPacket(i + (N+1)*accColsC, j + 1); - tRes.packet[6] = data.template loadPacket(i + (N+1)*accColsC, j + 2); - tRes.packet[7] = data.template loadPacket(i + (N+1)*accColsC, j + 3); - - PacketBlock acc1, acc2; - bcouple(taccReal, taccImag, tRes, acc1, acc2); - - data.template storePacketBlock(i + N*accColsC, j, acc1); - data.template storePacketBlock(i + (N+1)*accColsC, j, acc2); -} - -// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments -template -EIGEN_STRONG_INLINE void pger(__vector_quad *acc, const RhsPacket& a, const LhsPacket& b) -{ - if(NegativeAccumulate) - { - __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b); - } else { - __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b); - } -} - -template<> -EIGEN_STRONG_INLINE void pger, false>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) -{ - Packetx2u p; - p.pair = a; - __builtin_mma_xvf64gerpp(acc, p.vectorpair, (__vector unsigned char)b); -} - -template<> -EIGEN_STRONG_INLINE void pger, true>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) -{ - Packetx2u p; - p.pair = a; - __builtin_mma_xvf64gernp(acc, p.vectorpair, (__vector unsigned char)b); -} -#else - // 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm). template EIGEN_STRONG_INLINE void pger(PacketBlock *acc, const Scalar* lhs, const Scalar* rhs) @@ -1561,25 +1395,6 @@ EIGEN_STRONG_INLINE void pgerc(PacketBlock& accReal, PacketBlock(rhsV4, lhsVi, accImag.packet[3]); } } -#endif - -// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. 
-template -EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs) -{ - return *((Packet *)rhs); -} - -#ifdef __MMA__ -template<> -EIGEN_STRONG_INLINE PacketBlock ploadRhs >(const double *rhs) -{ - PacketBlock pair; - pair.packet[0] = *((Packet2d *)rhs ); - pair.packet[1] = *(((Packet2d *)rhs) + 1); - return pair; -} -#endif template EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs) @@ -1587,7 +1402,6 @@ EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs) return *((Packet *)lhs); } -#ifndef __MMA__ // Zero the accumulator on PacketBlock. template EIGEN_STRONG_INLINE void bsetzero(PacketBlock& acc) @@ -1656,7 +1470,7 @@ EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res acc.packet[6] = res.template loadPacket(row + (N+1)*accCols, col + 2); acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); } -#endif + // PEEL loop factor. #define PEEL 10 @@ -1682,31 +1496,6 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const const Scalar *lhs_base = blockA; Index row = 0; -#ifdef __MMA__ - for(; row + accCols <= rows; row += accCols) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; - - __vector_quad acc; - __builtin_mma_xxsetaccz(&acc); - - lhs_ptr1 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - Packet lhsV = ploadLhs(lhs_ptr1); - RhsPacket rhsV = ploadRhs(rhs_ptr); - - pger(&acc, rhsV, lhsV); - - lhs_ptr1 += accCols; - rhs_ptr += accRows; - } - - storeAccumulator(row, col, res, pAlpha, &acc); - } -#else for(; row + 6*accCols <= rows; row += 6*accCols) { #define MICRO() \ @@ -2135,7 +1924,6 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const res.template storePacketBlock(row, col, acc1); #undef MICRO } -#endif if(remaining_rows > 0) { const Scalar *rhs_ptr = rhs_base; @@ -2239,60 +2027,6 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl const Scalar *lhs_base = blockA; Index row = 0; -#ifdef __MMA__ - for(; row + accCols <= rows; row += accCols) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; - - __vector_quad accReal, accImag; - __builtin_mma_xxsetaccz(&accReal); - __builtin_mma_xxsetaccz(&accImag); - - lhs_ptr += accCols*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += accCols*offsetA; - rhs_ptr += accRows*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - Packet lhsV = ploadLhs(lhs_ptr); - RhsPacket rhsV = ploadRhs(rhs_ptr); - - Packet lhsVi = ploadLhs(lhs_ptr_imag); - RhsPacket rhsVi = ploadRhs(rhs_ptr_imag); - - if(ConjugateLhs && !LhsIsReal) lhsVi = pmul(lhsVi, conj); - if(ConjugateRhs && !RhsIsReal) rhsVi = pmul(rhsVi, conj); - - if(LhsIsReal) - { - pger(&accReal, rhsV, lhsV); - pger(&accImag, rhsVi, lhsV); - } else if(RhsIsReal) { - pger(&accReal, rhsV, lhsV); - pger(&accImag, rhsV, lhsVi); - } else { - pger(&accReal, rhsV, lhsV); - pger(&accReal, rhsVi, lhsVi); - pger(&accImag, rhsVi, lhsV); - pger(&accImag, rhsV, lhsVi); - } - - lhs_ptr += accCols; - rhs_ptr += accRows; - if(!LhsIsReal) - lhs_ptr_imag += accCols; - if(!RhsIsReal) - rhs_ptr_imag += accRows; - } - - storeComplexAccumulator(row, col, res, pAlphaReal, pAlphaImag, &accReal, &accImag, accColsC); - } -#else for(; row + accCols <= rows; row += accCols) { #define 
MICRO() \ @@ -2302,7 +2036,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl if(!LhsIsReal) \ lhs_ptr_imag1 += accCols; \ if(!RhsIsReal) \ - rhs_ptr_imag += accRows; + rhs_ptr_imag += accRows; const Scalar *rhs_ptr = rhs_base; const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; @@ -2356,7 +2090,6 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl res.template storePacketBlock(row + accColsC, col, acc2); #undef MICRO } -#endif if(remaining_rows > 0) { const Scalar *rhs_ptr = rhs_base; @@ -2383,7 +2116,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl lhsc.real(lhs_real); if(!LhsIsReal) { - if(ConjugateLhs) + if(ConjugateLhs) lhsc.imag(-lhs_imag); else lhsc.imag(lhs_imag); @@ -2457,7 +2190,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl { Scalar lhs_real = lhs_ptr[arow]; Scalar lhs_imag; - if(!LhsIsReal) + if(!LhsIsReal) { lhs_imag = lhs_ptr_imag[arow]; @@ -2534,7 +2267,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl lhsc.real(lhs_real); if(!LhsIsReal) { - if(ConjugateLhs) + if(ConjugateLhs) lhsc.imag(-lhs_imag); else lhsc.imag(lhs_imag); @@ -2819,8 +2552,22 @@ void gebp_kernel::rows; const int accCols = quad_traits::size; - - gemm(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index, const int, const int); + + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemmMMA; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemmMMA; + } + else{ + gemm_function = &Eigen::internal::gemm; + } + #else + gemm_function = &Eigen::internal::gemm; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2843,8 +2590,23 @@ void gebp_kernel, std::complex, Index, DataMapper, mr { const int accRows = quad_traits::rows; const int accCols = quad_traits::size; - - gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, + Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, 
std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2867,8 +2629,22 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjugat { const int accRows = quad_traits::rows; const int accCols = quad_traits::size; - - gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const float*, const std::complex*, + Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2891,8 +2667,22 @@ void gebp_kernel, float, Index, DataMapper, mr, nr, Conjugat { const int accRows = quad_traits::rows; const int accCols = quad_traits::size; - - gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const std::complex*, const float*, + Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2914,8 +2704,22 @@ void gebp_kernel::rows; const int accCols = quad_traits::size; - - gemm(res, blockA, blockB, 
rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index, const int, const int); + + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemmMMA; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemmMMA; + } + else{ + gemm_function = &Eigen::internal::gemm; + } + #else + gemm_function = &Eigen::internal::gemm; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2938,8 +2742,22 @@ void gebp_kernel, std::complex, Index, DataMapper, { const int accRows = quad_traits::rows; const int accCols = quad_traits::size; - - gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, + Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2962,8 +2780,22 @@ void gebp_kernel, double, Index, DataMapper, mr, nr, Conjug { const int accRows = quad_traits::rows; const int accCols = quad_traits::size; - - gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const std::complex*, const double*, + Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, 
false, true>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } template @@ -2986,10 +2818,25 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjug { const int accRows = quad_traits::rows; const int accCols = quad_traits::size; - - gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + void (*gemm_function)(const DataMapper&, const double*, const std::complex*, + Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } } // end namespace internal } // end namespace Eigen -#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H \ No newline at end of file + +#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h new file mode 100644 index 000000000..87b60c22c --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -0,0 +1,80 @@ +namespace Eigen { + +namespace internal { + +const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, + 16, 17, 18, 19, + 4, 5, 6, 7, + 20, 21, 22, 23}; + +const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, + 24, 25, 26, 27, + 12, 13, 14, 15, + 28, 29, 30, 31}; +//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 +const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, + 16, 17, 18, 19, 20, 21, 22, 23}; + +//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 +const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, + 24, 25, 26, 27, 28, 29, 30, 31}; + + +// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
+template +EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +template<> +EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs) +{ + return *((Packet *)rhs); +} + +} // end namespace internal +} // end namespace Eigen diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h new file mode 100644 index 000000000..1866a71bf --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -0,0 +1,638 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. 
If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H +#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + +#pragma GCC target("cpu=power10") + +namespace Eigen { + +namespace internal { + +template +union Packetx2u +{ + __vector_pair vectorpair; + PacketBlock pair; +}; +const static Packet16uc MMA_p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, + 16, 17, 18, 19, + 4, 5, 6, 7, + 20, 21, 22, 23}; + +const static Packet16uc MMA_p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, + 24, 25, 26, 27, + 12, 13, 14, 15, + 28, 29, 30, 31}; +//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 +const static Packet16uc MMA_p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, + 16, 17, 18, 19, 20, 21, 22, 23}; + +//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 +const static Packet16uc MMA_p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, + 24, 25, 26, 27, 28, 29, 30, 31}; + + + +// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. +template +EIGEN_STRONG_INLINE void bcoupleMMA(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX32_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX32_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX32_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX32_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX32_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX32_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX32_SECOND); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +template<> +EIGEN_STRONG_INLINE void bcoupleMMA(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX64_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX64_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX64_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX64_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX64_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX64_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX64_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX64_SECOND); + + acc1.packet[0] = padd(tRes.packet[0], 
acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +template +EIGEN_STRONG_INLINE Packet ploadLhsMMA(const Scalar *lhs) +{ + return *((Packet *)lhs); +} + +template +EIGEN_STRONG_INLINE PacketBlock pmul (const PacketBlock& a, const Packet& b) +{ + PacketBlock pb; + pb.packet[0] = a.packet[0]*b; + pb.packet[1] = a.packet[1]*b; + return pb; +} + +template +EIGEN_STRONG_INLINE void bsetzeroMMA(__vector_quad *acc) +{ + __builtin_mma_xxsetaccz(acc); +} + +template +EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad *acc) +{ + PacketBlock result; + __builtin_mma_disassemble_acc(&result.packet, acc); + + PacketBlock block; + block.packet[0] = data.template loadPacket(i, j + 0) + pmul(alpha, result.packet[0]); + block.packet[1] = data.template loadPacket(i, j + 1) + pmul(alpha, result.packet[1]); + block.packet[2] = data.template loadPacket(i, j + 2) + pmul(alpha, result.packet[2]); + block.packet[3] = data.template loadPacket(i, j + 3) + pmul(alpha, result.packet[3]); + + data.template storePacketBlock(i, j, block); +} + +template +EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad *accReal, __vector_quad *accImag, const int accColsC) +{ + PacketBlock resultReal, resultImag; + __builtin_mma_disassemble_acc(&resultReal.packet, accReal); + __builtin_mma_disassemble_acc(&resultImag.packet, accImag); + + PacketBlock taccReal, taccImag; + taccReal.packet[0] = pmul(resultReal.packet[0], alphaReal); + taccReal.packet[1] = pmul(resultReal.packet[1], alphaReal); + taccReal.packet[2] = pmul(resultReal.packet[2], alphaReal); + taccReal.packet[3] = pmul(resultReal.packet[3], alphaReal); + + taccImag.packet[0] = pmul(resultImag.packet[0], alphaReal); + taccImag.packet[1] = pmul(resultImag.packet[1], alphaReal); + taccImag.packet[2] = pmul(resultImag.packet[2], alphaReal); + taccImag.packet[3] = pmul(resultImag.packet[3], alphaReal); + + taccReal.packet[0] = psub(taccReal.packet[0], pmul(resultImag.packet[0], alphaImag)); + taccReal.packet[1] = psub(taccReal.packet[1], pmul(resultImag.packet[1], alphaImag)); + taccReal.packet[2] = psub(taccReal.packet[2], pmul(resultImag.packet[2], alphaImag)); + taccReal.packet[3] = psub(taccReal.packet[3], pmul(resultImag.packet[3], alphaImag)); + + taccImag.packet[0] = pmadd(resultReal.packet[0], alphaImag, taccImag.packet[0]); + taccImag.packet[1] = pmadd(resultReal.packet[1], alphaImag, taccImag.packet[1]); + taccImag.packet[2] = pmadd(resultReal.packet[2], alphaImag, taccImag.packet[2]); + taccImag.packet[3] = pmadd(resultReal.packet[3], alphaImag, taccImag.packet[3]); + + PacketBlock tRes; + tRes.packet[0] = data.template loadPacket(i + N*accColsC, j + 0); + tRes.packet[1] = data.template loadPacket(i + N*accColsC, j + 1); + tRes.packet[2] = data.template loadPacket(i + N*accColsC, j + 2); + tRes.packet[3] = data.template loadPacket(i + N*accColsC, j + 3); + + tRes.packet[4] = data.template loadPacket(i + (N+1)*accColsC, j + 0); + tRes.packet[5] = data.template loadPacket(i + (N+1)*accColsC, j + 1); + tRes.packet[6] = data.template 
loadPacket(i + (N+1)*accColsC, j + 2); + tRes.packet[7] = data.template loadPacket(i + (N+1)*accColsC, j + 3); + + PacketBlock acc1, acc2; + bcoupleMMA(taccReal, taccImag, tRes, acc1, acc2); + + data.template storePacketBlock(i + N*accColsC, j, acc1); + data.template storePacketBlock(i + (N+1)*accColsC, j, acc2); +} + +// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments +template +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const RhsPacket& a, const LhsPacket& b) +{ + if(NegativeAccumulate) + { + __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b); + } else { + __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b); + } +} + +template<> +EIGEN_STRONG_INLINE void pgerMMA, false>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) +{ + Packetx2u p; + p.pair = a; + __builtin_mma_xvf64gerpp(acc, p.vectorpair, (__vector unsigned char)b); +} + +template<> +EIGEN_STRONG_INLINE void pgerMMA, true>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) +{ + Packetx2u p; + p.pair = a; + __builtin_mma_xvf64gernp(acc, p.vectorpair, (__vector unsigned char)b); +} + +// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_STRONG_INLINE Packet ploadRhsMMA(const Scalar *rhs) +{ + return *((Packet *)rhs); +} + +template<> +EIGEN_STRONG_INLINE PacketBlock ploadRhsMMA >(const double *rhs) +{ + PacketBlock pair; + pair.packet[0] = *((Packet2d *)rhs ); + pair.packet[1] = *(((Packet2d *)rhs) + 1); + return pair; +} + +template +void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, + Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlpha = pset1(alpha); + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows; + const Scalar *lhs_base = blockA; + + Index row = 0; + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; + + __vector_quad acc; + bsetzeroMMA(&acc); + + lhs_ptr1 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + Packet lhsV = ploadLhsMMA(lhs_ptr1); + RhsPacket rhsV = ploadRhsMMA(rhs_ptr); + + pgerMMA(&acc, rhsV, lhsV); + + lhs_ptr1 += accCols; + rhs_ptr += accRows; + } + + storeAccumulator(row, col, res, pAlpha, &acc); + } + if(remaining_rows > 0) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + + lhs_ptr += remaining_rows*offsetA; + rhs_ptr += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + for(Index acol = 0; acol < accRows; acol++ ) + { + res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; + } + } + rhs_ptr += accRows; + lhs_ptr += remaining_rows; + } + } + } + + if(remaining_cols > 0) + { + const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows; + const Scalar *lhs_base = blockA; + + Index row = 0; + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + 
(row/accCols)*strideA*accCols; + + lhs_ptr += accCols*offsetA; + rhs_ptr += remaining_cols*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < accCols; arow++) + { + for(Index acol = 0; acol < remaining_cols; acol++ ) + { + res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; + } + } + rhs_ptr += remaining_cols; + lhs_ptr += accCols; + } + } + + if(remaining_rows > 0 ) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + + lhs_ptr += remaining_rows*offsetA; + rhs_ptr += remaining_cols*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + for(Index acol = 0; acol < remaining_cols; acol++ ) + { + res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; + } + } + rhs_ptr += remaining_cols; + lhs_ptr += remaining_rows; + } + } + } +} + +template +void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, + Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +{ + const int remaining_rows = rows % accCols; + const int remaining_cols = cols % accRows; + const int accColsC = accCols / 2; + int advanceCols = 2; + int advanceRows = 2; + + if(LhsIsReal) advanceRows = 1; + if(RhsIsReal) advanceCols = 1; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlphaReal = pset1(alpha.real()); + const Packet pAlphaImag = pset1(alpha.imag()); + + const Scalar *blockA = (Scalar *) blockAc; + const Scalar *blockB = (Scalar *) blockBc; + + Packet conj = pset1((Scalar)-1.0f); + + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; + const Scalar *lhs_base = blockA; + + Index row = 0; + + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; + + __vector_quad accReal, accImag; + __builtin_mma_xxsetaccz(&accReal); + __builtin_mma_xxsetaccz(&accImag); + + lhs_ptr += accCols*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += accCols*offsetA; + rhs_ptr += accRows*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + Packet lhsV = ploadLhsMMA(lhs_ptr); + RhsPacket rhsV = ploadRhs(rhs_ptr); + + Packet lhsVi = ploadLhsMMA(lhs_ptr_imag); + RhsPacket rhsVi = ploadRhs(rhs_ptr_imag); + + if(ConjugateLhs && !LhsIsReal) lhsVi = pmul(lhsVi, conj); + if(ConjugateRhs && !RhsIsReal) rhsVi = pmul(rhsVi, conj); + + if(LhsIsReal) + { + pgerMMA(&accReal, rhsV, lhsV); + pgerMMA(&accImag, rhsVi, lhsV); + } else if(RhsIsReal) { + pgerMMA(&accReal, rhsV, lhsV); + pgerMMA(&accImag, rhsV, lhsVi); + } else { + pgerMMA(&accReal, rhsV, lhsV); + pgerMMA(&accReal, rhsVi, lhsVi); + pgerMMA(&accImag, rhsVi, lhsV); + pgerMMA(&accImag, rhsV, lhsVi); + } + + lhs_ptr += accCols; + rhs_ptr += accRows; + if(!LhsIsReal) + lhs_ptr_imag += accCols; + if(!RhsIsReal) + rhs_ptr_imag += accRows; + } + + storeComplexAccumulator(row, col, res, pAlphaReal, pAlphaImag, &accReal, &accImag, accColsC); + } + + if(remaining_rows > 0) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; + const Scalar *lhs_ptr = lhs_base + 
((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; + + lhs_ptr += remaining_rows*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows*offsetA; + rhs_ptr += accRows*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + Scalar lhs_real = lhs_ptr[arow]; + Scalar lhs_imag; + if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; + + Scalarc lhsc; + + lhsc.real(lhs_real); + if(!LhsIsReal) + { + if(ConjugateLhs) + lhsc.imag(-lhs_imag); + else + lhsc.imag(lhs_imag); + } else { + //Lazy approach for now + lhsc.imag((Scalar)0); + } + + for(int acol = 0; acol < accRows; acol++ ) + { + Scalar rhs_real = rhs_ptr[acol]; + Scalar rhs_imag; + if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; + Scalarc rhsc; + + rhsc.real(rhs_real); + if(!RhsIsReal) + { + if(ConjugateRhs) + rhsc.imag(-rhs_imag); + else + rhsc.imag(rhs_imag); + } else { + //Lazy approach for now + rhsc.imag((Scalar)0); + } + res(row + arow, col + acol) += alpha*lhsc*rhsc; + } + } + rhs_ptr += accRows; + lhs_ptr += remaining_rows; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows; + if(!RhsIsReal) + rhs_ptr_imag += accRows; + } + } + } + + if(remaining_cols > 0) + { + const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; + const Scalar *lhs_base = blockA; + Index row = 0; + + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; + + lhs_ptr += accCols*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += accCols*offsetA; + rhs_ptr += remaining_cols*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols*offsetB; + Scalarc scalarAcc[4][4]; + for(Index arow = 0; arow < 4; arow++ ) + { + for(Index acol = 0; acol < 4; acol++ ) + { + scalarAcc[arow][acol].real((Scalar)0.0f); + scalarAcc[arow][acol].imag((Scalar)0.0f); + } + } + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < accCols; arow++) + { + Scalar lhs_real = lhs_ptr[arow]; + Scalar lhs_imag; + if(!LhsIsReal) + { + lhs_imag = lhs_ptr_imag[arow]; + + if(ConjugateLhs) + lhs_imag *= -1; + } else { + lhs_imag = (Scalar)0; + } + + for(int acol = 0; acol < remaining_cols; acol++ ) + { + Scalar rhs_real = rhs_ptr[acol]; + Scalar rhs_imag; + if(!RhsIsReal) + { + rhs_imag = rhs_ptr_imag[acol]; + + if(ConjugateRhs) + rhs_imag *= -1; + } else { + rhs_imag = (Scalar)0; + } + + scalarAcc[arow][acol].real(scalarAcc[arow][acol].real() + lhs_real*rhs_real - lhs_imag*rhs_imag); + scalarAcc[arow][acol].imag(scalarAcc[arow][acol].imag() + lhs_imag*rhs_real + lhs_real*rhs_imag); + } + } + rhs_ptr += remaining_cols; + lhs_ptr += accCols; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols; + if(!LhsIsReal) + lhs_ptr_imag += accCols; + } + for(int arow = 0; arow < accCols; arow++ ) + { + for(int acol = 0; acol < remaining_cols; acol++ ) + { + Scalar accR = scalarAcc[arow][acol].real(); + Scalar accI = scalarAcc[arow][acol].imag(); + Scalar aR = alpha.real(); + Scalar aI = alpha.imag(); + Scalar resR = res(row + arow, col + acol).real(); + Scalar resI = res(row + arow, col + acol).imag(); + + res(row + arow, col + acol).real(resR + accR*aR - accI*aI); + res(row + arow, col + acol).imag(resI + accR*aI + accI*aR); + } + } + } + + if(remaining_rows > 0) + { + const Scalar *rhs_ptr = rhs_base; + 
const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; + + lhs_ptr += remaining_rows*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows*offsetA; + rhs_ptr += remaining_cols*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + Scalar lhs_real = lhs_ptr[arow]; + Scalar lhs_imag; + if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; + Scalarc lhsc; + + lhsc.real(lhs_real); + if(!LhsIsReal) + { + if(ConjugateLhs) + lhsc.imag(-lhs_imag); + else + lhsc.imag(lhs_imag); + } else { + lhsc.imag((Scalar)0); + } + + for(Index acol = 0; acol < remaining_cols; acol++ ) + { + Scalar rhs_real = rhs_ptr[acol]; + Scalar rhs_imag; + if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; + Scalarc rhsc; + + rhsc.real(rhs_real); + if(!RhsIsReal) + { + if(ConjugateRhs) + rhsc.imag(-rhs_imag); + else + rhsc.imag(rhs_imag); + } else { + rhsc.imag((Scalar)0); + } + res(row + arow, col + acol) += alpha*lhsc*rhsc; + } + } + rhs_ptr += remaining_cols; + lhs_ptr += remaining_rows; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols; + } + } + } +} + +#pragma GCC reset_options +} // end namespace internal + +} // end namespace Eigen +#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + -- cgit v1.2.3
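
The dispatch added to the gebp_kernel specializations above reduces to one small pattern: declare a function pointer with the kernel's signature, test at runtime for Power ISA 3.1 and the MMA facility with GCC's __builtin_cpu_supports(), and point it at either the MMA kernel or the generic VSX one (EIGEN_ALTIVEC_MMA_ONLY and EIGEN_ALTIVEC_DISABLE_MMA force one side at compile time). Below is a minimal stand-alone sketch of that pattern, assuming a powerpc64le build with GCC >= 10.2; gemm_generic and gemm_mma are placeholder names for illustration, not Eigen entry points, and their bodies are stubs.

#include <cstdio>

// Placeholder kernels standing in for Eigen::internal::gemm and Eigen::internal::gemmMMA.
static void gemm_generic(const float* a, const float* b, float* c, int n)
{
  for (int i = 0; i < n; i++) c[i] = a[i] * b[i];   // baseline VSX path would go here
}

static void gemm_mma(const float* a, const float* b, float* c, int n)
{
  for (int i = 0; i < n; i++) c[i] = a[i] * b[i];   // POWER10 __builtin_mma_* path would go here
}

int main()
{
  void (*gemm_function)(const float*, const float*, float*, int);

#if defined(__powerpc64__) && defined(__BUILTIN_CPU_SUPPORTS__)
  // Same runtime test the patch uses: ISA 3.1 plus the MMA facility must both be present.
  if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma"))
    gemm_function = &gemm_mma;
  else
    gemm_function = &gemm_generic;
#else
  gemm_function = &gemm_generic;
#endif

  float a[4] = {1, 2, 3, 4}, b[4] = {5, 6, 7, 8}, c[4];
  gemm_function(a, b, c, 4);
  std::printf("%f %f %f %f\n", c[0], c[1], c[2], c[3]);
  return 0;
}

Keeping the selection behind a function pointer means both kernels are compiled into the same object file (the MMA one under #pragma GCC target("cpu=power10")), so a single Eigen build runs correctly on pre-POWER10 and POWER10 hardware.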
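
The MMA kernel itself (gemmMMA in the new MatrixProductMMA.h) is built around three builtins: __builtin_mma_xxsetaccz zeroes a __vector_quad accumulator, __builtin_mma_xvf32gerpp adds a rank-1 (outer-product) update per depth step, and __builtin_mma_disassemble_acc spills the 4x4 tile back into vector registers before the result is scaled by alpha and stored. The sketch below shows that flow in isolation, assuming GCC >= 10.2 compiling with -mcpu=power10 (or the #pragma GCC target("cpu=power10") used in the header); rank1_accumulate is an illustrative name, not an Eigen function, and the alpha scaling and result load are omitted.

#include <altivec.h>

// Accumulate out += sum_k lhs_column_k (outer product) rhs_row_k for a 4x4 float tile.
void rank1_accumulate(const float* lhs, const float* rhs, float out[4][4], int depth)
{
  __vector_quad acc;
  __builtin_mma_xxsetaccz(&acc);                      // acc = 0 (one 4x4 float tile)

  for (int k = 0; k < depth; k++) {
    __vector float lhsV = vec_xl(0, lhs + 4 * k);     // one packed column of A
    __vector float rhsV = vec_xl(0, rhs + 4 * k);     // one packed row of B
    // Rank-1 update, same operand order as the patch's pgerMMA(&acc, rhsV, lhsV) for float.
    __builtin_mma_xvf32gerpp(&acc, (__vector unsigned char)rhsV, (__vector unsigned char)lhsV);
  }

  __vector float result[4];
  __builtin_mma_disassemble_acc(result, &acc);        // spill the accumulator to 4 VSRs
  for (int i = 0; i < 4; i++)
    vec_xst(result[i], 0, out[i]);                    // store one row of the tile
}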