aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-22 23:12:22 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-22 23:12:22 +0200
commit0cb4f32e12a190f43f794a8258661bade74f2eb2 (patch)
tree9e7311990bad978415cc443f77499958eb99028d /Eigen/src/Core/products
parente7f8e939e282a64025203a7a22e511165e1e3647 (diff)
implement high level API for SYMM and fix a couple of bugs related to complex
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h7
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixMatrix.h49
2 files changed, 28 insertions, 28 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 1c48a5ed4..b2f51ca5b 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -334,22 +334,23 @@ struct ei_gebp_kernel
};
// pack a block of the lhs
-template<typename Scalar, int mr, int StorageOrder>
+template<typename Scalar, int mr, int StorageOrder, bool Conjugate>
struct ei_gemm_pack_lhs
{
void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int actual_kc, int actual_mc)
{
+ ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
int count = 0;
const int peeled_mc = (actual_mc/mr)*mr;
for(int i=0; i<peeled_mc; i+=mr)
for(int k=0; k<actual_kc; k++)
for(int w=0; w<mr; w++)
- blockA[count++] = lhs(i+w, k);
+ blockA[count++] = cj(lhs(i+w, k));
for(int i=peeled_mc; i<actual_mc; i++)
{
for(int k=0; k<actual_kc; k++)
- blockA[count++] = lhs(i, k);
+ blockA[count++] = cj(lhs(i, k));
}
}
};
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
index af3767e18..4ec9987d6 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
@@ -39,32 +39,30 @@ struct ei_symm_pack_lhs
// normal copy
for(int k=0; k<i; k++)
for(int w=0; w<mr; w++)
- blockA[count++] = lhs(i+w,k);
+ blockA[count++] = lhs(i+w,k); // normal
// symmetric copy
int h = 0;
for(int k=i; k<i+mr; k++)
{
- // transposed copy
for(int w=0; w<h; w++)
- blockA[count++] = lhs(k, i+w);
+ blockA[count++] = ei_conj(lhs(k, i+w)); // transposed
for(int w=h; w<mr; w++)
- blockA[count++] = lhs(i+w, k);
+ blockA[count++] = lhs(i+w, k); // normal
++h;
}
// transposed copy
for(int k=i+mr; k<actual_kc; k++)
for(int w=0; w<mr; w++)
- blockA[count++] = lhs(k, i+w);
+ blockA[count++] = ei_conj(lhs(k, i+w)); // transposed
}
// do the same with mr==1
for(int i=peeled_mc; i<actual_mc; i++)
{
for(int k=0; k<=i; k++)
- blockA[count++] = lhs(i, k);
- // transposed copy
+ blockA[count++] = lhs(i, k); // normal
for(int k=i+1; k<actual_kc; k++)
- blockA[count++] = lhs(k, i);
+ blockA[count++] = ei_conj(lhs(k, i)); // transposed
}
}
};
@@ -79,7 +77,7 @@ struct ei_symm_pack_rhs
int count = 0;
ei_const_blas_data_mapper<Scalar,StorageOrder> rhs(_rhs,rhsStride);
- // first part: standard case
+ // first part: normal case
for(int j2=0; j2<k2; j2+=nr)
{
for(int k=k2; k<end_k; k++)
@@ -102,12 +100,12 @@ struct ei_symm_pack_rhs
// transpose
for(int k=k2; k<j2; k++)
{
- ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k)));
- ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k)));
+ ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+0,k))));
+ ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+1,k))));
if (nr==4)
{
- ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k)));
- ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k)));
+ ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+2,k))));
+ ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+3,k))));
}
count += nr*PacketSize;
}
@@ -120,7 +118,7 @@ struct ei_symm_pack_rhs
ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(k,j2+w)));
// transpose
for (int w=h ; w<nr; ++w)
- ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(j2+w,k)));
+ ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+w,k))));
count += nr*PacketSize;
++h;
}
@@ -138,17 +136,17 @@ struct ei_symm_pack_rhs
}
}
- // third part: transpose
+ // third part: transposed
for(int j2=k2+actual_kc; j2<packet_cols; j2+=nr)
{
for(int k=k2; k<end_k; k++)
{
- ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k)));
- ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k)));
+ ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+0,k))));
+ ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+1,k))));
if (nr==4)
{
- ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k)));
- ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k)));
+ ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+2,k))));
+ ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+3,k))));
}
count += nr*PacketSize;
}
@@ -161,7 +159,7 @@ struct ei_symm_pack_rhs
int half = std::min(end_k,j2);
for(int k=k2; k<half; k++)
{
- ei_pstore(&blockB[count], ei_pset1(alpha*rhs(j2,k)));
+ ei_pstore(&blockB[count], ei_pset1(alpha*ei_conj(rhs(j2,k))));
count += PacketSize;
}
// normal
@@ -198,8 +196,9 @@ struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,LhsSelfAdjoint,Conju
{
ei_product_selfadjoint_matrix<Scalar,
RhsStorageOrder==RowMajor ? ColMajor : RowMajor, RhsSelfAdjoint, ConjugateRhs,
- LhsStorageOrder==RowMajor ? ColMajor : RowMajor, LhsSelfAdjoint, ConjugateLhs, ColMajor>
- ::run(rows, cols, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
+ EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
+ LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && !ConjugateLhs, ColMajor>
+ ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
}
};
@@ -254,8 +253,8 @@ struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,true,ConjugateLhs, R
for(int i2=0; i2<k2; i2+=mc)
{
const int actual_mc = std::min(i2+mc,k2)-i2;
- // transposed packed copy if Lower part
- ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor>()
+ // transposed packed copy
+ ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor, true>()
(blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
@@ -273,7 +272,7 @@ struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,true,ConjugateLhs, R
for(int i2=k2+kc; i2<size; i2+=mc)
{
const int actual_mc = std::min(i2+mc,size)-i2;
- ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>()
+ ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder,false>()
(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);