diff options
author | Chip Kerchner <chip.kerchner@ibm.com> | 2021-03-25 11:08:19 +0000 |
---|---|---|
committer | David Tellenbach <david.tellenbach@me.com> | 2021-03-25 11:08:19 +0000 |
commit | d59ef212e14012250127a244df1484f626d39e42 (patch) | |
tree | 92328c42d2b9214778bb1ac3ea6dbd9e7b45a0b9 /Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h | |
parent | e7b8643d70dfbb02ad94186169a8f16041f05bc2 (diff) |
Fixed performance issues for complex VSX and P10 MMA in gebp_kernel (level 3).
Diffstat (limited to 'Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h')
-rw-r--r-- | Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h | 133 |
1 file changed, 114 insertions, 19 deletions
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index a1799c061..024767868 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -1,3 +1,10 @@ +//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines +#ifdef EIGEN_POWER_USE_PREFETCH +#define EIGEN_POWER_PREFETCH(p) prefetch(p) +#else +#define EIGEN_POWER_PREFETCH(p) +#endif + namespace Eigen { namespace internal { @@ -5,8 +12,8 @@ namespace internal { template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows> EIGEN_STRONG_INLINE void gemm_extra_col( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -16,16 +23,17 @@ EIGEN_STRONG_INLINE void gemm_extra_col( Index remaining_cols, const Packet& pAlpha); -template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows> +template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols> EIGEN_STRONG_INLINE void gemm_extra_row( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, Index row, Index col, + Index rows, Index cols, Index remaining_rows, const Packet& pAlpha, @@ -34,8 +42,8 @@ EIGEN_STRONG_INLINE void gemm_extra_row( template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols> EIGEN_STRONG_INLINE void gemm_unrolled_col( const DataMapper& res, - const Scalar *lhs_base, - const Scalar *rhs_base, + const Scalar* lhs_base, + const Scalar* rhs_base, Index depth, Index strideA, Index offsetA, @@ -48,6 +56,71 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col( template<typename Packet> EIGEN_STRONG_INLINE Packet 
bmask(const int remaining_rows); +template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> +EIGEN_STRONG_INLINE void gemm_complex_extra_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag); + +template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> +EIGEN_STRONG_INLINE void gemm_complex_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask); + +template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag); + +template<typename Scalar, typename Packet> +EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar* lhs); + +template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder> +EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col); + +template<typename DataMapper, 
typename Packet, typename Index, const Index accCols, int N, int StorageOrder> +EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col); + +template<typename Packet> +EIGEN_STRONG_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha); + +template<typename Packet, int N> +EIGEN_STRONG_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag); + const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, @@ -68,7 +141,7 @@ const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14 // Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. template<typename Packet, typename Packetc> -EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2) +EIGEN_STRONG_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2) { acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); @@ -79,6 +152,12 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Pa acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND); acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND); acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND); +} + +template<typename Packet, typename Packetc> +EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, 
PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2) +{ + bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2); acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]); acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]); @@ -91,8 +170,26 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Pa acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]); } +template<typename Packet, typename Packetc> +EIGEN_STRONG_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); +} + +template<typename Packet, typename Packetc> +EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2) +{ + bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2); + + acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]); + + acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]); +} + template<> -EIGEN_STRONG_INLINE void bcouple<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd,8>& tRes, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2) +EIGEN_STRONG_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2) { acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); @@ -103,23 +200,21 @@ EIGEN_STRONG_INLINE void 
bcouple<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& t acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); +} - acc1.packet[0] = padd<Packet1cd>(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd<Packet1cd>(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd<Packet1cd>(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd<Packet1cd>(tRes.packet[3], acc1.packet[3]); +template<> +EIGEN_STRONG_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); - acc2.packet[0] = padd<Packet1cd>(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd<Packet1cd>(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd<Packet1cd>(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd<Packet1cd>(tRes.packet[7], acc2.packet[3]); + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); } // This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. template<typename Scalar, typename Packet> -EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs) +EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar* rhs) { - return *((Packet *)rhs); + return *((Packet *)rhs); } } // end namespace internal |