From d59ef212e14012250127a244df1484f626d39e42 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Thu, 25 Mar 2021 11:08:19 +0000 Subject: Fixed performance issues for complex VSX and P10 MMA in gebp_kernel (level 3). --- Eigen/src/Core/arch/AltiVec/Complex.h | 147 +- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 2622 +++++++++++---------- Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h | 133 +- Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h | 886 +++---- 4 files changed, 1913 insertions(+), 1875 deletions(-) (limited to 'Eigen') diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h index d21e13979..8110f53d0 100644 --- a/Eigen/src/Core/arch/AltiVec/Complex.h +++ b/Eigen/src/Core/arch/AltiVec/Complex.h @@ -31,6 +31,52 @@ struct Packet2cf { EIGEN_STRONG_INLINE explicit Packet2cf() {} EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {} + + EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) + { + Packet4f v1, v2; + + // Permute and multiply the real parts of a and b + v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD); + // Get the imaginary parts of a + v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN); + // multiply a_re * b + v1 = vec_madd(v1, b.v, p4f_ZERO); + // multiply a_im * b and get the conjugate result + v2 = vec_madd(v2, b.v, p4f_ZERO); + v2 = reinterpret_cast(pxor(v2, reinterpret_cast(p4ui_CONJ_XOR))); + // permute back to a proper order + v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV); + + return Packet2cf(padd(v1, v2)); + } + + EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) { + v = pmul(Packet2cf(*this), b).v; + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const { + return Packet2cf(*this) *= b; + } + + EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) { + v = padd(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const { + return Packet2cf(*this) += b; + } + EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) { + v = psub(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const { + return Packet2cf(*this) -= b; + } + EIGEN_STRONG_INLINE Packet2cf operator-(void) const { + return Packet2cf(vec_neg(v)); + } + Packet4f v; }; @@ -81,6 +127,25 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex< template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet2cf& from) { pstore((float*)to, from.v); } template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cf& from) { pstoreu((float*)to, from.v); } +EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex* from0, const std::complex* from1) +{ + Packet4f res0, res1; +#ifdef __VSX__ + __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0)); + __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1)); +#ifdef _BIG_ENDIAN + __asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1)); +#else + __asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1)); +#endif +#else + *((std::complex *)&res0[0]) = *from0; + *((std::complex *)&res1[0]) = *from1; + res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI); +#endif + return Packet2cf(res0); +} + template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather, Packet2cf>(const std::complex* from, Index stride) { EIGEN_ALIGN16 std::complex af[2]; @@ -101,25 +166,6 @@ template<> EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, con template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { 
return Packet2cf(pnegate(a.v)); } template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { return Packet2cf(pxor(a.v, reinterpret_cast(p4ui_CONJ_XOR))); } -template<> EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) -{ - Packet4f v1, v2; - - // Permute and multiply the real parts of a and b - v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD); - // Get the imaginary parts of a - v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN); - // multiply a_re * b - v1 = vec_madd(v1, b.v, p4f_ZERO); - // multiply a_im * b and get the conjugate result - v2 = vec_madd(v2, b.v, p4f_ZERO); - v2 = reinterpret_cast(pxor(v2, reinterpret_cast(p4ui_CONJ_XOR))); - // permute back to a proper order - v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV); - - return Packet2cf(padd(v1, v2)); -} - template<> EIGEN_STRONG_INLINE Packet2cf pand (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand(a.v, b.v)); } template<> EIGEN_STRONG_INLINE Packet2cf por (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por(a.v, b.v)); } template<> EIGEN_STRONG_INLINE Packet2cf pxor (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor(a.v, b.v)); } @@ -239,6 +285,51 @@ struct Packet1cd { EIGEN_STRONG_INLINE Packet1cd() {} EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {} + + EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) + { + Packet2d a_re, a_im, v1, v2; + + // Permute and multiply the real parts of a and b + a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI); + // Get the imaginary parts of a + a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO); + // multiply a_re * b + v1 = vec_madd(a_re, b.v, p2d_ZERO); + // multiply a_im * b and get the conjugate result + v2 = vec_madd(a_im, b.v, p2d_ZERO); + v2 = reinterpret_cast(vec_sld(reinterpret_cast(v2), reinterpret_cast(v2), 8)); + v2 = pxor(v2, reinterpret_cast(p2ul_CONJ_XOR1)); + + return Packet1cd(padd(v1, v2)); + } + + EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) { + v = pmul(Packet1cd(*this), b).v; + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const { + return Packet1cd(*this) *= b; + } + + EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) { + v = padd(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const { + return Packet1cd(*this) += b; + } + EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) { + v = psub(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const { + return Packet1cd(*this) -= b; + } + EIGEN_STRONG_INLINE Packet1cd operator-(void) const { + return Packet1cd(vec_neg(v)); + } + Packet2d v; }; @@ -290,24 +381,6 @@ template<> EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, con template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); } template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd(pxor(a.v, reinterpret_cast(p2ul_CONJ_XOR2))); } -template<> EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) -{ - Packet2d a_re, a_im, v1, v2; - - // Permute and multiply the real parts of a and b - a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI); - // Get the imaginary parts of a - a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO); - // multiply a_re * b - v1 = vec_madd(a_re, b.v, p2d_ZERO); - // multiply a_im * b and get the conjugate result - v2 = vec_madd(a_im, b.v, p2d_ZERO); - v2 = reinterpret_cast(vec_sld(reinterpret_cast(v2), 
reinterpret_cast(v2), 8)); - v2 = pxor(v2, reinterpret_cast(p2ul_CONJ_XOR1)); - - return Packet1cd(padd(v1, v2)); -} - template<> EIGEN_STRONG_INLINE Packet1cd pand (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pand(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet1cd por (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(por(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet1cd pxor (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pxor(a.v,b.v)); } diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 03d474a70..30b814241 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com) +// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com) // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -12,6 +13,7 @@ #include "MatrixProductCommon.h" +// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX #if EIGEN_COMP_LLVM #if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY) #ifdef __MMA__ @@ -34,10 +36,8 @@ /************************************************************************************************** * TODO * - * - Check StorageOrder on lhs_pack (the innermost second loop seems unvectorized when it could). * + * - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). * * - Check the possibility of transposing as GETREAL and GETIMAG when needed. * - * - Check if change conjugation to xor instead of mul gains any performance. * - * - Remove IsComplex template argument from complex packing. * **************************************************************************************************/ namespace Eigen { @@ -46,13 +46,11 @@ namespace internal { /************************** * Constants and typedefs * **************************/ -const int QuadRegisterCount = 8; - template struct quad_traits { typedef typename packet_traits::type vectortype; - typedef PacketBlock type; + typedef PacketBlock type; typedef vectortype rhstype; enum { @@ -66,7 +64,7 @@ template<> struct quad_traits { typedef Packet2d vectortype; - typedef PacketBlock type; + typedef PacketBlock type; typedef PacketBlock rhstype; enum { @@ -79,9 +77,6 @@ struct quad_traits // MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector, this turned out // to be faster than Eigen's usual approach of having real/imaginary pairs on a single vector. This constants then // are responsible to extract from convert between Eigen's and MatrixProduct approach. -const static Packet4f p4f_CONJUGATE = {float(-1.0), float(-1.0), float(-1.0), float(-1.0)}; - -const static Packet2d p2d_CONJUGATE = {-1.0, -1.0}; const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3, 8, 9, 10, 11, @@ -109,7 +104,7 @@ const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15, * conjugated. There's no PanelMode available for symm packing. * * Packing in general is supposed to leave the lhs block and the rhs block easy to be read by gemm using - * it's respective rank-update instructions. The float32/64 versions are different because at this moment + * its respective rank-update instructions. 
The float32/64 versions are different because at this moment * the size of the accumulator is fixed at 512-bits so you can't have a 4x4 accumulator of 64-bit elements. * * As mentioned earlier MatrixProduct breaks complex numbers into a real vector and a complex vector so packing has @@ -123,156 +118,148 @@ EIGEN_STRONG_INLINE std::complex getAdjointVal(Index i, Index j, const_b std::complex v; if(i < j) { - v.real(dt(j,i).real()); + v.real( dt(j,i).real()); v.imag(-dt(j,i).imag()); } else if(i > j) { - v.real(dt(i,j).real()); - v.imag(dt(i,j).imag()); + v.real( dt(i,j).real()); + v.imag( dt(i,j).imag()); } else { - v.real(dt(i,j).real()); + v.real( dt(i,j).real()); v.imag((Scalar)0.0); } return v; } template -EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex *blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) +EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex* blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) { const Index depth = k2 + rows; const_blas_data_mapper, Index, StorageOrder> rhs(_rhs, rhsStride); - const int vectorSize = N*quad_traits::vectorsize; + const Index vectorSize = N*quad_traits::vectorsize; + const Index vectorDelta = vectorSize * rows; Scalar* blockBf = reinterpret_cast(blockB); - Index ri = 0, j = 0; + Index rir = 0, rii, j = 0; for(; j + vectorSize <= cols; j+=vectorSize) { - Index i = k2; - for(; i < depth; i++) - { - for(Index k = 0; k < vectorSize; k++) - { - std::complex v = getAdjointVal(i, j + k, rhs); - blockBf[ri + k] = v.real(); - } - ri += vectorSize; - } + rii = rir + vectorDelta; - i = k2; - - for(; i < depth; i++) - { - for(Index k = 0; k < vectorSize; k++) - { - std::complex v = getAdjointVal(i, j + k, rhs); - blockBf[ri + k] = v.imag(); - } - ri += vectorSize; - } - } - for(Index i = k2; i < depth; i++) - { - Index k = j; - for(; k < cols; k++) + for(Index i = k2; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) { - std::complex v = getAdjointVal(i, k, rhs); - blockBf[ri] = v.real(); - ri += 1; + std::complex v = getAdjointVal(i, j + k, rhs); + + blockBf[rir + k] = v.real(); + blockBf[rii + k] = v.imag(); } + rir += vectorSize; + rii += vectorSize; + } + + rir += vectorDelta; } - for(Index i = k2; i < depth; i++) + if (j < cols) { + rii = rir + ((cols - j) * rows); + + for(Index i = k2; i < depth; i++) + { Index k = j; for(; k < cols; k++) { std::complex v = getAdjointVal(i, k, rhs); - blockBf[ri] = v.imag(); - ri += 1; + + blockBf[rir] = v.real(); + blockBf[rii] = v.imag(); + + rir += 1; + rii += 1; } + } } } template -EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex *blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) +EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) { const Index depth = cols; const_blas_data_mapper, Index, StorageOrder> lhs(_lhs, lhsStride); - const int vectorSize = quad_traits::vectorsize; - Index ri = 0, j = 0; - Scalar *blockAf = (Scalar *)(blockA); + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = vectorSize * depth; + Scalar* blockAf = (Scalar *)(blockA); + Index rir = 0, rii, j = 0; for(; j + vectorSize <= rows; j+=vectorSize) { - Index i = 0; + rii = rir + vectorDelta; - for(; i < depth; i++) - { - for(int k = 0; k < vectorSize; k++) - { - std::complex v = getAdjointVal(j+k, i, lhs); - blockAf[ri + k] = v.real(); - } - ri += vectorSize; - } - i = 0; 
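As a scalar sketch (not part of the patch, helper name hypothetical), the single-pass real/imaginary split produced by the reworked symm_pack_complex_*_helper routines looks like the following: within each vectorSize-wide panel, real parts are written at cursor rir and imaginary parts vectorDelta scalars later at rii, replacing the two separate sweeps over the source that the old code performed.

#include <complex>
#include <vector>

// Illustrative scalar reference only; `Src` is any callable returning the
// (depth, lane) element of the panel as std::complex<Scalar>.
template <typename Scalar, typename Src>
void pack_split_panel(std::vector<Scalar>& out, Src src, long depth, long vectorSize)
{
  const long vectorDelta = vectorSize * depth;   // distance between real and imaginary planes
  out.resize(2 * vectorDelta);
  long rir = 0, rii = vectorDelta;               // real / imaginary write cursors
  for (long i = 0; i < depth; i++) {
    for (long k = 0; k < vectorSize; k++) {
      const std::complex<Scalar> v = src(i, k);
      out[rir + k] = v.real();
      out[rii + k] = v.imag();
    }
    rir += vectorSize;
    rii += vectorSize;
  }
}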
- for(; i < depth; i++) + for(Index i = 0; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) { - for(int k = 0; k < vectorSize; k++) - { - std::complex v = getAdjointVal(j+k, i, lhs); - blockAf[ri + k] = v.imag(); - } - ri += vectorSize; - } - } + std::complex v = getAdjointVal(j+k, i, lhs); - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < rows; k++) - { - std::complex v = getAdjointVal(k, i, lhs); - blockAf[ri] = v.real(); - ri += 1; + blockAf[rir + k] = v.real(); + blockAf[rii + k] = v.imag(); } + rir += vectorSize; + rii += vectorSize; + } + + rir += vectorDelta; } - for(Index i = 0; i < depth; i++) + + if (j < rows) { + rii = rir + ((rows - j) * depth); + + for(Index i = 0; i < depth; i++) + { Index k = j; for(; k < rows; k++) { - std::complex v = getAdjointVal(k, i, lhs); - blockAf[ri] = v.imag(); - ri += 1; + std::complex v = getAdjointVal(k, i, lhs); + + blockAf[rir] = v.real(); + blockAf[rii] = v.imag(); + + rir += 1; + rii += 1; } + } } } template -EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2) +EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2) { const Index depth = k2 + rows; const_blas_data_mapper rhs(_rhs, rhsStride); - const int vectorSize = quad_traits::vectorsize; + const Index vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; for(; j + N*vectorSize <= cols; j+=N*vectorSize) { - Index i = k2; - for(; i < depth; i++) + Index i = k2; + for(; i < depth; i++) + { + for(Index k = 0; k < N*vectorSize; k++) { - for(int k = 0; k < N*vectorSize; k++) - { - if(i <= j+k) - blockB[ri + k] = rhs(j+k, i); - else - blockB[ri + k] = rhs(i, j+k); - } - ri += N*vectorSize; + if(i <= j+k) + blockB[ri + k] = rhs(j+k, i); + else + blockB[ri + k] = rhs(i, j+k); } + ri += N*vectorSize; + } } - for(Index i = k2; i < depth; i++) + + if (j < cols) { + for(Index i = k2; i < depth; i++) + { Index k = j; for(; k < cols; k++) { @@ -282,45 +269,49 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs blockB[ri] = rhs(k, i); ri += 1; } + } } } template -EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows) +EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows) { const Index depth = cols; const_blas_data_mapper lhs(_lhs, lhsStride); - const int vectorSize = quad_traits::vectorsize; - Index ri = 0, j = 0; + const Index vectorSize = quad_traits::vectorsize; - for(j = 0; j + vectorSize <= rows; j+=vectorSize) + Index ri = 0, j = 0; + for(; j + vectorSize <= rows; j+=vectorSize) { - Index i = 0; + Index i = 0; - for(; i < depth; i++) + for(; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) { - for(int k = 0; k < vectorSize; k++) - { - if(i <= j+k) - blockA[ri + k] = lhs(j+k, i); - else - blockA[ri + k] = lhs(i, j+k); - } - ri += vectorSize; + if(i <= j+k) + blockA[ri + k] = lhs(j+k, i); + else + blockA[ri + k] = lhs(i, j+k); } + ri += vectorSize; + } } - for(Index i = 0; i < depth; i++) + if (j < rows) { + for(Index i = 0; i < depth; i++) + { Index k = j; for(; k < rows; k++) { - if(i <= k) - blockA[ri] = lhs(k, i); - else - blockA[ri] = lhs(i, k); - ri += 1; + if(i <= k) + blockA[ri] = lhs(k, i); + else + blockA[ri] = lhs(i, k); + ri += 1; } + } } } @@ -338,7 +329,7 @@ struct symm_pack_lhs, Index, Pack1, Pack2_dummy, StorageOrde { 
void operator()(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) { - symm_pack_complex_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + symm_pack_complex_lhs_helper(blockA, _lhs, lhsStride, cols, rows); } }; @@ -368,7 +359,7 @@ struct symm_pack_rhs { void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) { - symm_pack_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + symm_pack_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); } }; @@ -377,7 +368,7 @@ struct symm_pack_lhs { void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) { - symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); } }; @@ -396,7 +387,7 @@ struct symm_pack_lhs { void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) { - symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); } }; @@ -411,517 +402,252 @@ struct symm_pack_lhs * and offset and behaves accordingly. **/ -// General template for lhs complex packing. -template -struct lhs_cpack { +template +EIGEN_STRONG_INLINE void storeBlock(Scalar* to, PacketBlock& block) +{ + const Index size = 16 / sizeof(Scalar); + pstore(to + (0 * size), block.packet[0]); + pstore(to + (1 * size), block.packet[1]); + pstore(to + (2 * size), block.packet[2]); + pstore(to + (3 * size), block.packet[3]); +} + +template +EIGEN_STRONG_INLINE void storeBlock(Scalar* to, PacketBlock& block) +{ + const Index size = 16 / sizeof(Scalar); + pstore(to + (0 * size), block.packet[0]); + pstore(to + (1 * size), block.packet[1]); +} + +// General template for lhs & rhs complex packing. +template +struct dhs_cpack { EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - const int vectorSize = quad_traits::vectorsize; - Index ri = 0, j = 0; - Scalar *blockAt = reinterpret_cast(blockA); - Packet conj = pset1((Scalar)-1.0); + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth); + Index rir = ((PanelMode) ? 
(vectorSize*offset) : 0), rii; + Scalar* blockAt = reinterpret_cast(blockA); + Index j = 0; - for(j = 0; j + vectorSize <= rows; j+=vectorSize) + for(; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; - if(PanelMode) ri += vectorSize*offset; + rii = rir + vectorDelta; for(; i + vectorSize <= depth; i+=vectorSize) { - PacketBlock block; + PacketBlock blockr, blocki; + PacketBlock cblock; - PacketBlock cblock; - if(StorageOrder == ColMajor) - { - cblock.packet[0] = lhs.template loadPacket(j, i + 0); - cblock.packet[1] = lhs.template loadPacket(j, i + 1); - cblock.packet[2] = lhs.template loadPacket(j, i + 2); - cblock.packet[3] = lhs.template loadPacket(j, i + 3); - - cblock.packet[4] = lhs.template loadPacket(j + 2, i + 0); - cblock.packet[5] = lhs.template loadPacket(j + 2, i + 1); - cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); - cblock.packet[7] = lhs.template loadPacket(j + 2, i + 3); + if (UseLhs) { + bload(cblock, lhs, j, i); } else { - cblock.packet[0] = lhs.template loadPacket(j + 0, i); - cblock.packet[1] = lhs.template loadPacket(j + 1, i); - cblock.packet[2] = lhs.template loadPacket(j + 2, i); - cblock.packet[3] = lhs.template loadPacket(j + 3, i); - - cblock.packet[4] = lhs.template loadPacket(j + 0, i + 2); - cblock.packet[5] = lhs.template loadPacket(j + 1, i + 2); - cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); - cblock.packet[7] = lhs.template loadPacket(j + 3, i + 2); + bload(cblock, lhs, i, j); } - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); - block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); - block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); - block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); - - if(StorageOrder == RowMajor) ptranspose(block); + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); + blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); + blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); + blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); - pstore(blockAt + ri , block.packet[0]); - pstore(blockAt + ri + 4, block.packet[1]); - pstore(blockAt + ri + 8, block.packet[2]); - pstore(blockAt + ri + 12, block.packet[3]); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); + blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); + blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); + blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); - ri += 4*vectorSize; - } - for(; i < depth; i++) - { - blockAt[ri + 0] = lhs(j + 0, i).real(); - blockAt[ri + 1] = lhs(j + 1, i).real(); - blockAt[ri + 2] = lhs(j + 2, i).real(); - blockAt[ri + 3] = lhs(j + 3, i).real(); - - ri += vectorSize; - } - if(PanelMode) ri += vectorSize*(stride - offset - depth); - - i = 0; - - if(PanelMode) ri += vectorSize*offset; - - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock cblock; - if(StorageOrder == ColMajor) + if(Conjugate) { - cblock.packet[0] = lhs.template loadPacket(j, i + 0); - cblock.packet[1] = lhs.template loadPacket(j, i + 1); - cblock.packet[2] = lhs.template loadPacket(j, i + 2); - cblock.packet[3] = lhs.template loadPacket(j, i + 3); - - cblock.packet[4] = lhs.template loadPacket(j + 2, i + 0); - cblock.packet[5] = 
lhs.template loadPacket(j + 2, i + 1); - cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); - cblock.packet[7] = lhs.template loadPacket(j + 2, i + 3); - } else { - cblock.packet[0] = lhs.template loadPacket(j + 0, i); - cblock.packet[1] = lhs.template loadPacket(j + 1, i); - cblock.packet[2] = lhs.template loadPacket(j + 2, i); - cblock.packet[3] = lhs.template loadPacket(j + 3, i); - - cblock.packet[4] = lhs.template loadPacket(j + 0, i + 2); - cblock.packet[5] = lhs.template loadPacket(j + 1, i + 2); - cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); - cblock.packet[7] = lhs.template loadPacket(j + 3, i + 2); + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + blocki.packet[2] = -blocki.packet[2]; + blocki.packet[3] = -blocki.packet[3]; } - PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); - block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); - block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); - block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); - - if(Conjugate) + if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) { - block.packet[0] *= conj; - block.packet[1] *= conj; - block.packet[2] *= conj; - block.packet[3] *= conj; + ptranspose(blockr); + ptranspose(blocki); } - if(StorageOrder == RowMajor) ptranspose(block); + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); - pstore(blockAt + ri , block.packet[0]); - pstore(blockAt + ri + 4, block.packet[1]); - pstore(blockAt + ri + 8, block.packet[2]); - pstore(blockAt + ri + 12, block.packet[3]); - - ri += 4*vectorSize; + rir += 4*vectorSize; + rii += 4*vectorSize; } for(; i < depth; i++) { - if(Conjugate) + PacketBlock blockr, blocki; + PacketBlock cblock; + + if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) { - blockAt[ri + 0] = -lhs(j + 0, i).imag(); - blockAt[ri + 1] = -lhs(j + 1, i).imag(); - blockAt[ri + 2] = -lhs(j + 2, i).imag(); - blockAt[ri + 3] = -lhs(j + 3, i).imag(); + if (UseLhs) { + cblock.packet[0] = pload(&lhs(j + 0, i)); + cblock.packet[1] = pload(&lhs(j + 2, i)); + } else { + cblock.packet[0] = pload(&lhs(i, j + 0)); + cblock.packet[1] = pload(&lhs(i, j + 2)); + } } else { - blockAt[ri + 0] = lhs(j + 0, i).imag(); - blockAt[ri + 1] = lhs(j + 1, i).imag(); - blockAt[ri + 2] = lhs(j + 2, i).imag(); - blockAt[ri + 3] = lhs(j + 3, i).imag(); + if (UseLhs) { + cblock.packet[0] = pload2(&lhs(j + 0, i), &lhs(j + 1, i)); + cblock.packet[1] = pload2(&lhs(j + 2, i), &lhs(j + 3, i)); + } else { + cblock.packet[0] = pload2(&lhs(i, j + 0), &lhs(i, j + 1)); + cblock.packet[1] = pload2(&lhs(i, j + 2), &lhs(i, j + 3)); + } } - ri += vectorSize; - } - if(PanelMode) ri += vectorSize*(stride - offset - depth); - } - - if(PanelMode) ri += offset*(rows - j); - - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < rows; k++) - { - blockAt[ri] = lhs(k, i).real(); - ri += 1; - } - } - - if(PanelMode) ri += (rows - j)*(stride - offset - depth); + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32); - if(PanelMode) ri += offset*(rows - j); - - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < rows; k++) - { if(Conjugate) - blockAt[ri] = -lhs(k, i).imag(); - else - blockAt[ri] = lhs(k, i).imag(); - ri += 1; 
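A scalar picture of what the vec_perm calls with p16uc_GETREAL32 / p16uc_GETIMAG32 compute (illustrative only, not patch code): four complex floats held in two vector registers are separated into one register of real parts and one of imaginary parts, and conjugation then reduces to a sign flip on the imaginary register rather than a multiply by a vector of -1.0.

#include <array>
#include <complex>

// Hypothetical helper name; mirrors the permute + optional negate above.
inline void get_real_imag32_ref(const std::complex<float> src[4],
                                std::array<float, 4>& re,
                                std::array<float, 4>& im,
                                bool conjugate)
{
  for (int k = 0; k < 4; k++) {
    re[k] = src[k].real();
    im[k] = conjugate ? -src[k].imag() : src[k].imag();  // conjugation is a sign flip
  }
}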
- } - } - - if(PanelMode) ri += (rows - j)*(stride - offset - depth); - } -}; - -// General template for lhs packing. -template -struct lhs_pack{ - EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) - { - const int vectorSize = quad_traits::vectorsize; - Index ri = 0, j = 0; - - for(j = 0; j + vectorSize <= rows; j+=vectorSize) - { - Index i = 0; - - if(PanelMode) ri += vectorSize*offset; - - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock block; - - if(StorageOrder == RowMajor) { - block.packet[0] = lhs.template loadPacket(j + 0, i); - block.packet[1] = lhs.template loadPacket(j + 1, i); - block.packet[2] = lhs.template loadPacket(j + 2, i); - block.packet[3] = lhs.template loadPacket(j + 3, i); - - ptranspose(block); - } else { - block.packet[0] = lhs.template loadPacket(j, i + 0); - block.packet[1] = lhs.template loadPacket(j, i + 1); - block.packet[2] = lhs.template loadPacket(j, i + 2); - block.packet[3] = lhs.template loadPacket(j, i + 3); + blocki.packet[0] = -blocki.packet[0]; } - pstore(blockA + ri , block.packet[0]); - pstore(blockA + ri + 4, block.packet[1]); - pstore(blockA + ri + 8, block.packet[2]); - pstore(blockA + ri + 12, block.packet[3]); + pstore(blockAt + rir, blockr.packet[0]); + pstore(blockAt + rii, blocki.packet[0]); - ri += 4*vectorSize; + rir += vectorSize; + rii += vectorSize; } - for(; i < depth; i++) - { - if(StorageOrder == RowMajor) - { - blockA[ri+0] = lhs(j+0, i); - blockA[ri+1] = lhs(j+1, i); - blockA[ri+2] = lhs(j+2, i); - blockA[ri+3] = lhs(j+3, i); - } else { - Packet lhsV = lhs.template loadPacket(j, i); - pstore(blockA + ri, lhsV); - } - ri += vectorSize; - } - if(PanelMode) ri += vectorSize*(stride - offset - depth); + rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta); } - if(PanelMode) ri += offset*(rows - j); - if (j < rows) { + if(PanelMode) rir += (offset*(rows - j - vectorSize)); + rii = rir + (((PanelMode) ? stride : depth) * (rows - j)); + for(Index i = 0; i < depth; i++) { Index k = j; for(; k < rows; k++) { - blockA[ri] = lhs(k, i); - ri += 1; - } - } - } - - if(PanelMode) ri += (rows - j)*(stride - offset - depth); - } -}; - -// General template for rhs complex packing. 
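The PanelMode cursor bookkeeping of the new packer can be checked with a small standalone program (a sketch under assumed example values, not part of the patch): each panel reserves stride scalars per lane for its real plane and stride more for its imaginary plane, so the end-of-panel jump of vectorSize*(2*stride - depth) lands exactly on the next panel's first real write.

#include <cassert>

int main()
{
  const long vectorSize = 4, stride = 10, offset = 2, depth = 7;  // example values
  const long vectorDelta = vectorSize * stride;   // real-to-imaginary plane distance in PanelMode
  long rir = vectorSize * offset;                 // first real write of panel 0
  long rii = rir + vectorDelta;                   // first imaginary write of panel 0

  rir += vectorSize * depth;                      // cursors after packing `depth` rows
  rii += vectorSize * depth;
  rir += vectorSize * (2 * stride - depth);       // the end-of-panel jump used above

  assert(rir == 2 * vectorDelta + vectorSize * offset);  // start of panel 1's real plane
  (void)rii;
  return 0;
}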
-template -struct rhs_cpack -{ - EIGEN_STRONG_INLINE void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) - { - const int vectorSize = quad_traits::vectorsize; - Scalar *blockBt = reinterpret_cast(blockB); - Packet conj = pset1((Scalar)-1.0); - - Index ri = 0, j = 0; - for(; j + vectorSize <= cols; j+=vectorSize) - { - Index i = 0; - - if(PanelMode) ri += offset*vectorSize; - - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock cblock; - if(StorageOrder == ColMajor) - { - cblock.packet[0] = rhs.template loadPacket(i, j + 0); - cblock.packet[1] = rhs.template loadPacket(i, j + 1); - cblock.packet[2] = rhs.template loadPacket(i, j + 2); - cblock.packet[3] = rhs.template loadPacket(i, j + 3); - - cblock.packet[4] = rhs.template loadPacket(i + 2, j + 0); - cblock.packet[5] = rhs.template loadPacket(i + 2, j + 1); - cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); - cblock.packet[7] = rhs.template loadPacket(i + 2, j + 3); - } else { - cblock.packet[0] = rhs.template loadPacket(i + 0, j); - cblock.packet[1] = rhs.template loadPacket(i + 1, j); - cblock.packet[2] = rhs.template loadPacket(i + 2, j); - cblock.packet[3] = rhs.template loadPacket(i + 3, j); - - cblock.packet[4] = rhs.template loadPacket(i + 0, j + 2); - cblock.packet[5] = rhs.template loadPacket(i + 1, j + 2); - cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); - cblock.packet[7] = rhs.template loadPacket(i + 3, j + 2); - } - - PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); - block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); - block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); - block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); - - if(StorageOrder == ColMajor) ptranspose(block); - - pstore(blockBt + ri , block.packet[0]); - pstore(blockBt + ri + 4, block.packet[1]); - pstore(blockBt + ri + 8, block.packet[2]); - pstore(blockBt + ri + 12, block.packet[3]); - - ri += 4*vectorSize; - } - for(; i < depth; i++) - { - blockBt[ri+0] = rhs(i, j+0).real(); - blockBt[ri+1] = rhs(i, j+1).real(); - blockBt[ri+2] = rhs(i, j+2).real(); - blockBt[ri+3] = rhs(i, j+3).real(); - ri += vectorSize; - } - - if(PanelMode) ri += vectorSize*(stride - offset - depth); - - i = 0; - - if(PanelMode) ri += offset*vectorSize; - - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock cblock; - if(StorageOrder == ColMajor) - { - - cblock.packet[0] = rhs.template loadPacket(i, j + 0); - cblock.packet[1] = rhs.template loadPacket(i, j + 1); - cblock.packet[2] = rhs.template loadPacket(i, j + 2); - cblock.packet[3] = rhs.template loadPacket(i, j + 3); - - cblock.packet[4] = rhs.template loadPacket(i + 2, j + 0); - cblock.packet[5] = rhs.template loadPacket(i + 2, j + 1); - cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); - cblock.packet[7] = rhs.template loadPacket(i + 2, j + 3); - } else { - cblock.packet[0] = rhs.template loadPacket(i + 0, j); - cblock.packet[1] = rhs.template loadPacket(i + 1, j); - cblock.packet[2] = rhs.template loadPacket(i + 2, j); - cblock.packet[3] = rhs.template loadPacket(i + 3, j); - - cblock.packet[4] = rhs.template loadPacket(i + 0, j + 2); - cblock.packet[5] = rhs.template loadPacket(i + 1, j + 2); - cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); - cblock.packet[7] = rhs.template loadPacket(i + 3, j + 2); - } - - PacketBlock block; - block.packet[0] = 
vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); - block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); - block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); - block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); + if (UseLhs) { + blockAt[rir] = lhs(k, i).real(); - if(Conjugate) - { - block.packet[0] *= conj; - block.packet[1] *= conj; - block.packet[2] *= conj; - block.packet[3] *= conj; - } - - if(StorageOrder == ColMajor) ptranspose(block); - - pstore(blockBt + ri , block.packet[0]); - pstore(blockBt + ri + 4, block.packet[1]); - pstore(blockBt + ri + 8, block.packet[2]); - pstore(blockBt + ri + 12, block.packet[3]); - - ri += 4*vectorSize; - } - for(; i < depth; i++) - { if(Conjugate) - { - blockBt[ri+0] = -rhs(i, j+0).imag(); - blockBt[ri+1] = -rhs(i, j+1).imag(); - blockBt[ri+2] = -rhs(i, j+2).imag(); - blockBt[ri+3] = -rhs(i, j+3).imag(); - } else { - blockBt[ri+0] = rhs(i, j+0).imag(); - blockBt[ri+1] = rhs(i, j+1).imag(); - blockBt[ri+2] = rhs(i, j+2).imag(); - blockBt[ri+3] = rhs(i, j+3).imag(); - } - ri += vectorSize; - } - - if(PanelMode) ri += vectorSize*(stride - offset - depth); - } - - if(PanelMode) ri += offset*(cols - j); - - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < cols; k++) - { - blockBt[ri] = rhs(i, k).real(); - ri += 1; - } - } - if(PanelMode) ri += (cols - j)*(stride - offset - depth); - - if(PanelMode) ri += offset*(cols - j); + blockAt[rii] = -lhs(k, i).imag(); + else + blockAt[rii] = lhs(k, i).imag(); + } else { + blockAt[rir] = lhs(i, k).real(); - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < cols; k++) - { if(Conjugate) - blockBt[ri] = -rhs(i, k).imag(); + blockAt[rii] = -lhs(i, k).imag(); else - blockBt[ri] = rhs(i, k).imag(); - ri += 1; + blockAt[rii] = lhs(i, k).imag(); + } + + rir += 1; + rii += 1; } + } } - if(PanelMode) ri += (cols - j)*(stride - offset - depth); } }; -// General template for rhs packing. -template -struct rhs_pack { - EIGEN_STRONG_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +// General template for lhs & rhs packing. 
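The dhs_pack / dhs_cpack templates that follow fold the old lhs_* and rhs_* packers into one. A minimal sketch of the two ideas involved (hypothetical helper names, not patch code): an index swap selected by UseLhs, plus a transpose whenever the load direction does not match the packed layout.

#include <cstddef>

// lhs panels walk rows, rhs panels walk columns; one routine serves both.
template <bool UseLhs, typename Mapper>
auto packed_element(const Mapper& m, std::ptrdiff_t panelIndex, std::ptrdiff_t depthIndex)
    -> decltype(m(panelIndex, depthIndex))
{
  return UseLhs ? m(panelIndex, depthIndex)
                : m(depthIndex, panelIndex);
}

// Transpose a 4x4 tile when the vector loads ran along the other dimension.
constexpr bool needs_transpose(bool rowMajor, bool useLhs)
{
  return (rowMajor && useLhs) || (!rowMajor && !useLhs);
}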
+template +struct dhs_pack{ + EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - const int vectorSize = quad_traits::vectorsize; + const Index vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + vectorSize <= cols; j+=vectorSize) + for(; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; - if(PanelMode) ri += offset*vectorSize; + if(PanelMode) ri += vectorSize*offset; for(; i + vectorSize <= depth; i+=vectorSize) { - PacketBlock block; - if(StorageOrder == ColMajor) - { - block.packet[0] = rhs.template loadPacket(i, j + 0); - block.packet[1] = rhs.template loadPacket(i, j + 1); - block.packet[2] = rhs.template loadPacket(i, j + 2); - block.packet[3] = rhs.template loadPacket(i, j + 3); + PacketBlock block; - ptranspose(block); + if (UseLhs) { + bload(block, lhs, j, i); } else { - block.packet[0] = rhs.template loadPacket(i + 0, j); - block.packet[1] = rhs.template loadPacket(i + 1, j); - block.packet[2] = rhs.template loadPacket(i + 2, j); - block.packet[3] = rhs.template loadPacket(i + 3, j); + bload(block, lhs, i, j); + } + if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) + { + ptranspose(block); } - pstore(blockB + ri , block.packet[0]); - pstore(blockB + ri + 4, block.packet[1]); - pstore(blockB + ri + 8, block.packet[2]); - pstore(blockB + ri + 12, block.packet[3]); + storeBlock(blockA + ri, block); ri += 4*vectorSize; } for(; i < depth; i++) { - if(StorageOrder == ColMajor) + if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) { - blockB[ri+0] = rhs(i, j+0); - blockB[ri+1] = rhs(i, j+1); - blockB[ri+2] = rhs(i, j+2); - blockB[ri+3] = rhs(i, j+3); + if (UseLhs) { + blockA[ri+0] = lhs(j+0, i); + blockA[ri+1] = lhs(j+1, i); + blockA[ri+2] = lhs(j+2, i); + blockA[ri+3] = lhs(j+3, i); + } else { + blockA[ri+0] = lhs(i, j+0); + blockA[ri+1] = lhs(i, j+1); + blockA[ri+2] = lhs(i, j+2); + blockA[ri+3] = lhs(i, j+3); + } } else { - Packet rhsV = rhs.template loadPacket(i, j); - pstore(blockB + ri, rhsV); + Packet lhsV; + if (UseLhs) { + lhsV = lhs.template loadPacket(j, i); + } else { + lhsV = lhs.template loadPacket(i, j); + } + pstore(blockA + ri, lhsV); } + ri += vectorSize; } if(PanelMode) ri += vectorSize*(stride - offset - depth); } - if(PanelMode) ri += offset*(cols - j); - - if (j < cols) + if (j < rows) { + if(PanelMode) ri += offset*(rows - j); + for(Index i = 0; i < depth; i++) { Index k = j; - for(; k < cols; k++) + for(; k < rows; k++) { - blockB[ri] = rhs(i, k); + if (UseLhs) { + blockA[ri] = lhs(k, i); + } else { + blockA[ri] = lhs(i, k); + } ri += 1; } } } - if(PanelMode) ri += (cols - j)*(stride - offset - depth); } }; // General template for lhs packing, float64 specialization. 
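For a full panel, the layout the dhs_pack template above produces can be expressed as a scalar reference (illustrative, assuming Mapper exposes operator()(row, col)): vectorSize contiguous entries per depth step, so the gemm kernel can stream the block with aligned vector loads.

#include <vector>

// Hypothetical reference packer for one full vectorSize x depth lhs panel.
template <typename Scalar, typename Mapper>
void pack_full_panel_reference(std::vector<Scalar>& out, const Mapper& lhs,
                               long rowStart, long depth, long vectorSize)
{
  for (long i = 0; i < depth; i++)          // one depth step at a time
    for (long k = 0; k < vectorSize; k++)   // vectorSize contiguous panel rows
      out.push_back(lhs(rowStart + k, i));
}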
template -struct lhs_pack +struct dhs_pack { EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - const int vectorSize = quad_traits::vectorsize; + const Index vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize <= rows; j+=vectorSize) + for(; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; @@ -929,7 +655,7 @@ struct lhs_pack for(; i + vectorSize <= depth; i+=vectorSize) { - PacketBlock block; + PacketBlock block; if(StorageOrder == RowMajor) { block.packet[0] = lhs.template loadPacket(j + 0, i); @@ -941,8 +667,7 @@ struct lhs_pack block.packet[1] = lhs.template loadPacket(j, i + 1); } - pstore(blockA + ri , block.packet[0]); - pstore(blockA + ri + 2, block.packet[1]); + storeBlock(blockA + ri, block); ri += 2*vectorSize; } @@ -959,44 +684,48 @@ struct lhs_pack ri += vectorSize; } + if(PanelMode) ri += vectorSize*(stride - offset - depth); } - if(PanelMode) ri += offset*(rows - j); - - for(Index i = 0; i < depth; i++) + if (j < rows) { - Index k = j; - for(; k < rows; k++) + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) { - blockA[ri] = lhs(k, i); - ri += 1; + Index k = j; + for(; k < rows; k++) + { + blockA[ri] = lhs(k, i); + ri += 1; + } } } - - if(PanelMode) ri += (rows - j)*(stride - offset - depth); } }; // General template for rhs packing, float64 specialization. template -struct rhs_pack +struct dhs_pack { EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - const int vectorSize = quad_traits::vectorsize; + const Index vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { Index i = 0; if(PanelMode) ri += offset*(2*vectorSize); + for(; i + vectorSize <= depth; i+=vectorSize) { - PacketBlock block; + PacketBlock block; if(StorageOrder == ColMajor) { - PacketBlock block1, block2; + PacketBlock block1, block2; block1.packet[0] = rhs.template loadPacket(i, j + 0); block1.packet[1] = rhs.template loadPacket(i, j + 1); block2.packet[0] = rhs.template loadPacket(i, j + 2); @@ -1015,10 +744,7 @@ struct rhs_pack block.packet[2] = rhs.template loadPacket(i + 1, j + 0); //[b1 b2] block.packet[3] = rhs.template loadPacket(i + 1, j + 2); //[b3 b4] - pstore(blockB + ri , block.packet[0]); - pstore(blockB + ri + 2, block.packet[1]); - pstore(blockB + ri + 4, block.packet[2]); - pstore(blockB + ri + 6, block.packet[3]); + storeBlock(blockB + ri, block); } ri += 4*vectorSize; @@ -1049,43 +775,46 @@ struct rhs_pack if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); } - if(PanelMode) ri += offset*(cols - j); - - for(Index i = 0; i < depth; i++) + if (j < cols) { - Index k = j; - for(; k < cols; k++) + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) { - blockB[ri] = rhs(i, k); - ri += 1; + Index k = j; + for(; k < cols; k++) + { + blockB[ri] = rhs(i, k); + ri += 1; + } } } - if(PanelMode) ri += (cols - j)*(stride - offset - depth); } }; // General template for lhs complex packing, float64 specialization. 
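Similarly, the float64 rhs packer above emits 4-column panels (2*vectorSize) with the four values rhs(i, j..j+3) contiguous per depth step; the vector code assembles them from Packet2d pairs, transposing 2x2 tiles when the source is column-major. A scalar reference of that layout, again with a hypothetical helper name:

#include <vector>

template <typename Mapper>
void pack_rhs_panel4_reference(std::vector<double>& out, const Mapper& rhs,
                               long colStart, long depth)
{
  for (long i = 0; i < depth; i++)
    for (long k = 0; k < 4; k++)            // 2*vectorSize = 4 columns per panel
      out.push_back(rhs(i, colStart + k));
}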
-template -struct lhs_cpack +template +struct dhs_cpack { EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - const int vectorSize = quad_traits::vectorsize; - Index ri = 0, j = 0; - double *blockAt = reinterpret_cast(blockA); - Packet conj = pset1(-1.0); + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth); + Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii; + double* blockAt = reinterpret_cast(blockA); + Index j = 0; - for(j = 0; j + vectorSize <= rows; j+=vectorSize) + for(; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; - if(PanelMode) ri += vectorSize*offset; + rii = rir + vectorDelta; for(; i + vectorSize <= depth; i+=vectorSize) { - PacketBlock block; + PacketBlock blockr, blocki; + PacketBlock cblock; - PacketBlock cblock; if(StorageOrder == ColMajor) { cblock.packet[0] = lhs.template loadPacket(j, i + 0); //[a1 a1i] @@ -1094,219 +823,157 @@ struct lhs_cpack(j + 1, i + 0); //[a2 a2i] cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] - block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] + blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64); + blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64); } else { cblock.packet[0] = lhs.template loadPacket(j + 0, i); //[a1 a1i] cblock.packet[1] = lhs.template loadPacket(j + 1, i); //[a2 a2i] cblock.packet[2] = lhs.template loadPacket(j + 0, i + 1); //[b1 b1i] - cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] - - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] - block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] - } - - pstore(blockAt + ri , block.packet[0]); - pstore(blockAt + ri + 2, block.packet[1]); - - ri += 2*vectorSize; - } - for(; i < depth; i++) - { - blockAt[ri + 0] = lhs(j + 0, i).real(); - blockAt[ri + 1] = lhs(j + 1, i).real(); - ri += vectorSize; - } - if(PanelMode) ri += vectorSize*(stride - offset - depth); - - i = 0; - - if(PanelMode) ri += vectorSize*offset; - - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock block; - - PacketBlock cblock; - if(StorageOrder == ColMajor) - { - cblock.packet[0] = lhs.template loadPacket(j, i + 0); - cblock.packet[1] = lhs.template loadPacket(j, i + 1); + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i - cblock.packet[2] = lhs.template loadPacket(j + 1, i + 0); - cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] + blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64); - } else { - cblock.packet[0] = lhs.template loadPacket(j + 0, i); - cblock.packet[1] = lhs.template loadPacket(j + 1, i); - - cblock.packet[2] = lhs.template loadPacket(j + 0, i + 1); - 
cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); - - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); } if(Conjugate) { - block.packet[0] *= conj; - block.packet[1] *= conj; + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; } - pstore(blockAt + ri , block.packet[0]); - pstore(blockAt + ri + 2, block.packet[1]); + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); - ri += 2*vectorSize; + rir += 2*vectorSize; + rii += 2*vectorSize; } for(; i < depth; i++) { + PacketBlock blockr, blocki; + PacketBlock cblock; + + cblock.packet[0] = pload(&lhs(j + 0, i)); + cblock.packet[1] = pload(&lhs(j + 1, i)); + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + if(Conjugate) { - blockAt[ri + 0] = -lhs(j + 0, i).imag(); - blockAt[ri + 1] = -lhs(j + 1, i).imag(); - } else { - blockAt[ri + 0] = lhs(j + 0, i).imag(); - blockAt[ri + 1] = lhs(j + 1, i).imag(); + blocki.packet[0] = -blocki.packet[0]; } - ri += vectorSize; + pstore(blockAt + rir, blockr.packet[0]); + pstore(blockAt + rii, blocki.packet[0]); + + rir += vectorSize; + rii += vectorSize; } - if(PanelMode) ri += vectorSize*(stride - offset - depth); - } - if(PanelMode) ri += offset*(rows - j); + rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta); + } - for(Index i = 0; i < depth; i++) + if (j < rows) { - Index k = j; - for(; k < rows; k++) - { - blockAt[ri] = lhs(k, i).real(); - ri += 1; - } - } + if(PanelMode) rir += (offset*(rows - j - vectorSize)); + rii = rir + (((PanelMode) ? stride : depth) * (rows - j)); - if(PanelMode) ri += (rows - j)*(stride - offset - depth); + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockAt[rir] = lhs(k, i).real(); - if(PanelMode) ri += offset*(rows - j); + if(Conjugate) + blockAt[rii] = -lhs(k, i).imag(); + else + blockAt[rii] = lhs(k, i).imag(); - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < rows; k++) - { - if(Conjugate) - blockAt[ri] = -lhs(k, i).imag(); - else - blockAt[ri] = lhs(k, i).imag(); - ri += 1; + rir += 1; + rii += 1; + } } } - - if(PanelMode) ri += (rows - j)*(stride - offset - depth); } }; // General template for rhs complex packing, float64 specialization. template -struct rhs_cpack +struct dhs_cpack { EIGEN_STRONG_INLINE void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - const int vectorSize = quad_traits::vectorsize; - double *blockBt = reinterpret_cast(blockB); - Packet conj = pset1(-1.0); + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth); + Index rir = ((PanelMode) ? 
(2*vectorSize*offset) : 0), rii; + double* blockBt = reinterpret_cast(blockB); + Index j = 0; - Index ri = 0, j = 0; for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { Index i = 0; - if(PanelMode) ri += offset*(2*vectorSize); + rii = rir + vectorDelta; for(; i < depth; i++) { - PacketBlock cblock; - PacketBlock block; - - cblock.packet[0] = rhs.template loadPacket(i, j + 0); - cblock.packet[1] = rhs.template loadPacket(i, j + 1); - cblock.packet[2] = rhs.template loadPacket(i, j + 2); - cblock.packet[3] = rhs.template loadPacket(i, j + 3); - - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); - block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); - - pstore(blockBt + ri , block.packet[0]); - pstore(blockBt + ri + 2, block.packet[1]); - - ri += 2*vectorSize; - } - - if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); + PacketBlock cblock; + PacketBlock blockr, blocki; - i = 0; - - if(PanelMode) ri += offset*(2*vectorSize); - - for(; i < depth; i++) - { - PacketBlock cblock; - PacketBlock block; + bload(cblock, rhs, i, j); - cblock.packet[0] = rhs.template loadPacket(i, j + 0); //[a1 a1i] - cblock.packet[1] = rhs.template loadPacket(i, j + 1); //[b1 b1i] - cblock.packet[2] = rhs.template loadPacket(i, j + 2); //[c1 c1i] - cblock.packet[3] = rhs.template loadPacket(i, j + 3); //[d1 d1i] + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); + blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); - block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); if(Conjugate) { - block.packet[0] *= conj; - block.packet[1] *= conj; + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; } - pstore(blockBt + ri , block.packet[0]); - pstore(blockBt + ri + 2, block.packet[1]); + storeBlock(blockBt + rir, blockr); + storeBlock(blockBt + rii, blocki); - ri += 2*vectorSize; + rir += 2*vectorSize; + rii += 2*vectorSize; } - if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); - } - if(PanelMode) ri += offset*(cols - j); + rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta); + } - for(Index i = 0; i < depth; i++) + if (j < cols) { - Index k = j; - for(; k < cols; k++) + if(PanelMode) rir += (offset*(cols - j - 2*vectorSize)); + rii = rir + (((PanelMode) ? 
stride : depth) * (cols - j)); + + for(Index i = 0; i < depth; i++) { - blockBt[ri] = rhs(i, k).real(); - ri += 1; - } - } - if(PanelMode) ri += (cols - j)*(stride - offset - depth); + Index k = j; + for(; k < cols; k++) + { + blockBt[rir] = rhs(i, k).real(); - if(PanelMode) ri += offset*(cols - j); + if(Conjugate) + blockBt[rii] = -rhs(i, k).imag(); + else + blockBt[rii] = rhs(i, k).imag(); - for(Index i = 0; i < depth; i++) - { - Index k = j; - for(; k < cols; k++) - { - if(Conjugate) - blockBt[ri] = -rhs(i, k).imag(); - else - blockBt[ri] = rhs(i, k).imag(); - ri += 1; + rir += 1; + rii += 1; + } } } - if(PanelMode) ri += (cols - j)*(stride - offset - depth); } }; @@ -1315,12 +982,9 @@ struct rhs_cpack -EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) +template +EIGEN_STRONG_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) { - asm("#pger begin"); - Packet lhsV = pload(lhs); - if(NegativeAccumulate) { acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); @@ -1333,14 +997,11 @@ EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, con acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); } - asm("#pger end"); } -template -EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) +template +EIGEN_STRONG_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) { - Packet lhsV = pload(lhs); - if(NegativeAccumulate) { acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); @@ -1349,130 +1010,95 @@ EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, con } } -template -EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) +{ + Packet lhsV = pload(lhs); + + pger_common(acc, lhsV, rhsV); +} + +template +EIGEN_STRONG_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows) { #ifdef _ARCH_PWR9 - Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); + lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); #else - Packet lhsV; Index i = 0; do { lhsV[i] = lhs[i]; } while (++i < remaining_rows); #endif +} - if(NegativeAccumulate) +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +{ + Packet lhsV; + loadPacketRemaining(lhs, lhsV, remaining_rows); + + pger_common(acc, lhsV, rhsV); +} + +// 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. 
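The pger_common helpers above perform the rank-1 style update of an accumulator tile; a scalar reference of the same arithmetic (not patch code, hypothetical name) makes the vec_madd / vec_nmsub split explicit.

#include <array>

template <bool NegativeAccumulate>
void rank1_update_ref(std::array<std::array<float, 4>, 4>& acc,
                      const std::array<float, 4>& lhsV,
                      const std::array<std::array<float, 4>, 4>& rhsV)
{
  for (int j = 0; j < 4; j++)                // four accumulator vectors
    for (int lane = 0; lane < 4; lane++)
      acc[j][lane] = NegativeAccumulate
                       ? acc[j][lane] - lhsV[lane] * rhsV[j][lane]   // vec_nmsub path
                       : acc[j][lane] + lhsV[lane] * rhsV[j][lane];  // vec_madd path
}

vec_nmsub(a, b, c) returns c - a*b, which is why the NegativeAccumulate path subtracts the product.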
+template +EIGEN_STRONG_INLINE void pgerc_common(PacketBlock* accReal, PacketBlock* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi) +{ + pger_common(accReal, lhsV, rhsV); + if(LhsIsReal) { - acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); - acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); - acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); - acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); + pger_common(accImag, lhsV, rhsVi); + EIGEN_UNUSED_VARIABLE(lhsVi); } else { - acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); - acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); - acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); - acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + if (!RhsIsReal) { + pger_common(accReal, lhsVi, rhsVi); + pger_common(accImag, lhsV, rhsVi); + } else { + EIGEN_UNUSED_VARIABLE(rhsVi); + } + pger_common(accImag, lhsVi, rhsV); } } -template -EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +template +EIGEN_STRONG_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) +{ + Packet lhsV = ploadLhs(lhs_ptr); + Packet lhsVi; + if(!LhsIsReal) lhsVi = ploadLhs(lhs_ptr_imag); + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + + pgerc_common(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); +} + +template +EIGEN_STRONG_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows) { #ifdef _ARCH_PWR9 - Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); + lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar)); + if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar)); + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); #else - Packet lhsV; Index i = 0; do { - lhsV[i] = lhs[i]; + lhsV[i] = lhs_ptr[i]; + if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i]; } while (++i < remaining_rows); + if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); #endif - - if(NegativeAccumulate) - { - acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); - } else { - acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); - } } -// 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. 
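In scalar terms, pgerc_common computes one complex multiply-accumulate step on decoupled real/imaginary accumulators, with the LhsIsReal / RhsIsReal flags dropping the terms that are identically zero; a minimal reference (hypothetical name, not patch code):

inline void cmadd_ref(float& accRe, float& accIm,
                      float lhsRe, float lhsIm, float rhsRe, float rhsIm,
                      bool lhsIsReal, bool rhsIsReal)
{
  accRe += lhsRe * rhsRe;                                // always present
  if (!lhsIsReal && !rhsIsReal) accRe -= lhsIm * rhsIm;  // the NegativeAccumulate term
  if (!rhsIsReal)               accIm += lhsRe * rhsIm;
  if (!lhsIsReal)               accIm += lhsIm * rhsRe;
}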
-template -EIGEN_STRONG_INLINE void pgerc(PacketBlock& accReal, PacketBlock& accImag, const Scalar *rhs_ptr, const Scalar *rhs_ptr_imag, const Scalar *lhs_ptr, const Scalar* lhs_ptr_imag, Packet& conj) +template +EIGEN_STRONG_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows) { - Packet lhsV = *((Packet *) lhs_ptr); - Packet rhsV1 = pset1(rhs_ptr[0]); - Packet rhsV2 = pset1(rhs_ptr[1]); - Packet rhsV3 = pset1(rhs_ptr[2]); - Packet rhsV4 = pset1(rhs_ptr[3]); - - Packet lhsVi; - if(!LhsIsReal) lhsVi = *((Packet *) lhs_ptr_imag); - Packet rhsV1i, rhsV2i, rhsV3i, rhsV4i; - if(!RhsIsReal) - { - rhsV1i = pset1(rhs_ptr_imag[0]); - rhsV2i = pset1(rhs_ptr_imag[1]); - rhsV3i = pset1(rhs_ptr_imag[2]); - rhsV4i = pset1(rhs_ptr_imag[3]); - } - - if(ConjugateLhs && !LhsIsReal) lhsVi = pmul(lhsVi,conj); - if(ConjugateRhs && !RhsIsReal) - { - rhsV1i = pmul(rhsV1i,conj); - rhsV2i = pmul(rhsV2i,conj); - rhsV3i = pmul(rhsV3i,conj); - rhsV4i = pmul(rhsV4i,conj); - } + Packet lhsV, lhsVi; + loadPacketRemaining(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows); - if(LhsIsReal) - { - accReal.packet[0] = pmadd(rhsV1, lhsV, accReal.packet[0]); - accReal.packet[1] = pmadd(rhsV2, lhsV, accReal.packet[1]); - accReal.packet[2] = pmadd(rhsV3, lhsV, accReal.packet[2]); - accReal.packet[3] = pmadd(rhsV4, lhsV, accReal.packet[3]); - - accImag.packet[0] = pmadd(rhsV1i, lhsV, accImag.packet[0]); - accImag.packet[1] = pmadd(rhsV2i, lhsV, accImag.packet[1]); - accImag.packet[2] = pmadd(rhsV3i, lhsV, accImag.packet[2]); - accImag.packet[3] = pmadd(rhsV4i, lhsV, accImag.packet[3]); - } else if(RhsIsReal) { - accReal.packet[0] = pmadd(rhsV1, lhsV, accReal.packet[0]); - accReal.packet[1] = pmadd(rhsV2, lhsV, accReal.packet[1]); - accReal.packet[2] = pmadd(rhsV3, lhsV, accReal.packet[2]); - accReal.packet[3] = pmadd(rhsV4, lhsV, accReal.packet[3]); - - accImag.packet[0] = pmadd(rhsV1, lhsVi, accImag.packet[0]); - accImag.packet[1] = pmadd(rhsV2, lhsVi, accImag.packet[1]); - accImag.packet[2] = pmadd(rhsV3, lhsVi, accImag.packet[2]); - accImag.packet[3] = pmadd(rhsV4, lhsVi, accImag.packet[3]); - } else { - accReal.packet[0] = pmadd(rhsV1, lhsV, accReal.packet[0]); - accReal.packet[1] = pmadd(rhsV2, lhsV, accReal.packet[1]); - accReal.packet[2] = pmadd(rhsV3, lhsV, accReal.packet[2]); - accReal.packet[3] = pmadd(rhsV4, lhsV, accReal.packet[3]); - - accImag.packet[0] = pmadd(rhsV1i, lhsV, accImag.packet[0]); - accImag.packet[1] = pmadd(rhsV2i, lhsV, accImag.packet[1]); - accImag.packet[2] = pmadd(rhsV3i, lhsV, accImag.packet[2]); - accImag.packet[3] = pmadd(rhsV4i, lhsV, accImag.packet[3]); - - accReal.packet[0] = psub(accReal.packet[0], pmul(rhsV1i, lhsVi)); - accReal.packet[1] = psub(accReal.packet[1], pmul(rhsV2i, lhsVi)); - accReal.packet[2] = psub(accReal.packet[2], pmul(rhsV3i, lhsVi)); - accReal.packet[3] = psub(accReal.packet[3], pmul(rhsV4i, lhsVi)); - - accImag.packet[0] = pmadd(rhsV1, lhsVi, accImag.packet[0]); - accImag.packet[1] = pmadd(rhsV2, lhsVi, accImag.packet[1]); - accImag.packet[2] = pmadd(rhsV3, lhsVi, accImag.packet[2]); - accImag.packet[3] = pmadd(rhsV4, lhsVi, accImag.packet[3]); - } + pgerc_common(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); } template -EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs) +EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar* lhs) { return *((Packet *)lhs); } @@ -1509,53 +1135,99 @@ EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock 
-EIGEN_STRONG_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag) +EIGEN_STRONG_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) { - cReal.packet[0] = pmul(aReal.packet[0], bReal); - cReal.packet[1] = pmul(aReal.packet[1], bReal); - cReal.packet[2] = pmul(aReal.packet[2], bReal); - cReal.packet[3] = pmul(aReal.packet[3], bReal); + acc.packet[0] = pmul(accZ.packet[0], pAlpha); + acc.packet[1] = pmul(accZ.packet[1], pAlpha); + acc.packet[2] = pmul(accZ.packet[2], pAlpha); + acc.packet[3] = pmul(accZ.packet[3], pAlpha); +} + +template +EIGEN_STRONG_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmul(accZ.packet[0], pAlpha); +} + +// Complex version of PacketBlock scaling. +template +EIGEN_STRONG_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag) +{ + bscalec_common(cReal, aReal, bReal); - cImag.packet[0] = pmul(aImag.packet[0], bReal); - cImag.packet[1] = pmul(aImag.packet[1], bReal); - cImag.packet[2] = pmul(aImag.packet[2], bReal); - cImag.packet[3] = pmul(aImag.packet[3], bReal); + bscalec_common(cImag, aImag, bReal); - cReal.packet[0] = psub(cReal.packet[0], pmul(aImag.packet[0], bImag)); - cReal.packet[1] = psub(cReal.packet[1], pmul(aImag.packet[1], bImag)); - cReal.packet[2] = psub(cReal.packet[2], pmul(aImag.packet[2], bImag)); - cReal.packet[3] = psub(cReal.packet[3], pmul(aImag.packet[3], bImag)); + pger_common(&cReal, bImag, aImag.packet); - cImag.packet[0] = pmadd(aReal.packet[0], bImag, cImag.packet[0]); - cImag.packet[1] = pmadd(aReal.packet[1], bImag, cImag.packet[1]); - cImag.packet[2] = pmadd(aReal.packet[2], bImag, cImag.packet[2]); - cImag.packet[3] = pmadd(aReal.packet[3], bImag, cImag.packet[3]); + pger_common(&cImag, bImag, aReal.packet); } -// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. -template -EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col, Index accCols) +template +EIGEN_STRONG_INLINE void band(PacketBlock& acc, const Packet& pMask) { - acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); - acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); - acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); - acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); + acc.packet[0] = pand(acc.packet[0], pMask); + acc.packet[1] = pand(acc.packet[1], pMask); + acc.packet[2] = pand(acc.packet[2], pMask); + acc.packet[3] = pand(acc.packet[3], pMask); +} + +template +EIGEN_STRONG_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask) +{ + band(aReal, pMask); + band(aImag, pMask); + + bscalec(aReal, aImag, bReal, bImag, cReal, cImag); +} + +// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. 
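Note, before the bload overloads that the comment just above introduces: the new bscalec multiplies the decoupled accumulators (aReal, aImag) by alpha = (bReal, bImag), reusing pger_common with negative accumulation for the aImag*bImag term, and the masked overload first zeroes the lanes past remaining_rows via band/pMask. A scalar sketch of the per-lane arithmetic (illustrative only, not part of the patch; the function name is hypothetical):

#include <complex>

// Full complex multiply of a decoupled accumulator by alpha, as bscalec does per vector lane.
inline std::complex<float> scalar_bscalec(float aReal, float aImag, float bReal, float bImag)
{
  float cReal = aReal * bReal - aImag * bImag;  // bscalec_common, then pger_common with negative accumulate
  float cImag = aImag * bReal + aReal * bImag;  // bscalec_common, then pger_common with positive accumulate
  return std::complex<float>(cReal, cImag);
}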
+template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) +{ + if (StorageOrder == RowMajor) { + acc.packet[0] = res.template loadPacket(row + 0, col + N*accCols); + acc.packet[1] = res.template loadPacket(row + 1, col + N*accCols); + acc.packet[2] = res.template loadPacket(row + 2, col + N*accCols); + acc.packet[3] = res.template loadPacket(row + 3, col + N*accCols); + } else { + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); + acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); + acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); + } } // An overload of bload when you have a PacketBLock with 8 vectors. -template -EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col, Index accCols) +template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) +{ + if (StorageOrder == RowMajor) { + acc.packet[0] = res.template loadPacket(row + 0, col + N*accCols); + acc.packet[1] = res.template loadPacket(row + 1, col + N*accCols); + acc.packet[2] = res.template loadPacket(row + 2, col + N*accCols); + acc.packet[3] = res.template loadPacket(row + 3, col + N*accCols); + acc.packet[4] = res.template loadPacket(row + 0, col + (N+1)*accCols); + acc.packet[5] = res.template loadPacket(row + 1, col + (N+1)*accCols); + acc.packet[6] = res.template loadPacket(row + 2, col + (N+1)*accCols); + acc.packet[7] = res.template loadPacket(row + 3, col + (N+1)*accCols); + } else { + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); + acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); + acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); + acc.packet[4] = res.template loadPacket(row + (N+1)*accCols, col + 0); + acc.packet[5] = res.template loadPacket(row + (N+1)*accCols, col + 1); + acc.packet[6] = res.template loadPacket(row + (N+1)*accCols, col + 2); + acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); + } +} + +template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) { acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); - acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); - acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); - acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); - acc.packet[4] = res.template loadPacket(row + (N+1)*accCols, col + 0); - acc.packet[5] = res.template loadPacket(row + (N+1)*accCols, col + 1); - acc.packet[6] = res.template loadPacket(row + (N+1)*accCols, col + 2); - acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); + acc.packet[1] = res.template loadPacket(row + (N+1)*accCols, col + 0); } const static Packet4i mask41 = { -1, 0, 0, 0 }; @@ -1568,7 +1240,7 @@ template EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows) { if (remaining_rows == 0) { - return pset1(float(0.0)); + return pset1(float(0.0)); // Not used } else { switch (remaining_rows) { case 1: return Packet(mask41); @@ -1582,7 +1254,7 @@ template<> EIGEN_STRONG_INLINE Packet2d bmask(const int remaining_rows) { if (remaining_rows == 0) { - return pset1(double(0.0)); + return pset1(double(0.0)); // Not used } else { return Packet2d(mask21); } @@ -1591,14 +1263,30 @@ EIGEN_STRONG_INLINE Packet2d bmask(const int 
remaining_rows) template EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) { - acc.packet[0] = pmadd(pAlpha, pand(accZ.packet[0], pMask), acc.packet[0]); - acc.packet[1] = pmadd(pAlpha, pand(accZ.packet[1], pMask), acc.packet[1]); - acc.packet[2] = pmadd(pAlpha, pand(accZ.packet[2], pMask), acc.packet[2]); - acc.packet[3] = pmadd(pAlpha, pand(accZ.packet[3], pMask), acc.packet[3]); + band(accZ, pMask); + + bscale(acc, accZ, pAlpha); +} + +template +EIGEN_STRONG_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3) +{ + pbroadcast4(a, a0, a1, a2, a3); +} + +template<> +EIGEN_STRONG_INLINE void pbroadcast4_old(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) +{ + a1 = pload(a); + a3 = pload(a + 2); + a0 = vec_splat(a1, 0); + a1 = vec_splat(a1, 1); + a2 = vec_splat(a3, 0); + a3 = vec_splat(a3, 1); } // PEEL loop factor. -#define PEEL 10 +#define PEEL 7 template EIGEN_STRONG_INLINE void MICRO_EXTRA_COL( @@ -1610,7 +1298,7 @@ EIGEN_STRONG_INLINE void MICRO_EXTRA_COL( { Packet rhsV[1]; rhsV[0] = pset1(rhs_ptr[0]); - pger(&accZero, lhs_ptr, rhsV); + pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV); lhs_ptr += remaining_rows; rhs_ptr += remaining_cols; } @@ -1618,8 +1306,8 @@ EIGEN_STRONG_INLINE void MICRO_EXTRA_COL( template EIGEN_STRONG_INLINE void gemm_extra_col( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -1629,9 +1317,9 @@ EIGEN_STRONG_INLINE void gemm_extra_col( Index remaining_cols, const Packet& pAlpha) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; - PacketBlock accZero, acc; + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero; bsetzero(accZero); @@ -1639,8 +1327,8 @@ EIGEN_STRONG_INLINE void gemm_extra_col( Index k = 0; for(; k + PEEL <= remaining_depth; k+= PEEL) { - prefetch(rhs_ptr); - prefetch(lhs_ptr); + EIGEN_POWER_PREFETCH(rhs_ptr); + EIGEN_POWER_PREFETCH(lhs_ptr); for (int l = 0; l < PEEL; l++) { MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); } @@ -1653,14 +1341,14 @@ EIGEN_STRONG_INLINE void gemm_extra_col( { Packet rhsV[1]; rhsV[0] = pset1(rhs_ptr[0]); - pger(&accZero, lhs_ptr, rhsV, remaining_rows); + pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows); lhs_ptr += remaining_rows; rhs_ptr += remaining_cols; } - acc.packet[0] = vec_mul(pAlpha, accZero.packet[0]); - for(Index i = 0; i < remaining_rows; i++){ - res(row + i, col) += acc.packet[0][i]; + accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]); + for(Index i = 0; i < remaining_rows; i++) { + res(row + i, col) += accZero.packet[0][i]; } } @@ -1673,28 +1361,29 @@ EIGEN_STRONG_INLINE void MICRO_EXTRA_ROW( { Packet rhsV[4]; pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - pger(&accZero, lhs_ptr, rhsV); + pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV); lhs_ptr += remaining_rows; rhs_ptr += accRows; } -template +template EIGEN_STRONG_INLINE void gemm_extra_row( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, Index row, Index col, + Index rows, Index cols, Index remaining_rows, const Packet& pAlpha, const Packet& 
pMask) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; PacketBlock accZero, acc; bsetzero(accZero); @@ -1703,8 +1392,8 @@ EIGEN_STRONG_INLINE void gemm_extra_row( Index k = 0; for(; k + PEEL <= remaining_depth; k+= PEEL) { - prefetch(rhs_ptr); - prefetch(lhs_ptr); + EIGEN_POWER_PREFETCH(rhs_ptr); + EIGEN_POWER_PREFETCH(lhs_ptr); for (int l = 0; l < PEEL; l++) { MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); } @@ -1714,78 +1403,103 @@ EIGEN_STRONG_INLINE void gemm_extra_row( MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); } - if (remaining_depth == depth) + if ((remaining_depth == depth) && (rows >= accCols)) { - for(Index j = 0; j < 4; j++){ + for(Index j = 0; j < 4; j++) { acc.packet[j] = res.template loadPacket(row, col + j); } - bscale(acc, accZero, pAlpha, pMask); - res.template storePacketBlock(row, col, acc); + bscale(acc, accZero, pAlpha, pMask); + res.template storePacketBlock(row, col, acc); } else { for(; k < depth; k++) { Packet rhsV[4]; pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - pger(&accZero, lhs_ptr, rhsV, remaining_rows); + pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows); lhs_ptr += remaining_rows; rhs_ptr += accRows; } - for(Index j = 0; j < 4; j++){ - acc.packet[j] = vec_mul(pAlpha, accZero.packet[j]); + for(Index j = 0; j < 4; j++) { + accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]); } - for(Index j = 0; j < 4; j++){ - for(Index i = 0; i < remaining_rows; i++){ - res(row + i, col + j) += acc.packet[j][i]; + for(Index j = 0; j < 4; j++) { + for(Index i = 0; i < remaining_rows; i++) { + res(row + i, col + j) += accZero.packet[j][i]; } } } } -#define MICRO_DST \ - PacketBlock *accZero0, PacketBlock *accZero1, PacketBlock *accZero2, \ - PacketBlock *accZero3, PacketBlock *accZero4, PacketBlock *accZero5, \ - PacketBlock *accZero6, PacketBlock *accZero7 +#define MICRO_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) -#define MICRO_COL_DST \ - PacketBlock *accZero0, PacketBlock *accZero1, PacketBlock *accZero2, \ - PacketBlock *accZero3, PacketBlock *accZero4, PacketBlock *accZero5, \ - PacketBlock *accZero6, PacketBlock *accZero7 +#define MICRO_UNROLL_WORK(func, func2, peel) \ + MICRO_UNROLL(func2); \ + func(0,peel) func(1,peel) func(2,peel) func(3,peel) \ + func(4,peel) func(5,peel) func(6,peel) func(7,peel) -#define MICRO_SRC \ - const Scalar **lhs_ptr0, const Scalar **lhs_ptr1, const Scalar **lhs_ptr2, \ - const Scalar **lhs_ptr3, const Scalar **lhs_ptr4, const Scalar **lhs_ptr5, \ - const Scalar **lhs_ptr6, const Scalar **lhs_ptr7 +#define MICRO_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr##iter); \ + lhs_ptr##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + } -#define MICRO_ONE \ - MICRO(\ - &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ - rhs_ptr, \ - &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7); +#define MICRO_WORK_ONE(iter, peel) \ + if (unroll_factor > iter) { \ + pger_common(&accZero##iter, lhsV##iter, rhsV##peel); \ + } -#define MICRO_COL_ONE \ - MICRO_COL(\ - &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ - rhs_ptr, \ - &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, 
&accZero7, \ - remaining_cols); +#define MICRO_TYPE_PEEL4(func, func2, peel) \ + if (PEEL > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ + pbroadcast4(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + MICRO_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + } -#define MICRO_WORK_ONE(iter) \ - if (N > iter) { \ - pger(accZero##iter, *lhs_ptr##iter, rhsV); \ - *lhs_ptr##iter += accCols; \ +#define MICRO_TYPE_PEEL1(func, func2, peel) \ + if (PEEL > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ + rhsV##peel[0] = pset1(rhs_ptr[remaining_cols * peel]); \ + MICRO_UNROLL_WORK(func, func2, peel) \ } else { \ - EIGEN_UNUSED_VARIABLE(accZero##iter); \ - EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ } -#define MICRO_UNROLL(func) \ - func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) +#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \ + Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \ + func(func1,func2,0); func(func1,func2,1); \ + func(func1,func2,2); func(func1,func2,3); \ + func(func1,func2,4); func(func1,func2,5); \ + func(func1,func2,6); func(func1,func2,7); \ + func(func1,func2,8); func(func1,func2,9); -#define MICRO_WORK MICRO_UNROLL(MICRO_WORK_ONE) +#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \ + Packet rhsV0[M]; \ + func(func1,func2,0); + +#define MICRO_ONE_PEEL4 \ + MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += (accRows * PEEL); + +#define MICRO_ONE4 \ + MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += accRows; + +#define MICRO_ONE_PEEL1 \ + MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += (remaining_cols * PEEL); + +#define MICRO_ONE1 \ + MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += remaining_cols; #define MICRO_DST_PTR_ONE(iter) \ - if (unroll_factor > iter){ \ + if (unroll_factor > iter) { \ bsetzero(accZero##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##iter); \ @@ -1803,52 +1517,38 @@ EIGEN_STRONG_INLINE void gemm_extra_row( #define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE) #define MICRO_PREFETCH_ONE(iter) \ - if (unroll_factor > iter){ \ - prefetch(lhs_ptr##iter); \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ } #define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE) #define MICRO_STORE_ONE(iter) \ - if (unroll_factor > iter){ \ + if (unroll_factor > iter) { \ acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ acc.packet[1] = res.template loadPacket(row + iter*accCols, col + 1); \ acc.packet[2] = res.template loadPacket(row + iter*accCols, col + 2); \ acc.packet[3] = res.template loadPacket(row + iter*accCols, col + 3); \ - bscale(acc, accZero##iter, pAlpha); \ - res.template storePacketBlock(row + iter*accCols, col, acc); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ } #define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE) #define MICRO_COL_STORE_ONE(iter) \ - if (unroll_factor > iter){ \ + if (unroll_factor > iter) { \ acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ - bscale(acc, accZero##iter, pAlpha); \ - res.template storePacketBlock(row + iter*accCols, col, acc); \ + bscale(acc, accZero##iter, pAlpha); \ + 
res.template storePacketBlock(row + iter*accCols, col, acc); \ } #define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE) -template -EIGEN_STRONG_INLINE void MICRO( - MICRO_SRC, - const Scalar* &rhs_ptr, - MICRO_DST) - { - Packet rhsV[4]; - pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - asm("#unrolled pger? begin"); - MICRO_WORK - asm("#unrolled pger? end"); - rhs_ptr += accRows; - } - template EIGEN_STRONG_INLINE void gemm_unrolled_iteration( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -1856,56 +1556,37 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration( Index col, const Packet& pAlpha) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; +asm("#gemm begin"); + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr0, * lhs_ptr1, * lhs_ptr2, * lhs_ptr3, * lhs_ptr4, * lhs_ptr5, * lhs_ptr6, * lhs_ptr7; PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; PacketBlock acc; - asm("#unrolled start"); MICRO_SRC_PTR - asm("#unrolled zero?"); MICRO_DST_PTR Index k = 0; for(; k + PEEL <= depth; k+= PEEL) { - prefetch(rhs_ptr); + EIGEN_POWER_PREFETCH(rhs_ptr); MICRO_PREFETCH - asm("#unrolled inner loop?"); - for (int l = 0; l < PEEL; l++) { - MICRO_ONE - } - asm("#unrolled inner loop end?"); + MICRO_ONE_PEEL4 } for(; k < depth; k++) { - MICRO_ONE + MICRO_ONE4 } MICRO_STORE row += unroll_factor*accCols; +asm("#gemm end"); } -template -EIGEN_STRONG_INLINE void MICRO_COL( - MICRO_SRC, - const Scalar* &rhs_ptr, - MICRO_COL_DST, - Index remaining_rows) - { - Packet rhsV[1]; - rhsV[0] = pset1(rhs_ptr[0]); - asm("#unrolled pger? begin"); - MICRO_WORK - asm("#unrolled pger? 
end"); - rhs_ptr += remaining_rows; - } - template EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -1914,8 +1595,8 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( Index remaining_cols, const Packet& pAlpha) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr0, * lhs_ptr1, * lhs_ptr2, * lhs_ptr3, * lhs_ptr4, * lhs_ptr5, * lhs_ptr6, *lhs_ptr7; PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; PacketBlock acc; @@ -1925,15 +1606,13 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( Index k = 0; for(; k + PEEL <= depth; k+= PEEL) { - prefetch(rhs_ptr); + EIGEN_POWER_PREFETCH(rhs_ptr); MICRO_PREFETCH - for (int l = 0; l < PEEL; l++) { - MICRO_COL_ONE - } + MICRO_ONE_PEEL1 } for(; k < depth; k++) { - MICRO_COL_ONE + MICRO_ONE1 } MICRO_COL_STORE @@ -1943,8 +1622,8 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( template EIGEN_STRONG_INLINE void gemm_unrolled_col( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -1955,10 +1634,10 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col( const Packet& pAlpha) { #define MAX_UNROLL 6 - while(row + MAX_UNROLL*accCols <= rows){ + while(row + MAX_UNROLL*accCols <= rows) { gemm_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); } - switch( (rows-row)/accCols ){ + switch( (rows-row)/accCols ) { #if MAX_UNROLL > 7 case 7: gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); @@ -2018,16 +1697,15 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar *rhs_base = blockB + col*strideB + accRows*offsetB; - const Scalar *lhs_base = blockA; + const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; Index row = 0; - asm("#jump table"); #define MAX_UNROLL 6 - while(row + MAX_UNROLL*accCols <= rows){ + while(row + MAX_UNROLL*accCols <= rows) { gemm_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); } - switch( (rows-row)/accCols ){ + switch( (rows-row)/accCols ) { #if MAX_UNROLL > 7 case 7: gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); @@ -2067,18 +1745,17 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const break; } #undef MAX_UNROLL - asm("#jump table end"); if(remaining_rows > 0) { - gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, cols, remaining_rows, pAlpha, pMask); + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); } } if(remaining_cols > 0) { - const Scalar *rhs_base = blockB + col*strideB + remaining_cols*offsetB; - const Scalar *lhs_base = blockA; + const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; for(; col < cols; col++) { @@ 
-2095,316 +1772,641 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const } } -template -EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, - Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +#define accColsC (accCols / 2) +#define advanceRows ((LhsIsReal) ? 1 : 2) +#define advanceCols ((RhsIsReal) ? 1 : 2) + +// PEEL_COMPLEX loop factor. +#define PEEL_COMPLEX 3 + +template +EIGEN_STRONG_INLINE void MICRO_COMPLEX_EXTRA_COL( + const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, + const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, + PacketBlock &accReal, PacketBlock &accImag, + Index remaining_rows, + Index remaining_cols) +{ + Packet rhsV[1], rhsVi[1]; + rhsV[0] = pset1(rhs_ptr_real[0]); + if(!RhsIsReal) rhsVi[0] = pset1(rhs_ptr_imag[0]); + pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + rhs_ptr_real += remaining_cols; + if(!RhsIsReal) rhs_ptr_imag += remaining_cols; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); +} + +template +EIGEN_STRONG_INLINE void pstore_add_half(std::complex* to, Packetc &from) +{ +#ifdef __VSX__ + Packetc from2; +#ifndef _BIG_ENDIAN + __asm__ ("xxswapd %x0, %x0" : : "wa" (from.v)); +#endif + __asm__ ("lxsdx %x0,%y1" : "=wa" (from2.v) : "Z" (*to)); + from2 += from; + __asm__ ("stxsdx %x0,%y1" : : "wa" (from2.v), "Z" (*to)); +#else + std::complex mem[accColsC]; + pstoreu >(mem, from); + *to += *mem; +#endif +} + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; + const Scalar* lhs_ptr_imag; + if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + PacketBlock accReal, accImag; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + + bsetzero(accReal); + bsetzero(accImag); + + Index remaining_depth = (depth & -accRows); + Index k = 0; + for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + EIGEN_POWER_PREFETCH(lhs_ptr_real); + if(!LhsIsReal) { + EIGEN_POWER_PREFETCH(lhs_ptr_imag); + } + for (int l = 0; l < PEEL_COMPLEX; l++) { + MICRO_COMPLEX_EXTRA_COL(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols); + } + } + for(; k < remaining_depth; k++) + { + MICRO_COMPLEX_EXTRA_COL(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols); + } + + for(; k < depth; k++) + { + Packet rhsV[1], rhsVi[1]; + rhsV[0] = pset1(rhs_ptr_real[0]); + if(!RhsIsReal) rhsVi[0] = pset1(rhs_ptr_imag[0]); + pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, 
RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + rhs_ptr_real += remaining_cols; + if(!RhsIsReal) rhs_ptr_imag += remaining_cols; + } + + bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag); + bcouple_common(taccReal, taccImag, acc0, acc1); + + if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) + { + pstore_add_half(&res(row + 0, col + 0), acc0.packet[0]); + } else { + acc0.packet[0] += res.template loadPacket(row + 0, col + 0); + res.template storePacketBlock(row + 0, col + 0, acc0); + if(remaining_rows > accColsC) { + pstore_add_half(&res(row + accColsC, col + 0), acc1.packet[0]); + } + } +} + +template +EIGEN_STRONG_INLINE void MICRO_COMPLEX_EXTRA_ROW( + const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, + const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, + PacketBlock &accReal, PacketBlock &accImag, + Index remaining_rows) +{ + Packet rhsV[4], rhsVi[4]; + pbroadcast4_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + if(!RhsIsReal) pbroadcast4_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + rhs_ptr_real += accRows; + if(!RhsIsReal) rhs_ptr_imag += accRows; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); +} + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask) +{ +asm("#gemm_complex begin"); + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; + const Scalar* lhs_ptr_imag; + if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + PacketBlock accReal, accImag; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; + + bsetzero(accReal); + bsetzero(accImag); + + Index remaining_depth = (col + accRows < cols) ? 
depth : (depth & -accRows); + Index k = 0; + for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + EIGEN_POWER_PREFETCH(lhs_ptr_real); + if(!LhsIsReal) { + EIGEN_POWER_PREFETCH(lhs_ptr_imag); + } + for (int l = 0; l < PEEL_COMPLEX; l++) { + MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows); + } + } + for(; k < remaining_depth; k++) + { + MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows); + } + + if ((remaining_depth == depth) && (rows >= accCols)) + { + bload(tRes, res, row, col); + bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); + bcouple(taccReal, taccImag, tRes, acc0, acc1); + res.template storePacketBlock(row + 0, col, acc0); + res.template storePacketBlock(row + accColsC, col, acc1); + } else { + for(; k < depth; k++) + { + Packet rhsV[4], rhsVi[4]; + pbroadcast4_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + if(!RhsIsReal) pbroadcast4_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + rhs_ptr_real += accRows; + if(!RhsIsReal) rhs_ptr_imag += accRows; + } + + bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag); + bcouple_common(taccReal, taccImag, acc0, acc1); + + if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) + { + for(Index j = 0; j < 4; j++) { + pstore_add_half(&res(row + 0, col + j), acc0.packet[j]); + } + } else { + for(Index j = 0; j < 4; j++) { + PacketBlock acc2; + acc2.packet[0] = res.template loadPacket(row + 0, col + j) + acc0.packet[j]; + res.template storePacketBlock(row + 0, col + j, acc2); + if(remaining_rows > accColsC) { + pstore_add_half(&res(row + accColsC, col + j), acc1.packet[j]); + } + } + } + } +asm("#gemm_complex end"); +} + +#define MICRO_COMPLEX_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) + +#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ + MICRO_COMPLEX_UNROLL(func2); \ + func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel) + +#define MICRO_COMPLEX_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ + lhs_ptr_real##iter += accCols; \ + if(!LhsIsReal) { \ + lhsVi##iter = ploadLhs(lhs_ptr_imag##iter); \ + lhs_ptr_imag##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } + +#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \ + if (unroll_factor > iter) { \ + pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \ + if (unroll_factor > iter) { \ + pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \ + if (PEEL_COMPLEX > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + pbroadcast4_old(rhs_ptr_real + (accRows * peel), rhsV##peel[0], 
rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + if(!RhsIsReal) { \ + pbroadcast4_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } \ + MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \ + if (PEEL_COMPLEX > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + rhsV##peel[0] = pset1(rhs_ptr_real[remaining_cols * peel]); \ + if(!RhsIsReal) { \ + rhsVi##peel[0] = pset1(rhs_ptr_imag[remaining_cols * peel]); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } \ + MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \ + Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \ + Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \ + func(func1,func2,0); func(func1,func2,1); \ + func(func1,func2,2); func(func1,func2,3); \ + func(func1,func2,4); func(func1,func2,5); \ + func(func1,func2,6); func(func1,func2,7); \ + func(func1,func2,8); func(func1,func2,9); + +#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \ + Packet rhsV0[M], rhsVi0[M];\ + func(func1,func2,0); + +#define MICRO_COMPLEX_ONE_PEEL4 \ + MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += (accRows * PEEL_COMPLEX); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX); + +#define MICRO_COMPLEX_ONE4 \ + MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += accRows; \ + if(!RhsIsReal) rhs_ptr_imag += accRows; + +#define MICRO_COMPLEX_ONE_PEEL1 \ + MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \ + if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX); + +#define MICRO_COMPLEX_ONE1 \ + MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += remaining_cols; \ + if(!RhsIsReal) rhs_ptr_imag += remaining_cols; + +#define MICRO_COMPLEX_DST_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + bsetzero(accReal##iter); \ + bsetzero(accImag##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accReal##iter); \ + EIGEN_UNUSED_VARIABLE(accImag##iter); \ + } + +#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE) + +#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \ + if(!LhsIsReal) { \ + lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } + +#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE) + +#define MICRO_COMPLEX_PREFETCH_ONE(iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ + if(!LhsIsReal) { \ + 
EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \ + } \ + } + +#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE) + +#define MICRO_COMPLEX_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + bload(tRes, res, row + iter*accCols, col); \ + bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ + bcouple(taccReal, taccImag, tRes, acc0, acc1); \ + res.template storePacketBlock(row + iter*accCols + 0, col, acc0); \ + res.template storePacketBlock(row + iter*accCols + accColsC, col, acc1); \ + } + +#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE) + +#define MICRO_COMPLEX_COL_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + bload(tRes, res, row + iter*accCols, col); \ + bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ + bcouple(taccReal, taccImag, tRes, acc0, acc1); \ + res.template storePacketBlock(row + iter*accCols + 0, col, acc0); \ + res.template storePacketBlock(row + iter*accCols + accColsC, col, acc1); \ + } + +#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE) + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index col, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ +asm("#gemm_complex_unrolled begin"); + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) { + rhs_ptr_imag = rhs_base + accRows*strideB; + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + } + const Scalar* lhs_ptr_real0, * lhs_ptr_imag0, * lhs_ptr_real1, * lhs_ptr_imag1; + const Scalar* lhs_ptr_real2, * lhs_ptr_imag2, * lhs_ptr_real3, * lhs_ptr_imag3; + const Scalar* lhs_ptr_real4, * lhs_ptr_imag4; + PacketBlock accReal0, accImag0, accReal1, accImag1; + PacketBlock accReal2, accImag2, accReal3, accImag3; + PacketBlock accReal4, accImag4; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; + + MICRO_COMPLEX_SRC_PTR + MICRO_COMPLEX_DST_PTR + + Index k = 0; + for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + MICRO_COMPLEX_PREFETCH + MICRO_COMPLEX_ONE_PEEL4 + } + for(; k < depth; k++) + { + MICRO_COMPLEX_ONE4 + } + MICRO_COMPLEX_STORE + + row += unroll_factor*accCols; +asm("#gemm_complex_unrolled end"); +} + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) { + rhs_ptr_imag = rhs_base + remaining_cols*strideB; + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + } + const Scalar* lhs_ptr_real0, * lhs_ptr_imag0, * lhs_ptr_real1, * lhs_ptr_imag1; + const Scalar* lhs_ptr_real2, * lhs_ptr_imag2, * lhs_ptr_real3, * lhs_ptr_imag3; + const Scalar* lhs_ptr_real4, * lhs_ptr_imag4; + PacketBlock accReal0, accImag0, accReal1, accImag1; + PacketBlock accReal2, accImag2, accReal3, accImag3; + PacketBlock accReal4, accImag4; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; + + MICRO_COMPLEX_SRC_PTR + MICRO_COMPLEX_DST_PTR + + Index k = 0; + for(; k + 
PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + MICRO_COMPLEX_PREFETCH + MICRO_COMPLEX_ONE_PEEL1 + } + for(; k < depth; k++) + { + MICRO_COMPLEX_ONE1 + } + MICRO_COMPLEX_COL_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag) { - const int remaining_rows = rows % accCols; - const int remaining_cols = cols % accRows; - const int accColsC = accCols / 2; - int advanceCols = 2; - int advanceRows = 2; +#define MAX_COMPLEX_UNROLL 3 + while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { + gemm_complex_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_UNROLL > 4 + case 4: + gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 3 + case 3: + gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 2 + case 2: + gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 1 + case 1: + gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; + } +#undef MAX_COMPLEX_UNROLL +} - if(LhsIsReal) advanceRows = 1; - if(RhsIsReal) advanceCols = 1; +template +EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; const Packet pAlphaReal = pset1(alpha.real()); const Packet pAlphaImag = pset1(alpha.imag()); + const Packet pMask = bmask((const int)(remaining_rows)); - const Scalar *blockA = (Scalar *) blockAc; - const Scalar *blockB = (Scalar *) blockBc; - - Packet conj = pset1((Scalar)-1.0); + const Scalar* blockA = (Scalar *) blockAc; + const Scalar* blockB = (Scalar *) blockBc; Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; - const Scalar *lhs_base = blockA; - + const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; Index row = 0; - for(; row + accCols 
<= rows; row += accCols) - { -#define MICRO() \ - pgerc(accReal1, accImag1, rhs_ptr, rhs_ptr_imag, lhs_ptr1, lhs_ptr_imag1, conj); \ - lhs_ptr1 += accCols; \ - rhs_ptr += accRows; \ - if(!LhsIsReal) \ - lhs_ptr_imag1 += accCols; \ - if(!RhsIsReal) \ - rhs_ptr_imag += accRows; - - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; - const Scalar *lhs_ptr1 = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag1 = lhs_ptr1 + accCols*strideA; - - PacketBlock accReal1, accImag1; - bsetzero(accReal1); - bsetzero(accImag1); - - lhs_ptr1 += accCols*offsetA; - if(!LhsIsReal) - lhs_ptr_imag1 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += accRows*offsetB; - Index k = 0; - for(; k + PEEL <= depth; k+=PEEL) - { - prefetch(rhs_ptr); - prefetch(rhs_ptr_imag); - prefetch(lhs_ptr1); - prefetch(lhs_ptr_imag1); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } - PacketBlock taccReal, taccImag; - bscalec(accReal1, accImag1, pAlphaReal, pAlphaImag, taccReal, taccImag); - - PacketBlock tRes; - bload(tRes, res, row, col, accColsC); - PacketBlock acc1, acc2; - bcouple(taccReal, taccImag, tRes, acc1, acc2); +#define MAX_COMPLEX_UNROLL 3 + while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { + gemm_complex_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_UNROLL > 4 + case 4: + gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 3 + case 3: + gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 2 + case 2: + gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 1 + case 1: + gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; + } +#undef MAX_COMPLEX_UNROLL - res.template storePacketBlock(row + 0, col, acc1); - res.template storePacketBlock(row + accColsC, col, acc2); -#undef MICRO - } - if(remaining_rows > 0) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; - - lhs_ptr += remaining_rows*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows*offsetA; - rhs_ptr += accRows*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < remaining_rows; arow++) - { - Scalar lhs_real = lhs_ptr[arow]; - Scalar 
lhs_imag; - if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; - - Scalarc lhsc; - - lhsc.real(lhs_real); - if(!LhsIsReal) - { - if(ConjugateLhs) - lhsc.imag(-lhs_imag); - else - lhsc.imag(lhs_imag); - } else { - //Lazy approach for now - lhsc.imag((Scalar)0); - } - - for(int acol = 0; acol < accRows; acol++ ) - { - Scalar rhs_real = rhs_ptr[acol]; - Scalar rhs_imag; - if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; - Scalarc rhsc; - - rhsc.real(rhs_real); - if(!RhsIsReal) - { - if(ConjugateRhs) - rhsc.imag(-rhs_imag); - else - rhsc.imag(rhs_imag); - } else { - //Lazy approach for now - rhsc.imag((Scalar)0); - } - res(row + arow, col + acol) += alpha*lhsc*rhsc; - } - } - rhs_ptr += accRows; - lhs_ptr += remaining_rows; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows; - if(!RhsIsReal) - rhs_ptr_imag += accRows; - } - } + if(remaining_rows > 0) + { + gemm_complex_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } } if(remaining_cols > 0) { - const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; - const Scalar *lhs_base = blockA; - Index row = 0; + const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; - for(; row + accCols <= rows; row += accCols) + for(; col < cols; col++) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; - - lhs_ptr += accCols*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += accCols*offsetA; - rhs_ptr += remaining_cols*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols*offsetB; - Scalarc scalarAcc[4][4]; - for(Index arow = 0; arow < 4; arow++ ) - { - for(Index acol = 0; acol < 4; acol++ ) - { - scalarAcc[arow][acol].real((Scalar)0.0); - scalarAcc[arow][acol].imag((Scalar)0.0); - } - } - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < accCols; arow++) - { - Scalar lhs_real = lhs_ptr[arow]; - Scalar lhs_imag; - if(!LhsIsReal) - { - lhs_imag = lhs_ptr_imag[arow]; - - if(ConjugateLhs) - lhs_imag *= -1; - } else { - lhs_imag = (Scalar)0; - } - - for(int acol = 0; acol < remaining_cols; acol++ ) - { - Scalar rhs_real = rhs_ptr[acol]; - Scalar rhs_imag; - if(!RhsIsReal) - { - rhs_imag = rhs_ptr_imag[acol]; - - if(ConjugateRhs) - rhs_imag *= -1; - } else { - rhs_imag = (Scalar)0; - } - - scalarAcc[arow][acol].real(scalarAcc[arow][acol].real() + lhs_real*rhs_real - lhs_imag*rhs_imag); - scalarAcc[arow][acol].imag(scalarAcc[arow][acol].imag() + lhs_imag*rhs_real + lhs_real*rhs_imag); - } - } - rhs_ptr += remaining_cols; - lhs_ptr += accCols; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols; - if(!LhsIsReal) - lhs_ptr_imag += accCols; - } - for(int arow = 0; arow < accCols; arow++ ) - { - for(int acol = 0; acol < remaining_cols; acol++ ) - { - Scalar accR = scalarAcc[arow][acol].real(); - Scalar accI = scalarAcc[arow][acol].imag(); - Scalar aR = alpha.real(); - Scalar aI = alpha.imag(); - Scalar resR = res(row + arow, col + acol).real(); - Scalar resI = res(row + arow, col + acol).imag(); - - res(row + arow, col + acol).real(resR + accR*aR - accI*aI); - res(row + arow, col + acol).imag(resI + accR*aI + accI*aR); - } - } - } + Index row = 0; - if(remaining_rows > 0) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; - const Scalar *lhs_ptr = 
lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; - - lhs_ptr += remaining_rows*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows*offsetA; - rhs_ptr += remaining_cols*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) + gemm_complex_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); + + if (remaining_rows > 0) { - for(Index arow = 0; arow < remaining_rows; arow++) - { - Scalar lhs_real = lhs_ptr[arow]; - Scalar lhs_imag; - if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; - Scalarc lhsc; - - lhsc.real(lhs_real); - if(!LhsIsReal) - { - if(ConjugateLhs) - lhsc.imag(-lhs_imag); - else - lhsc.imag(lhs_imag); - } else { - lhsc.imag((Scalar)0); - } - - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - Scalar rhs_real = rhs_ptr[acol]; - Scalar rhs_imag; - if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; - Scalarc rhsc; - - rhsc.real(rhs_real); - if(!RhsIsReal) - { - if(ConjugateRhs) - rhsc.imag(-rhs_imag); - else - rhsc.imag(rhs_imag); - } else { - rhsc.imag((Scalar)0); - } - res(row + arow, col + acol) += alpha*lhsc*rhsc; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += remaining_rows; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols; + gemm_complex_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); } + rhs_base++; } } } +#undef accColsC +#undef advanceCols +#undef advanceRows + /************************************ * ppc64le template specializations * * **********************************/ @@ -2418,7 +2420,7 @@ template ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_pack pack; + dhs_pack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2432,7 +2434,7 @@ template ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_pack pack; + dhs_pack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2446,7 +2448,7 @@ template ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_pack pack; + dhs_pack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2460,7 +2462,7 @@ template ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_pack pack; + dhs_pack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2474,7 +2476,7 @@ template ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_pack pack; + dhs_pack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2488,7 +2490,7 @@ template ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_pack pack; + dhs_pack pack; pack(blockA, lhs, depth, rows, stride, offset); } template @@ -2501,7 +2503,7 @@ template, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_cpack pack; + dhs_cpack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2515,7 +2517,7 @@ template, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> ::operator()(std::complex* blockA, const 
DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_cpack pack; + dhs_cpack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2529,7 +2531,7 @@ template ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_pack pack; + dhs_pack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2543,7 +2545,7 @@ template ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_pack pack; + dhs_pack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2557,7 +2559,7 @@ template, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_cpack pack; + dhs_cpack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2571,7 +2573,7 @@ template, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_cpack pack; + dhs_cpack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2585,7 +2587,7 @@ template, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_cpack pack; + dhs_cpack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2599,7 +2601,7 @@ template, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - lhs_cpack pack; + dhs_cpack pack; pack(blockA, lhs, depth, rows, stride, offset); } @@ -2613,7 +2615,7 @@ template, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_cpack pack; + dhs_cpack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2627,7 +2629,7 @@ template, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - rhs_cpack pack; + dhs_cpack pack; pack(blockB, rhs, depth, cols, stride, offset); } @@ -2687,8 +2689,8 @@ void gebp_kernel, std::complex, Index, DataMapper, mr Index rows, Index depth, Index cols, std::complex alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - const int accRows = quad_traits::rows; - const int accCols = quad_traits::size; + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, Index, Index, Index, std::complex, Index, Index, Index, Index); @@ -2726,8 +2728,8 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjugat Index rows, Index depth, Index cols, std::complex alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - const int accRows = quad_traits::rows; - const int accCols = quad_traits::size; + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const float*, const std::complex*, Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY @@ -2764,8 +2766,8 @@ void gebp_kernel, float, Index, DataMapper, mr, nr, Conjugat Index rows, Index depth, Index cols, std::complex alpha, Index strideA, 
Index strideB, Index offsetA, Index offsetB) { - const int accRows = quad_traits::rows; - const int accCols = quad_traits::size; + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const float*, Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY @@ -2839,8 +2841,8 @@ void gebp_kernel, std::complex, Index, DataMapper, Index rows, Index depth, Index cols, std::complex alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - const int accRows = quad_traits::rows; - const int accCols = quad_traits::size; + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY @@ -2877,8 +2879,8 @@ void gebp_kernel, double, Index, DataMapper, mr, nr, Conjug Index rows, Index depth, Index cols, std::complex alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - const int accRows = quad_traits::rows; - const int accCols = quad_traits::size; + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const double*, Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY @@ -2915,8 +2917,8 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjug Index rows, Index depth, Index cols, std::complex alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - const int accRows = quad_traits::rows; - const int accCols = quad_traits::size; + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const double*, const std::complex*, Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index a1799c061..024767868 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -1,3 +1,10 @@ +//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines +#ifdef EIGEN_POWER_USE_PREFETCH +#define EIGEN_POWER_PREFETCH(p) prefetch(p) +#else +#define EIGEN_POWER_PREFETCH(p) +#endif + namespace Eigen { namespace internal { @@ -5,8 +12,8 @@ namespace internal { template EIGEN_STRONG_INLINE void gemm_extra_col( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -16,16 +23,17 @@ EIGEN_STRONG_INLINE void gemm_extra_col( Index remaining_cols, const Packet& pAlpha); -template +template EIGEN_STRONG_INLINE void gemm_extra_row( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, Index row, Index col, + Index rows, Index cols, Index remaining_rows, const Packet& pAlpha, @@ -34,8 +42,8 @@ EIGEN_STRONG_INLINE void gemm_extra_row( template EIGEN_STRONG_INLINE void gemm_unrolled_col( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -48,6 +56,71 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col( template 
EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows); +template +EIGEN_STRONG_INLINE void gemm_complex_extra_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag); + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask); + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag); + +template +EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar* lhs); + +template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); + +template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); + +template +EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha); + +template +EIGEN_STRONG_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag); + const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, @@ -68,7 +141,7 @@ const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14 // Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
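// A scalar sketch for orientation (helper name and lane count are illustrative
// only): the kernels keep real and imaginary sums in separate planar packets,
// and the bcouple/bcouple_common helpers below interleave those planes back
// into complex lanes with vec_perm and add in the existing result tile. In
// scalar terms the coupling step amounts to:

#include <complex>

static void couple_accumulate_sketch(const float* re, const float* im,
                                     std::complex<float>* res, int lanes)
{
  for (int i = 0; i < lanes; i++)
    res[i] += std::complex<float>(re[i], im[i]);  // interleave + accumulate
}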
template -EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +EIGEN_STRONG_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) { acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); @@ -79,6 +152,12 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock +EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + bcouple_common(taccReal, taccImag, acc1, acc2); acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); @@ -91,8 +170,26 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock(tRes.packet[7], acc2.packet[3]); } +template +EIGEN_STRONG_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); +} + +template +EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + bcouple_common(taccReal, taccImag, acc1, acc2); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + + acc2.packet[0] = padd(tRes.packet[1], acc2.packet[0]); +} + template<> -EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +EIGEN_STRONG_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) { acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); @@ -103,23 +200,21 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock& t acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); +} - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); +template<> +EIGEN_STRONG_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); - acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); } // This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. 
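// A rough scalar picture of the generic case below (assumed 4-lane float
// width, invented helper name): the load simply reinterprets one packet's
// worth of the packed RHS buffer, whereas the double MMA specialization has
// to assemble two 2-lane vectors into a __vector_pair so that a single
// xvf64ger instruction can consume them.

static inline void load_rhs_sketch(const float* rhs, float out[4])
{
  for (int i = 0; i < 4; i++)
    out[i] = rhs[i];  // same data the vector load would fetch in one shot
}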
template -EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs) +EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar* rhs) { - return *((Packet *)rhs); + return *((Packet *)rhs); } } // end namespace internal diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index 64f11727f..8edf79c4b 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com) +// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com) // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -22,149 +23,41 @@ namespace Eigen { namespace internal { -const static Packet16uc MMA_p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, - 16, 17, 18, 19, - 4, 5, 6, 7, - 20, 21, 22, 23}; - -const static Packet16uc MMA_p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, - 24, 25, 26, 27, - 12, 13, 14, 15, - 28, 29, 30, 31}; -//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 -const static Packet16uc MMA_p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, - 16, 17, 18, 19, 20, 21, 22, 23}; - -//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 -const static Packet16uc MMA_p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, - 24, 25, 26, 27, 28, 29, 30, 31}; - - -// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. -template -EIGEN_STRONG_INLINE void bcoupleMMA(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX32_FIRST); - acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX32_FIRST); - acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX32_FIRST); - acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX32_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX32_SECOND); - acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX32_SECOND); - acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX32_SECOND); - acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX32_SECOND); - - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); - - acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); -} - -template<> -EIGEN_STRONG_INLINE void bcoupleMMA(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX64_FIRST); - acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX64_FIRST); - acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX64_FIRST); - acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], 
MMA_p16uc_SETCOMPLEX64_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], MMA_p16uc_SETCOMPLEX64_SECOND); - acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], MMA_p16uc_SETCOMPLEX64_SECOND); - acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], MMA_p16uc_SETCOMPLEX64_SECOND); - acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], MMA_p16uc_SETCOMPLEX64_SECOND); - - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); - - acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); -} - template -EIGEN_STRONG_INLINE Packet ploadLhsMMA(const Scalar *lhs) -{ - return *((Packet *)lhs); -} - -template -EIGEN_STRONG_INLINE PacketBlock pmul(const PacketBlock& a, const Packet& b) -{ - PacketBlock pb; - pb.packet[0] = a.packet[0]*b; - pb.packet[1] = a.packet[1]*b; - return pb; -} - -template -EIGEN_STRONG_INLINE void bsetzeroMMA(__vector_quad *acc) +EIGEN_STRONG_INLINE void bsetzeroMMA(__vector_quad* acc) { __builtin_mma_xxsetaccz(acc); } -template -EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad *acc) +template +EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc) { PacketBlock result; __builtin_mma_disassemble_acc(&result.packet, acc); - result.packet[0] = pmadd(alpha, result.packet[0], data.template loadPacket(i, j + 0)); - result.packet[1] = pmadd(alpha, result.packet[1], data.template loadPacket(i, j + 1)); - result.packet[2] = pmadd(alpha, result.packet[2], data.template loadPacket(i, j + 2)); - result.packet[3] = pmadd(alpha, result.packet[3], data.template loadPacket(i, j + 3)); + PacketBlock tRes; + bload(tRes, data, i, j); + + bscale(tRes, result, alpha); - data.template storePacketBlock(i, j, result); + data.template storePacketBlock(i, j, tRes); } -template -EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad *accReal, __vector_quad *accImag, const int accColsC) +template +EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag) { PacketBlock resultReal, resultImag; __builtin_mma_disassemble_acc(&resultReal.packet, accReal); __builtin_mma_disassemble_acc(&resultImag.packet, accImag); - PacketBlock taccReal, taccImag; - taccReal.packet[0] = pmul(resultReal.packet[0], alphaReal); - taccReal.packet[1] = pmul(resultReal.packet[1], alphaReal); - taccReal.packet[2] = pmul(resultReal.packet[2], alphaReal); - taccReal.packet[3] = pmul(resultReal.packet[3], alphaReal); - - taccImag.packet[0] = pmul(resultImag.packet[0], alphaReal); - taccImag.packet[1] = pmul(resultImag.packet[1], alphaReal); - taccImag.packet[2] = pmul(resultImag.packet[2], alphaReal); - taccImag.packet[3] = pmul(resultImag.packet[3], alphaReal); - - taccReal.packet[0] = psub(taccReal.packet[0], pmul(resultImag.packet[0], alphaImag)); - taccReal.packet[1] = psub(taccReal.packet[1], pmul(resultImag.packet[1], alphaImag)); - taccReal.packet[2] = 
psub(taccReal.packet[2], pmul(resultImag.packet[2], alphaImag)); - taccReal.packet[3] = psub(taccReal.packet[3], pmul(resultImag.packet[3], alphaImag)); - - taccImag.packet[0] = pmadd(resultReal.packet[0], alphaImag, taccImag.packet[0]); - taccImag.packet[1] = pmadd(resultReal.packet[1], alphaImag, taccImag.packet[1]); - taccImag.packet[2] = pmadd(resultReal.packet[2], alphaImag, taccImag.packet[2]); - taccImag.packet[3] = pmadd(resultReal.packet[3], alphaImag, taccImag.packet[3]); - PacketBlock tRes; - tRes.packet[0] = data.template loadPacket(i + N*accColsC, j + 0); - tRes.packet[1] = data.template loadPacket(i + N*accColsC, j + 1); - tRes.packet[2] = data.template loadPacket(i + N*accColsC, j + 2); - tRes.packet[3] = data.template loadPacket(i + N*accColsC, j + 3); + bload(tRes, data, i, j); - tRes.packet[4] = data.template loadPacket(i + (N+1)*accColsC, j + 0); - tRes.packet[5] = data.template loadPacket(i + (N+1)*accColsC, j + 1); - tRes.packet[6] = data.template loadPacket(i + (N+1)*accColsC, j + 2); - tRes.packet[7] = data.template loadPacket(i + (N+1)*accColsC, j + 3); + PacketBlock taccReal, taccImag; + bscalec(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag); PacketBlock acc1, acc2; - bcoupleMMA(taccReal, taccImag, tRes, acc1, acc2); + bcouple(taccReal, taccImag, tRes, acc1, acc2); data.template storePacketBlock(i + N*accColsC, j, acc1); data.template storePacketBlock(i + (N+1)*accColsC, j, acc2); @@ -172,7 +65,7 @@ EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMap // Defaults to float32, since Eigen still supports C++03 we can't use default template arguments template -EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const RhsPacket& a, const LhsPacket& b) +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) { if(NegativeAccumulate) { @@ -182,110 +75,148 @@ EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const RhsPacket& a, const L } } -template<> -EIGEN_STRONG_INLINE void pgerMMA, false>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) +template +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock& a, const Packet2d& b) { - __vector_pair *a0 = (__vector_pair *)(&a.packet[0]); - __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b); + __vector_pair* a0 = (__vector_pair *)(&a.packet[0]); + if(NegativeAccumulate) + { + __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b); + } else { + __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b); + } } -template<> -EIGEN_STRONG_INLINE void pgerMMA, true>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) +template +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) { - __vector_pair *a0 = (__vector_pair *)(&a.packet[0]); - __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b); + if(NegativeAccumulate) + { + __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b); + } else { + __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b); + } } -template<> -EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const __vector_pair& a, const Packet2d& b) +template +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet4f& b) { - __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b); + EIGEN_UNUSED_VARIABLE(acc); // Just for compilation + EIGEN_UNUSED_VARIABLE(a); + EIGEN_UNUSED_VARIABLE(b); } -template<> -EIGEN_STRONG_INLINE void 
pgerMMA(__vector_quad *acc, const __vector_pair& a, const Packet2d& b) +template +EIGEN_STRONG_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi) { - __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b); + pgerMMA(accReal, rhsV, lhsV); + if(LhsIsReal) { + pgerMMA(accImag, rhsVi, lhsV); + } else { + if(!RhsIsReal) { + pgerMMA(accReal, rhsVi, lhsVi); + pgerMMA(accImag, rhsVi, lhsV); + } else { + EIGEN_UNUSED_VARIABLE(rhsVi); + } + pgerMMA(accImag, rhsV, lhsVi); + } } -template<> -EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const __vector_pair& a, const Packet4f& b) +// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_STRONG_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) { - // Just for compilation - EIGEN_UNUSED_VARIABLE(acc) - EIGEN_UNUSED_VARIABLE(a) - EIGEN_UNUSED_VARIABLE(b) -} + rhsV = ploadRhs((const Scalar*)(rhs)); +} template<> -EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const __vector_pair& a, const Packet4f& b) +EIGEN_STRONG_INLINE void ploadRhsMMA >(const double* rhs, PacketBlock& rhsV) { - // Just for compilation - EIGEN_UNUSED_VARIABLE(acc) - EIGEN_UNUSED_VARIABLE(a) - EIGEN_UNUSED_VARIABLE(b) + rhsV.packet[0] = ploadRhs((const double *)((Packet2d *)rhs )); + rhsV.packet[1] = ploadRhs((const double *)(((Packet2d *)rhs) + 1)); } -// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. -template -EIGEN_STRONG_INLINE void ploadRhsMMA(const Scalar *rhs, Packet &rhsV) -{ - rhsV = *((Packet *)rhs); -} - template<> -EIGEN_STRONG_INLINE void ploadRhsMMA >(const double *rhs, PacketBlock &rhsV) +EIGEN_STRONG_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) { - rhsV.packet[0] = *((Packet2d *)rhs ); - rhsV.packet[1] = *(((Packet2d *)rhs) + 1); +#if EIGEN_COMP_LLVM + __builtin_vsx_assemble_pair(&rhsV, + (__vector unsigned char)(ploadRhs((const double *)(((Packet2d *)rhs) + 1))), + (__vector unsigned char)(ploadRhs((const double *)((Packet2d *)rhs )))); +#else + __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs)); +#endif } template<> -EIGEN_STRONG_INLINE void ploadRhsMMA(const double *rhs, __vector_pair &rhsV) +EIGEN_STRONG_INLINE void ploadRhsMMA(const float* rhs, __vector_pair& rhsV) { - __builtin_vsx_assemble_pair(&rhsV, (__vector unsigned char)(*(((Packet2d *)rhs) + 1)), (__vector unsigned char)(*((Packet2d *)rhs))); + // Just for compilation + EIGEN_UNUSED_VARIABLE(rhs); + EIGEN_UNUSED_VARIABLE(rhsV); } -#define MICRO_MMA_DST \ - __vector_quad *accZero0, __vector_quad *accZero1, __vector_quad *accZero2, \ - __vector_quad *accZero3, __vector_quad *accZero4, __vector_quad *accZero5, \ - __vector_quad *accZero6, __vector_quad *accZero7 +// PEEL_MMA loop factor. 
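// In the kernels below, "peeling" means the main k-loop walks PEEL_MMA depth
// steps per iteration: one batch of (optionally enabled) EIGEN_POWER_PREFETCH
// calls followed by PEEL_MMA back-to-back rank-1 MMA update groups, with a
// plain one-step loop handling the remaining depth % PEEL_MMA iterations,
// roughly:
//
//   Index k = 0;
//   for(; k + PEEL_MMA <= depth; k += PEEL_MMA) { /* prefetch + PEEL_MMA updates */ }
//   for(; k < depth; k++)                       { /* single update */ }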
+#define PEEL_MMA 7 -#define MICRO_MMA_SRC \ - const Scalar **lhs_ptr0, const Scalar **lhs_ptr1, const Scalar **lhs_ptr2, \ - const Scalar **lhs_ptr3, const Scalar **lhs_ptr4, const Scalar **lhs_ptr5, \ - const Scalar **lhs_ptr6, const Scalar **lhs_ptr7 +#define MICRO_MMA_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) -#define MICRO_MMA_ONE \ - if (sizeof(Scalar) == sizeof(float)) { \ - MICRO_MMA(\ - &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ - rhs_ptr, \ - &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7); \ +#define MICRO_MMA_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr##iter); \ + lhs_ptr##iter += accCols; \ } else { \ - MICRO_MMA(\ - &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ - rhs_ptr, \ - &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7); \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ } -#define MICRO_MMA_WORK_ONE(iter) \ - if (N > iter) { \ - Packet lhsV = ploadLhsMMA(*lhs_ptr##iter); \ - pgerMMA(accZero##iter, rhsV, lhsV); \ - *lhs_ptr##iter += accCols; \ +#define MICRO_MMA_WORK_ONE(iter, type, peel) \ + if (unroll_factor > iter) { \ + pgerMMA(&accZero##iter, rhsV##peel, lhsV##iter); \ + } + +#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \ + if (PEEL_MMA > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ + ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV##peel); \ + MICRO_MMA_UNROLL(func2); \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ + func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \ } else { \ - EIGEN_UNUSED_VARIABLE(accZero##iter); \ - EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ } -#define MICRO_MMA_UNROLL(func) \ - func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) +#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ + type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \ + MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9); + +#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \ + type rhsV0; \ + MICRO_MMA_TYPE_PEEL(func,func2,type,0); -#define MICRO_MMA_WORK MICRO_MMA_UNROLL(MICRO_MMA_WORK_ONE) +#define MICRO_MMA_ONE_PEEL \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr += (accRows * PEEL_MMA); + +#define MICRO_MMA_ONE \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr += accRows; #define MICRO_MMA_DST_PTR_ONE(iter) \ - if (unroll_factor > iter){ \ + if (unroll_factor > iter) { \ bsetzeroMMA(&accZero##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##iter); \ @@ -303,39 +234,24 @@ EIGEN_STRONG_INLINE void ploadRhsMMA(const double *rhs, _ #define 
MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE) #define MICRO_MMA_PREFETCH_ONE(iter) \ - if (unroll_factor > iter){ \ - prefetch(lhs_ptr##iter); \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ } #define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE) #define MICRO_MMA_STORE_ONE(iter) \ - if (unroll_factor > iter){ \ - storeAccumulator(row + iter*accCols, col, res, pAlpha, &accZero##iter); \ + if (unroll_factor > iter) { \ + storeAccumulator(row + iter*accCols, col, res, pAlpha, &accZero##iter); \ } #define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE) -// PEEL_MMA loop factor. -#define PEEL_MMA 10 - -template -EIGEN_STRONG_INLINE void MICRO_MMA( - MICRO_MMA_SRC, - const Scalar* &rhs_ptr, - MICRO_MMA_DST) - { - RhsPacket rhsV; - ploadRhsMMA(rhs_ptr, rhsV); - MICRO_MMA_WORK - rhs_ptr += accRows; - } - template EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -343,22 +259,20 @@ EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( Index col, const Packet& pAlpha) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; +asm("#gemm_MMA begin"); + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr0, * lhs_ptr1, * lhs_ptr2, * lhs_ptr3, * lhs_ptr4, * lhs_ptr5, * lhs_ptr6, * lhs_ptr7; __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; - asm("#unrolled MMA start"); MICRO_MMA_SRC_PTR MICRO_MMA_DST_PTR Index k = 0; for(; k + PEEL_MMA <= depth; k+= PEEL_MMA) { - prefetch(rhs_ptr); + EIGEN_POWER_PREFETCH(rhs_ptr); MICRO_MMA_PREFETCH - for (int l = 0; l < PEEL_MMA; l++) { - MICRO_MMA_ONE - } + MICRO_MMA_ONE_PEEL } for(; k < depth; k++) { @@ -367,7 +281,7 @@ EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( MICRO_MMA_STORE row += unroll_factor*accCols; - asm("#unrolled MMA end"); +asm("#gemm_MMA end"); } template @@ -385,15 +299,15 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar *rhs_base = blockB + col*strideB + accRows*offsetB; - const Scalar *lhs_base = blockA; + const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; Index row = 0; #define MAX_MMA_UNROLL 7 - while(row + MAX_MMA_UNROLL*accCols <= rows){ + while(row + MAX_MMA_UNROLL*accCols <= rows) { gemm_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); } - switch( (rows-row)/accCols ){ + switch( (rows-row)/accCols ) { #if MAX_MMA_UNROLL > 7 case 7: gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); @@ -436,334 +350,288 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, if(remaining_rows > 0) { - gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, cols, remaining_rows, pAlpha, pMask); + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); } - } - - if(remaining_cols > 0) - { - const Scalar *rhs_base = blockB + col*strideB + remaining_cols*offsetB; - const Scalar *lhs_base = blockA; + } - for(; col < cols; col++) + if(remaining_cols > 0) { - Index row = 0; - - 
gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); + const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; - if (remaining_rows > 0) + for(; col < cols; col++) { - gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); + Index row = 0; + + gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); + + if (remaining_rows > 0) + { + gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); + } + rhs_base++; } - rhs_base++; } - } } -template -void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, - Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +#define accColsC (accCols / 2) +#define advanceRows ((LhsIsReal) ? 1 : 2) +#define advanceCols ((RhsIsReal) ? 1 : 2) + +// PEEL_COMPLEX_MMA loop factor. +#define PEEL_COMPLEX_MMA 7 + +#define MICRO_COMPLEX_MMA_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) + +#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ + lhs_ptr_real##iter += accCols; \ + if(!LhsIsReal) { \ + lhsVi##iter = ploadLhs(lhs_ptr_imag##iter); \ + lhs_ptr_imag##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } + +#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \ + if (unroll_factor > iter) { \ + pgercMMA(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \ + if (PEEL_COMPLEX_MMA > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \ + if(!RhsIsReal) { \ + ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } \ + MICRO_COMPLEX_MMA_UNROLL(func2); \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ + type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \ + type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9); + +#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \ + type rhsV0, rhsVi0; \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); + +#define MICRO_COMPLEX_MMA_ONE_PEEL \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + 
MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA); + +#define MICRO_COMPLEX_MMA_ONE \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr_real += accRows; \ + if(!RhsIsReal) rhs_ptr_imag += accRows; + +#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + bsetzeroMMA(&accReal##iter); \ + bsetzeroMMA(&accImag##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accReal##iter); \ + EIGEN_UNUSED_VARIABLE(accImag##iter); \ + } + +#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE) + +#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \ + if(!LhsIsReal) { \ + lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } + +#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE) + +#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ + if(!LhsIsReal) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \ + } \ + } + +#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE) + +#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + storeComplexAccumulator(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \ + } + +#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE) + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index col, + const Packet& pAlphaReal, + const Packet& pAlphaImag) { - const int remaining_rows = rows % accCols; - const int remaining_cols = cols % accRows; - const int accColsC = accCols / 2; - int advanceCols = 2; - int advanceRows = 2; +asm("#gemm_complex_MMA begin"); + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) { + rhs_ptr_imag = rhs_base + accRows*strideB; + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + } + const Scalar* lhs_ptr_real0, * lhs_ptr_imag0, * lhs_ptr_real1, * lhs_ptr_imag1; + const Scalar* lhs_ptr_real2, * lhs_ptr_imag2, * lhs_ptr_real3, * lhs_ptr_imag3; + const Scalar* lhs_ptr_real4, * lhs_ptr_imag4; + __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4; - if(LhsIsReal) advanceRows = 1; - if(RhsIsReal) advanceCols = 1; + MICRO_COMPLEX_MMA_SRC_PTR + MICRO_COMPLEX_MMA_DST_PTR + + Index k = 0; + for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + MICRO_COMPLEX_MMA_PREFETCH + MICRO_COMPLEX_MMA_ONE_PEEL + } + for(; k < depth; k++) + { + MICRO_COMPLEX_MMA_ONE + } + 
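// At this point the planar accumulators hold the usual complex product terms;
// per lane and per depth step the unconjugated case contributed
//   accReal += lhs_re*rhs_re - lhs_im*rhs_im;
//   accImag += lhs_re*rhs_im + lhs_im*rhs_re;
// with the appropriate signs flipped by pgercMMA when ConjugateLhs /
// ConjugateRhs are set. MICRO_COMPLEX_MMA_STORE then scales by alpha and
// re-interleaves the two planes into complex results via bcouple before the
// final store.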
MICRO_COMPLEX_MMA_STORE + + row += unroll_factor*accCols; +asm("#gemm_complex_MMA end"); +} + +template +void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; const Packet pAlphaReal = pset1(alpha.real()); const Packet pAlphaImag = pset1(alpha.imag()); + const Packet pMask = bmask((const int)(remaining_rows)); - const Scalar *blockA = (Scalar *) blockAc; - const Scalar *blockB = (Scalar *) blockBc; - - Packet conj = pset1((Scalar)-1.0f); + const Scalar* blockA = (Scalar *) blockAc; + const Scalar* blockB = (Scalar *) blockBc; Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; - const Scalar *lhs_base = blockA; - + const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; Index row = 0; - for(; row + accCols <= rows; row += accCols) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; - - __vector_quad accReal, accImag; - __builtin_mma_xxsetaccz(&accReal); - __builtin_mma_xxsetaccz(&accImag); - - lhs_ptr += accCols*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += accCols*offsetA; - rhs_ptr += accRows*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - Packet lhsV = ploadLhsMMA(lhs_ptr); - RhsPacket rhsV = ploadRhs(rhs_ptr); - - Packet lhsVi = ploadLhsMMA(lhs_ptr_imag); - RhsPacket rhsVi = ploadRhs(rhs_ptr_imag); - - if(ConjugateLhs && !LhsIsReal) lhsVi = pmul(lhsVi, conj); - if(ConjugateRhs && !RhsIsReal) rhsVi = pmul(rhsVi, conj); - - if(LhsIsReal) - { - pgerMMA(&accReal, rhsV, lhsV); - pgerMMA(&accImag, rhsVi, lhsV); - } else if(RhsIsReal) { - pgerMMA(&accReal, rhsV, lhsV); - pgerMMA(&accImag, rhsV, lhsVi); - } else { - pgerMMA(&accReal, rhsV, lhsV); - pgerMMA(&accReal, rhsVi, lhsVi); - pgerMMA(&accImag, rhsVi, lhsV); - pgerMMA(&accImag, rhsV, lhsVi); - } - - lhs_ptr += accCols; - rhs_ptr += accRows; - if(!LhsIsReal) - lhs_ptr_imag += accCols; - if(!RhsIsReal) - rhs_ptr_imag += accRows; - } - - storeComplexAccumulator(row, col, res, pAlphaReal, pAlphaImag, &accReal, &accImag, accColsC); +#define MAX_COMPLEX_MMA_UNROLL 4 + while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { + gemm_complex_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_MMA_UNROLL > 4 + case 4: + gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 3 + case 3: + gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 2 + 
case 2: + gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 1 + case 1: + gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; } +#undef MAX_COMPLEX_MMA_UNROLL - if(remaining_rows > 0) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; - - lhs_ptr += remaining_rows*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows*offsetA; - rhs_ptr += accRows*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < remaining_rows; arow++) - { - Scalar lhs_real = lhs_ptr[arow]; - Scalar lhs_imag; - if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; - - Scalarc lhsc; - - lhsc.real(lhs_real); - if(!LhsIsReal) - { - if(ConjugateLhs) - lhsc.imag(-lhs_imag); - else - lhsc.imag(lhs_imag); - } else { - //Lazy approach for now - lhsc.imag((Scalar)0); - } - - for(int acol = 0; acol < accRows; acol++ ) - { - Scalar rhs_real = rhs_ptr[acol]; - Scalar rhs_imag; - if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; - Scalarc rhsc; - - rhsc.real(rhs_real); - if(!RhsIsReal) - { - if(ConjugateRhs) - rhsc.imag(-rhs_imag); - else - rhsc.imag(rhs_imag); - } else { - //Lazy approach for now - rhsc.imag((Scalar)0); - } - res(row + arow, col + acol) += alpha*lhsc*rhsc; - } - } - rhs_ptr += accRows; - lhs_ptr += remaining_rows; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows; - if(!RhsIsReal) - rhs_ptr_imag += accRows; - } - } + if(remaining_rows > 0) + { + gemm_complex_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } } if(remaining_cols > 0) { - const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; - const Scalar *lhs_base = blockA; - Index row = 0; + const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; - for(; row + accCols <= rows; row += accCols) + for(; col < cols; col++) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; - - lhs_ptr += accCols*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += accCols*offsetA; - rhs_ptr += remaining_cols*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols*offsetB; - Scalarc scalarAcc[4][4]; - for(Index arow = 0; arow < 4; arow++ ) - { - for(Index acol = 0; acol < 4; acol++ ) - { - scalarAcc[arow][acol].real((Scalar)0.0f); - scalarAcc[arow][acol].imag((Scalar)0.0f); - } - } - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < accCols; arow++) - { - Scalar lhs_real = lhs_ptr[arow]; - Scalar lhs_imag; - if(!LhsIsReal) - { - lhs_imag = lhs_ptr_imag[arow]; - - if(ConjugateLhs) - lhs_imag *= -1; - } else { - lhs_imag = (Scalar)0; - } - - for(int acol = 0; acol < 
remaining_cols; acol++ ) - { - Scalar rhs_real = rhs_ptr[acol]; - Scalar rhs_imag; - if(!RhsIsReal) - { - rhs_imag = rhs_ptr_imag[acol]; - - if(ConjugateRhs) - rhs_imag *= -1; - } else { - rhs_imag = (Scalar)0; - } - - scalarAcc[arow][acol].real(scalarAcc[arow][acol].real() + lhs_real*rhs_real - lhs_imag*rhs_imag); - scalarAcc[arow][acol].imag(scalarAcc[arow][acol].imag() + lhs_imag*rhs_real + lhs_real*rhs_imag); - } - } - rhs_ptr += remaining_cols; - lhs_ptr += accCols; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols; - if(!LhsIsReal) - lhs_ptr_imag += accCols; - } - for(int arow = 0; arow < accCols; arow++ ) - { - for(int acol = 0; acol < remaining_cols; acol++ ) - { - Scalar accR = scalarAcc[arow][acol].real(); - Scalar accI = scalarAcc[arow][acol].imag(); - Scalar aR = alpha.real(); - Scalar aI = alpha.imag(); - Scalar resR = res(row + arow, col + acol).real(); - Scalar resI = res(row + arow, col + acol).imag(); - - res(row + arow, col + acol).real(resR + accR*aR - accI*aI); - res(row + arow, col + acol).imag(resI + accR*aI + accI*aR); - } - } - } + Index row = 0; - if(remaining_rows > 0) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; - const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; - const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; - - lhs_ptr += remaining_rows*offsetA; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows*offsetA; - rhs_ptr += remaining_cols*offsetB; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) + gemm_complex_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); + + if (remaining_rows > 0) { - for(Index arow = 0; arow < remaining_rows; arow++) - { - Scalar lhs_real = lhs_ptr[arow]; - Scalar lhs_imag; - if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; - Scalarc lhsc; - - lhsc.real(lhs_real); - if(!LhsIsReal) - { - if(ConjugateLhs) - lhsc.imag(-lhs_imag); - else - lhsc.imag(lhs_imag); - } else { - lhsc.imag((Scalar)0); - } - - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - Scalar rhs_real = rhs_ptr[acol]; - Scalar rhs_imag; - if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; - Scalarc rhsc; - - rhsc.real(rhs_real); - if(!RhsIsReal) - { - if(ConjugateRhs) - rhsc.imag(-rhs_imag); - else - rhsc.imag(rhs_imag); - } else { - rhsc.imag((Scalar)0); - } - res(row + arow, col + acol) += alpha*lhsc*rhsc; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += remaining_rows; - if(!LhsIsReal) - lhs_ptr_imag += remaining_rows; - if(!RhsIsReal) - rhs_ptr_imag += remaining_cols; + gemm_complex_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); } + rhs_base++; } } } +#undef accColsC +#undef advanceRows +#undef advanceCols + #pragma GCC reset_options } // end namespace internal } // end namespace Eigen + #endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H -- cgit v1.2.3