From 6fe88a3c9db27c00a3817e391cf70116451bf046 Mon Sep 17 00:00:00 2001
From: Everton Constantino
Date: Wed, 20 May 2020 14:01:02 -0300
Subject: MatrixProduct enhancements:

 - Changes to Altivec/MatrixProduct
   Adapting code to gcc 10.
   Generic code style and performance enhancements.
   Adding PanelMode support.
   Adding stride/offset support.
   Enabling float64, std::complex<float> and std::complex<double>.
   Fixing lack of symm_pack.
   Enabling mixedtypes.
 - Adding std::complex tests to blasutil.
 - Adding an implementation of storePacketBlock when Incr != 1.
---
 Eigen/Core                                  |    2 +-
 Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 3160 +++++++++++++++++++++++++--
 Eigen/src/Core/util/BlasUtil.h              |   71 +
 test/blasutil.cpp                           |    2 +
 4 files changed, 2996 insertions(+), 239 deletions(-)

diff --git a/Eigen/Core b/Eigen/Core
index f44b77831..7d1bdd6e8 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -335,7 +335,7 @@ using std::ptrdiff_t;
 #include "src/Core/CoreIterators.h"
 #include "src/Core/ConditionEstimator.h"
 
-#if EIGEN_ARCH_PPC
+#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
   #include "src/Core/arch/AltiVec/MatrixProduct.h"
 #endif
 
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 3bfbfdc87..57227e23b 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -10,31 +10,2642 @@
 #ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
 #define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
 
+/**************************************************************************************************
+ * TODO                                                                                           *
+ * - Check StorageOrder on lhs_pack (the innermost second loop seems unvectorized when it could). *
+ * - Check the possibility of transposing as GETREAL and GETIMAG when needed.                     *
+ * - Check if change conjugation to xor instead of mul gains any performance.                     *
+ * - Remove IsComplex template argument from complex packing.                                     *
+ **************************************************************************************************/
+namespace Eigen {
+
+namespace internal {
+
+/**************************
+ * Constants and typedefs *
+ **************************/
+const int QuadRegisterCount = 8;
+
+#ifdef __MMA__
+
+template<typename Packet>
+union Packetx2u
+{
+  __vector_pair         vectorpair;
+  PacketBlock<Packet,2> pair;
+};
+
+#endif
+
+template<typename Scalar>
+struct quad_traits
+{
+  typedef typename packet_traits<Scalar>::type  vectortype;
+  typedef PacketBlock<vectortype,4>             type;
+  typedef vectortype                            rhstype;
+  enum
+  {
+    vectorsize = packet_traits<Scalar>::size,
+    size = 4,
+    rows = 4
+  };
+};
+
+template<>
+struct quad_traits<double>
+{
+  typedef Packet2d                              vectortype;
+  typedef PacketBlock<vectortype,4>             type;
+  typedef PacketBlock<Packet2d,2>               rhstype;
+  enum
+  {
+    vectorsize = packet_traits<double>::size,
+    size = 2,
+    rows = 4
+  };
+};
+
+// MatrixProduct decomposes complex data into a separate real vector and imaginary vector; this turned out
+// to be faster than Eigen's usual approach of keeping real/imaginary pairs in a single vector. The constants
+// below convert between Eigen's layout and the MatrixProduct layout.
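As an illustration (not part of the patch), a scalar sketch of the layout conversion that the permute masks defined below perform four lanes at a time:

#include <complex>
#include <cstddef>

// Split an interleaved complex buffer [r0,i0,r1,i1,...] into one buffer of real
// parts and one of imaginary parts - the form in which MatrixProduct keeps
// packed complex blocks.  p16uc_GETREAL32/p16uc_GETIMAG32 perform the same
// selection on vector registers.
static void split_complex(const std::complex<float>* in, std::size_t n,
                          float* re, float* im)
{
  for (std::size_t k = 0; k < n; ++k)
  {
    re[k] = in[k].real();
    im[k] = in[k].imag();
  }
}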
+const static Packet4f p4f_CONJUGATE = {-1.0f, -1.0f, -1.0f, -1.0f}; + +const static Packet2d p2d_CONJUGATE = {-1.0f, -1.0f}; + +const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3, + 8, 9, 10, 11, + 16, 17, 18, 19, + 24, 25, 26, 27}; + +const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7, + 12, 13, 14, 15, + 20, 21, 22, 23, + 28, 29, 30, 31}; + +const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, + 16, 17, 18, 19, + 4, 5, 6, 7, + 20, 21, 22, 23}; + +const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, + 24, 25, 26, 27, + 12, 13, 14, 15, + 28, 29, 30, 31}; +//[a,ai],[b,bi] = [a,b] +const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7, + 16, 17, 18, 19, 20, 21, 22, 23}; + +//[a,ai],[b,bi] = [ai,bi] +const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15, + 24, 25, 26, 27, 28, 29, 30, 31}; + +//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 +const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, + 16, 17, 18, 19, 20, 21, 22, 23}; + +//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 +const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, + 24, 25, 26, 27, 28, 29, 30, 31}; + +/********************************************* + * Single precision real and complex packing * + * *******************************************/ + +/** + * Symm packing is related to packing of symmetric adjoint blocks, as expected the packing leaves + * the diagonal real, whatever is below it is copied from the respective upper diagonal element and + * conjugated. There's no PanelMode available for symm packing. + * + * Packing in general is supposed to leave the lhs block and the rhs block easy to be read by gemm using + * it's respective rank-update instructions. The float32/64 versions are different because at this moment + * the size of the accumulator is fixed at 512-bits so you can't have a 4x4 accumulator of 64-bit elements. + * + * As mentioned earlier MatrixProduct breaks complex numbers into a real vector and a complex vector so packing has + * to take that into account, at the moment, we run pack the real part and then the imaginary part, this is the main + * reason why packing for complex is broken down into several different parts, also the reason why we endup having a + * float32/64 and complex float32/64 version. 
+ **/ +template +EIGEN_STRONG_INLINE std::complex getAdjointVal(Index i, Index j, const_blas_data_mapper, Index, StorageOrder>& dt) +{ + std::complex v; + if(i < j) + { + v.real(dt(j,i).real()); + v.imag(-dt(j,i).imag()); + } else if(i > j) + { + v.real(dt(i,j).real()); + v.imag(dt(i,j).imag()); + } else { + v.real(dt(i,j).real()); + v.imag((Scalar)0.0f); + } + return v; +} + +template +EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex *blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) +{ + const Index depth = k2 + rows; + const_blas_data_mapper, Index, StorageOrder> rhs(_rhs, rhsStride); + const int vectorSize = N*quad_traits::vectorsize; + Scalar* blockBf = reinterpret_cast(blockB); + + Index ri = 0, j = 0; + for(; j + vectorSize < cols; j+=vectorSize) + { + Index i = k2; + for(; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) + { + std::complex v = getAdjointVal(i, j + k, rhs); + blockBf[ri + k] = v.real(); + } + ri += vectorSize; + } + + i = k2; + + for(; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) + { + std::complex v = getAdjointVal(i, j + k, rhs); + blockBf[ri + k] = v.imag(); + } + ri += vectorSize; + } + } + for(Index i = k2; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + std::complex v = getAdjointVal(i, k, rhs); + blockBf[ri] = v.real(); + ri += 1; + } + } + for(Index i = k2; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + std::complex v = getAdjointVal(i, k, rhs); + blockBf[ri] = v.imag(); + ri += 1; + } + } +} + +template +EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex *blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) +{ + const Index depth = cols; + const_blas_data_mapper, Index, StorageOrder> lhs(_lhs, lhsStride); + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + Scalar *blockAf = (Scalar *)(blockA); + + for(; j + vectorSize < rows; j+=vectorSize) + { + Index i = 0; + + for(; i < depth; i++) + { + for(int k = 0; k < vectorSize; k++) + { + std::complex v = getAdjointVal(j+k, i, lhs); + blockAf[ri + k] = v.real(); + } + ri += vectorSize; + } + i = 0; + for(; i < depth; i++) + { + for(int k = 0; k < vectorSize; k++) + { + std::complex v = getAdjointVal(j+k, i, lhs); + blockAf[ri + k] = v.imag(); + } + ri += vectorSize; + } + } + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + std::complex v = getAdjointVal(k, i, lhs); + blockAf[ri] = v.real(); + ri += 1; + } + } + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + std::complex v = getAdjointVal(k, i, lhs); + blockAf[ri] = v.imag(); + ri += 1; + } + } +} + +template +EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2) +{ + const Index depth = k2 + rows; + const_blas_data_mapper rhs(_rhs, rhsStride); + const int vectorSize = quad_traits::vectorsize; + + Index ri = 0, j = 0; + for(; j + N*vectorSize < cols; j+=N*vectorSize) + { + Index i = k2; + for(; i < depth; i++) + { + for(int k = 0; k < N*vectorSize; k++) + { + if(i <= j+k) + blockB[ri + k] = rhs(j+k, i); + else + blockB[ri + k] = rhs(i, j+k); + } + ri += N*vectorSize; + } + } + for(Index i = k2; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + if(k <= i) + blockB[ri] = rhs(i, k); + else + blockB[ri] = rhs(k, i); + ri += 1; + } + } +} + +template +EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar* _lhs, Index 
lhsStride, Index cols, Index rows) +{ + const Index depth = cols; + const_blas_data_mapper lhs(_lhs, lhsStride); + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(j = 0; j + vectorSize < rows; j+=vectorSize) + { + Index i = 0; + + for(; i < depth; i++) + { + for(int k = 0; k < vectorSize; k++) + { + if(i <= j+k) + blockA[ri + k] = lhs(j+k, i); + else + blockA[ri + k] = lhs(i, j+k); + } + ri += vectorSize; + } + } + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + if(i <= k) + blockA[ri] = lhs(k, i); + else + blockA[ri] = lhs(i, k); + ri += 1; + } + } +} + +template +struct symm_pack_rhs, Index, nr, StorageOrder> +{ + void operator()(std::complex* blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_complex_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs, Index, Pack1, Pack2_dummy, StorageOrder> +{ + void operator()(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_complex_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +// *********** symm_pack std::complex *********** + +template +struct symm_pack_rhs, Index, nr, StorageOrder> +{ + void operator()(std::complex* blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_complex_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs, Index, Pack1, Pack2_dummy, StorageOrder> +{ + void operator()(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_complex_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +// *********** symm_pack float32 *********** +template +struct symm_pack_rhs +{ + void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs +{ + void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +// *********** symm_pack float64 *********** +template +struct symm_pack_rhs +{ + void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs +{ + void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +/** + * PanelMode + * Packing might be called several times before being multiplied by gebp_kernel, this happens because + * on special occasions it fills part of block with other parts of the matrix. Two variables control + * how PanelMode should behave: offset and stride. The idea is that those variables represent whatever + * is going to be the real offset and stride in the future and this is what you should obey. The process + * is to behave as you would with normal packing but leave the start of each part with the correct offset + * and the end as well respecting the real stride the block will have. Gebp is aware of both blocks stride + * and offset and behaves accordingly. + **/ + +// General template for lhs complex packing. 
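Before the packing templates that follow, a small sketch to make the PanelMode bookkeeping above concrete (illustrative only, not part of the patch; packed_index is a hypothetical helper). It mirrors the ri += vectorSize*offset and ri += vectorSize*(stride - offset - depth) adjustments used by every packer below:

// Each panel of width vectorSize reserves vectorSize*stride entries in the packed
// block; its data starts after vectorSize*offset entries and the trailing
// vectorSize*(stride - offset - depth) entries are left untouched.
static long packed_index(long panel, long k, long lane,
                         long vectorSize, long stride, long offset)
{
  // panel: which row/column strip, k: depth index, lane: position inside the strip
  return panel*vectorSize*stride + vectorSize*(offset + k) + lane;
}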
+template +struct lhs_cpack { + EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + Scalar *blockAt = reinterpret_cast(blockA); + Packet conj = pset1((Scalar)-1.0f); + + for(j = 0; j + vectorSize < rows; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + + PacketBlock cblock; + if(StorageOrder == ColMajor) + { + cblock.packet[0] = lhs.template loadPacket(j, i + 0); + cblock.packet[1] = lhs.template loadPacket(j, i + 1); + cblock.packet[2] = lhs.template loadPacket(j, i + 2); + cblock.packet[3] = lhs.template loadPacket(j, i + 3); + + cblock.packet[4] = lhs.template loadPacket(j + 2, i + 0); + cblock.packet[5] = lhs.template loadPacket(j + 2, i + 1); + cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); + cblock.packet[7] = lhs.template loadPacket(j + 2, i + 3); + } else { + cblock.packet[0] = lhs.template loadPacket(j + 0, i); + cblock.packet[1] = lhs.template loadPacket(j + 1, i); + cblock.packet[2] = lhs.template loadPacket(j + 2, i); + cblock.packet[3] = lhs.template loadPacket(j + 3, i); + + cblock.packet[4] = lhs.template loadPacket(j + 0, i + 2); + cblock.packet[5] = lhs.template loadPacket(j + 1, i + 2); + cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); + cblock.packet[7] = lhs.template loadPacket(j + 3, i + 2); + } + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32); + block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32); + block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32); + block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32); + + if(StorageOrder == RowMajor) ptranspose(block); + + pstore(blockAt + ri , block.packet[0]); + pstore(blockAt + ri + 4, block.packet[1]); + pstore(blockAt + ri + 8, block.packet[2]); + pstore(blockAt + ri + 12, block.packet[3]); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + blockAt[ri + 0] = lhs(j + 0, i).real(); + blockAt[ri + 1] = lhs(j + 1, i).real(); + blockAt[ri + 2] = lhs(j + 2, i).real(); + blockAt[ri + 3] = lhs(j + 3, i).real(); + + ri += vectorSize; + } + if(PanelMode) ri += vectorSize*(stride - offset - depth); + + i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock cblock; + if(StorageOrder == ColMajor) + { + cblock.packet[0] = lhs.template loadPacket(j, i + 0); + cblock.packet[1] = lhs.template loadPacket(j, i + 1); + cblock.packet[2] = lhs.template loadPacket(j, i + 2); + cblock.packet[3] = lhs.template loadPacket(j, i + 3); + + cblock.packet[4] = lhs.template loadPacket(j + 2, i + 0); + cblock.packet[5] = lhs.template loadPacket(j + 2, i + 1); + cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); + cblock.packet[7] = lhs.template loadPacket(j + 2, i + 3); + } else { + cblock.packet[0] = lhs.template loadPacket(j + 0, i); + cblock.packet[1] = lhs.template loadPacket(j + 1, i); + cblock.packet[2] = lhs.template loadPacket(j + 2, i); + cblock.packet[3] = lhs.template loadPacket(j + 3, i); + + cblock.packet[4] = lhs.template loadPacket(j + 0, i + 2); + cblock.packet[5] = lhs.template loadPacket(j + 1, i + 2); + cblock.packet[6] = lhs.template loadPacket(j + 2, i + 2); + cblock.packet[7] = lhs.template loadPacket(j + 3, i + 2); + } + + PacketBlock block; + 
block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32); + block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32); + block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32); + block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32); + + if(Conjugate) + { + block.packet[0] *= conj; + block.packet[1] *= conj; + block.packet[2] *= conj; + block.packet[3] *= conj; + } + + if(StorageOrder == RowMajor) ptranspose(block); + + pstore(blockAt + ri , block.packet[0]); + pstore(blockAt + ri + 4, block.packet[1]); + pstore(blockAt + ri + 8, block.packet[2]); + pstore(blockAt + ri + 12, block.packet[3]); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(Conjugate) + { + blockAt[ri + 0] = -lhs(j + 0, i).imag(); + blockAt[ri + 1] = -lhs(j + 1, i).imag(); + blockAt[ri + 2] = -lhs(j + 2, i).imag(); + blockAt[ri + 3] = -lhs(j + 3, i).imag(); + } else { + blockAt[ri + 0] = lhs(j + 0, i).imag(); + blockAt[ri + 1] = lhs(j + 1, i).imag(); + blockAt[ri + 2] = lhs(j + 2, i).imag(); + blockAt[ri + 3] = lhs(j + 3, i).imag(); + } + + ri += vectorSize; + } + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockAt[ri] = lhs(k, i).real(); + ri += 1; + } + } + + if(PanelMode) ri += (rows - j)*(stride - offset - depth); + + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + if(Conjugate) + blockAt[ri] = -lhs(k, i).imag(); + else + blockAt[ri] = lhs(k, i).imag(); + ri += 1; + } + } + + if(PanelMode) ri += (rows - j)*(stride - offset - depth); + } +}; + +// General template for lhs packing. 
+template +struct lhs_pack{ + EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(j = 0; j + vectorSize < rows; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + + if(StorageOrder == RowMajor) + { + block.packet[0] = lhs.template loadPacket(j + 0, i); + block.packet[1] = lhs.template loadPacket(j + 1, i); + block.packet[2] = lhs.template loadPacket(j + 2, i); + block.packet[3] = lhs.template loadPacket(j + 3, i); + + ptranspose(block); + } else { + block.packet[0] = lhs.template loadPacket(j, i + 0); + block.packet[1] = lhs.template loadPacket(j, i + 1); + block.packet[2] = lhs.template loadPacket(j, i + 2); + block.packet[3] = lhs.template loadPacket(j, i + 3); + } + + pstore(blockA + ri , block.packet[0]); + pstore(blockA + ri + 4, block.packet[1]); + pstore(blockA + ri + 8, block.packet[2]); + pstore(blockA + ri + 12, block.packet[3]); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(StorageOrder == RowMajor) + { + blockA[ri+0] = lhs(j+0, i); + blockA[ri+1] = lhs(j+1, i); + blockA[ri+2] = lhs(j+2, i); + blockA[ri+3] = lhs(j+3, i); + } else { + Packet lhsV = lhs.template loadPacket(j, i); + pstore(blockA + ri, lhsV); + } + + ri += vectorSize; + } + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockA[ri] = lhs(k, i); + ri += 1; + } + } + + if(PanelMode) ri += (rows - j)*(stride - offset - depth); + } +}; + +// General template for rhs complex packing. 
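For reference, before the rhs complex packing template below, a simplified scalar version of the real lhs packing just defined (ColMajor path, no PanelMode; illustrative only, not part of the patch):

// Rows are grouped in strips of vectorSize; inside a strip the packed data is
// depth-major with the strip's rows contiguous, which is the layout the gemm
// kernel consumes with its rank-1 updates.
static void pack_lhs_reference(float* blockA, const float* lhs, long lhsStride,
                               long depth, long rows, long vectorSize)
{
  long ri = 0, j = 0;
  for (; j + vectorSize <= rows; j += vectorSize)      // full strips
    for (long i = 0; i < depth; ++i)
      for (long r = 0; r < vectorSize; ++r)
        blockA[ri++] = lhs[(j + r) + i*lhsStride];     // lhs(j+r, i)
  for (long i = 0; i < depth; ++i)                     // leftover rows
    for (long r = j; r < rows; ++r)
      blockA[ri++] = lhs[r + i*lhsStride];
}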
+template +struct rhs_cpack +{ + EIGEN_STRONG_INLINE void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Scalar *blockBt = reinterpret_cast(blockB); + Packet conj = pset1((Scalar)-1.0f); + + Index ri = 0, j = 0; + for(; j + vectorSize < cols; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += offset*vectorSize; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock cblock; + if(StorageOrder == ColMajor) + { + cblock.packet[0] = rhs.template loadPacket(i, j + 0); + cblock.packet[1] = rhs.template loadPacket(i, j + 1); + cblock.packet[2] = rhs.template loadPacket(i, j + 2); + cblock.packet[3] = rhs.template loadPacket(i, j + 3); + + cblock.packet[4] = rhs.template loadPacket(i + 2, j + 0); + cblock.packet[5] = rhs.template loadPacket(i + 2, j + 1); + cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); + cblock.packet[7] = rhs.template loadPacket(i + 2, j + 3); + } else { + cblock.packet[0] = rhs.template loadPacket(i + 0, j); + cblock.packet[1] = rhs.template loadPacket(i + 1, j); + cblock.packet[2] = rhs.template loadPacket(i + 2, j); + cblock.packet[3] = rhs.template loadPacket(i + 3, j); + + cblock.packet[4] = rhs.template loadPacket(i + 0, j + 2); + cblock.packet[5] = rhs.template loadPacket(i + 1, j + 2); + cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); + cblock.packet[7] = rhs.template loadPacket(i + 3, j + 2); + } + + PacketBlock block; + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32); + block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32); + block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32); + block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32); + + if(StorageOrder == ColMajor) ptranspose(block); + + pstore(blockBt + ri , block.packet[0]); + pstore(blockBt + ri + 4, block.packet[1]); + pstore(blockBt + ri + 8, block.packet[2]); + pstore(blockBt + ri + 12, block.packet[3]); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + blockBt[ri+0] = rhs(i, j+0).real(); + blockBt[ri+1] = rhs(i, j+1).real(); + blockBt[ri+2] = rhs(i, j+2).real(); + blockBt[ri+3] = rhs(i, j+3).real(); + ri += vectorSize; + } + + if(PanelMode) ri += vectorSize*(stride - offset - depth); + + i = 0; + + if(PanelMode) ri += offset*vectorSize; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock cblock; + if(StorageOrder == ColMajor) + { + + cblock.packet[0] = rhs.template loadPacket(i, j + 0); + cblock.packet[1] = rhs.template loadPacket(i, j + 1); + cblock.packet[2] = rhs.template loadPacket(i, j + 2); + cblock.packet[3] = rhs.template loadPacket(i, j + 3); + + cblock.packet[4] = rhs.template loadPacket(i + 2, j + 0); + cblock.packet[5] = rhs.template loadPacket(i + 2, j + 1); + cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); + cblock.packet[7] = rhs.template loadPacket(i + 2, j + 3); + } else { + cblock.packet[0] = rhs.template loadPacket(i + 0, j); + cblock.packet[1] = rhs.template loadPacket(i + 1, j); + cblock.packet[2] = rhs.template loadPacket(i + 2, j); + cblock.packet[3] = rhs.template loadPacket(i + 3, j); + + cblock.packet[4] = rhs.template loadPacket(i + 0, j + 2); + cblock.packet[5] = rhs.template loadPacket(i + 1, j + 2); + cblock.packet[6] = rhs.template loadPacket(i + 2, j + 2); + cblock.packet[7] = rhs.template loadPacket(i + 3, j + 2); + } + + PacketBlock block; + block.packet[0] = 
vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32); + block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32); + block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32); + block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32); + + if(Conjugate) + { + block.packet[0] *= conj; + block.packet[1] *= conj; + block.packet[2] *= conj; + block.packet[3] *= conj; + } + + if(StorageOrder == ColMajor) ptranspose(block); + + pstore(blockBt + ri , block.packet[0]); + pstore(blockBt + ri + 4, block.packet[1]); + pstore(blockBt + ri + 8, block.packet[2]); + pstore(blockBt + ri + 12, block.packet[3]); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(Conjugate) + { + blockBt[ri+0] = -rhs(i, j+0).imag(); + blockBt[ri+1] = -rhs(i, j+1).imag(); + blockBt[ri+2] = -rhs(i, j+2).imag(); + blockBt[ri+3] = -rhs(i, j+3).imag(); + } else { + blockBt[ri+0] = rhs(i, j+0).imag(); + blockBt[ri+1] = rhs(i, j+1).imag(); + blockBt[ri+2] = rhs(i, j+2).imag(); + blockBt[ri+3] = rhs(i, j+3).imag(); + } + ri += vectorSize; + } + + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + blockBt[ri] = rhs(i, k).real(); + ri += 1; + } + } + if(PanelMode) ri += (cols - j)*(stride - offset - depth); + + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + if(Conjugate) + blockBt[ri] = -rhs(i, k).imag(); + else + blockBt[ri] = rhs(i, k).imag(); + ri += 1; + } + } + if(PanelMode) ri += (cols - j)*(stride - offset - depth); + } +}; + +// General template for rhs packing. +template +struct rhs_pack { + EIGEN_STRONG_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + for(; j + vectorSize < cols; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += offset*vectorSize; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + if(StorageOrder == ColMajor) + { + block.packet[0] = rhs.template loadPacket(i, j + 0); + block.packet[1] = rhs.template loadPacket(i, j + 1); + block.packet[2] = rhs.template loadPacket(i, j + 2); + block.packet[3] = rhs.template loadPacket(i, j + 3); + + ptranspose(block); + } else { + block.packet[0] = rhs.template loadPacket(i + 0, j); + block.packet[1] = rhs.template loadPacket(i + 1, j); + block.packet[2] = rhs.template loadPacket(i + 2, j); + block.packet[3] = rhs.template loadPacket(i + 3, j); + } + + pstore(blockB + ri , block.packet[0]); + pstore(blockB + ri + 4, block.packet[1]); + pstore(blockB + ri + 8, block.packet[2]); + pstore(blockB + ri + 12, block.packet[3]); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(StorageOrder == ColMajor) + { + blockB[ri+0] = rhs(i, j+0); + blockB[ri+1] = rhs(i, j+1); + blockB[ri+2] = rhs(i, j+2); + blockB[ri+3] = rhs(i, j+3); + } else { + Packet rhsV = rhs.template loadPacket(i, j); + pstore(blockB + ri, rhsV); + } + ri += vectorSize; + } + + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + blockB[ri] = rhs(i, k); + ri += 1; + } + } + if(PanelMode) ri += (cols - j)*(stride - offset - depth); + } +}; + +// General template for lhs packing, 
float64 specialization. +template +struct lhs_pack +{ + EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(j = 0; j + vectorSize < rows; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + if(StorageOrder == RowMajor) + { + block.packet[0] = lhs.template loadPacket(j + 0, i); + block.packet[1] = lhs.template loadPacket(j + 1, i); + + ptranspose(block); + } else { + block.packet[0] = lhs.template loadPacket(j, i + 0); + block.packet[1] = lhs.template loadPacket(j, i + 1); + } + + pstore(blockA + ri , block.packet[0]); + pstore(blockA + ri + 2, block.packet[1]); + + ri += 2*vectorSize; + } + for(; i < depth; i++) + { + if(StorageOrder == RowMajor) + { + blockA[ri+0] = lhs(j+0, i); + blockA[ri+1] = lhs(j+1, i); + } else { + Packet2d lhsV = lhs.template loadPacket(j, i); + pstore(blockA + ri, lhsV); + } + + ri += vectorSize; + } + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockA[ri] = lhs(k, i); + ri += 1; + } + } + + if(PanelMode) ri += (rows - j)*(stride - offset - depth); + } +}; + +// General template for rhs packing, float64 specialization. +template +struct rhs_pack +{ + EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + for(; j + 2*vectorSize < cols; j+=2*vectorSize) + { + Index i = 0; + + if(PanelMode) ri += offset*(2*vectorSize); + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + if(StorageOrder == ColMajor) + { + PacketBlock block1, block2; + block1.packet[0] = rhs.template loadPacket(i, j + 0); + block1.packet[1] = rhs.template loadPacket(i, j + 1); + block2.packet[0] = rhs.template loadPacket(i, j + 2); + block2.packet[1] = rhs.template loadPacket(i, j + 3); + + ptranspose(block1); + ptranspose(block2); + + pstore(blockB + ri , block1.packet[0]); + pstore(blockB + ri + 2, block2.packet[0]); + pstore(blockB + ri + 4, block1.packet[1]); + pstore(blockB + ri + 6, block2.packet[1]); + } else { + block.packet[0] = rhs.template loadPacket(i + 0, j + 0); //[a1 a2] + block.packet[1] = rhs.template loadPacket(i + 0, j + 2); //[a3 a4] + block.packet[2] = rhs.template loadPacket(i + 1, j + 0); //[b1 b2] + block.packet[3] = rhs.template loadPacket(i + 1, j + 2); //[b3 b4] + + pstore(blockB + ri , block.packet[0]); + pstore(blockB + ri + 2, block.packet[1]); + pstore(blockB + ri + 4, block.packet[2]); + pstore(blockB + ri + 6, block.packet[3]); + } + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(StorageOrder == ColMajor) + { + blockB[ri+0] = rhs(i, j+0); + blockB[ri+1] = rhs(i, j+1); + + ri += vectorSize; + + blockB[ri+0] = rhs(i, j+2); + blockB[ri+1] = rhs(i, j+3); + } else { + Packet2d rhsV = rhs.template loadPacket(i, j); + pstore(blockB + ri, rhsV); + + ri += vectorSize; + + rhsV = rhs.template loadPacket(i, j + 2); + pstore(blockB + ri, rhsV); + } + ri += vectorSize; + } + + if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + blockB[ri] = rhs(i, k); + ri += 
1; + } + } + if(PanelMode) ri += (cols - j)*(stride - offset - depth); + } +}; + +// General template for lhs complex packing, float64 specialization. +template +struct lhs_cpack +{ + EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + double *blockAt = reinterpret_cast(blockA); + Packet conj = pset1((double)-1.0f); + + for(j = 0; j + vectorSize < rows; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + + PacketBlock cblock; + if(StorageOrder == ColMajor) + { + cblock.packet[0] = lhs.template loadPacket(j, i + 0); //[a1 a1i] + cblock.packet[1] = lhs.template loadPacket(j, i + 1); //[b1 b1i] + + cblock.packet[2] = lhs.template loadPacket(j + 1, i + 0); //[a2 a2i] + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] + block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + } else { + cblock.packet[0] = lhs.template loadPacket(j + 0, i); //[a1 a1i] + cblock.packet[1] = lhs.template loadPacket(j + 1, i); //[a2 a2i] + + cblock.packet[2] = lhs.template loadPacket(j + 0, i + 1); //[b1 b1i] + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] + block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + } + + pstore(blockAt + ri , block.packet[0]); + pstore(blockAt + ri + 2, block.packet[1]); + + ri += 2*vectorSize; + } + for(; i < depth; i++) + { + blockAt[ri + 0] = lhs(j + 0, i).real(); + blockAt[ri + 1] = lhs(j + 1, i).real(); + ri += vectorSize; + } + if(PanelMode) ri += vectorSize*(stride - offset - depth); + + i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize < depth; i+=vectorSize) + { + PacketBlock block; + + PacketBlock cblock; + if(StorageOrder == ColMajor) + { + cblock.packet[0] = lhs.template loadPacket(j, i + 0); + cblock.packet[1] = lhs.template loadPacket(j, i + 1); + + cblock.packet[2] = lhs.template loadPacket(j + 1, i + 0); + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETIMAG64); + } else { + cblock.packet[0] = lhs.template loadPacket(j + 0, i); + cblock.packet[1] = lhs.template loadPacket(j + 1, i); + + cblock.packet[2] = lhs.template loadPacket(j + 0, i + 1); + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64); + } + + if(Conjugate) + { + block.packet[0] *= conj; + block.packet[1] *= conj; + } + + pstore(blockAt + ri , block.packet[0]); + pstore(blockAt + ri + 2, block.packet[1]); + + ri += 2*vectorSize; + } + for(; i < depth; i++) + { + if(Conjugate) + { + blockAt[ri + 0] = -lhs(j + 0, i).imag(); + blockAt[ri + 1] = -lhs(j + 1, i).imag(); + } else { + blockAt[ri + 0] = lhs(j + 0, i).imag(); + blockAt[ri + 1] = lhs(j + 1, i).imag(); + } + + ri += vectorSize; + } + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + 
+ if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockAt[ri] = lhs(k, i).real(); + ri += 1; + } + } + + if(PanelMode) ri += (rows - j)*(stride - offset - depth); + + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + if(Conjugate) + blockAt[ri] = -lhs(k, i).imag(); + else + blockAt[ri] = lhs(k, i).imag(); + ri += 1; + } + } + + if(PanelMode) ri += (rows - j)*(stride - offset - depth); + } +}; + +// General template for rhs complex packing, float64 specialization. +template +struct rhs_cpack +{ + EIGEN_STRONG_INLINE void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const int vectorSize = quad_traits::vectorsize; + double *blockBt = reinterpret_cast(blockB); + Packet conj = pset1((double)-1.0f); + + Index ri = 0, j = 0; + for(; j + 2*vectorSize < cols; j+=2*vectorSize) + { + Index i = 0; + + if(PanelMode) ri += offset*(2*vectorSize); + + for(; i < depth; i++) + { + PacketBlock cblock; + PacketBlock block; + + cblock.packet[0] = rhs.template loadPacket(i, j + 0); + cblock.packet[1] = rhs.template loadPacket(i, j + 1); + cblock.packet[2] = rhs.template loadPacket(i, j + 2); + cblock.packet[3] = rhs.template loadPacket(i, j + 3); + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); + block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); + + pstore(blockBt + ri , block.packet[0]); + pstore(blockBt + ri + 2, block.packet[1]); + + ri += 2*vectorSize; + } + + if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); + + i = 0; + + if(PanelMode) ri += offset*(2*vectorSize); + + for(; i < depth; i++) + { + PacketBlock cblock; + PacketBlock block; + + cblock.packet[0] = rhs.template loadPacket(i, j + 0); //[a1 a1i] + cblock.packet[1] = rhs.template loadPacket(i, j + 1); //[b1 b1i] + cblock.packet[2] = rhs.template loadPacket(i, j + 2); //[c1 c1i] + cblock.packet[3] = rhs.template loadPacket(i, j + 3); //[d1 d1i] + + block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64); + + if(Conjugate) + { + block.packet[0] *= conj; + block.packet[1] *= conj; + } + + pstore(blockBt + ri , block.packet[0]); + pstore(blockBt + ri + 2, block.packet[1]); + + ri += 2*vectorSize; + } + if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); + } + + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + blockBt[ri] = rhs(i, k).real(); + ri += 1; + } + } + if(PanelMode) ri += (cols - j)*(stride - offset - depth); + + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + if(Conjugate) + blockBt[ri] = -rhs(i, k).imag(); + else + blockBt[ri] = rhs(i, k).imag(); + ri += 1; + } + } + if(PanelMode) ri += (cols - j)*(stride - offset - depth); + } +}; + +/************** + * GEMM utils * + **************/ + +// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
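A scalar sketch of what bcouple below does (illustrative only, not part of the patch): re-interleave the split real/imaginary accumulators into complex values and add the previously loaded result block (tRes in the code). The interleaving itself is what the SETCOMPLEX masks implement.

#include <complex>

static void couple_reference(const float* accReal, const float* accImag,
                             const std::complex<float>* res,
                             std::complex<float>* out, long n)
{
  for (long k = 0; k < n; ++k)
    out[k] = res[k] + std::complex<float>(accReal[k], accImag[k]);
}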
+template +EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +template<> +EIGEN_STRONG_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +#ifdef __MMA__ +template +EIGEN_STRONG_INLINE PacketBlock pmul (const PacketBlock& a, const Packet& b) +{ + PacketBlock pb; + pb.packet[0] = a.packet[0]*b; + pb.packet[1] = a.packet[1]*b; + return pb; +} +template +EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad *acc) +{ + PacketBlock result; + __builtin_mma_disassemble_acc(&result.packet, acc); + + PacketBlock block; + block.packet[0] = data.template loadPacket(i, j + 0) + pmul(alpha, result.packet[0]); + block.packet[1] = data.template loadPacket(i, j + 1) + pmul(alpha, result.packet[1]); + block.packet[2] = data.template loadPacket(i, j + 2) + pmul(alpha, result.packet[2]); + block.packet[3] = data.template loadPacket(i, j + 
3) + pmul(alpha, result.packet[3]); + + data.template storePacketBlock(i, j, block); +} + +template +EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad *accReal, __vector_quad *accImag, const int accColsC) +{ + PacketBlock resultReal, resultImag; + __builtin_mma_disassemble_acc(&resultReal.packet, accReal); + __builtin_mma_disassemble_acc(&resultImag.packet, accImag); + + PacketBlock taccReal, taccImag; + taccReal.packet[0] = pmul(resultReal.packet[0], alphaReal); + taccReal.packet[1] = pmul(resultReal.packet[1], alphaReal); + taccReal.packet[2] = pmul(resultReal.packet[2], alphaReal); + taccReal.packet[3] = pmul(resultReal.packet[3], alphaReal); + + taccImag.packet[0] = pmul(resultImag.packet[0], alphaReal); + taccImag.packet[1] = pmul(resultImag.packet[1], alphaReal); + taccImag.packet[2] = pmul(resultImag.packet[2], alphaReal); + taccImag.packet[3] = pmul(resultImag.packet[3], alphaReal); + + taccReal.packet[0] = psub(taccReal.packet[0], pmul(resultImag.packet[0], alphaImag)); + taccReal.packet[1] = psub(taccReal.packet[1], pmul(resultImag.packet[1], alphaImag)); + taccReal.packet[2] = psub(taccReal.packet[2], pmul(resultImag.packet[2], alphaImag)); + taccReal.packet[3] = psub(taccReal.packet[3], pmul(resultImag.packet[3], alphaImag)); + + taccImag.packet[0] = pmadd(resultReal.packet[0], alphaImag, taccImag.packet[0]); + taccImag.packet[1] = pmadd(resultReal.packet[1], alphaImag, taccImag.packet[1]); + taccImag.packet[2] = pmadd(resultReal.packet[2], alphaImag, taccImag.packet[2]); + taccImag.packet[3] = pmadd(resultReal.packet[3], alphaImag, taccImag.packet[3]); + + PacketBlock tRes; + tRes.packet[0] = data.template loadPacket(i + N*accColsC, j + 0); + tRes.packet[1] = data.template loadPacket(i + N*accColsC, j + 1); + tRes.packet[2] = data.template loadPacket(i + N*accColsC, j + 2); + tRes.packet[3] = data.template loadPacket(i + N*accColsC, j + 3); + + tRes.packet[4] = data.template loadPacket(i + (N+1)*accColsC, j + 0); + tRes.packet[5] = data.template loadPacket(i + (N+1)*accColsC, j + 1); + tRes.packet[6] = data.template loadPacket(i + (N+1)*accColsC, j + 2); + tRes.packet[7] = data.template loadPacket(i + (N+1)*accColsC, j + 3); + + PacketBlock acc1, acc2; + bcouple(taccReal, taccImag, tRes, acc1, acc2); + + data.template storePacketBlock(i + N*accColsC, j, acc1); + data.template storePacketBlock(i + (N+1)*accColsC, j, acc2); +} + +// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments +template +EIGEN_STRONG_INLINE void pger(__vector_quad *acc, const RhsPacket& a, const LhsPacket& b) +{ + if(NegativeAccumulate) + { + __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b); + } else { + __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b); + } +} + +template<> +EIGEN_STRONG_INLINE void pger, false>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) +{ + Packetx2u p; + p.pair = a; + __builtin_mma_xvf64gerpp(acc, p.vectorpair, (__vector unsigned char)b); +} + +template<> +EIGEN_STRONG_INLINE void pger, true>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) +{ + Packetx2u p; + p.pair = a; + __builtin_mma_xvf64gernp(acc, p.vectorpair, (__vector unsigned char)b); +} +#else + +// 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm). 
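In scalar terms, the rank-1 update performed by pger below is the outer product of the lhs strip with four broadcast rhs values, optionally negated (illustrative only, not part of the patch):

// acc[c] accumulates the column of results for rhs element c; each column is a
// vector over the vectorSize lhs rows, matching the packet layout used below.
template<bool NegativeAccumulate>
void rank1_update_reference(float acc[4][4], const float lhs[4], const float rhs[4])
{
  for (int c = 0; c < 4; ++c)
    for (int r = 0; r < 4; ++r)
      acc[c][r] += (NegativeAccumulate ? -lhs[r] : lhs[r]) * rhs[c];
}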
+template +EIGEN_STRONG_INLINE void pger(PacketBlock *acc, const Scalar* lhs, const Scalar* rhs) +{ + Packet lhsV = *((Packet *) lhs); + Packet rhsV1 = pset1(rhs[0]); + Packet rhsV2 = pset1(rhs[1]); + Packet rhsV3 = pset1(rhs[2]); + Packet rhsV4 = pset1(rhs[3]); + + if(NegativeAccumulate) + { + acc->packet[0] -= lhsV*rhsV1; + acc->packet[1] -= lhsV*rhsV2; + acc->packet[2] -= lhsV*rhsV3; + acc->packet[3] -= lhsV*rhsV4; + } else { + acc->packet[0] += lhsV*rhsV1; + acc->packet[1] += lhsV*rhsV2; + acc->packet[2] += lhsV*rhsV3; + acc->packet[3] += lhsV*rhsV4; + } +} + +// 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. +template +EIGEN_STRONG_INLINE void pgerc(PacketBlock& accReal, PacketBlock& accImag, const Scalar *rhs_ptr, const Scalar *rhs_ptr_imag, const Scalar *lhs_ptr, const Scalar* lhs_ptr_imag, Packet& conj) +{ + Packet lhsV = *((Packet *) lhs_ptr); + Packet rhsV1 = pset1(rhs_ptr[0]); + Packet rhsV2 = pset1(rhs_ptr[1]); + Packet rhsV3 = pset1(rhs_ptr[2]); + Packet rhsV4 = pset1(rhs_ptr[3]); + + Packet lhsVi; + if(!LhsIsReal) lhsVi = *((Packet *) lhs_ptr_imag); + Packet rhsV1i, rhsV2i, rhsV3i, rhsV4i; + if(!RhsIsReal) + { + rhsV1i = pset1(rhs_ptr_imag[0]); + rhsV2i = pset1(rhs_ptr_imag[1]); + rhsV3i = pset1(rhs_ptr_imag[2]); + rhsV4i = pset1(rhs_ptr_imag[3]); + } + + if(ConjugateLhs && !LhsIsReal) lhsVi = pmul(lhsVi,conj); + if(ConjugateRhs && !RhsIsReal) + { + rhsV1i = pmul(rhsV1i,conj); + rhsV2i = pmul(rhsV2i,conj); + rhsV3i = pmul(rhsV3i,conj); + rhsV4i = pmul(rhsV4i,conj); + } + + if(LhsIsReal) + { + accReal.packet[0] = pmadd(rhsV1, lhsV, accReal.packet[0]); + accReal.packet[1] = pmadd(rhsV2, lhsV, accReal.packet[1]); + accReal.packet[2] = pmadd(rhsV3, lhsV, accReal.packet[2]); + accReal.packet[3] = pmadd(rhsV4, lhsV, accReal.packet[3]); + + accImag.packet[0] = pmadd(rhsV1i, lhsV, accImag.packet[0]); + accImag.packet[1] = pmadd(rhsV2i, lhsV, accImag.packet[1]); + accImag.packet[2] = pmadd(rhsV3i, lhsV, accImag.packet[2]); + accImag.packet[3] = pmadd(rhsV4i, lhsV, accImag.packet[3]); + } else if(RhsIsReal) { + accReal.packet[0] = pmadd(rhsV1, lhsV, accReal.packet[0]); + accReal.packet[1] = pmadd(rhsV2, lhsV, accReal.packet[1]); + accReal.packet[2] = pmadd(rhsV3, lhsV, accReal.packet[2]); + accReal.packet[3] = pmadd(rhsV4, lhsV, accReal.packet[3]); + + accImag.packet[0] = pmadd(rhsV1, lhsVi, accImag.packet[0]); + accImag.packet[1] = pmadd(rhsV2, lhsVi, accImag.packet[1]); + accImag.packet[2] = pmadd(rhsV3, lhsVi, accImag.packet[2]); + accImag.packet[3] = pmadd(rhsV4, lhsVi, accImag.packet[3]); + } else { + accReal.packet[0] = pmadd(rhsV1, lhsV, accReal.packet[0]); + accReal.packet[1] = pmadd(rhsV2, lhsV, accReal.packet[1]); + accReal.packet[2] = pmadd(rhsV3, lhsV, accReal.packet[2]); + accReal.packet[3] = pmadd(rhsV4, lhsV, accReal.packet[3]); + + accImag.packet[0] = pmadd(rhsV1i, lhsV, accImag.packet[0]); + accImag.packet[1] = pmadd(rhsV2i, lhsV, accImag.packet[1]); + accImag.packet[2] = pmadd(rhsV3i, lhsV, accImag.packet[2]); + accImag.packet[3] = pmadd(rhsV4i, lhsV, accImag.packet[3]); + + accReal.packet[0] = psub(accReal.packet[0], pmul(rhsV1i, lhsVi)); + accReal.packet[1] = psub(accReal.packet[1], pmul(rhsV2i, lhsVi)); + accReal.packet[2] = psub(accReal.packet[2], pmul(rhsV3i, lhsVi)); + accReal.packet[3] = psub(accReal.packet[3], pmul(rhsV4i, lhsVi)); + + accImag.packet[0] = pmadd(rhsV1, lhsVi, accImag.packet[0]); + accImag.packet[1] = pmadd(rhsV2, lhsVi, accImag.packet[1]); + 
accImag.packet[2] = pmadd(rhsV3, lhsVi, accImag.packet[2]); + accImag.packet[3] = pmadd(rhsV4, lhsVi, accImag.packet[3]); + } +} +#endif + +// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs) +{ + return *((Packet *)rhs); +} + +#ifdef __MMA__ +template<> +EIGEN_STRONG_INLINE PacketBlock ploadRhs >(const double *rhs) +{ + PacketBlock pair; + pair.packet[0] = *((Packet2d *)rhs ); + pair.packet[1] = *(((Packet2d *)rhs) + 1); + return pair; +} +#endif + +template +EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs) +{ + return *((Packet *)lhs); +} + +#ifndef __MMA__ +// Zero the accumulator on PacketBlock. +template +EIGEN_STRONG_INLINE void bsetzero(PacketBlock& acc) +{ + acc.packet[0] = pset1((Scalar)0); + acc.packet[1] = pset1((Scalar)0); + acc.packet[2] = pset1((Scalar)0); + acc.packet[3] = pset1((Scalar)0); +} + +// Scale the PacketBlock vectors by alpha. +template +EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmadd(pAlpha,accZ.packet[0], acc.packet[0]); + acc.packet[1] = pmadd(pAlpha,accZ.packet[1], acc.packet[1]); + acc.packet[2] = pmadd(pAlpha,accZ.packet[2], acc.packet[2]); + acc.packet[3] = pmadd(pAlpha,accZ.packet[3], acc.packet[3]); +} + +// Complex version of PacketBlock scaling. +template +EIGEN_STRONG_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag) +{ + cReal.packet[0] = pmul(aReal.packet[0], bReal); + cReal.packet[1] = pmul(aReal.packet[1], bReal); + cReal.packet[2] = pmul(aReal.packet[2], bReal); + cReal.packet[3] = pmul(aReal.packet[3], bReal); + + cImag.packet[0] = pmul(aImag.packet[0], bReal); + cImag.packet[1] = pmul(aImag.packet[1], bReal); + cImag.packet[2] = pmul(aImag.packet[2], bReal); + cImag.packet[3] = pmul(aImag.packet[3], bReal); + + cReal.packet[0] = psub(cReal.packet[0], pmul(aImag.packet[0], bImag)); + cReal.packet[1] = psub(cReal.packet[1], pmul(aImag.packet[1], bImag)); + cReal.packet[2] = psub(cReal.packet[2], pmul(aImag.packet[2], bImag)); + cReal.packet[3] = psub(cReal.packet[3], pmul(aImag.packet[3], bImag)); + + cImag.packet[0] = pmadd(aReal.packet[0], bImag, cImag.packet[0]); + cImag.packet[1] = pmadd(aReal.packet[1], bImag, cImag.packet[1]); + cImag.packet[2] = pmadd(aReal.packet[2], bImag, cImag.packet[2]); + cImag.packet[3] = pmadd(aReal.packet[3], bImag, cImag.packet[3]); +} + +// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. +template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col, Index accCols) +{ + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); + acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); + acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); +} + +// An overload of bload when you have a PacketBLock with 8 vectors. 
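Before the eight-vector bload overload that follows, a quick note on bscalec above: on the split representation it is simply a lane-wise complex multiplication by alpha, as in this sketch (illustrative only, not part of the patch):

#include <complex>

static std::complex<float> cscale_reference(std::complex<float> a,
                                            std::complex<float> alpha)
{
  // cReal = aReal*alphaReal - aImag*alphaImag
  // cImag = aImag*alphaReal + aReal*alphaImag
  return std::complex<float>(a.real()*alpha.real() - a.imag()*alpha.imag(),
                             a.imag()*alpha.real() + a.real()*alpha.imag());
}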
+template +EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col, Index accCols) +{ + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); + acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); + acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); + acc.packet[4] = res.template loadPacket(row + (N+1)*accCols, col + 0); + acc.packet[5] = res.template loadPacket(row + (N+1)*accCols, col + 1); + acc.packet[6] = res.template loadPacket(row + (N+1)*accCols, col + 2); + acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); +} +#endif + +// PEEL loop factor. +#define PEEL 10 + +/**************** + * GEMM kernels * + * **************/ +template +EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, + Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlpha = pset1(alpha); + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows; + const Scalar *lhs_base = blockA; + + Index row = 0; +#ifdef __MMA__ + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; + + __vector_quad acc; + __builtin_mma_xxsetaccz(&acc); + + lhs_ptr1 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + Packet lhsV = ploadLhs(lhs_ptr1); + RhsPacket rhsV = ploadRhs(rhs_ptr); + + pger(&acc, rhsV, lhsV); + + lhs_ptr1 += accCols; + rhs_ptr += accRows; + } + + storeAccumulator(row, col, res, pAlpha, &acc); + } +#else + for(; row + 6*accCols <= rows; row += 6*accCols) + { +#define MICRO() \ + pger(&accZero1, lhs_ptr1, rhs_ptr); \ + lhs_ptr1 += accCols; \ + pger(&accZero2, lhs_ptr2, rhs_ptr); \ + lhs_ptr2 += accCols; \ + pger(&accZero3, lhs_ptr3, rhs_ptr); \ + lhs_ptr3 += accCols; \ + pger(&accZero4, lhs_ptr4, rhs_ptr); \ + lhs_ptr4 += accCols; \ + pger(&accZero5, lhs_ptr5, rhs_ptr); \ + lhs_ptr5 += accCols; \ + pger(&accZero6, lhs_ptr6, rhs_ptr); \ + lhs_ptr6 += accCols; \ + rhs_ptr += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; + const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; + const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; + const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; + const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols; + const Scalar *lhs_ptr6 = lhs_base + ((row/accCols) + 5)*strideA*accCols; + + PacketBlock acc1, accZero1; + PacketBlock acc2, accZero2; + PacketBlock acc3, accZero3; + PacketBlock acc4, accZero4; + PacketBlock acc5, accZero5; + PacketBlock acc6, accZero6; + + bload(acc1, res, row, col, accCols); + bsetzero(accZero1); + bload(acc2, res, row, col, accCols); + bsetzero(accZero2); + bload(acc3, res, row, col, accCols); + bsetzero(accZero3); + bload(acc4, res, row, col, accCols); + bsetzero(accZero4); + bload(acc5, res, row, col, accCols); + bsetzero(accZero5); + bload(acc6, res, row, col, accCols); + bsetzero(accZero6); + + lhs_ptr1 += 
accCols*offsetA; + lhs_ptr2 += accCols*offsetA; + lhs_ptr3 += accCols*offsetA; + lhs_ptr4 += accCols*offsetA; + lhs_ptr5 += accCols*offsetA; + lhs_ptr6 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + + Index k = 0; + for(; k + PEEL < depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr1); + prefetch(lhs_ptr2); + prefetch(lhs_ptr3); + prefetch(lhs_ptr4); + prefetch(lhs_ptr5); + prefetch(lhs_ptr6); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + + bscale(acc1,accZero1, pAlpha); + bscale(acc2,accZero2, pAlpha); + bscale(acc3,accZero3, pAlpha); + bscale(acc4,accZero4, pAlpha); + bscale(acc5,accZero5, pAlpha); + bscale(acc6,accZero6, pAlpha); + + res.template storePacketBlock(row + 0*accCols, col, acc1); + res.template storePacketBlock(row + 1*accCols, col, acc2); + res.template storePacketBlock(row + 2*accCols, col, acc3); + res.template storePacketBlock(row + 3*accCols, col, acc4); + res.template storePacketBlock(row + 4*accCols, col, acc5); + res.template storePacketBlock(row + 5*accCols, col, acc6); +#undef MICRO + } + for(; row + 5*accCols <= rows; row += 5*accCols) + { +#define MICRO() \ + pger(&accZero1, lhs_ptr1, rhs_ptr); \ + lhs_ptr1 += accCols; \ + pger(&accZero2, lhs_ptr2, rhs_ptr); \ + lhs_ptr2 += accCols; \ + pger(&accZero3, lhs_ptr3, rhs_ptr); \ + lhs_ptr3 += accCols; \ + pger(&accZero4, lhs_ptr4, rhs_ptr); \ + lhs_ptr4 += accCols; \ + pger(&accZero5, lhs_ptr5, rhs_ptr); \ + lhs_ptr5 += accCols; \ + rhs_ptr += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; + const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; + const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; + const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; + const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols; + + PacketBlock acc1, accZero1; + PacketBlock acc2, accZero2; + PacketBlock acc3, accZero3; + PacketBlock acc4, accZero4; + PacketBlock acc5, accZero5; + + bload(acc1, res, row, col, accCols); + bsetzero(accZero1); + bload(acc2, res, row, col, accCols); + bsetzero(accZero2); + bload(acc3, res, row, col, accCols); + bsetzero(accZero3); + bload(acc4, res, row, col, accCols); + bsetzero(accZero4); + bload(acc5, res, row, col, accCols); + bsetzero(accZero5); + + lhs_ptr1 += accCols*offsetA; + lhs_ptr2 += accCols*offsetA; + lhs_ptr3 += accCols*offsetA; + lhs_ptr4 += accCols*offsetA; + lhs_ptr5 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + Index k = 0; + + for(; k + PEEL < depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr1); + prefetch(lhs_ptr2); + prefetch(lhs_ptr3); + prefetch(lhs_ptr4); + prefetch(lhs_ptr5); + + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + + bscale(acc1,accZero1, pAlpha); + bscale(acc2,accZero2, pAlpha); + bscale(acc3,accZero3, pAlpha); + bscale(acc4,accZero4, pAlpha); + bscale(acc5,accZero5, pAlpha); + + res.template storePacketBlock(row + 0*accCols, col, acc1); + res.template storePacketBlock(row + 1*accCols, col, acc2); + res.template storePacketBlock(row + 2*accCols, col, acc3); + res.template storePacketBlock(row + 3*accCols, col, acc4); + res.template storePacketBlock(row + 4*accCols, col, acc5); +#undef MICRO + } + for(; row + 4*accCols <= rows; row += 
4*accCols) + { +#define MICRO() \ + pger(&accZero1, lhs_ptr1, rhs_ptr); \ + lhs_ptr1 += accCols; \ + pger(&accZero2, lhs_ptr2, rhs_ptr); \ + lhs_ptr2 += accCols; \ + pger(&accZero3, lhs_ptr3, rhs_ptr); \ + lhs_ptr3 += accCols; \ + pger(&accZero4, lhs_ptr4, rhs_ptr); \ + lhs_ptr4 += accCols; \ + rhs_ptr += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; + const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; + const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; + const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; + + PacketBlock acc1, accZero1; + PacketBlock acc2, accZero2; + PacketBlock acc3, accZero3; + PacketBlock acc4, accZero4; + + bload(acc1, res, row, col, accCols); + bsetzero(accZero1); + bload(acc2, res, row, col, accCols); + bsetzero(accZero2); + bload(acc3, res, row, col, accCols); + bsetzero(accZero3); + bload(acc4, res, row, col, accCols); + bsetzero(accZero4); + + lhs_ptr1 += accCols*offsetA; + lhs_ptr2 += accCols*offsetA; + lhs_ptr3 += accCols*offsetA; + lhs_ptr4 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + Index k = 0; + + for(; k + PEEL < depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr1); + prefetch(lhs_ptr2); + prefetch(lhs_ptr3); + prefetch(lhs_ptr4); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + + bscale(acc1,accZero1, pAlpha); + bscale(acc2,accZero2, pAlpha); + bscale(acc3,accZero3, pAlpha); + bscale(acc4,accZero4, pAlpha); + + res.template storePacketBlock(row + 0*accCols, col, acc1); + res.template storePacketBlock(row + 1*accCols, col, acc2); + res.template storePacketBlock(row + 2*accCols, col, acc3); + res.template storePacketBlock(row + 3*accCols, col, acc4); +#undef MICRO + } + for(; row + 3*accCols <= rows; row += 3*accCols) + { +#define MICRO() \ + pger(&accZero1, lhs_ptr1, rhs_ptr); \ + lhs_ptr1 += accCols; \ + pger(&accZero2, lhs_ptr2, rhs_ptr); \ + lhs_ptr2 += accCols; \ + pger(&accZero3, lhs_ptr3, rhs_ptr); \ + lhs_ptr3 += accCols; \ + rhs_ptr += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; + const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; + const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; + + PacketBlock acc1, accZero1; + PacketBlock acc2, accZero2; + PacketBlock acc3, accZero3; + + bload(acc1, res, row, col, accCols); + bsetzero(accZero1); + bload(acc2, res, row, col, accCols); + bsetzero(accZero2); + bload(acc3, res, row, col, accCols); + bsetzero(accZero3); + + lhs_ptr1 += accCols*offsetA; + lhs_ptr2 += accCols*offsetA; + lhs_ptr3 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + Index k = 0; + for(; k + PEEL < depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr1); + prefetch(lhs_ptr2); + prefetch(lhs_ptr3); + + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + + bscale(acc1,accZero1, pAlpha); + bscale(acc2,accZero2, pAlpha); + bscale(acc3,accZero3, pAlpha); + + res.template storePacketBlock(row + 0*accCols, col, acc1); + res.template storePacketBlock(row + 1*accCols, col, acc2); + res.template storePacketBlock(row + 2*accCols, col, acc3); +#undef MICRO + } + for(; row + 2*accCols <= rows; row += 2*accCols) + { 
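/* Illustrative sketch (not part of the patch): each of these non-MMA row blocks follows the
 * same scheme -- MICRO() performs one rank-1 update per in-flight accumulator, the peeled
 * loop issues PEEL such steps per trip, and a tail loop finishes the remaining depth.
 * A standalone scalar stand-in, with a single 4-wide accumulator and a hypothetical
 * helper ger1() (not an Eigen function): */
#include <cstddef>

static const int kPeel = 10;                          // mirrors the PEEL factor above

inline void ger1(float acc[4], const float* lhs, float rhs)
{
  for (int i = 0; i < 4; ++i) acc[i] += lhs[i]*rhs;   // one rank-1 update step
}

void micro_loop(float acc[4], const float* lhs, const float* rhs, std::ptrdiff_t depth)
{
  std::ptrdiff_t k = 0;
  for (; k + kPeel < depth; k += kPeel)               // peeled body: kPeel steps per trip
    for (int p = 0; p < kPeel; ++p)
      ger1(acc, lhs + 4*(k + p), rhs[k + p]);
  for (; k < depth; ++k)                              // remainder, one step at a time
    ger1(acc, lhs + 4*k, rhs[k]);
}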
+#define MICRO() \ + pger(&accZero1, lhs_ptr1, rhs_ptr); \ + lhs_ptr1 += accCols; \ + pger(&accZero2, lhs_ptr2, rhs_ptr); \ + lhs_ptr2 += accCols; \ + rhs_ptr += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; + const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; + + PacketBlock acc1, accZero1; + PacketBlock acc2, accZero2; + + bload(acc1, res, row, col, accCols); + bsetzero(accZero1); + bload(acc2, res, row, col, accCols); + bsetzero(accZero2); + + lhs_ptr1 += accCols*offsetA; + lhs_ptr2 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + Index k = 0; + for(; k + PEEL < depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr1); + prefetch(lhs_ptr2); + + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + + bscale(acc1,accZero1, pAlpha); + bscale(acc2,accZero2, pAlpha); + + res.template storePacketBlock(row + 0*accCols, col, acc1); + res.template storePacketBlock(row + 1*accCols, col, acc2); +#undef MICRO + } + + for(; row + accCols <= rows; row += accCols) + { +#define MICRO() \ + pger(&accZero1, lhs_ptr1, rhs_ptr); \ + lhs_ptr1 += accCols; \ + rhs_ptr += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; + + PacketBlock acc1, accZero1; + + bload(acc1, res, row, col, accCols); + bsetzero(accZero1); + + lhs_ptr1 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + Index k = 0; + for(; k + PEEL < depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr1); + + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + + bscale(acc1,accZero1, pAlpha); + + res.template storePacketBlock(row, col, acc1); +#undef MICRO + } +#endif + if(remaining_rows > 0) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + + lhs_ptr += remaining_rows*offsetA; + rhs_ptr += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + for(Index acol = 0; acol < accRows; acol++ ) + { + res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; + } + } + rhs_ptr += accRows; + lhs_ptr += remaining_rows; + } + } + } + + if(remaining_cols > 0) + { + const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows; + const Scalar *lhs_base = blockA; + + Index row = 0; + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + + lhs_ptr += accCols*offsetA; + rhs_ptr += remaining_cols*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < accCols; arow++) + { + for(Index acol = 0; acol < remaining_cols; acol++ ) + { + res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; + } + } + rhs_ptr += remaining_cols; + lhs_ptr += accCols; + } + } + + if(remaining_rows > 0 ) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + + lhs_ptr += remaining_rows*offsetA; + rhs_ptr += remaining_cols*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + for(Index acol = 0; acol < remaining_cols; acol++ ) + { + res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; + } + } + 
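/* Reference semantics (illustrative, not part of the patch): whether a tile goes through
 * the vectorized accumulators or these scalar remainder loops, the net effect is
 * res += alpha * A*B over the tile. A standalone scalar model, assuming column-major
 * storage for both packed operands and the result: */
#include <cstddef>

void gemm_tile_ref(float* res, const float* lhs, const float* rhs,
                   std::ptrdiff_t rows, std::ptrdiff_t depth, std::ptrdiff_t cols,
                   float alpha, std::ptrdiff_t resStride)
{
  for (std::ptrdiff_t j = 0; j < cols; ++j)
    for (std::ptrdiff_t i = 0; i < rows; ++i)
    {
      float acc = 0.0f;                               // plays the role of accZero
      for (std::ptrdiff_t k = 0; k < depth; ++k)
        acc += lhs[i + k*rows] * rhs[k + j*depth];
      res[i + j*resStride] += alpha*acc;              // plays the role of bscale + store
    }
}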
rhs_ptr += remaining_cols; + lhs_ptr += remaining_rows; + } + } + } +} + +template +EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, + Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +{ + const int remaining_rows = rows % accCols; + const int remaining_cols = cols % accRows; + const int accColsC = accCols / 2; + int advanceCols = 2; + int advanceRows = 2; + + if(LhsIsReal) advanceRows = 1; + if(RhsIsReal) advanceCols = 1; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlphaReal = pset1(alpha.real()); + const Packet pAlphaImag = pset1(alpha.imag()); + + const Scalar *blockA = (Scalar *) blockAc; + const Scalar *blockB = (Scalar *) blockBc; + + Packet conj = pset1((Scalar)-1.0f); + + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; + const Scalar *lhs_base = blockA; + + Index row = 0; #ifdef __MMA__ + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; + + __vector_quad accReal, accImag; + __builtin_mma_xxsetaccz(&accReal); + __builtin_mma_xxsetaccz(&accImag); + + lhs_ptr += accCols*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += accCols*offsetA; + rhs_ptr += accRows*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + Packet lhsV = ploadLhs(lhs_ptr); + RhsPacket rhsV = ploadRhs(rhs_ptr); + + Packet lhsVi = ploadLhs(lhs_ptr_imag); + RhsPacket rhsVi = ploadRhs(rhs_ptr_imag); + + if(ConjugateLhs && !LhsIsReal) lhsVi = pmul(lhsVi, conj); + if(ConjugateRhs && !RhsIsReal) rhsVi = pmul(rhsVi, conj); + + if(LhsIsReal) + { + pger(&accReal, rhsV, lhsV); + pger(&accImag, rhsVi, lhsV); + } else if(RhsIsReal) { + pger(&accReal, rhsV, lhsV); + pger(&accImag, rhsV, lhsVi); + } else { + pger(&accReal, rhsV, lhsV); + pger(&accReal, rhsVi, lhsVi); + pger(&accImag, rhsVi, lhsV); + pger(&accImag, rhsV, lhsVi); + } + + lhs_ptr += accCols; + rhs_ptr += accRows; + if(!LhsIsReal) + lhs_ptr_imag += accCols; + if(!RhsIsReal) + rhs_ptr_imag += accRows; + } + + storeComplexAccumulator(row, col, res, pAlphaReal, pAlphaImag, &accReal, &accImag, accColsC); + } +#else + for(; row + accCols <= rows; row += accCols) + { +#define MICRO() \ + pgerc(accReal1, accImag1, rhs_ptr, rhs_ptr_imag, lhs_ptr1, lhs_ptr_imag1, conj); \ + lhs_ptr1 += accCols; \ + rhs_ptr += accRows; \ + if(!LhsIsReal) \ + lhs_ptr_imag1 += accCols; \ + if(!RhsIsReal) \ + rhs_ptr_imag += accRows; + + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; + const Scalar *lhs_ptr1 = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag1 = lhs_ptr1 + accCols*strideA; + + PacketBlock accReal1, accImag1; + bsetzero(accReal1); + bsetzero(accImag1); + + lhs_ptr1 += accCols*offsetA; + if(!LhsIsReal) + lhs_ptr_imag1 += accCols*offsetA; + rhs_ptr += accRows*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += accRows*offsetB; + Index k = 0; + for(; k + PEEL < depth; k+=PEEL) + { + prefetch(rhs_ptr); + prefetch(rhs_ptr_imag); + prefetch(lhs_ptr1); + prefetch(lhs_ptr_imag1); + MICRO(); + MICRO(); + MICRO(); + MICRO(); + MICRO(); 
+ MICRO(); + MICRO(); + MICRO(); +#if PEEL > 8 + MICRO(); + MICRO(); +#endif + } + for(; k < depth; k++) + { + MICRO(); + } + PacketBlock taccReal, taccImag; + bscalec(accReal1, accImag1, pAlphaReal, pAlphaImag, taccReal, taccImag); + + PacketBlock tRes; + bload(tRes, res, row, col, accColsC); + + PacketBlock acc1, acc2; + bcouple(taccReal, taccImag, tRes, acc1, acc2); + + res.template storePacketBlock(row + 0, col, acc1); + res.template storePacketBlock(row + accColsC, col, acc2); +#undef MICRO + } +#endif + if(remaining_rows > 0) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; + + lhs_ptr += remaining_rows*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows*offsetA; + rhs_ptr += accRows*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += accRows*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + Scalar lhs_real = lhs_ptr[arow]; + Scalar lhs_imag; + if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; + + Scalarc lhsc; -namespace Eigen { + lhsc.real(lhs_real); + if(!LhsIsReal) + { + if(ConjugateLhs) + lhsc.imag(-lhs_imag); + else + lhsc.imag(lhs_imag); + } else { + //Lazy approach for now + lhsc.imag((Scalar)0); + } -namespace internal { + for(int acol = 0; acol < accRows; acol++ ) + { + Scalar rhs_real = rhs_ptr[acol]; + Scalar rhs_imag; + if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; + Scalarc rhsc; + + rhsc.real(rhs_real); + if(!RhsIsReal) + { + if(ConjugateRhs) + rhsc.imag(-rhs_imag); + else + rhsc.imag(rhs_imag); + } else { + //Lazy approach for now + rhsc.imag((Scalar)0); + } + res(row + arow, col + acol) += alpha*lhsc*rhsc; + } + } + rhs_ptr += accRows; + lhs_ptr += remaining_rows; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows; + if(!RhsIsReal) + rhs_ptr_imag += accRows; + } + } + } + + if(remaining_cols > 0) + { + const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows; + const Scalar *lhs_base = blockA; + Index row = 0; + + for(; row + accCols <= rows; row += accCols) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA; + + lhs_ptr += accCols*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += accCols*offsetA; + rhs_ptr += remaining_cols*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols*offsetB; + Scalarc scalarAcc[4][4]; + for(Index arow = 0; arow < 4; arow++ ) + { + for(Index acol = 0; acol < 4; acol++ ) + { + scalarAcc[arow][acol].real((Scalar)0.0f); + scalarAcc[arow][acol].imag((Scalar)0.0f); + } + } + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < accCols; arow++) + { + Scalar lhs_real = lhs_ptr[arow]; + Scalar lhs_imag; + if(!LhsIsReal) + { + lhs_imag = lhs_ptr_imag[arow]; + + if(ConjugateLhs) + lhs_imag *= -1; + } else { + lhs_imag = (Scalar)0; + } + + for(int acol = 0; acol < remaining_cols; acol++ ) + { + Scalar rhs_real = rhs_ptr[acol]; + Scalar rhs_imag; + if(!RhsIsReal) + { + rhs_imag = rhs_ptr_imag[acol]; + + if(ConjugateRhs) + rhs_imag *= -1; + } else { + rhs_imag = (Scalar)0; + } + + scalarAcc[arow][acol].real(scalarAcc[arow][acol].real() + lhs_real*rhs_real - lhs_imag*rhs_imag); + scalarAcc[arow][acol].imag(scalarAcc[arow][acol].imag() + lhs_imag*rhs_real + lhs_real*rhs_imag); + } + } + 
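/* The scalar accumulation just above spells out the arithmetic every complex kernel
 * implements, with real and imaginary parts kept in separate planes just like the packed
 * blocks. Conjugation is applied up front by negating the imaginary input, which is what
 * the pmul(..., conj) calls do on the vector path. A self-contained sketch (cmac_split is
 * an illustrative name, not an Eigen function): */
void cmac_split(float& accRe, float& accIm,
                float lhsRe, float lhsIm, float rhsRe, float rhsIm,
                bool conjLhs, bool conjRhs)
{
  if (conjLhs) lhsIm = -lhsIm;            // conj(a + bi) = a - bi
  if (conjRhs) rhsIm = -rhsIm;
  accRe += lhsRe*rhsRe - lhsIm*rhsIm;     // real part of (a+bi)*(c+di)
  accIm += lhsIm*rhsRe + lhsRe*rhsIm;     // imaginary part
}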
rhs_ptr += remaining_cols; + lhs_ptr += accCols; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols; + if(!LhsIsReal) + lhs_ptr_imag += accCols; + } + for(int arow = 0; arow < accCols; arow++ ) + { + for(int acol = 0; acol < remaining_cols; acol++ ) + { + Scalar accR = scalarAcc[arow][acol].real(); + Scalar accI = scalarAcc[arow][acol].imag(); + Scalar aR = alpha.real(); + Scalar aI = alpha.imag(); + Scalar resR = res(row + arow, col + acol).real(); + Scalar resI = res(row + arow, col + acol).imag(); + + res(row + arow, col + acol).real(resR + accR*aR - accI*aI); + res(row + arow, col + acol).imag(resI + accR*aI + accI*aR); + } + } + } + + if(remaining_rows > 0) + { + const Scalar *rhs_ptr = rhs_base; + const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB; + const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols; + const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA; + + lhs_ptr += remaining_rows*offsetA; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows*offsetA; + rhs_ptr += remaining_cols*offsetB; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols*offsetB; + for(Index k = 0; k < depth; k++) + { + for(Index arow = 0; arow < remaining_rows; arow++) + { + Scalar lhs_real = lhs_ptr[arow]; + Scalar lhs_imag; + if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow]; + Scalarc lhsc; + + lhsc.real(lhs_real); + if(!LhsIsReal) + { + if(ConjugateLhs) + lhsc.imag(-lhs_imag); + else + lhsc.imag(lhs_imag); + } else { + lhsc.imag((Scalar)0); + } + + for(Index acol = 0; acol < remaining_cols; acol++ ) + { + Scalar rhs_real = rhs_ptr[acol]; + Scalar rhs_imag; + if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol]; + Scalarc rhsc; + + rhsc.real(rhs_real); + if(!RhsIsReal) + { + if(ConjugateRhs) + rhsc.imag(-rhs_imag); + else + rhsc.imag(rhs_imag); + } else { + rhsc.imag((Scalar)0); + } + res(row + arow, col + acol) += alpha*lhsc*rhsc; + } + } + rhs_ptr += remaining_cols; + lhs_ptr += remaining_rows; + if(!LhsIsReal) + lhs_ptr_imag += remaining_rows; + if(!RhsIsReal) + rhs_ptr_imag += remaining_cols; + } + } + } +} + +/************************************ + * ppc64le template specializations * + * **********************************/ +template +struct gemm_pack_lhs +{ + void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs +{ + void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_rhs +{ + void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} -const int accRows = 4; -const int accCols = 4; -const int accCount = 4; -const int floatVectorSize = 4; +template +struct gemm_pack_rhs +{ + void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; -typedef struct 
+template +void gemm_pack_rhs + ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - __vector float v0; - __vector float v1; - __vector float v2; - __vector float v3; -} Packet4fx4; + rhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} -union PacketQuad +template +struct gemm_pack_lhs { - __struct_quad sc; - Packet4fx4 sf; + void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); }; +template +void gemm_pack_lhs + ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + template struct gemm_pack_lhs { @@ -45,40 +2656,35 @@ template ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - int ri = 0, j; - for(j = 0; j + floatVectorSize < rows; j+=floatVectorSize) - { - int i; - for(i = 0; i + floatVectorSize < depth; i+=floatVectorSize) - { - PacketBlock block; - block.packet[0] = lhs.template loadPacket(j, i + 0); - block.packet[1] = lhs.template loadPacket(j, i + 1); - block.packet[2] = lhs.template loadPacket(j, i + 2); - block.packet[3] = lhs.template loadPacket(j, i + 3); + lhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; - pstore((float *)(blockA + ri ), block.packet[0]); - pstore((float *)(blockA + ri + 4), block.packet[1]); - pstore((float *)(blockA + ri + 8), block.packet[2]); - pstore((float *)(blockA + ri + 12), block.packet[3]); - ri += 4*floatVectorSize; - } - for(; i < depth; i++) - { - Packet4f lhsV = lhs.template loadPacket(j, i); - pstore((float *)(blockA + ri), lhsV); - ri += floatVectorSize; - } - } - for(int i = 0; i < depth; i++) - { - int k = j; - for(; k < rows; k++) - { - blockA[ri] = lhs(k, i); - ri += 1; - } - } +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); } template @@ -91,221 +2697,299 @@ template ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { - int ri = 0, j; - for(j = 0; j + floatVectorSize < cols; j+=floatVectorSize) - { - int i; - for(i = 0; i + floatVectorSize < depth; i+=floatVectorSize) - { - PacketBlock block; - block.packet[0] = rhs.template loadPacket(i, j + 0); - block.packet[1] = rhs.template loadPacket(i, j + 1); - block.packet[2] = rhs.template loadPacket(i, j + 2); - block.packet[3] = rhs.template loadPacket(i, j + 3); + rhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} - 
ptranspose(block); +template +struct gemm_pack_rhs +{ + void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; - pstore((float *)(blockB + ri ), block.packet[0]); - pstore((float *)(blockB + ri + 4), block.packet[1]); - pstore((float *)(blockB + ri + 8), block.packet[2]); - pstore((float *)(blockB + ri + 12), block.packet[3]); +template +void gemm_pack_rhs + ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} - ri += 4*floatVectorSize; - } - for(; i < depth; i++) - { - blockB[ri+0] = rhs(i, j+0); - blockB[ri+1] = rhs(i, j+1); - blockB[ri+2] = rhs(i, j+2); - blockB[ri+3] = rhs(i, j+3); - ri += floatVectorSize; - } - } - for(int i = 0; i < depth; i++) - { - int k = j; - for(; k < cols; k++) - { - blockB[ri] = rhs(i, k); - ri += 1; - } - } +template +struct gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); } -template -EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, Scalar alpha, __vector_quad *acc) +template +struct gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> { - //[TODO] - // - //Packet4fx4 r; - // - //__builtin_mma_disassemble_acc((void *)&r, *acc); - // - PacketQuad result; - result.sc = __builtin_mma_disassemble_acc(*acc); + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; - Packet4f pAlpha = pset1(alpha); +template +void gemm_pack_rhs, 
Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} - PacketBlock block; - block.packet[0] = pAlpha*result.sf.v3; - block.packet[1] = pAlpha*result.sf.v2; - block.packet[2] = pAlpha*result.sf.v1; - block.packet[3] = pAlpha*result.sf.v0; +template +struct gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; - data.template storePacketBlock(i, j, block); +template +void gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); } -template -struct gebp_kernel +// ********* gebp specializations ********* +template +struct gebp_kernel { - void operator()(const DataMapper& res, const float* blockA, const RhsScalar* blockB, + typedef typename quad_traits::vectortype Packet; + typedef typename quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha, Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); }; -template -void gebp_kernel - ::operator()(const DataMapper& res, const float* blockA, const RhsScalar* blockB, +template +void gebp_kernel + ::operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - const int remaining_rows = rows % accRows; - const int remaining_cols = cols % accCols; - const int remaining_depth = depth % floatVectorSize; - const int timesRows = (rows / accRows); - const int timesCols = (cols / accCols); - - int row; - for(row = 0; row + accRows <= rows; row += accRows) - { - const float *rhs_base = blockB; - const float *lhs_base = blockA + (row/accRows)*depth*floatVectorSize; - - int col; - for(col = 0; col + accCount*accCols <= cols; col += accCount*accCols){ - const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize; - const float *rhs_ptr2 = rhs_base + ((col/accCols) + 1)*depth*floatVectorSize; - const float *rhs_ptr3 = rhs_base + ((col/accCols) + 2)*depth*floatVectorSize; - const float *rhs_ptr4 = rhs_base + ((col/accCols) + 3)*depth*floatVectorSize; - const float *lhs_ptr = lhs_base; - - __vector_quad acc, acc2, acc3, acc4; - __builtin_mma_xxsetaccz(&acc); - __builtin_mma_xxsetaccz(&acc2); - __builtin_mma_xxsetaccz(&acc3); - __builtin_mma_xxsetaccz(&acc4); - - for(int k = 0; k < depth; k++) - { - __vector float lhsV = *((__vector float *)lhs_ptr ); - __vector float rhsV = *((__vector float *)rhs_ptr ); - __vector float rhs2V = *((__vector float *)rhs_ptr2); - __vector float rhs3V = *((__vector float *)rhs_ptr3); - __vector float rhs4V = *((__vector float *)rhs_ptr4); - - __builtin_mma_xvf32gerpp(&acc, (__vector unsigned char) rhsV, (__vector unsigned char) lhsV); - __builtin_mma_xvf32gerpp(&acc2, (__vector unsigned char) rhs2V, (__vector unsigned char) lhsV); - __builtin_mma_xvf32gerpp(&acc3, (__vector unsigned char) rhs3V, (__vector unsigned char) lhsV); - __builtin_mma_xvf32gerpp(&acc4, (__vector unsigned char) rhs4V, (__vector unsigned char) lhsV); - - lhs_ptr += 
floatVectorSize; - rhs_ptr += floatVectorSize; - rhs_ptr2 += floatVectorSize; - rhs_ptr3 += floatVectorSize; - rhs_ptr4 += floatVectorSize; - } + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; - storeAccumulator(row, col , res, alpha, &acc ); - storeAccumulator(row, col + 1*accCols, res, alpha, &acc2); - storeAccumulator(row, col + 2*accCols, res, alpha, &acc3); - storeAccumulator(row, col + 3*accCols, res, alpha, &acc4); - } - for(; col + accCols <= cols; col += accCols){ - const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize; - const float *lhs_ptr = lhs_base; - - __vector_quad acc; - __builtin_mma_xxsetaccz(&acc); - for(int k = 0; k < depth; k++) - { - __vector float lhsV = *((__vector float *)lhs_ptr); - __vector float rhsV = *((__vector float *)rhs_ptr); - - __builtin_mma_xvf32gerpp(&acc, (__vector unsigned char) rhsV, (__vector unsigned char) lhsV); - - lhs_ptr += floatVectorSize; - rhs_ptr += floatVectorSize; - } + gemm(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } - storeAccumulator(row, col, res, alpha, &acc); - } - - if(remaining_cols > 0) - { - const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize; - const float *lhs_ptr = lhs_base; - for(int k = 0; k < depth; k++) - { - for(int arow = 0; arow < accRows; arow++) - { - for(int acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += floatVectorSize; - } - } - } +template +struct gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef Packet4f Packet; + typedef Packet2cf Packetc; + typedef Packet4f RhsPacket; - if(remaining_rows > 0) - { - const float *rhs_base = blockB; - const float *lhs_base = blockA + (row/accRows)*depth*floatVectorSize; + void operator()(const DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; - int col; - for(col = 0; col + accCols <= cols; col += accCols) - { - const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize; - const float *lhs_ptr = lhs_base; - for(int k = 0; k < depth; k++) - { - for(int arow = 0; arow < remaining_rows; arow++) - { - for(int acol = 0; acol < accCols; acol++ ) - { - res(row + arow, col + acol) += lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += floatVectorSize; - lhs_ptr += remaining_rows; - } - } - - if(remaining_cols > 0) - { - const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize; - const float *lhs_ptr = lhs_base; - for(int k = 0; k < depth; k++) - { - for(int arow = 0; arow < remaining_rows; arow++) - { - for(int acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += remaining_rows; - } - } - } +template +void gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>(res, 
blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } + +template +struct gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef Packet4f Packet; + typedef Packet2cf Packetc; + typedef Packet4f RhsPacket; + + void operator()(const DataMapper& res, const float* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const float* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } + +template +struct gebp_kernel, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef Packet4f Packet; + typedef Packet2cf Packetc; + typedef Packet4f RhsPacket; + + void operator()(const DataMapper& res, const std::complex* blockA, const float* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const float* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } + +template +struct gebp_kernel +{ + typedef typename quad_traits::vectortype Packet; + typedef typename quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const double* blockA, const double* blockB, + Index rows, Index depth, Index cols, double alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel + ::operator()(const DataMapper& res, const double* blockA, const double* blockB, + Index rows, Index depth, Index cols, double alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } + +template +struct gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef quad_traits::vectortype Packet; + typedef Packet1cd Packetc; + typedef quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* 
blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } + +template +struct gebp_kernel, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef quad_traits::vectortype Packet; + typedef Packet1cd Packetc; + typedef quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const std::complex* blockA, const double* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const double* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); } +template +struct gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef quad_traits::vectortype Packet; + typedef Packet1cd Packetc; + typedef quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const double* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const double* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const int accRows = quad_traits::rows; + const int accCols = quad_traits::size; + + gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + } } // end namespace internal } // end namespace Eigen - -#endif // __MMA__ -#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H +#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H \ No newline at end of file diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index 01e647f17..a90e57446 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -391,6 +391,77 @@ public: return pgather(&operator()(i, j), m_stride); } + // storePacketBlock_helper defines a way to access values inside the PacketBlock, this is essentially required by the Complex types. 
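/* The helpers below use compile-time recursion over the packet index: the nested spbh
 * member covers indices 0..idx-1, the current specialization then writes packet[idx],
 * and a partial specialization at idx == -1 terminates the recursion. A self-contained
 * sketch of that pattern (illustrative names, not the Eigen API): */
#include <cstdio>

template<int n, int idx = n-1>
struct print_helper
{
  print_helper<n, idx-1> rest;                        // handles the lower indices first
  void print(const int (&v)[n]) const { rest.print(v); std::printf("%d ", v[idx]); }
};

template<int n>
struct print_helper<n, -1>                            // base case: nothing left to write
{
  void print(const int (&)[n]) const {}
};

// usage: print_helper<4>().print(values) prints values[0] through values[3] in order.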
+ template + struct storePacketBlock_helper + { + storePacketBlock_helper spbh; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper* sup, Index i, Index j, const PacketBlock& block) const { + spbh.store(sup, i,j,block); + for(int l = 0; l < unpacket_traits::size; l++) + { + ScalarT *v = &sup->operator()(i+l, j+idx); + *v = block.packet[idx][l]; + } + } + }; + + template + struct storePacketBlock_helper, n, idx> + { + storePacketBlock_helper, n, idx-1> spbh; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper* sup, Index i, Index j, const PacketBlock& block) const { + spbh.store(sup,i,j,block); + for(int l = 0; l < unpacket_traits::size; l++) + { + std::complex *v = &sup->operator()(i+l, j+idx); + v->real(block.packet[idx].v[2*l+0]); + v->imag(block.packet[idx].v[2*l+1]); + } + } + }; + + template + struct storePacketBlock_helper, n, idx> + { + storePacketBlock_helper, n, idx-1> spbh; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper* sup, Index i, Index j, const PacketBlock& block) const { + spbh.store(sup,i,j,block); + for(int l = 0; l < unpacket_traits::size; l++) + { + std::complex *v = &sup->operator()(i+l, j+idx); + v->real(block.packet[idx].v[2*l+0]); + v->imag(block.packet[idx].v[2*l+1]); + } + } + }; + + template + struct storePacketBlock_helper + { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper*, Index, Index, const PacketBlock& ) const { + } + }; + + template + struct storePacketBlock_helper, n, -1> + { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper*, Index, Index, const PacketBlock& ) const { + } + }; + + template + struct storePacketBlock_helper, n, -1> + { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper*, Index, Index, const PacketBlock& ) const { + } + }; + // This function stores a PacketBlock on m_data, this approach is really quite slow compare to Incr=1 and should be avoided when possible. + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j, const PacketBlock&block) const { + storePacketBlock_helper spb; + spb.store(this, i,j,block); + } protected: Scalar* EIGEN_RESTRICT m_data; const Index m_stride; diff --git a/test/blasutil.cpp b/test/blasutil.cpp index 9caacfbab..01942918b 100644 --- a/test/blasutil.cpp +++ b/test/blasutil.cpp @@ -200,5 +200,7 @@ EIGEN_DECLARE_TEST(blasutil) CALL_SUBTEST_5(run_test()); CALL_SUBTEST_6(run_test()); + CALL_SUBTEST_7(run_test >()); + CALL_SUBTEST_8(run_test >()); } } -- cgit v1.2.3
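The slow path noted above exists because, with a non-unit increment, consecutive lanes of a
packet are no longer contiguous in memory, so every lane has to be written through the
mapper's stride and increment individually instead of with a single vector store. A rough
standalone model of that scatter for a 4x4 block (plain arrays and illustrative names, not
the Eigen API):

#include <cstddef>

struct Block4x4 { float col[4][4]; };                 // col[j][l] ~ block.packet[j][l]

void store_block_strided(float* data, std::ptrdiff_t stride, std::ptrdiff_t incr,
                         std::ptrdiff_t i, std::ptrdiff_t j, const Block4x4& b)
{
  for (int c = 0; c < 4; ++c)                         // one packet per column of the block
    for (int l = 0; l < 4; ++l)                       // scatter each lane element-wise
      data[(i + l)*incr + (j + c)*stride] = b.col[c][l];
}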