aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
diff options
context:
space:
mode:
author Chip Kerchner <chip.kerchner@ibm.com> 2021-03-25 11:08:19 +0000
committer David Tellenbach <david.tellenbach@me.com> 2021-03-25 11:08:19 +0000
commit d59ef212e14012250127a244df1484f626d39e42 (patch)
tree 92328c42d2b9214778bb1ac3ea6dbd9e7b45a0b9 /Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
parent e7b8643d70dfbb02ad94186169a8f16041f05bc2 (diff)
Fixed performance issues for complex VSX and P10 MMA in gebp_kernel (level 3).
Diffstat (limited to 'Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h')
-rw-r--r-- Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h 133
1 files changed, 114 insertions, 19 deletions
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
index a1799c061..024767868 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -1,3 +1,10 @@
+//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines
+#ifdef EIGEN_POWER_USE_PREFETCH
+#define EIGEN_POWER_PREFETCH(p) prefetch(p)
+#else
+#define EIGEN_POWER_PREFETCH(p)
+#endif
+
namespace Eigen {
namespace internal {
@@ -5,8 +12,8 @@ namespace internal {
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
EIGEN_STRONG_INLINE void gemm_extra_col(
const DataMapper& res,
- const Scalar *lhs_base,
- const Scalar *rhs_base,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
@@ -16,16 +23,17 @@ EIGEN_STRONG_INLINE void gemm_extra_col(
Index remaining_cols,
const Packet& pAlpha);
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_row(
const DataMapper& res,
- const Scalar *lhs_base,
- const Scalar *rhs_base,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index row,
Index col,
+ Index rows,
Index cols,
Index remaining_rows,
const Packet& pAlpha,
@@ -34,8 +42,8 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_col(
const DataMapper& res,
- const Scalar *lhs_base,
- const Scalar *rhs_base,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
@@ -48,6 +56,71 @@ EIGEN_STRONG_INLINE void gemm_unrolled_col(
template<typename Packet>
EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows);
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag);
+
+template<typename Scalar, typename Packet>
+EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar* lhs);
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col);
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col);
+
+template<typename Packet>
+EIGEN_STRONG_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha);
+
+template<typename Packet, int N>
+EIGEN_STRONG_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
+
const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
16, 17, 18, 19,
4, 5, 6, 7,
@@ -68,7 +141,7 @@ const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14
// Grab two decoupled real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
template<typename Packet, typename Packetc>
-EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+EIGEN_STRONG_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
{
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
@@ -79,6 +152,12 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Pa
acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+{
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
@@ -91,8 +170,26 @@ EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Pa
acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
}
+template<typename Packet, typename Packetc>
+EIGEN_STRONG_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+{
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
+
+ acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]);
+}
+
template<>
-EIGEN_STRONG_INLINE void bcouple<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd,8>& tRes, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
+EIGEN_STRONG_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
{
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
@@ -103,23 +200,21 @@ EIGEN_STRONG_INLINE void bcouple<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& t
acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
+}
- acc1.packet[0] = padd<Packet1cd>(tRes.packet[0], acc1.packet[0]);
- acc1.packet[1] = padd<Packet1cd>(tRes.packet[1], acc1.packet[1]);
- acc1.packet[2] = padd<Packet1cd>(tRes.packet[2], acc1.packet[2]);
- acc1.packet[3] = padd<Packet1cd>(tRes.packet[3], acc1.packet[3]);
+template<>
+EIGEN_STRONG_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
- acc2.packet[0] = padd<Packet1cd>(tRes.packet[4], acc2.packet[0]);
- acc2.packet[1] = padd<Packet1cd>(tRes.packet[5], acc2.packet[1]);
- acc2.packet[2] = padd<Packet1cd>(tRes.packet[6], acc2.packet[2]);
- acc2.packet[3] = padd<Packet1cd>(tRes.packet[7], acc2.packet[3]);
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
}
// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
-EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs)
+EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar* rhs)
{
- return *((Packet *)rhs);
+ return *((Packet *)rhs);
}
} // end namespace internal