author    Everton Constantino <everton.constantino@ibm.com>  2020-05-20 14:01:02 -0300
committer Everton Constantino <everton.constantino@ibm.com>  2020-09-02 18:21:36 -0300
commit    6fe88a3c9db27c00a3817e391cf70116451bf046 (patch)
tree      4d62e610f6fdb0c3f5f571f70cf1f984fbeff907
parent    6568856275de8bfcdd74e1de8fdf8656aca5ddb4 (diff)
MatrixProduct enhancements:
- Changes to AltiVec/MatrixProduct: adapting code to GCC 10; generic code style and performance enhancements; adding PanelMode support; adding stride/offset support; enabling float64, std::complex<float> and std::complex<double>; fixing the lack of symm_pack; enabling mixed types.
- Adding std::complex tests to blasutil.
- Adding an implementation of storePacketBlock when Incr != 1.
-rw-r--r--  Eigen/Core                                   |    2
-rw-r--r--  Eigen/src/Core/arch/AltiVec/MatrixProduct.h  | 3114
-rwxr-xr-x  Eigen/src/Core/util/BlasUtil.h               |   71
-rw-r--r--  test/blasutil.cpp                            |    2
4 files changed, 2973 insertions, 216 deletions
diff --git a/Eigen/Core b/Eigen/Core
index f44b77831..7d1bdd6e8 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -335,7 +335,7 @@ using std::ptrdiff_t;
#include "src/Core/CoreIterators.h"
#include "src/Core/ConditionEstimator.h"
-#if EIGEN_ARCH_PPC
+#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
#include "src/Core/arch/AltiVec/MatrixProduct.h"
#endif
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 3bfbfdc87..57227e23b 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -10,302 +10,2986 @@
#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
-#ifdef __MMA__
-
+/**************************************************************************************************
+ * TODO *
+ * - Check StorageOrder on lhs_pack (the innermost second loop seems unvectorized when it could be). *
+ * - Check the possibility of transposing as GETREAL and GETIMAG when needed. *
+ * - Check whether changing conjugation to xor instead of mul gains any performance. *
+ * - Remove the IsComplex template argument from complex packing. *
+ **************************************************************************************************/
namespace Eigen {
namespace internal {
-const int accRows = 4;
-const int accCols = 4;
-const int accCount = 4;
-const int floatVectorSize = 4;
+/**************************
+ * Constants and typedefs *
+ **************************/
+const int QuadRegisterCount = 8;
+
+#ifdef __MMA__
-typedef struct
+template<typename Packet>
+union Packetx2u
{
- __vector float v0;
- __vector float v1;
- __vector float v2;
- __vector float v3;
-} Packet4fx4;
+ __vector_pair vectorpair;
+ PacketBlock<Packet, 2> pair;
+};
+
+#endif
-union PacketQuad
+
+template<typename Scalar>
+struct quad_traits
{
- __struct_quad sc;
- Packet4fx4 sf;
+ typedef typename packet_traits<Scalar>::type vectortype;
+ typedef PacketBlock<vectortype, 4> type;
+ typedef vectortype rhstype;
+ enum
+ {
+ vectorsize = packet_traits<Scalar>::size,
+ size = 4,
+ rows = 4
+ };
};
-template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
-struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+template<>
+struct quad_traits<double>
{
- void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+ typedef Packet2d vectortype;
+ typedef PacketBlock<vectortype, 4> type;
+ typedef PacketBlock<Packet2d,2> rhstype;
+ enum
+ {
+ vectorsize = packet_traits<double>::size,
+ size = 2,
+ rows = 4
+ };
};
-template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
-void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
- ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+// MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector; this turned out
+// to be faster than Eigen's usual approach of keeping real/imaginary pairs on a single vector. These constants
+// are then responsible for converting between Eigen's layout and MatrixProduct's.
+const static Packet4f p4f_CONJUGATE = {-1.0f, -1.0f, -1.0f, -1.0f};
+
+const static Packet2d p2d_CONJUGATE = {-1.0, -1.0};
+
+const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3,
+ 8, 9, 10, 11,
+ 16, 17, 18, 19,
+ 24, 25, 26, 27};
+
+const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7,
+ 12, 13, 14, 15,
+ 20, 21, 22, 23,
+ 28, 29, 30, 31};
+
+const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
+ 16, 17, 18, 19,
+ 4, 5, 6, 7,
+ 20, 21, 22, 23};
+
+const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11,
+ 24, 25, 26, 27,
+ 12, 13, 14, 15,
+ 28, 29, 30, 31};
+//[a,ai],[b,bi] = [a,b]
+const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+
+//[a,ai],[b,bi] = [ai,bi]
+const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15,
+ 24, 25, 26, 27, 28, 29, 30, 31};
+
+//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64
+const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+
+//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64
+const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15,
+ 24, 25, 26, 27, 28, 29, 30, 31};
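// [Editorial sketch, not part of the patch] How the 32-bit masks above are used: two Packet4f vectors
// holding interleaved complex floats are split into a pure-real and a pure-imaginary vector, and the
// SETCOMPLEX masks invert the split. Names here are illustrative only.
static inline void split_complex32_sketch(Packet4f lo /*{r0,i0,r1,i1}*/, Packet4f hi /*{r2,i2,r3,i3}*/,
                                          Packet4f& re, Packet4f& im)
{
  re = vec_perm(lo, hi, p16uc_GETREAL32); // {r0, r1, r2, r3}
  im = vec_perm(lo, hi, p16uc_GETIMAG32); // {i0, i1, i2, i3}
  // vec_perm(re, im, p16uc_SETCOMPLEX32_FIRST)  -> {r0, i0, r1, i1}
  // vec_perm(re, im, p16uc_SETCOMPLEX32_SECOND) -> {r2, i2, r3, i3}
}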
+
+/*********************************************
+ * Single precision real and complex packing *
+ * *******************************************/
+
+/**
+ * Symm packing handles symmetric adjoint blocks: as expected, the packing leaves the diagonal real,
+ * and whatever lies below it is copied from the respective upper-diagonal element and conjugated.
+ * There's no PanelMode available for symm packing.
+ *
+ * Packing in general is supposed to lay out the lhs block and the rhs block so that gemm can read them
+ * with its respective rank-update instructions. The float32/64 versions differ because, at this moment,
+ * the size of the accumulator is fixed at 512 bits, so you can't have a 4x4 accumulator of 64-bit elements.
+ *
+ * As mentioned earlier, MatrixProduct breaks complex numbers into a real vector and an imaginary vector, so
+ * packing has to take that into account; at the moment we pack the real part and then the imaginary part. This
+ * is the main reason packing for complex is broken down into several different parts, and also the reason why
+ * we end up with a float32/64 and a complex float32/64 version.
+ **/
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
+{
+ std::complex<Scalar> v;
+ if(i < j)
+ {
+ v.real(dt(j,i).real());
+ v.imag(-dt(j,i).imag());
+ } else if(i > j)
+ {
+ v.real(dt(i,j).real());
+ v.imag(dt(i,j).imag());
+ } else {
+ v.real(dt(i,j).real());
+ v.imag((Scalar)0.0f);
+ }
+ return v;
+}
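// [Editorial worked example, not part of the patch] For a self-adjoint mapper dt with
// dt(1,0) = 5+1i and dt(0,0) = 2+3i:
//   getAdjointVal(1, 0, dt) -> dt(1,0)       = 5+1i  (i > j: copied as-is)
//   getAdjointVal(0, 1, dt) -> conj(dt(1,0)) = 5-1i  (i < j: mirrored and conjugated)
//   getAdjointVal(0, 0, dt) -> 2+0i                  (diagonal forced real)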
+
+template<typename Scalar, typename Index, int StorageOrder, int N>
+EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar> *blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+{
+ const Index depth = k2 + rows;
+ const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
+ const int vectorSize = N*quad_traits<Scalar>::vectorsize;
+ Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);
+
+ Index ri = 0, j = 0;
+ for(; j + vectorSize < cols; j+=vectorSize)
+ {
+ Index i = k2;
+ for(; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);
+ blockBf[ri + k] = v.real();
+ }
+ ri += vectorSize;
+ }
+
+ i = k2;
+
+ for(; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);
+ blockBf[ri + k] = v.imag();
+ }
+ ri += vectorSize;
+ }
+ }
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
+ blockBf[ri] = v.real();
+ ri += 1;
+ }
+ }
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
+ blockBf[ri] = v.imag();
+ ri += 1;
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar> *blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
{
- int ri = 0, j;
- for(j = 0; j + floatVectorSize < rows; j+=floatVectorSize)
+ const Index depth = cols;
+ const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+ Scalar *blockAf = (Scalar *)(blockA);
+
+ for(; j + vectorSize < rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ for(; i < depth; i++)
+ {
+ for(int k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);
+ blockAf[ri + k] = v.real();
+ }
+ ri += vectorSize;
+ }
+ i = 0;
+ for(; i < depth; i++)
+ {
+ for(int k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);
+ blockAf[ri + k] = v.imag();
+ }
+ ri += vectorSize;
+ }
+ }
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);
+ blockAf[ri] = v.real();
+ ri += 1;
+ }
+ }
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);
+ blockAf[ri] = v.imag();
+ ri += 1;
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder, int N>
+EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+{
+ const Index depth = k2 + rows;
+ const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+
+ Index ri = 0, j = 0;
+ for(; j + N*vectorSize < cols; j+=N*vectorSize)
+ {
+ Index i = k2;
+ for(; i < depth; i++)
+ {
+ for(int k = 0; k < N*vectorSize; k++)
+ {
+ if(i <= j+k)
+ blockB[ri + k] = rhs(j+k, i);
+ else
+ blockB[ri + k] = rhs(i, j+k);
+ }
+ ri += N*vectorSize;
+ }
+ }
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ if(k <= i)
+ blockB[ri] = rhs(i, k);
+ else
+ blockB[ri] = rhs(k, i);
+ ri += 1;
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
+{
+ const Index depth = cols;
+ const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ for(; i < depth; i++)
+ {
+ for(int k = 0; k < vectorSize; k++)
+ {
+ if(i <= j+k)
+ blockA[ri + k] = lhs(j+k, i);
+ else
+ blockA[ri + k] = lhs(i, j+k);
+ }
+ ri += vectorSize;
+ }
+ }
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if(i <= k)
+ blockA[ri] = lhs(k, i);
+ else
+ blockA[ri] = lhs(i, k);
+ ri += 1;
+ }
+ }
+}
+
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
+{
+ void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack std::complex<float64> ***********
+
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
+{
+ void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack float32 ***********
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<float, Index, nr, StorageOrder>
+{
+ void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack float64 ***********
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<double, Index, nr, StorageOrder>
+{
+ void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+/**
+ * PanelMode
+ * Packing might be called several times before the block is multiplied by gebp_kernel; this happens
+ * because on special occasions it fills part of the block with other parts of the matrix. Two variables
+ * control how PanelMode should behave: offset and stride. The idea is that these variables represent
+ * the real offset and stride the block is going to have in the future, and that is what packing must
+ * obey. The process behaves as normal packing would, but leaves the start of each part at the correct
+ * offset, and likewise its end, respecting the real stride the block will have. Gebp is aware of both
+ * blocks' stride and offset and behaves accordingly.
+ **/
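// [Editorial sketch, not part of the patch] The resulting index arithmetic for each packed panel of
// width w (w = vectorSize for full panels, rows - j or cols - j for the remainder) is:
//
//   ri += w*offset;                      // leave room before the panel's data
//   ...pack depth rows of w scalars...   // the panel body
//   ri += w*(stride - offset - depth);   // pad out to the panel's real stride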
+
+// General template for lhs complex packing.
+template<typename Scalar, bool IsComplex, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct lhs_cpack {
+ EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+ Scalar *blockAt = reinterpret_cast<Scalar *>(blockA);
+ Packet conj = pset1<Packet>((Scalar)-1.0f);
+
+ for(j = 0; j + vectorSize < rows; j+=vectorSize)
{
- int i;
- for(i = 0; i + floatVectorSize < depth; i+=floatVectorSize)
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet, 4> block;
+
+ PacketBlock<PacketC, 8> cblock;
+ if(StorageOrder == ColMajor)
{
- PacketBlock<Packet4f, 4> block;
- block.packet[0] = lhs.template loadPacket<Packet4f>(j, i + 0);
- block.packet[1] = lhs.template loadPacket<Packet4f>(j, i + 1);
- block.packet[2] = lhs.template loadPacket<Packet4f>(j, i + 2);
- block.packet[3] = lhs.template loadPacket<Packet4f>(j, i + 3);
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1);
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j, i + 2);
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j, i + 3);
- pstore<float>((float *)(blockA + ri ), block.packet[0]);
- pstore<float>((float *)(blockA + ri + 4), block.packet[1]);
- pstore<float>((float *)(blockA + ri + 8), block.packet[2]);
- pstore<float>((float *)(blockA + ri + 12), block.packet[3]);
- ri += 4*floatVectorSize;
+ cblock.packet[4] = lhs.template loadPacket<PacketC>(j + 2, i + 0);
+ cblock.packet[5] = lhs.template loadPacket<PacketC>(j + 2, i + 1);
+ cblock.packet[6] = lhs.template loadPacket<PacketC>(j + 2, i + 2);
+ cblock.packet[7] = lhs.template loadPacket<PacketC>(j + 2, i + 3);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 2, i);
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 3, i);
+
+ cblock.packet[4] = lhs.template loadPacket<PacketC>(j + 0, i + 2);
+ cblock.packet[5] = lhs.template loadPacket<PacketC>(j + 1, i + 2);
+ cblock.packet[6] = lhs.template loadPacket<PacketC>(j + 2, i + 2);
+ cblock.packet[7] = lhs.template loadPacket<PacketC>(j + 3, i + 2);
}
- for(; i < depth; i++)
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32);
+ block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32);
+ block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32);
+ block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32);
+
+ if(StorageOrder == RowMajor) ptranspose(block);
+
+ pstore<Scalar>(blockAt + ri , block.packet[0]);
+ pstore<Scalar>(blockAt + ri + 4, block.packet[1]);
+ pstore<Scalar>(blockAt + ri + 8, block.packet[2]);
+ pstore<Scalar>(blockAt + ri + 12, block.packet[3]);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ blockAt[ri + 0] = lhs(j + 0, i).real();
+ blockAt[ri + 1] = lhs(j + 1, i).real();
+ blockAt[ri + 2] = lhs(j + 2, i).real();
+ blockAt[ri + 3] = lhs(j + 3, i).real();
+
+ ri += vectorSize;
+ }
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+
+ i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<PacketC, 8> cblock;
+ if(StorageOrder == ColMajor)
+ {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1);
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j, i + 2);
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j, i + 3);
+
+ cblock.packet[4] = lhs.template loadPacket<PacketC>(j + 2, i + 0);
+ cblock.packet[5] = lhs.template loadPacket<PacketC>(j + 2, i + 1);
+ cblock.packet[6] = lhs.template loadPacket<PacketC>(j + 2, i + 2);
+ cblock.packet[7] = lhs.template loadPacket<PacketC>(j + 2, i + 3);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 2, i);
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 3, i);
+
+ cblock.packet[4] = lhs.template loadPacket<PacketC>(j + 0, i + 2);
+ cblock.packet[5] = lhs.template loadPacket<PacketC>(j + 1, i + 2);
+ cblock.packet[6] = lhs.template loadPacket<PacketC>(j + 2, i + 2);
+ cblock.packet[7] = lhs.template loadPacket<PacketC>(j + 3, i + 2);
+ }
+
+ PacketBlock<Packet, 4> block;
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32);
+ block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32);
+ block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32);
+ block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32);
+
+ if(Conjugate)
+ {
+ block.packet[0] *= conj;
+ block.packet[1] *= conj;
+ block.packet[2] *= conj;
+ block.packet[3] *= conj;
+ }
+
+ if(StorageOrder == RowMajor) ptranspose(block);
+
+ pstore<Scalar>(blockAt + ri , block.packet[0]);
+ pstore<Scalar>(blockAt + ri + 4, block.packet[1]);
+ pstore<Scalar>(blockAt + ri + 8, block.packet[2]);
+ pstore<Scalar>(blockAt + ri + 12, block.packet[3]);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(Conjugate)
{
- Packet4f lhsV = lhs.template loadPacket<Packet4f>(j, i);
- pstore<float>((float *)(blockA + ri), lhsV);
- ri += floatVectorSize;
+ blockAt[ri + 0] = -lhs(j + 0, i).imag();
+ blockAt[ri + 1] = -lhs(j + 1, i).imag();
+ blockAt[ri + 2] = -lhs(j + 2, i).imag();
+ blockAt[ri + 3] = -lhs(j + 3, i).imag();
+ } else {
+ blockAt[ri + 0] = lhs(j + 0, i).imag();
+ blockAt[ri + 1] = lhs(j + 1, i).imag();
+ blockAt[ri + 2] = lhs(j + 2, i).imag();
+ blockAt[ri + 3] = lhs(j + 3, i).imag();
}
+
+ ri += vectorSize;
+ }
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockAt[ri] = lhs(k, i).real();
+ ri += 1;
+ }
+ }
+
+ if(PanelMode) ri += (rows - j)*(stride - offset - depth);
+
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if(Conjugate)
+ blockAt[ri] = -lhs(k, i).imag();
+ else
+ blockAt[ri] = lhs(k, i).imag();
+ ri += 1;
+ }
}
- for(int i = 0; i < depth; i++)
+
+ if(PanelMode) ri += (rows - j)*(stride - offset - depth);
+ }
+};
+
+// General template for lhs packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode>
+struct lhs_pack{
+ EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(j = 0; j + vectorSize < rows; j+=vectorSize)
{
- int k = j;
- for(; k < rows; k++)
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet, 4> block;
+
+ if(StorageOrder == RowMajor)
{
- blockA[ri] = lhs(k, i);
- ri += 1;
+ block.packet[0] = lhs.template loadPacket<Packet>(j + 0, i);
+ block.packet[1] = lhs.template loadPacket<Packet>(j + 1, i);
+ block.packet[2] = lhs.template loadPacket<Packet>(j + 2, i);
+ block.packet[3] = lhs.template loadPacket<Packet>(j + 3, i);
+
+ ptranspose(block);
+ } else {
+ block.packet[0] = lhs.template loadPacket<Packet>(j, i + 0);
+ block.packet[1] = lhs.template loadPacket<Packet>(j, i + 1);
+ block.packet[2] = lhs.template loadPacket<Packet>(j, i + 2);
+ block.packet[3] = lhs.template loadPacket<Packet>(j, i + 3);
}
+
+ pstore<Scalar>(blockA + ri , block.packet[0]);
+ pstore<Scalar>(blockA + ri + 4, block.packet[1]);
+ pstore<Scalar>(blockA + ri + 8, block.packet[2]);
+ pstore<Scalar>(blockA + ri + 12, block.packet[3]);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == RowMajor)
+ {
+ blockA[ri+0] = lhs(j+0, i);
+ blockA[ri+1] = lhs(j+1, i);
+ blockA[ri+2] = lhs(j+2, i);
+ blockA[ri+3] = lhs(j+3, i);
+ } else {
+ Packet lhsV = lhs.template loadPacket<Packet>(j, i);
+ pstore<Scalar>(blockA + ri, lhsV);
+ }
+
+ ri += vectorSize;
+ }
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
}
-}
-template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
-struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
-{
- void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockA[ri] = lhs(k, i);
+ ri += 1;
+ }
+ }
+
+ if(PanelMode) ri += (rows - j)*(stride - offset - depth);
+ }
};
-template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
-void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
- ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+// General template for rhs complex packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct rhs_cpack
{
- int ri = 0, j;
- for(j = 0; j + floatVectorSize < cols; j+=floatVectorSize)
+ EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+ Scalar *blockBt = reinterpret_cast<Scalar *>(blockB);
+ Packet conj = pset1<Packet>((Scalar)-1.0f);
+
+ Index ri = 0, j = 0;
+ for(; j + vectorSize < cols; j+=vectorSize)
{
- int i;
- for(i = 0; i + floatVectorSize < depth; i+=floatVectorSize)
+ Index i = 0;
+
+ if(PanelMode) ri += offset*vectorSize;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<PacketC, 8> cblock;
+ if(StorageOrder == ColMajor)
+ {
+ cblock.packet[0] = rhs.template loadPacket<PacketC>(i, j + 0);
+ cblock.packet[1] = rhs.template loadPacket<PacketC>(i, j + 1);
+ cblock.packet[2] = rhs.template loadPacket<PacketC>(i, j + 2);
+ cblock.packet[3] = rhs.template loadPacket<PacketC>(i, j + 3);
+
+ cblock.packet[4] = rhs.template loadPacket<PacketC>(i + 2, j + 0);
+ cblock.packet[5] = rhs.template loadPacket<PacketC>(i + 2, j + 1);
+ cblock.packet[6] = rhs.template loadPacket<PacketC>(i + 2, j + 2);
+ cblock.packet[7] = rhs.template loadPacket<PacketC>(i + 2, j + 3);
+ } else {
+ cblock.packet[0] = rhs.template loadPacket<PacketC>(i + 0, j);
+ cblock.packet[1] = rhs.template loadPacket<PacketC>(i + 1, j);
+ cblock.packet[2] = rhs.template loadPacket<PacketC>(i + 2, j);
+ cblock.packet[3] = rhs.template loadPacket<PacketC>(i + 3, j);
+
+ cblock.packet[4] = rhs.template loadPacket<PacketC>(i + 0, j + 2);
+ cblock.packet[5] = rhs.template loadPacket<PacketC>(i + 1, j + 2);
+ cblock.packet[6] = rhs.template loadPacket<PacketC>(i + 2, j + 2);
+ cblock.packet[7] = rhs.template loadPacket<PacketC>(i + 3, j + 2);
+ }
+
+ PacketBlock<Packet, 4> block;
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32);
+ block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32);
+ block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32);
+ block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32);
+
+ if(StorageOrder == ColMajor) ptranspose(block);
+
+ pstore<Scalar>(blockBt + ri , block.packet[0]);
+ pstore<Scalar>(blockBt + ri + 4, block.packet[1]);
+ pstore<Scalar>(blockBt + ri + 8, block.packet[2]);
+ pstore<Scalar>(blockBt + ri + 12, block.packet[3]);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ blockBt[ri+0] = rhs(i, j+0).real();
+ blockBt[ri+1] = rhs(i, j+1).real();
+ blockBt[ri+2] = rhs(i, j+2).real();
+ blockBt[ri+3] = rhs(i, j+3).real();
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+
+ i = 0;
+
+ if(PanelMode) ri += offset*vectorSize;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
{
- PacketBlock<Packet4f, 4> block;
- block.packet[0] = rhs.template loadPacket<Packet4f>(i, j + 0);
- block.packet[1] = rhs.template loadPacket<Packet4f>(i, j + 1);
- block.packet[2] = rhs.template loadPacket<Packet4f>(i, j + 2);
- block.packet[3] = rhs.template loadPacket<Packet4f>(i, j + 3);
+ PacketBlock<PacketC, 8> cblock;
+ if(StorageOrder == ColMajor)
+ {
+
+ cblock.packet[0] = rhs.template loadPacket<PacketC>(i, j + 0);
+ cblock.packet[1] = rhs.template loadPacket<PacketC>(i, j + 1);
+ cblock.packet[2] = rhs.template loadPacket<PacketC>(i, j + 2);
+ cblock.packet[3] = rhs.template loadPacket<PacketC>(i, j + 3);
+
+ cblock.packet[4] = rhs.template loadPacket<PacketC>(i + 2, j + 0);
+ cblock.packet[5] = rhs.template loadPacket<PacketC>(i + 2, j + 1);
+ cblock.packet[6] = rhs.template loadPacket<PacketC>(i + 2, j + 2);
+ cblock.packet[7] = rhs.template loadPacket<PacketC>(i + 2, j + 3);
+ } else {
+ cblock.packet[0] = rhs.template loadPacket<PacketC>(i + 0, j);
+ cblock.packet[1] = rhs.template loadPacket<PacketC>(i + 1, j);
+ cblock.packet[2] = rhs.template loadPacket<PacketC>(i + 2, j);
+ cblock.packet[3] = rhs.template loadPacket<PacketC>(i + 3, j);
- ptranspose(block);
+ cblock.packet[4] = rhs.template loadPacket<PacketC>(i + 0, j + 2);
+ cblock.packet[5] = rhs.template loadPacket<PacketC>(i + 1, j + 2);
+ cblock.packet[6] = rhs.template loadPacket<PacketC>(i + 2, j + 2);
+ cblock.packet[7] = rhs.template loadPacket<PacketC>(i + 3, j + 2);
+ }
+
+ PacketBlock<Packet, 4> block;
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32);
+ block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32);
+ block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32);
+ block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32);
- pstore<float>((float *)(blockB + ri ), block.packet[0]);
- pstore<float>((float *)(blockB + ri + 4), block.packet[1]);
- pstore<float>((float *)(blockB + ri + 8), block.packet[2]);
- pstore<float>((float *)(blockB + ri + 12), block.packet[3]);
+ if(Conjugate)
+ {
+ block.packet[0] *= conj;
+ block.packet[1] *= conj;
+ block.packet[2] *= conj;
+ block.packet[3] *= conj;
+ }
- ri += 4*floatVectorSize;
+ if(StorageOrder == ColMajor) ptranspose(block);
+
+ pstore<Scalar>(blockBt + ri , block.packet[0]);
+ pstore<Scalar>(blockBt + ri + 4, block.packet[1]);
+ pstore<Scalar>(blockBt + ri + 8, block.packet[2]);
+ pstore<Scalar>(blockBt + ri + 12, block.packet[3]);
+
+ ri += 4*vectorSize;
}
for(; i < depth; i++)
{
- blockB[ri+0] = rhs(i, j+0);
- blockB[ri+1] = rhs(i, j+1);
- blockB[ri+2] = rhs(i, j+2);
- blockB[ri+3] = rhs(i, j+3);
- ri += floatVectorSize;
+ if(Conjugate)
+ {
+ blockBt[ri+0] = -rhs(i, j+0).imag();
+ blockBt[ri+1] = -rhs(i, j+1).imag();
+ blockBt[ri+2] = -rhs(i, j+2).imag();
+ blockBt[ri+3] = -rhs(i, j+3).imag();
+ } else {
+ blockBt[ri+0] = rhs(i, j+0).imag();
+ blockBt[ri+1] = rhs(i, j+1).imag();
+ blockBt[ri+2] = rhs(i, j+2).imag();
+ blockBt[ri+3] = rhs(i, j+3).imag();
+ }
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockBt[ri] = rhs(i, k).real();
+ ri += 1;
}
}
- for(int i = 0; i < depth; i++)
+ if(PanelMode) ri += (cols - j)*(stride - offset - depth);
+
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
{
- int k = j;
+ Index k = j;
for(; k < cols; k++)
{
- blockB[ri] = rhs(i, k);
+ if(Conjugate)
+ blockBt[ri] = -rhs(i, k).imag();
+ else
+ blockBt[ri] = rhs(i, k).imag();
ri += 1;
}
}
+ if(PanelMode) ri += (cols - j)*(stride - offset - depth);
+ }
+};
+
+// General template for rhs packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode>
+struct rhs_pack {
+ EIGEN_STRONG_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+ for(; j + vectorSize < cols; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += offset*vectorSize;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet, 4> block;
+ if(StorageOrder == ColMajor)
+ {
+ block.packet[0] = rhs.template loadPacket<Packet>(i, j + 0);
+ block.packet[1] = rhs.template loadPacket<Packet>(i, j + 1);
+ block.packet[2] = rhs.template loadPacket<Packet>(i, j + 2);
+ block.packet[3] = rhs.template loadPacket<Packet>(i, j + 3);
+
+ ptranspose(block);
+ } else {
+ block.packet[0] = rhs.template loadPacket<Packet>(i + 0, j);
+ block.packet[1] = rhs.template loadPacket<Packet>(i + 1, j);
+ block.packet[2] = rhs.template loadPacket<Packet>(i + 2, j);
+ block.packet[3] = rhs.template loadPacket<Packet>(i + 3, j);
+ }
+
+ pstore<Scalar>(blockB + ri , block.packet[0]);
+ pstore<Scalar>(blockB + ri + 4, block.packet[1]);
+ pstore<Scalar>(blockB + ri + 8, block.packet[2]);
+ pstore<Scalar>(blockB + ri + 12, block.packet[3]);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == ColMajor)
+ {
+ blockB[ri+0] = rhs(i, j+0);
+ blockB[ri+1] = rhs(i, j+1);
+ blockB[ri+2] = rhs(i, j+2);
+ blockB[ri+3] = rhs(i, j+3);
+ } else {
+ Packet rhsV = rhs.template loadPacket<Packet>(i, j);
+ pstore<Scalar>(blockB + ri, rhsV);
+ }
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockB[ri] = rhs(i, k);
+ ri += 1;
+ }
+ }
+ if(PanelMode) ri += (cols - j)*(stride - offset - depth);
+ }
+};
+
+// General template for lhs packing, float64 specialization.
+template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
+struct lhs_pack<double,Index, DataMapper, Packet2d, StorageOrder, PanelMode>
+{
+ EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet2d, 2> block;
+ if(StorageOrder == RowMajor)
+ {
+ block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
+ block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);
+
+ ptranspose(block);
+ } else {
+ block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
+ block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
+ }
+
+ pstore<double>(blockA + ri , block.packet[0]);
+ pstore<double>(blockA + ri + 2, block.packet[1]);
+
+ ri += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == RowMajor)
+ {
+ blockA[ri+0] = lhs(j+0, i);
+ blockA[ri+1] = lhs(j+1, i);
+ } else {
+ Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
+ pstore<double>(blockA + ri, lhsV);
+ }
+
+ ri += vectorSize;
+ }
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockA[ri] = lhs(k, i);
+ ri += 1;
+ }
+ }
+
+ if(PanelMode) ri += (rows - j)*(stride - offset - depth);
+ }
+};
+
+// General template for rhs packing, float64 specialization.
+template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
+struct rhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode>
+{
+ EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+ for(; j + 2*vectorSize < cols; j+=2*vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += offset*(2*vectorSize);
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet2d, 4> block;
+ if(StorageOrder == ColMajor)
+ {
+ PacketBlock<Packet2d, 2> block1, block2;
+ block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
+ block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
+ block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
+ block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);
+
+ ptranspose(block1);
+ ptranspose(block2);
+
+ pstore<double>(blockB + ri , block1.packet[0]);
+ pstore<double>(blockB + ri + 2, block2.packet[0]);
+ pstore<double>(blockB + ri + 4, block1.packet[1]);
+ pstore<double>(blockB + ri + 6, block2.packet[1]);
+ } else {
+ block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
+ block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
+ block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
+ block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
+
+ pstore<double>(blockB + ri , block.packet[0]);
+ pstore<double>(blockB + ri + 2, block.packet[1]);
+ pstore<double>(blockB + ri + 4, block.packet[2]);
+ pstore<double>(blockB + ri + 6, block.packet[3]);
+ }
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == ColMajor)
+ {
+ blockB[ri+0] = rhs(i, j+0);
+ blockB[ri+1] = rhs(i, j+1);
+
+ ri += vectorSize;
+
+ blockB[ri+0] = rhs(i, j+2);
+ blockB[ri+1] = rhs(i, j+3);
+ } else {
+ Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
+ pstore<double>(blockB + ri, rhsV);
+
+ ri += vectorSize;
+
+ rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
+ pstore<double>(blockB + ri, rhsV);
+ }
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockB[ri] = rhs(i, k);
+ ri += 1;
+ }
+ }
+ if(PanelMode) ri += (cols - j)*(stride - offset - depth);
+ }
+};
+
+// General template for lhs complex packing, float64 specialization.
+template<bool IsComplex, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct lhs_cpack<double, IsComplex, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode>
+{
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+ double *blockAt = reinterpret_cast<double *>(blockA);
+ Packet conj = pset1<Packet>(-1.0);
+
+ for(j = 0; j + vectorSize < rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet, 2> block;
+
+ PacketBlock<PacketC, 4> cblock;
+ if(StorageOrder == ColMajor)
+ {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
+ block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
+ block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+ }
+
+ pstore<double>(blockAt + ri , block.packet[0]);
+ pstore<double>(blockAt + ri + 2, block.packet[1]);
+
+ ri += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ blockAt[ri + 0] = lhs(j + 0, i).real();
+ blockAt[ri + 1] = lhs(j + 1, i).real();
+ ri += vectorSize;
+ }
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+
+ i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize < depth; i+=vectorSize)
+ {
+ PacketBlock<Packet, 2> block;
+
+ PacketBlock<PacketC, 4> cblock;
+ if(StorageOrder == ColMajor)
+ {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1);
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0);
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1);
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETIMAG64);
+ block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETIMAG64);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1);
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1);
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64);
+ block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64);
+ }
+
+ if(Conjugate)
+ {
+ block.packet[0] *= conj;
+ block.packet[1] *= conj;
+ }
+
+ pstore<double>(blockAt + ri , block.packet[0]);
+ pstore<double>(blockAt + ri + 2, block.packet[1]);
+
+ ri += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(Conjugate)
+ {
+ blockAt[ri + 0] = -lhs(j + 0, i).imag();
+ blockAt[ri + 1] = -lhs(j + 1, i).imag();
+ } else {
+ blockAt[ri + 0] = lhs(j + 0, i).imag();
+ blockAt[ri + 1] = lhs(j + 1, i).imag();
+ }
+
+ ri += vectorSize;
+ }
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockAt[ri] = lhs(k, i).real();
+ ri += 1;
+ }
+ }
+
+ if(PanelMode) ri += (rows - j)*(stride - offset - depth);
+
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if(Conjugate)
+ blockAt[ri] = -lhs(k, i).imag();
+ else
+ blockAt[ri] = lhs(k, i).imag();
+ ri += 1;
+ }
+ }
+
+ if(PanelMode) ri += (rows - j)*(stride - offset - depth);
+ }
+};
+
+// General template for rhs complex packing, float64 specialization.
+template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct rhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode>
+{
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const int vectorSize = quad_traits<double>::vectorsize;
+ double *blockBt = reinterpret_cast<double *>(blockB);
+ Packet conj = pset1<Packet>(-1.0);
+
+ Index ri = 0, j = 0;
+ for(; j + 2*vectorSize < cols; j+=2*vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += offset*(2*vectorSize);
+
+ for(; i < depth; i++)
+ {
+ PacketBlock<PacketC, 4> cblock;
+ PacketBlock<Packet, 2> block;
+
+ cblock.packet[0] = rhs.template loadPacket<PacketC>(i, j + 0);
+ cblock.packet[1] = rhs.template loadPacket<PacketC>(i, j + 1);
+ cblock.packet[2] = rhs.template loadPacket<PacketC>(i, j + 2);
+ cblock.packet[3] = rhs.template loadPacket<PacketC>(i, j + 3);
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64);
+ block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64);
+
+ pstore<double>(blockBt + ri , block.packet[0]);
+ pstore<double>(blockBt + ri + 2, block.packet[1]);
+
+ ri += 2*vectorSize;
+ }
+
+ if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
+
+ i = 0;
+
+ if(PanelMode) ri += offset*(2*vectorSize);
+
+ for(; i < depth; i++)
+ {
+ PacketBlock<PacketC, 4> cblock;
+ PacketBlock<Packet, 2> block;
+
+ cblock.packet[0] = rhs.template loadPacket<PacketC>(i, j + 0); //[a1 a1i]
+ cblock.packet[1] = rhs.template loadPacket<PacketC>(i, j + 1); //[b1 b1i]
+ cblock.packet[2] = rhs.template loadPacket<PacketC>(i, j + 2); //[c1 c1i]
+ cblock.packet[3] = rhs.template loadPacket<PacketC>(i, j + 3); //[d1 d1i]
+
+ block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64);
+ block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64);
+
+ if(Conjugate)
+ {
+ block.packet[0] *= conj;
+ block.packet[1] *= conj;
+ }
+
+ pstore<double>(blockBt + ri , block.packet[0]);
+ pstore<double>(blockBt + ri + 2, block.packet[1]);
+
+ ri += 2*vectorSize;
+ }
+ if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
+ }
+
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockBt[ri] = rhs(i, k).real();
+ ri += 1;
+ }
+ }
+ if(PanelMode) ri += (cols - j)*(stride - offset - depth);
+
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ if(Conjugate)
+ blockBt[ri] = -rhs(i, k).imag();
+ else
+ blockBt[ri] = rhs(i, k).imag();
+ ri += 1;
+ }
+ }
+ if(PanelMode) ri += (cols - j)*(stride - offset - depth);
+ }
+};
+
+/**************
+ * GEMM utils *
+ **************/
+
+// Grab two decoupled real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
+template<typename Packet, typename Packetc>
+EIGEN_STRONG_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
+
+ acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
+ acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
+ acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
+ acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[4], acc2.packet[0]);
+ acc2.packet[1] = padd<Packetc>(tRes.packet[5], acc2.packet[1]);
+ acc2.packet[2] = padd<Packetc>(tRes.packet[6], acc2.packet[2]);
+ acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
}
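// [Editorial note, not part of the patch] In effect, bcouple re-interleaves the decoupled accumulators
// (taccReal holding {r0..r3}, taccImag holding {i0..i3} per column) back into complex pairs via the
// SETCOMPLEX masks, then folds in the values already present in the result block (tRes).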
-template<typename DataMapper, typename Index, typename Scalar>
-EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, Scalar alpha, __vector_quad *acc)
+template<>
+EIGEN_STRONG_INLINE void bcouple<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd,8>& tRes, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
{
- //[TODO]
- //
- //Packet4fx4 r;
- //
- //__builtin_mma_disassemble_acc((void *)&r, *acc);
- //
- PacketQuad result;
- result.sc = __builtin_mma_disassemble_acc(*acc);
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST);
- Packet4f pAlpha = pset1<Packet4f>(alpha);
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
- PacketBlock<Packet4f, 4> block;
- block.packet[0] = pAlpha*result.sf.v3;
- block.packet[1] = pAlpha*result.sf.v2;
- block.packet[2] = pAlpha*result.sf.v1;
- block.packet[3] = pAlpha*result.sf.v0;
+ acc1.packet[0] = padd<Packet1cd>(tRes.packet[0], acc1.packet[0]);
+ acc1.packet[1] = padd<Packet1cd>(tRes.packet[1], acc1.packet[1]);
+ acc1.packet[2] = padd<Packet1cd>(tRes.packet[2], acc1.packet[2]);
+ acc1.packet[3] = padd<Packet1cd>(tRes.packet[3], acc1.packet[3]);
- data.template storePacketBlock<Packet4f, 4>(i, j, block);
+ acc2.packet[0] = padd<Packet1cd>(tRes.packet[4], acc2.packet[0]);
+ acc2.packet[1] = padd<Packet1cd>(tRes.packet[5], acc2.packet[1]);
+ acc2.packet[2] = padd<Packet1cd>(tRes.packet[6], acc2.packet[2]);
+ acc2.packet[3] = padd<Packet1cd>(tRes.packet[7], acc2.packet[3]);
}
-template<typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-struct gebp_kernel<float, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+#ifdef __MMA__
+template<typename Packet>
+EIGEN_STRONG_INLINE PacketBlock<Packet,2> pmul (const PacketBlock<Packet,2>& a, const Packet& b)
{
- void operator()(const DataMapper& res, const float* blockA, const RhsScalar* blockB,
- Index rows, Index depth, Index cols, float alpha,
- Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
-};
+ PacketBlock<Packet,2> pb;
+ pb.packet[0] = a.packet[0]*b;
+ pb.packet[1] = a.packet[1]*b;
+ return pb;
+}
+template<typename DataMapper, typename Index, typename Packet>
+EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad *acc)
+{
+ PacketBlock<Packet, 4> result;
+ __builtin_mma_disassemble_acc(&result.packet, acc);
-template<typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-void gebp_kernel<float, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
- ::operator()(const DataMapper& res, const float* blockA, const RhsScalar* blockB,
- Index rows, Index depth, Index cols, float alpha,
- Index strideA, Index strideB, Index offsetA, Index offsetB)
+ PacketBlock<Packet, 4> block;
+ block.packet[0] = data.template loadPacket<Packet>(i, j + 0) + pmul<Packet>(alpha, result.packet[0]);
+ block.packet[1] = data.template loadPacket<Packet>(i, j + 1) + pmul<Packet>(alpha, result.packet[1]);
+ block.packet[2] = data.template loadPacket<Packet>(i, j + 2) + pmul<Packet>(alpha, result.packet[2]);
+ block.packet[3] = data.template loadPacket<Packet>(i, j + 3) + pmul<Packet>(alpha, result.packet[3]);
+
+ data.template storePacketBlock<Packet, 4>(i, j, block);
+}
+
+template<typename DataMapper, typename Index, typename Packet, typename Packetc, int N>
+EIGEN_STRONG_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad *accReal, __vector_quad *accImag, const int accColsC)
+{
+ PacketBlock<Packet, 4> resultReal, resultImag;
+ __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
+ __builtin_mma_disassemble_acc(&resultImag.packet, accImag);
+
+ PacketBlock<Packet,4> taccReal, taccImag;
+ taccReal.packet[0] = pmul<Packet>(resultReal.packet[0], alphaReal);
+ taccReal.packet[1] = pmul<Packet>(resultReal.packet[1], alphaReal);
+ taccReal.packet[2] = pmul<Packet>(resultReal.packet[2], alphaReal);
+ taccReal.packet[3] = pmul<Packet>(resultReal.packet[3], alphaReal);
+
+ taccImag.packet[0] = pmul<Packet>(resultImag.packet[0], alphaReal);
+ taccImag.packet[1] = pmul<Packet>(resultImag.packet[1], alphaReal);
+ taccImag.packet[2] = pmul<Packet>(resultImag.packet[2], alphaReal);
+ taccImag.packet[3] = pmul<Packet>(resultImag.packet[3], alphaReal);
+
+ taccReal.packet[0] = psub<Packet>(taccReal.packet[0], pmul<Packet>(resultImag.packet[0], alphaImag));
+ taccReal.packet[1] = psub<Packet>(taccReal.packet[1], pmul<Packet>(resultImag.packet[1], alphaImag));
+ taccReal.packet[2] = psub<Packet>(taccReal.packet[2], pmul<Packet>(resultImag.packet[2], alphaImag));
+ taccReal.packet[3] = psub<Packet>(taccReal.packet[3], pmul<Packet>(resultImag.packet[3], alphaImag));
+
+ taccImag.packet[0] = pmadd<Packet>(resultReal.packet[0], alphaImag, taccImag.packet[0]);
+ taccImag.packet[1] = pmadd<Packet>(resultReal.packet[1], alphaImag, taccImag.packet[1]);
+ taccImag.packet[2] = pmadd<Packet>(resultReal.packet[2], alphaImag, taccImag.packet[2]);
+ taccImag.packet[3] = pmadd<Packet>(resultReal.packet[3], alphaImag, taccImag.packet[3]);
+
+ PacketBlock<Packetc, 8> tRes;
+ tRes.packet[0] = data.template loadPacket<Packetc>(i + N*accColsC, j + 0);
+ tRes.packet[1] = data.template loadPacket<Packetc>(i + N*accColsC, j + 1);
+ tRes.packet[2] = data.template loadPacket<Packetc>(i + N*accColsC, j + 2);
+ tRes.packet[3] = data.template loadPacket<Packetc>(i + N*accColsC, j + 3);
+
+ tRes.packet[4] = data.template loadPacket<Packetc>(i + (N+1)*accColsC, j + 0);
+ tRes.packet[5] = data.template loadPacket<Packetc>(i + (N+1)*accColsC, j + 1);
+ tRes.packet[6] = data.template loadPacket<Packetc>(i + (N+1)*accColsC, j + 2);
+ tRes.packet[7] = data.template loadPacket<Packetc>(i + (N+1)*accColsC, j + 3);
+
+ PacketBlock<Packetc, 4> acc1, acc2;
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
+
+ data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
+ data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
+}
+
+// Defaults to float32; since Eigen still supports C++03, we can't use default template arguments.
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_STRONG_INLINE void pger(__vector_quad *acc, const RhsPacket& a, const LhsPacket& b)
+{
+ if(NegativeAccumulate)
{
- const int remaining_rows = rows % accRows;
- const int remaining_cols = cols % accCols;
- const int remaining_depth = depth % floatVectorSize;
- const int timesRows = (rows / accRows);
- const int timesCols = (cols / accCols);
-
- int row;
- for(row = 0; row + accRows <= rows; row += accRows)
- {
- const float *rhs_base = blockB;
- const float *lhs_base = blockA + (row/accRows)*depth*floatVectorSize;
-
- int col;
- for(col = 0; col + accCount*accCols <= cols; col += accCount*accCols){
- const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize;
- const float *rhs_ptr2 = rhs_base + ((col/accCols) + 1)*depth*floatVectorSize;
- const float *rhs_ptr3 = rhs_base + ((col/accCols) + 2)*depth*floatVectorSize;
- const float *rhs_ptr4 = rhs_base + ((col/accCols) + 3)*depth*floatVectorSize;
- const float *lhs_ptr = lhs_base;
-
- __vector_quad acc, acc2, acc3, acc4;
- __builtin_mma_xxsetaccz(&acc);
- __builtin_mma_xxsetaccz(&acc2);
- __builtin_mma_xxsetaccz(&acc3);
- __builtin_mma_xxsetaccz(&acc4);
-
- for(int k = 0; k < depth; k++)
- {
- __vector float lhsV = *((__vector float *)lhs_ptr );
- __vector float rhsV = *((__vector float *)rhs_ptr );
- __vector float rhs2V = *((__vector float *)rhs_ptr2);
- __vector float rhs3V = *((__vector float *)rhs_ptr3);
- __vector float rhs4V = *((__vector float *)rhs_ptr4);
-
- __builtin_mma_xvf32gerpp(&acc, (__vector unsigned char) rhsV, (__vector unsigned char) lhsV);
- __builtin_mma_xvf32gerpp(&acc2, (__vector unsigned char) rhs2V, (__vector unsigned char) lhsV);
- __builtin_mma_xvf32gerpp(&acc3, (__vector unsigned char) rhs3V, (__vector unsigned char) lhsV);
- __builtin_mma_xvf32gerpp(&acc4, (__vector unsigned char) rhs4V, (__vector unsigned char) lhsV);
-
- lhs_ptr += floatVectorSize;
- rhs_ptr += floatVectorSize;
- rhs_ptr2 += floatVectorSize;
- rhs_ptr3 += floatVectorSize;
- rhs_ptr4 += floatVectorSize;
- }
+ __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+ }
+}
+
+template<>
+EIGEN_STRONG_INLINE void pger<Packet2d, PacketBlock<Packet2d, 2>, false>(__vector_quad *acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
+{
+ Packetx2u<Packet2d> p;
+ p.pair = a;
+ __builtin_mma_xvf64gerpp(acc, p.vectorpair, (__vector unsigned char)b);
+}
+
+template<>
+EIGEN_STRONG_INLINE void pger<Packet2d, PacketBlock<Packet2d, 2>, true>(__vector_quad *acc, const PacketBlock<Packet2d, 2>& a, const Packet2d& b)
+{
+ Packetx2u<Packet2d> p;
+ p.pair = a;
+ __builtin_mma_xvf64gernp(acc, p.vectorpair, (__vector unsigned char)b);
+}
+#else
+
+// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex gemm).
+template<typename Scalar, typename Packet, bool NegativeAccumulate>
+EIGEN_STRONG_INLINE void pger(PacketBlock<Packet, 4> *acc, const Scalar* lhs, const Scalar* rhs)
+{
+ Packet lhsV = *((Packet *) lhs);
+ Packet rhsV1 = pset1<Packet>(rhs[0]);
+ Packet rhsV2 = pset1<Packet>(rhs[1]);
+ Packet rhsV3 = pset1<Packet>(rhs[2]);
+ Packet rhsV4 = pset1<Packet>(rhs[3]);
+
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] -= lhsV*rhsV1;
+ acc->packet[1] -= lhsV*rhsV2;
+ acc->packet[2] -= lhsV*rhsV3;
+ acc->packet[3] -= lhsV*rhsV4;
+ } else {
+ acc->packet[0] += lhsV*rhsV1;
+ acc->packet[1] += lhsV*rhsV2;
+ acc->packet[2] += lhsV*rhsV3;
+ acc->packet[3] += lhsV*rhsV4;
+ }
+}
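+// In effect each call updates acc.packet[j] with lhsV * rhs[j] for j = 0..3:
+// one depth step of a vectorsize-by-4 outer-product accumulation.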
+
+// 512-bit rank-1 update of a complex acc. It takes decoupled real/imaginary accumulators and also takes care of the mixed types real * complex and complex * real.
+template<typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void pgerc(PacketBlock<Packet, 4>& accReal, PacketBlock<Packet,4>& accImag, const Scalar *rhs_ptr, const Scalar *rhs_ptr_imag, const Scalar *lhs_ptr, const Scalar* lhs_ptr_imag, Packet& conj)
+{
+ Packet lhsV = *((Packet *) lhs_ptr);
+ Packet rhsV1 = pset1<Packet>(rhs_ptr[0]);
+ Packet rhsV2 = pset1<Packet>(rhs_ptr[1]);
+ Packet rhsV3 = pset1<Packet>(rhs_ptr[2]);
+ Packet rhsV4 = pset1<Packet>(rhs_ptr[3]);
+
+ Packet lhsVi;
+ if(!LhsIsReal) lhsVi = *((Packet *) lhs_ptr_imag);
+ Packet rhsV1i, rhsV2i, rhsV3i, rhsV4i;
+ if(!RhsIsReal)
+ {
+ rhsV1i = pset1<Packet>(rhs_ptr_imag[0]);
+ rhsV2i = pset1<Packet>(rhs_ptr_imag[1]);
+ rhsV3i = pset1<Packet>(rhs_ptr_imag[2]);
+ rhsV4i = pset1<Packet>(rhs_ptr_imag[3]);
+ }
+
+ if(ConjugateLhs && !LhsIsReal) lhsVi = pmul<Packet>(lhsVi,conj);
+ if(ConjugateRhs && !RhsIsReal)
+ {
+ rhsV1i = pmul<Packet>(rhsV1i,conj);
+ rhsV2i = pmul<Packet>(rhsV2i,conj);
+ rhsV3i = pmul<Packet>(rhsV3i,conj);
+ rhsV4i = pmul<Packet>(rhsV4i,conj);
+ }
+
+ if(LhsIsReal)
+ {
+ accReal.packet[0] = pmadd<Packet>(rhsV1, lhsV, accReal.packet[0]);
+ accReal.packet[1] = pmadd<Packet>(rhsV2, lhsV, accReal.packet[1]);
+ accReal.packet[2] = pmadd<Packet>(rhsV3, lhsV, accReal.packet[2]);
+ accReal.packet[3] = pmadd<Packet>(rhsV4, lhsV, accReal.packet[3]);
+
+ accImag.packet[0] = pmadd<Packet>(rhsV1i, lhsV, accImag.packet[0]);
+ accImag.packet[1] = pmadd<Packet>(rhsV2i, lhsV, accImag.packet[1]);
+ accImag.packet[2] = pmadd<Packet>(rhsV3i, lhsV, accImag.packet[2]);
+ accImag.packet[3] = pmadd<Packet>(rhsV4i, lhsV, accImag.packet[3]);
+ } else if(RhsIsReal) {
+ accReal.packet[0] = pmadd<Packet>(rhsV1, lhsV, accReal.packet[0]);
+ accReal.packet[1] = pmadd<Packet>(rhsV2, lhsV, accReal.packet[1]);
+ accReal.packet[2] = pmadd<Packet>(rhsV3, lhsV, accReal.packet[2]);
+ accReal.packet[3] = pmadd<Packet>(rhsV4, lhsV, accReal.packet[3]);
+
+ accImag.packet[0] = pmadd<Packet>(rhsV1, lhsVi, accImag.packet[0]);
+ accImag.packet[1] = pmadd<Packet>(rhsV2, lhsVi, accImag.packet[1]);
+ accImag.packet[2] = pmadd<Packet>(rhsV3, lhsVi, accImag.packet[2]);
+ accImag.packet[3] = pmadd<Packet>(rhsV4, lhsVi, accImag.packet[3]);
+ } else {
+ accReal.packet[0] = pmadd<Packet>(rhsV1, lhsV, accReal.packet[0]);
+ accReal.packet[1] = pmadd<Packet>(rhsV2, lhsV, accReal.packet[1]);
+ accReal.packet[2] = pmadd<Packet>(rhsV3, lhsV, accReal.packet[2]);
+ accReal.packet[3] = pmadd<Packet>(rhsV4, lhsV, accReal.packet[3]);
+
+ accImag.packet[0] = pmadd<Packet>(rhsV1i, lhsV, accImag.packet[0]);
+ accImag.packet[1] = pmadd<Packet>(rhsV2i, lhsV, accImag.packet[1]);
+ accImag.packet[2] = pmadd<Packet>(rhsV3i, lhsV, accImag.packet[2]);
+ accImag.packet[3] = pmadd<Packet>(rhsV4i, lhsV, accImag.packet[3]);
+
+ accReal.packet[0] = psub<Packet>(accReal.packet[0], pmul<Packet>(rhsV1i, lhsVi));
+ accReal.packet[1] = psub<Packet>(accReal.packet[1], pmul<Packet>(rhsV2i, lhsVi));
+ accReal.packet[2] = psub<Packet>(accReal.packet[2], pmul<Packet>(rhsV3i, lhsVi));
+ accReal.packet[3] = psub<Packet>(accReal.packet[3], pmul<Packet>(rhsV4i, lhsVi));
- storeAccumulator<DataMapper, Index, float>(row, col , res, alpha, &acc );
- storeAccumulator<DataMapper, Index, float>(row, col + 1*accCols, res, alpha, &acc2);
- storeAccumulator<DataMapper, Index, float>(row, col + 2*accCols, res, alpha, &acc3);
- storeAccumulator<DataMapper, Index, float>(row, col + 3*accCols, res, alpha, &acc4);
+ accImag.packet[0] = pmadd<Packet>(rhsV1, lhsVi, accImag.packet[0]);
+ accImag.packet[1] = pmadd<Packet>(rhsV2, lhsVi, accImag.packet[1]);
+ accImag.packet[2] = pmadd<Packet>(rhsV3, lhsVi, accImag.packet[2]);
+ accImag.packet[3] = pmadd<Packet>(rhsV4, lhsVi, accImag.packet[3]);
+ }
+}
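+// The branches above expand the complex product: real += rhsV*lhsV - rhsVi*lhsVi
+// and imag += rhsVi*lhsV + rhsV*lhsVi, with the imaginary terms dropped when
+// one side is real.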
+#endif
+
+// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
+template<typename Scalar, typename Packet>
+EIGEN_STRONG_INLINE Packet ploadRhs(const Scalar *rhs)
+{
+ return *((Packet *)rhs);
+}
+
+#ifdef __MMA__
+template<>
+EIGEN_STRONG_INLINE PacketBlock<Packet2d, 2> ploadRhs<double, PacketBlock<Packet2d, 2> >(const double *rhs)
+{
+ PacketBlock<Packet2d, 2> pair;
+ pair.packet[0] = *((Packet2d *)rhs );
+ pair.packet[1] = *(((Packet2d *)rhs) + 1);
+ return pair;
+}
+#endif
+
+template<typename Scalar, typename Packet>
+EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs)
+{
+ return *((Packet *)lhs);
+}
+
+#ifndef __MMA__
+// Zero a PacketBlock accumulator.
+template<typename Scalar, typename Packet>
+EIGEN_STRONG_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
+{
+ acc.packet[0] = pset1<Packet>((Scalar)0);
+ acc.packet[1] = pset1<Packet>((Scalar)0);
+ acc.packet[2] = pset1<Packet>((Scalar)0);
+ acc.packet[3] = pset1<Packet>((Scalar)0);
+}
+
+// Scale the PacketBlock vectors by alpha.
+template<typename Packet>
+EIGEN_STRONG_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmadd(pAlpha,accZ.packet[0], acc.packet[0]);
+ acc.packet[1] = pmadd(pAlpha,accZ.packet[1], acc.packet[1]);
+ acc.packet[2] = pmadd(pAlpha,accZ.packet[2], acc.packet[2]);
+ acc.packet[3] = pmadd(pAlpha,accZ.packet[3], acc.packet[3]);
+}
+
+// Complex version of PacketBlock scaling.
+template<typename Packet>
+EIGEN_STRONG_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag)
+{
+ cReal.packet[0] = pmul<Packet>(aReal.packet[0], bReal);
+ cReal.packet[1] = pmul<Packet>(aReal.packet[1], bReal);
+ cReal.packet[2] = pmul<Packet>(aReal.packet[2], bReal);
+ cReal.packet[3] = pmul<Packet>(aReal.packet[3], bReal);
+
+ cImag.packet[0] = pmul<Packet>(aImag.packet[0], bReal);
+ cImag.packet[1] = pmul<Packet>(aImag.packet[1], bReal);
+ cImag.packet[2] = pmul<Packet>(aImag.packet[2], bReal);
+ cImag.packet[3] = pmul<Packet>(aImag.packet[3], bReal);
+
+ cReal.packet[0] = psub<Packet>(cReal.packet[0], pmul<Packet>(aImag.packet[0], bImag));
+ cReal.packet[1] = psub<Packet>(cReal.packet[1], pmul<Packet>(aImag.packet[1], bImag));
+ cReal.packet[2] = psub<Packet>(cReal.packet[2], pmul<Packet>(aImag.packet[2], bImag));
+ cReal.packet[3] = psub<Packet>(cReal.packet[3], pmul<Packet>(aImag.packet[3], bImag));
+
+ cImag.packet[0] = pmadd<Packet>(aReal.packet[0], bImag, cImag.packet[0]);
+ cImag.packet[1] = pmadd<Packet>(aReal.packet[1], bImag, cImag.packet[1]);
+ cImag.packet[2] = pmadd<Packet>(aReal.packet[2], bImag, cImag.packet[2]);
+ cImag.packet[3] = pmadd<Packet>(aReal.packet[3], bImag, cImag.packet[3]);
+}
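+// The sequence above computes c = a * b for values kept in split real/imaginary
+// blocks: cReal = aReal*bReal - aImag*bImag and cImag = aImag*bReal + aReal*bImag.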
+
+// Load a PacketBlock; the N parameter makes tuning gemm easier, since we can add more accumulators as needed.
+template<typename DataMapper, typename Packet, typename Index, int N>
+EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col, Index accCols)
+{
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
+ acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
+ acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
+}
+
+// An overload of bload for a PacketBlock with 8 vectors.
+template<typename DataMapper, typename Packet, typename Index, int N>
+EIGEN_STRONG_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col, Index accCols)
+{
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
+ acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
+ acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
+ acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
+ acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
+ acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
+ acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
+}
+#endif
+
+// PEEL loop factor.
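+// Each peeled pass expands MICRO() 8 times unconditionally plus 2 more when
+// PEEL > 8, so this constant must stay in sync with those expansions.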
+#define PEEL 10
+
+/****************
+ * GEMM kernels *
+ ****************/
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper>
+EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
+ Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
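+  // A stride of -1 means the caller did not provide one; fall back to the
+  // packed panel depth.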
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlpha = pset1<Packet>(alpha);
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows;
+ const Scalar *lhs_base = blockA;
+
+ Index row = 0;
+#ifdef __MMA__
+ for(; row + accCols <= rows; row += accCols)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols;
+
+ __vector_quad acc;
+ __builtin_mma_xxsetaccz(&acc);
+
+ lhs_ptr1 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ for(Index k = 0; k < depth; k++)
+ {
+ Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr1);
+ RhsPacket rhsV = ploadRhs<Scalar, RhsPacket>(rhs_ptr);
+
+ pger<Packet, RhsPacket, false>(&acc, rhsV, lhsV);
+
+ lhs_ptr1 += accCols;
+ rhs_ptr += accRows;
}
- for(; col + accCols <= cols; col += accCols){
- const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize;
- const float *lhs_ptr = lhs_base;
-
- __vector_quad acc;
- __builtin_mma_xxsetaccz(&acc);
- for(int k = 0; k < depth; k++)
+
+ storeAccumulator<DataMapper, Index, Packet>(row, col, res, pAlpha, &acc);
+ }
+#else
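+    // Without MMA the row loop is blocked by 6, 5, 4, 3, 2 and finally 1
+    // panels of accCols rows; MICRO() is redefined at each width with its own
+    // fixed set of accumulators.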
+ for(; row + 6*accCols <= rows; row += 6*accCols)
+ {
+#define MICRO() \
+ pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
+ lhs_ptr1 += accCols; \
+ pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
+ lhs_ptr2 += accCols; \
+ pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
+ lhs_ptr3 += accCols; \
+ pger<Scalar, Packet, false>(&accZero4, lhs_ptr4, rhs_ptr); \
+ lhs_ptr4 += accCols; \
+ pger<Scalar, Packet, false>(&accZero5, lhs_ptr5, rhs_ptr); \
+ lhs_ptr5 += accCols; \
+ pger<Scalar, Packet, false>(&accZero6, lhs_ptr6, rhs_ptr); \
+ lhs_ptr6 += accCols; \
+ rhs_ptr += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols;
+ const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
+ const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
+ const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols;
+ const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols;
+ const Scalar *lhs_ptr6 = lhs_base + ((row/accCols) + 5)*strideA*accCols;
+
+ PacketBlock<Packet,4> acc1, accZero1;
+ PacketBlock<Packet,4> acc2, accZero2;
+ PacketBlock<Packet,4> acc3, accZero3;
+ PacketBlock<Packet,4> acc4, accZero4;
+ PacketBlock<Packet,4> acc5, accZero5;
+ PacketBlock<Packet,4> acc6, accZero6;
+
+ bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero1);
+ bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero2);
+ bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero3);
+ bload<DataMapper, Packet, Index, 3>(acc4, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero4);
+ bload<DataMapper, Packet, Index, 4>(acc5, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero5);
+ bload<DataMapper, Packet, Index, 5>(acc6, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero6);
+
+ lhs_ptr1 += accCols*offsetA;
+ lhs_ptr2 += accCols*offsetA;
+ lhs_ptr3 += accCols*offsetA;
+ lhs_ptr4 += accCols*offsetA;
+ lhs_ptr5 += accCols*offsetA;
+ lhs_ptr6 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+
+ Index k = 0;
+ for(; k + PEEL < depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr1);
+ prefetch(lhs_ptr2);
+ prefetch(lhs_ptr3);
+ prefetch(lhs_ptr4);
+ prefetch(lhs_ptr5);
+ prefetch(lhs_ptr6);
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+
+ bscale<Packet>(acc1,accZero1, pAlpha);
+ bscale<Packet>(acc2,accZero2, pAlpha);
+ bscale<Packet>(acc3,accZero3, pAlpha);
+ bscale<Packet>(acc4,accZero4, pAlpha);
+ bscale<Packet>(acc5,accZero5, pAlpha);
+ bscale<Packet>(acc6,accZero6, pAlpha);
+
+ res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
+ res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
+ res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
+ res.template storePacketBlock<Packet, 4>(row + 3*accCols, col, acc4);
+ res.template storePacketBlock<Packet, 4>(row + 4*accCols, col, acc5);
+ res.template storePacketBlock<Packet, 4>(row + 5*accCols, col, acc6);
+#undef MICRO
+ }
+ for(; row + 5*accCols <= rows; row += 5*accCols)
+ {
+#define MICRO() \
+ pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
+ lhs_ptr1 += accCols; \
+ pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
+ lhs_ptr2 += accCols; \
+ pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
+ lhs_ptr3 += accCols; \
+ pger<Scalar, Packet, false>(&accZero4, lhs_ptr4, rhs_ptr); \
+ lhs_ptr4 += accCols; \
+ pger<Scalar, Packet, false>(&accZero5, lhs_ptr5, rhs_ptr); \
+ lhs_ptr5 += accCols; \
+ rhs_ptr += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
+ const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
+ const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
+ const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols;
+ const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols;
+
+ PacketBlock<Packet,4> acc1, accZero1;
+ PacketBlock<Packet,4> acc2, accZero2;
+ PacketBlock<Packet,4> acc3, accZero3;
+ PacketBlock<Packet,4> acc4, accZero4;
+ PacketBlock<Packet,4> acc5, accZero5;
+
+ bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero1);
+ bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero2);
+ bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero3);
+ bload<DataMapper, Packet, Index, 3>(acc4, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero4);
+ bload<DataMapper, Packet, Index, 4>(acc5, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero5);
+
+ lhs_ptr1 += accCols*offsetA;
+ lhs_ptr2 += accCols*offsetA;
+ lhs_ptr3 += accCols*offsetA;
+ lhs_ptr4 += accCols*offsetA;
+ lhs_ptr5 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ Index k = 0;
+
+ for(; k + PEEL < depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr1);
+ prefetch(lhs_ptr2);
+ prefetch(lhs_ptr3);
+ prefetch(lhs_ptr4);
+ prefetch(lhs_ptr5);
+
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+
+ bscale<Packet>(acc1,accZero1, pAlpha);
+ bscale<Packet>(acc2,accZero2, pAlpha);
+ bscale<Packet>(acc3,accZero3, pAlpha);
+ bscale<Packet>(acc4,accZero4, pAlpha);
+ bscale<Packet>(acc5,accZero5, pAlpha);
+
+ res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
+ res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
+ res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
+ res.template storePacketBlock<Packet, 4>(row + 3*accCols, col, acc4);
+ res.template storePacketBlock<Packet, 4>(row + 4*accCols, col, acc5);
+#undef MICRO
+ }
+ for(; row + 4*accCols <= rows; row += 4*accCols)
+ {
+#define MICRO() \
+ pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
+ lhs_ptr1 += accCols; \
+ pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
+ lhs_ptr2 += accCols; \
+ pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
+ lhs_ptr3 += accCols; \
+ pger<Scalar, Packet, false>(&accZero4, lhs_ptr4, rhs_ptr); \
+ lhs_ptr4 += accCols; \
+ rhs_ptr += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
+ const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
+ const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
+ const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols;
+
+ PacketBlock<Packet,4> acc1, accZero1;
+ PacketBlock<Packet,4> acc2, accZero2;
+ PacketBlock<Packet,4> acc3, accZero3;
+ PacketBlock<Packet,4> acc4, accZero4;
+
+ bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero1);
+ bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero2);
+ bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero3);
+ bload<DataMapper, Packet, Index, 3>(acc4, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero4);
+
+ lhs_ptr1 += accCols*offsetA;
+ lhs_ptr2 += accCols*offsetA;
+ lhs_ptr3 += accCols*offsetA;
+ lhs_ptr4 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ Index k = 0;
+
+ for(; k + PEEL < depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr1);
+ prefetch(lhs_ptr2);
+ prefetch(lhs_ptr3);
+ prefetch(lhs_ptr4);
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+
+ bscale<Packet>(acc1,accZero1, pAlpha);
+ bscale<Packet>(acc2,accZero2, pAlpha);
+ bscale<Packet>(acc3,accZero3, pAlpha);
+ bscale<Packet>(acc4,accZero4, pAlpha);
+
+ res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
+ res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
+ res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
+ res.template storePacketBlock<Packet, 4>(row + 3*accCols, col, acc4);
+#undef MICRO
+ }
+ for(; row + 3*accCols <= rows; row += 3*accCols)
+ {
+#define MICRO() \
+ pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
+ lhs_ptr1 += accCols; \
+ pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
+ lhs_ptr2 += accCols; \
+ pger<Scalar, Packet, false>(&accZero3, lhs_ptr3, rhs_ptr); \
+ lhs_ptr3 += accCols; \
+ rhs_ptr += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
+ const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
+ const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols;
+
+ PacketBlock<Packet,4> acc1, accZero1;
+ PacketBlock<Packet,4> acc2, accZero2;
+ PacketBlock<Packet,4> acc3, accZero3;
+
+ bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero1);
+ bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero2);
+ bload<DataMapper, Packet, Index, 2>(acc3, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero3);
+
+ lhs_ptr1 += accCols*offsetA;
+ lhs_ptr2 += accCols*offsetA;
+ lhs_ptr3 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ Index k = 0;
+ for(; k + PEEL < depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr1);
+ prefetch(lhs_ptr2);
+ prefetch(lhs_ptr3);
+
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+
+ bscale<Packet>(acc1,accZero1, pAlpha);
+ bscale<Packet>(acc2,accZero2, pAlpha);
+ bscale<Packet>(acc3,accZero3, pAlpha);
+
+ res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
+ res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
+ res.template storePacketBlock<Packet, 4>(row + 2*accCols, col, acc3);
+#undef MICRO
+ }
+ for(; row + 2*accCols <= rows; row += 2*accCols)
+ {
+#define MICRO() \
+ pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
+ lhs_ptr1 += accCols; \
+ pger<Scalar, Packet, false>(&accZero2, lhs_ptr2, rhs_ptr); \
+ lhs_ptr2 += accCols; \
+ rhs_ptr += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols;
+ const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols;
+
+ PacketBlock<Packet,4> acc1, accZero1;
+ PacketBlock<Packet,4> acc2, accZero2;
+
+ bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero1);
+ bload<DataMapper, Packet, Index, 1>(acc2, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero2);
+
+ lhs_ptr1 += accCols*offsetA;
+ lhs_ptr2 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ Index k = 0;
+ for(; k + PEEL < depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr1);
+ prefetch(lhs_ptr2);
+
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+
+ bscale<Packet>(acc1,accZero1, pAlpha);
+ bscale<Packet>(acc2,accZero2, pAlpha);
+
+ res.template storePacketBlock<Packet, 4>(row + 0*accCols, col, acc1);
+ res.template storePacketBlock<Packet, 4>(row + 1*accCols, col, acc2);
+#undef MICRO
+ }
+
+ for(; row + accCols <= rows; row += accCols)
+ {
+#define MICRO() \
+ pger<Scalar, Packet, false>(&accZero1, lhs_ptr1, rhs_ptr); \
+ lhs_ptr1 += accCols; \
+ rhs_ptr += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols;
+
+ PacketBlock<Packet,4> acc1, accZero1;
+
+ bload<DataMapper, Packet, Index, 0>(acc1, res, row, col, accCols);
+ bsetzero<Scalar, Packet>(accZero1);
+
+ lhs_ptr1 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ Index k = 0;
+ for(; k + PEEL < depth; k+= PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(lhs_ptr1);
+
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+
+ bscale<Packet>(acc1,accZero1, pAlpha);
+
+ res.template storePacketBlock<Packet, 4>(row, col, acc1);
+#undef MICRO
+ }
+#endif
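+    // Scalar tail for the leftover rows that do not fill a full vector panel.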
+ if(remaining_rows > 0)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols;
+
+ lhs_ptr += remaining_rows*offsetA;
+ rhs_ptr += accRows*offsetB;
+ for(Index k = 0; k < depth; k++)
+ {
+ for(Index arow = 0; arow < remaining_rows; arow++)
{
- __vector float lhsV = *((__vector float *)lhs_ptr);
- __vector float rhsV = *((__vector float *)rhs_ptr);
-
- __builtin_mma_xvf32gerpp(&acc, (__vector unsigned char) rhsV, (__vector unsigned char) lhsV);
-
- lhs_ptr += floatVectorSize;
- rhs_ptr += floatVectorSize;
+ for(Index acol = 0; acol < accRows; acol++ )
+ {
+ res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol];
+ }
}
+ rhs_ptr += accRows;
+ lhs_ptr += remaining_rows;
+ }
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows;
+ const Scalar *lhs_base = blockA;
- storeAccumulator<DataMapper, Index, float>(row, col, res, alpha, &acc);
+ Index row = 0;
+ for(; row + accCols <= rows; row += accCols)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols;
+
+ lhs_ptr += accCols*offsetA;
+ rhs_ptr += remaining_cols*offsetB;
+ for(Index k = 0; k < depth; k++)
+ {
+ for(Index arow = 0; arow < accCols; arow++)
+ {
+ for(Index acol = 0; acol < remaining_cols; acol++ )
+ {
+ res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol];
+ }
}
-
- if(remaining_cols > 0)
+ rhs_ptr += remaining_cols;
+ lhs_ptr += accCols;
+ }
+ }
+
+ if(remaining_rows > 0 )
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols;
+
+ lhs_ptr += remaining_rows*offsetA;
+ rhs_ptr += remaining_cols*offsetB;
+ for(Index k = 0; k < depth; k++)
+ {
+ for(Index arow = 0; arow < remaining_rows; arow++)
+ {
+ for(Index acol = 0; acol < remaining_cols; acol++ )
+ {
+ res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol];
+ }
+ }
+ rhs_ptr += remaining_cols;
+ lhs_ptr += remaining_rows;
+ }
+ }
+ }
+}
+
+template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc,
+ Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols)
+{
+ const int remaining_rows = rows % accCols;
+ const int remaining_cols = cols % accRows;
+ const int accColsC = accCols / 2;
+ int advanceCols = 2;
+ int advanceRows = 2;
+
+ if(LhsIsReal) advanceRows = 1;
+ if(RhsIsReal) advanceCols = 1;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlphaReal = pset1<Packet>(alpha.real());
+ const Packet pAlphaImag = pset1<Packet>(alpha.imag());
+
+ const Scalar *blockA = (Scalar *) blockAc;
+ const Scalar *blockB = (Scalar *) blockBc;
+
+ Packet conj = pset1<Packet>((Scalar)-1.0f);
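+  // conj holds -1 in every lane; conjugating an operand amounts to
+  // multiplying its imaginary part by this vector.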
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows;
+ const Scalar *lhs_base = blockA;
+
+ Index row = 0;
+#ifdef __MMA__
+ for(; row + accCols <= rows; row += accCols)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB;
+ const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols;
+ const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA;
+
+ __vector_quad accReal, accImag;
+ __builtin_mma_xxsetaccz(&accReal);
+ __builtin_mma_xxsetaccz(&accImag);
+
+ lhs_ptr += accCols*offsetA;
+ if(!LhsIsReal)
+ lhs_ptr_imag += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ if(!RhsIsReal)
+ rhs_ptr_imag += accRows*offsetB;
+ for(Index k = 0; k < depth; k++)
{
- const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize;
- const float *lhs_ptr = lhs_base;
- for(int k = 0; k < depth; k++)
+ Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
+ RhsPacket rhsV = ploadRhs<Scalar, RhsPacket>(rhs_ptr);
+
+ Packet lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
+ RhsPacket rhsVi = ploadRhs<Scalar, RhsPacket>(rhs_ptr_imag);
+
+ if(ConjugateLhs && !LhsIsReal) lhsVi = pmul<Packet>(lhsVi, conj);
+ if(ConjugateRhs && !RhsIsReal) rhsVi = pmul<Packet>(rhsVi, conj);
+
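+          // Complex product as four rank-1 updates: real += rhs*lhs - rhsi*lhsi
+          // (the 'true' template flag negates the cross term) and
+          // imag += rhsi*lhs + rhs*lhsi.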
+ if(LhsIsReal)
+ {
+ pger<Packet, RhsPacket, false>(&accReal, rhsV, lhsV);
+ pger<Packet, RhsPacket, false>(&accImag, rhsVi, lhsV);
+ } else if(RhsIsReal) {
+ pger<Packet, RhsPacket, false>(&accReal, rhsV, lhsV);
+ pger<Packet, RhsPacket, false>(&accImag, rhsV, lhsVi);
+ } else {
+ pger<Packet, RhsPacket, false>(&accReal, rhsV, lhsV);
+ pger<Packet, RhsPacket, true>(&accReal, rhsVi, lhsVi);
+ pger<Packet, RhsPacket, false>(&accImag, rhsVi, lhsV);
+ pger<Packet, RhsPacket, false>(&accImag, rhsV, lhsVi);
+ }
+
+ lhs_ptr += accCols;
+ rhs_ptr += accRows;
+ if(!LhsIsReal)
+ lhs_ptr_imag += accCols;
+ if(!RhsIsReal)
+ rhs_ptr_imag += accRows;
+ }
+
+ storeComplexAccumulator<DataMapper, Index, Packet, Packetc, 0>(row, col, res, pAlphaReal, pAlphaImag, &accReal, &accImag, accColsC);
+ }
+#else
+ for(; row + accCols <= rows; row += accCols)
+ {
+#define MICRO() \
+ pgerc<Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal1, accImag1, rhs_ptr, rhs_ptr_imag, lhs_ptr1, lhs_ptr_imag1, conj); \
+ lhs_ptr1 += accCols; \
+ rhs_ptr += accRows; \
+ if(!LhsIsReal) \
+ lhs_ptr_imag1 += accCols; \
+ if(!RhsIsReal) \
+ rhs_ptr_imag += accRows;
+
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB;
+ const Scalar *lhs_ptr1 = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols;
+ const Scalar *lhs_ptr_imag1 = lhs_ptr1 + accCols*strideA;
+
+ PacketBlock<Packet,4> accReal1, accImag1;
+ bsetzero<Scalar, Packet>(accReal1);
+ bsetzero<Scalar, Packet>(accImag1);
+
+ lhs_ptr1 += accCols*offsetA;
+ if(!LhsIsReal)
+ lhs_ptr_imag1 += accCols*offsetA;
+ rhs_ptr += accRows*offsetB;
+ if(!RhsIsReal)
+ rhs_ptr_imag += accRows*offsetB;
+ Index k = 0;
+ for(; k + PEEL < depth; k+=PEEL)
+ {
+ prefetch(rhs_ptr);
+ prefetch(rhs_ptr_imag);
+ prefetch(lhs_ptr1);
+ prefetch(lhs_ptr_imag1);
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+ MICRO();
+#if PEEL > 8
+ MICRO();
+ MICRO();
+#endif
+ }
+ for(; k < depth; k++)
+ {
+ MICRO();
+ }
+ PacketBlock<Packet,4> taccReal, taccImag;
+ bscalec<Packet>(accReal1, accImag1, pAlphaReal, pAlphaImag, taccReal, taccImag);
+
+ PacketBlock<Packetc, 8> tRes;
+ bload<DataMapper, Packetc, Index, 0>(tRes, res, row, col, accColsC);
+
+ PacketBlock<Packetc, 4> acc1, acc2;
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
+
+ res.template storePacketBlock<Packetc, 4>(row + 0, col, acc1);
+ res.template storePacketBlock<Packetc, 4>(row + accColsC, col, acc2);
+#undef MICRO
+ }
+#endif
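+    // Scalar tail: build std::complex values explicitly, negating the
+    // imaginary part of whichever operand is conjugated.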
+ if(remaining_rows > 0)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *rhs_ptr_imag = rhs_ptr + accRows*strideB;
+ const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols;
+ const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA;
+
+ lhs_ptr += remaining_rows*offsetA;
+ if(!LhsIsReal)
+ lhs_ptr_imag += remaining_rows*offsetA;
+ rhs_ptr += accRows*offsetB;
+ if(!RhsIsReal)
+ rhs_ptr_imag += accRows*offsetB;
+ for(Index k = 0; k < depth; k++)
+ {
+ for(Index arow = 0; arow < remaining_rows; arow++)
{
- for(int arow = 0; arow < accRows; arow++)
- {
- for(int acol = 0; acol < remaining_cols; acol++ )
- {
- res(row + arow, col + acol) += lhs_ptr[arow]*rhs_ptr[acol];
- }
- }
- rhs_ptr += remaining_cols;
- lhs_ptr += floatVectorSize;
+ Scalar lhs_real = lhs_ptr[arow];
+ Scalar lhs_imag;
+ if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow];
+
+ Scalarc lhsc;
+
+ lhsc.real(lhs_real);
+ if(!LhsIsReal)
+ {
+ if(ConjugateLhs)
+ lhsc.imag(-lhs_imag);
+ else
+ lhsc.imag(lhs_imag);
+ } else {
+            // Lazy approach for now: give the real operand a zero imaginary part.
+ lhsc.imag((Scalar)0);
+ }
+
+ for(int acol = 0; acol < accRows; acol++ )
+ {
+ Scalar rhs_real = rhs_ptr[acol];
+ Scalar rhs_imag;
+ if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol];
+ Scalarc rhsc;
+
+ rhsc.real(rhs_real);
+ if(!RhsIsReal)
+ {
+ if(ConjugateRhs)
+ rhsc.imag(-rhs_imag);
+ else
+ rhsc.imag(rhs_imag);
+ } else {
+            // Lazy approach for now: give the real operand a zero imaginary part.
+ rhsc.imag((Scalar)0);
+ }
+ res(row + arow, col + acol) += alpha*lhsc*rhsc;
+ }
}
+ rhs_ptr += accRows;
+ lhs_ptr += remaining_rows;
+ if(!LhsIsReal)
+ lhs_ptr_imag += remaining_rows;
+ if(!RhsIsReal)
+ rhs_ptr_imag += accRows;
+ }
}
}
- if(remaining_rows > 0)
+ if(remaining_cols > 0)
{
- const float *rhs_base = blockB;
- const float *lhs_base = blockA + (row/accRows)*depth*floatVectorSize;
+ const Scalar *rhs_base = blockB + ( (advanceCols*col)/accRows )*strideB*accRows;
+ const Scalar *lhs_base = blockA;
+ Index row = 0;
- int col;
- for(col = 0; col + accCols <= cols; col += accCols)
+ for(; row + accCols <= rows; row += accCols)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB;
+ const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols;
+ const Scalar *lhs_ptr_imag = lhs_ptr + accCols*strideA;
+
+ lhs_ptr += accCols*offsetA;
+ if(!LhsIsReal)
+ lhs_ptr_imag += accCols*offsetA;
+ rhs_ptr += remaining_cols*offsetB;
+ if(!RhsIsReal)
+ rhs_ptr_imag += remaining_cols*offsetB;
+ Scalarc scalarAcc[4][4];
+ for(Index arow = 0; arow < 4; arow++ )
+ {
+ for(Index acol = 0; acol < 4; acol++ )
+ {
+ scalarAcc[arow][acol].real((Scalar)0.0f);
+ scalarAcc[arow][acol].imag((Scalar)0.0f);
+ }
+ }
+ for(Index k = 0; k < depth; k++)
{
- const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize;
- const float *lhs_ptr = lhs_base;
- for(int k = 0; k < depth; k++)
+ for(Index arow = 0; arow < accCols; arow++)
+ {
+ Scalar lhs_real = lhs_ptr[arow];
+ Scalar lhs_imag;
+ if(!LhsIsReal)
+ {
+ lhs_imag = lhs_ptr_imag[arow];
+
+ if(ConjugateLhs)
+ lhs_imag *= -1;
+ } else {
+ lhs_imag = (Scalar)0;
+ }
+
+ for(int acol = 0; acol < remaining_cols; acol++ )
{
- for(int arow = 0; arow < remaining_rows; arow++)
- {
- for(int acol = 0; acol < accCols; acol++ )
- {
- res(row + arow, col + acol) += lhs_ptr[arow]*rhs_ptr[acol];
- }
- }
- rhs_ptr += floatVectorSize;
- lhs_ptr += remaining_rows;
+ Scalar rhs_real = rhs_ptr[acol];
+ Scalar rhs_imag;
+ if(!RhsIsReal)
+ {
+ rhs_imag = rhs_ptr_imag[acol];
+
+ if(ConjugateRhs)
+ rhs_imag *= -1;
+ } else {
+ rhs_imag = (Scalar)0;
+ }
+
+ scalarAcc[arow][acol].real(scalarAcc[arow][acol].real() + lhs_real*rhs_real - lhs_imag*rhs_imag);
+ scalarAcc[arow][acol].imag(scalarAcc[arow][acol].imag() + lhs_imag*rhs_real + lhs_real*rhs_imag);
}
+ }
+ rhs_ptr += remaining_cols;
+ lhs_ptr += accCols;
+ if(!RhsIsReal)
+ rhs_ptr_imag += remaining_cols;
+ if(!LhsIsReal)
+ lhs_ptr_imag += accCols;
}
-
- if(remaining_cols > 0)
+ for(int arow = 0; arow < accCols; arow++ )
{
- const float *rhs_ptr = rhs_base + (col/accCols)*depth*floatVectorSize;
- const float *lhs_ptr = lhs_base;
- for(int k = 0; k < depth; k++)
+ for(int acol = 0; acol < remaining_cols; acol++ )
+ {
+ Scalar accR = scalarAcc[arow][acol].real();
+ Scalar accI = scalarAcc[arow][acol].imag();
+ Scalar aR = alpha.real();
+ Scalar aI = alpha.imag();
+ Scalar resR = res(row + arow, col + acol).real();
+ Scalar resI = res(row + arow, col + acol).imag();
+
+ res(row + arow, col + acol).real(resR + accR*aR - accI*aI);
+ res(row + arow, col + acol).imag(resI + accR*aI + accI*aR);
+ }
+ }
+ }
+
+ if(remaining_rows > 0)
+ {
+ const Scalar *rhs_ptr = rhs_base;
+ const Scalar *rhs_ptr_imag = rhs_ptr + remaining_cols*strideB;
+ const Scalar *lhs_ptr = lhs_base + ((advanceRows*row)/accCols)*strideA*accCols;
+ const Scalar *lhs_ptr_imag = lhs_ptr + remaining_rows*strideA;
+
+ lhs_ptr += remaining_rows*offsetA;
+ if(!LhsIsReal)
+ lhs_ptr_imag += remaining_rows*offsetA;
+ rhs_ptr += remaining_cols*offsetB;
+ if(!RhsIsReal)
+ rhs_ptr_imag += remaining_cols*offsetB;
+ for(Index k = 0; k < depth; k++)
+ {
+ for(Index arow = 0; arow < remaining_rows; arow++)
+ {
+ Scalar lhs_real = lhs_ptr[arow];
+ Scalar lhs_imag;
+ if(!LhsIsReal) lhs_imag = lhs_ptr_imag[arow];
+ Scalarc lhsc;
+
+ lhsc.real(lhs_real);
+ if(!LhsIsReal)
{
- for(int arow = 0; arow < remaining_rows; arow++)
- {
- for(int acol = 0; acol < remaining_cols; acol++ )
- {
- res(row + arow, col + acol) += lhs_ptr[arow]*rhs_ptr[acol];
- }
- }
- rhs_ptr += remaining_cols;
- lhs_ptr += remaining_rows;
+ if(ConjugateLhs)
+ lhsc.imag(-lhs_imag);
+ else
+ lhsc.imag(lhs_imag);
+ } else {
+ lhsc.imag((Scalar)0);
+ }
+
+ for(Index acol = 0; acol < remaining_cols; acol++ )
+ {
+ Scalar rhs_real = rhs_ptr[acol];
+ Scalar rhs_imag;
+ if(!RhsIsReal) rhs_imag = rhs_ptr_imag[acol];
+ Scalarc rhsc;
+
+ rhsc.real(rhs_real);
+ if(!RhsIsReal)
+ {
+ if(ConjugateRhs)
+ rhsc.imag(-rhs_imag);
+ else
+ rhsc.imag(rhs_imag);
+ } else {
+ rhsc.imag((Scalar)0);
+ }
+ res(row + arow, col + acol) += alpha*lhsc*rhsc;
}
+ }
+ rhs_ptr += remaining_cols;
+ lhs_ptr += remaining_rows;
+ if(!LhsIsReal)
+ lhs_ptr_imag += remaining_rows;
+ if(!RhsIsReal)
+ rhs_ptr_imag += remaining_cols;
}
+ }
}
+}
+
+/************************************
+ * ppc64le template specializations *
+ ************************************/
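+// Each pack specialization below forwards to the generic lhs_pack/rhs_pack
+// (or lhs_cpack/rhs_cpack for complex types) functor with the matching packet
+// type and storage order.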
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_cpack<float, true, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_cpack<float, true, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_cpack<double, true, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ lhs_cpack<double, true, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ rhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+// ********* gebp specializations *********
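+// Each gebp_kernel specialization fixes the packet types and forwards to gemm
+// or gemm_complex with LhsIsReal/RhsIsReal matching the operand types (e.g.
+// float * std::complex<float> sets LhsIsReal to true).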
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename quad_traits<float>::vectortype Packet;
+ typedef typename quad_traits<float>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const float* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, float alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, float alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<float>::rows;
+ const int accCols = quad_traits<float>::size;
+
+ gemm<float, Index, Packet, RhsPacket, DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<float>::rows;
+ const int accCols = quad_traits<float>::size;
+
+ gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<float>::rows;
+ const int accCols = quad_traits<float>::size;
+
+ gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<float>::rows;
+ const int accCols = quad_traits<float>::size;
+
+ gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
}
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename quad_traits<double>::vectortype Packet;
+ typedef typename quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const double* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, double alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, double alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<double>::rows;
+ const int accCols = quad_traits<double>::size;
+
+ gemm<double, Index, Packet, RhsPacket, DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<double>::rows;
+ const int accCols = quad_traits<double>::size;
+
+ gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<double>::rows;
+ const int accCols = quad_traits<double>::size;
+
+ gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const int accRows = quad_traits<double>::rows;
+ const int accCols = quad_traits<double>::size;
+
+ gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols);
+ }
} // end namespace internal
} // end namespace Eigen
-
-#endif // __MMA__
-#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H
+#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H
\ No newline at end of file
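Aside: every gebp_kernel specialization in this file follows the same shape. It picks its packet typedefs from quad_traits, reads accRows/accCols, and forwards to a single shared gemm/gemm_complex routine; the two trailing template booleans in the gemm_complex calls (LhsIsReal and RhsIsReal, judging by the argument patterns above) tell the shared routine which operand carries no imaginary part. Below is a minimal standalone sketch of that flag-driven sharing; gemm_complex_sketch and the re/im helpers are illustrative names, not Eigen code:

#include <complex>
#include <cstdio>

// View either a real or a complex scalar as a (re, im) pair.
static double re(double x)                      { return x; }
static double im(double)                        { return 0.0; }
static double re(const std::complex<double>& x) { return x.real(); }
static double im(const std::complex<double>& x) { return x.imag(); }

// One shared routine; the compile-time flags zero out the imaginary
// contribution of a real operand, so four scalar-type combinations
// share a single code path.
template<typename LhsScalar, typename RhsScalar, bool LhsIsReal, bool RhsIsReal>
void gemm_complex_sketch(const LhsScalar* A, const RhsScalar* B,
                         std::complex<double>* C,
                         int rows, int depth, int cols)
{
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j) {
      double accr = 0.0, acci = 0.0;
      for (int k = 0; k < depth; ++k) {
        const double ar = re(A[i*depth + k]);
        const double ai = LhsIsReal ? 0.0 : im(A[i*depth + k]);
        const double br = re(B[k*cols + j]);
        const double bi = RhsIsReal ? 0.0 : im(B[k*cols + j]);
        accr += ar*br - ai*bi;   // real part of the complex multiply
        acci += ar*bi + ai*br;   // imaginary part
      }
      C[i*cols + j] = std::complex<double>(accr, acci);
    }
}

int main()
{
  const std::complex<double> A[1] = { std::complex<double>(1.0, 2.0) };
  const double B[1] = { 3.0 };
  std::complex<double> C[1];
  // complex * real: RhsIsReal == true, the same flag combination the
  // gebp_kernel<std::complex<double>, double, ...> dispatch above uses.
  gemm_complex_sketch<std::complex<double>, double, false, true>(A, B, C, 1, 1, 1);
  std::printf("(%g,%g)\n", C[0].real(), C[0].imag());  // prints (3,6)
}

Because the flags are template parameters, the dead imaginary terms fold away at compile time, so the real-operand kernels pay nothing for sharing the code path.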
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 01e647f17..a90e57446 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -391,6 +391,77 @@ public:
return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
}
+ // storePacketBlock_helper defines a way to access values inside the PacketBlock; this is essentially required by the complex types.
+ template<typename SubPacket, typename ScalarT, int n, int idx>
+ struct storePacketBlock_helper
+ {
+ storePacketBlock_helper<SubPacket, ScalarT, n, idx-1> spbh;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ spbh.store(sup, i,j,block);
+ for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
+ {
+ ScalarT *v = &sup->operator()(i+l, j+idx);
+ *v = block.packet[idx][l];
+ }
+ }
+ };
+
+ template<typename SubPacket, int n, int idx>
+ struct storePacketBlock_helper<SubPacket, std::complex<float>, n, idx>
+ {
+ storePacketBlock_helper<SubPacket, std::complex<float>, n, idx-1> spbh;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ spbh.store(sup,i,j,block);
+ for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
+ {
+ std::complex<float> *v = &sup->operator()(i+l, j+idx);
+ v->real(block.packet[idx].v[2*l+0]);
+ v->imag(block.packet[idx].v[2*l+1]);
+ }
+ }
+ };
+
+ template<typename SubPacket, int n, int idx>
+ struct storePacketBlock_helper<SubPacket, std::complex<double>, n, idx>
+ {
+ storePacketBlock_helper<SubPacket, std::complex<double>, n, idx-1> spbh;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ spbh.store(sup,i,j,block);
+ for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
+ {
+ std::complex<double> *v = &sup->operator()(i+l, j+idx);
+ v->real(block.packet[idx].v[2*l+0]);
+ v->imag(block.packet[idx].v[2*l+1]);
+ }
+ }
+ };
+
+ template<typename SubPacket, typename ScalarT, int n>
+ struct storePacketBlock_helper<SubPacket, ScalarT, n, -1>
+ {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const {
+ }
+ };
+
+ template<typename SubPacket, int n>
+ struct storePacketBlock_helper<SubPacket, std::complex<float>, n, -1>
+ {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const {
+ }
+ };
+
+ template<typename SubPacket, int n>
+ struct storePacketBlock_helper<SubPacket, std::complex<double>, n, -1>
+ {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const {
+ }
+ };
+ // This function stores a PacketBlock into m_data; this approach is quite slow compared to the Incr==1 case and should be avoided when possible.
+ template<typename SubPacket, int n>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb;
+ spb.store(this, i,j,block);
+ }
protected:
Scalar* EIGEN_RESTRICT m_data;
const Index m_stride;
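For context, the helper above unrolls at compile time: storePacketBlock instantiates storePacketBlock_helper<SubPacket, Scalar, n, n-1>, each level recurses into idx-1 before scattering column idx scalar by scalar, and the idx == -1 partial specialization terminates the chain. Here is a minimal standalone sketch of that recursion pattern, using illustrative stand-ins (BlockSketch, store_helper) rather than Eigen's types:

#include <array>
#include <cstdio>

// Stand-in for Eigen's PacketBlock: n "packets" of 4 lanes each.
template<int n>
struct BlockSketch { std::array<std::array<float, 4>, n> packet; };

// Recursive helper: run(idx) first recurses into idx-1, then scatters
// column idx one scalar at a time, mirroring storePacketBlock_helper.
template<int n, int idx>
struct store_helper
{
  static void run(float* dst, int ld, const BlockSketch<n>& b)
  {
    store_helper<n, idx - 1>::run(dst, ld, b);
    for (int l = 0; l < 4; ++l)
      dst[idx * ld + l] = b.packet[idx][l];  // one lane per store
  }
};

// Base case: idx == -1 terminates the compile-time recursion.
template<int n>
struct store_helper<n, -1>
{
  static void run(float*, int, const BlockSketch<n>&) {}
};

template<int n>
void store_block(float* dst, int ld, const BlockSketch<n>& b)
{
  store_helper<n, n - 1>::run(dst, ld, b);  // stores columns 0..n-1
}

int main()
{
  BlockSketch<2> b;
  for (int c = 0; c < 2; ++c)
    for (int l = 0; l < 4; ++l) b.packet[c][l] = float(10*c + l);
  float out[8] = {};
  store_block(out, 4, b);
  for (float v : out) std::printf("%g ", v);  // 0 1 2 3 10 11 12 13
  std::printf("\n");
}

The recursion runs from n-1 down to -1, so exactly n columns are stored and the whole chain inlines into n straight-line scalar loops, which is why the comment above warns it is slow next to the packet stores available when Incr==1.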
diff --git a/test/blasutil.cpp b/test/blasutil.cpp
index 9caacfbab..01942918b 100644
--- a/test/blasutil.cpp
+++ b/test/blasutil.cpp
@@ -200,5 +200,7 @@ EIGEN_DECLARE_TEST(blasutil)
CALL_SUBTEST_5(run_test<float_t>());
CALL_SUBTEST_6(run_test<double_t>());
+ CALL_SUBTEST_7(run_test<std::complex<float> >());
+ CALL_SUBTEST_8(run_test<std::complex<double> >());
}
}
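The std::complex specializations of storePacketBlock_helper rely on the interleaved packet layout these new subtests now exercise: a complex SIMD packet stores its lanes as re,im pairs in its .v member (as the block.packet[idx].v[2*l+0] / [2*l+1] accesses above show), so lane l of a column must be split back into a std::complex. A minimal standalone sketch of that de-interleaving store, with a plain float array standing in for Packet2cf's .v field (names here are illustrative, not Eigen's):

#include <complex>
#include <cstdio>

// Stand-in for a complex SIMD packet: 2 complex lanes kept as 4
// interleaved floats {re0, im0, re1, im1}.
struct Packet2cfSketch { float v[4]; };

// De-interleaving store: lane l becomes one std::complex<float>,
// mirroring the std::complex<float> helper specialization.
void store_lanes(std::complex<float>* dst, const Packet2cfSketch& p)
{
  for (int l = 0; l < 2; ++l) {
    dst[l].real(p.v[2*l + 0]);
    dst[l].imag(p.v[2*l + 1]);
  }
}

int main()
{
  Packet2cfSketch p = {{1.f, 2.f, 3.f, 4.f}};
  std::complex<float> out[2];
  store_lanes(out, p);
  std::printf("(%g,%g) (%g,%g)\n",
              out[0].real(), out[0].imag(),
              out[1].real(), out[1].imag());  // prints (1,2) (3,4)
}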