From dd27e10360ada43cc9c33a802211d79cc3984b23 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sat, 17 Jul 2010 11:59:09 +0200 Subject: fix level3 blas: it now passes all computational tests --- blas/level3_impl.h | 152 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 139 insertions(+), 13 deletions(-) (limited to 'blas/level3_impl.h') diff --git a/blas/level3_impl.h b/blas/level3_impl.h index ff7ed6752..38b2637c3 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -243,6 +243,30 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa else matrix(c, *m, *n, *ldc) *= beta; } + if(*m==0 || *n==0) + { + return 1; + } + + #if ISCOMPLEX + // FIXME add support for symmetric complex matrix + int size = (SIDE(*side)==LEFT) ? (*m) : (*n); + Matrix matA(size,size); + if(UPLO(*uplo)==UP) + { + matA.triangularView() = matrix(a,size,size,*lda); + matA.triangularView() = matrix(a,size,size,*lda).transpose(); + } + else if(UPLO(*uplo)==LO) + { + matA.triangularView() = matrix(a,size,size,*lda); + matA.triangularView() = matrix(a,size,size,*lda).transpose(); + } + if(SIDE(*side)==LEFT) + matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb); + else if(SIDE(*side)==RIGHT) + matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA; + #else if(SIDE(*side)==LEFT) if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); @@ -253,6 +277,7 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa else return 0; else return 0; + #endif return 0; } @@ -301,7 +326,25 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp else matrix(c, *n, *n, *ldc).triangularView() *= beta; } + #if ISCOMPLEX + // FIXME add support for symmetric complex matrix + if(UPLO(*uplo)==UP) + { + if(OP(*op)==NOTR) + matrix(c, *n, *n, *ldc).triangularView() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose(); + else + matrix(c, *n, *n, *ldc).triangularView() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda); + } + else + { + if(OP(*op)==NOTR) + matrix(c, *n, *n, *ldc).triangularView() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose(); + else + matrix(c, *n, *n, *ldc).triangularView() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda); + } + #else func[code](*n, *k, a, *lda, c, *ldc, alpha); + #endif return 0; } @@ -316,8 +359,44 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - // TODO - std::cerr << "Eigen BLAS: _syr2k is not implemented yet\n"; + if(*n<=0 || *k<0) + { + return 0; + } + + if(beta!=Scalar(1)) + { + if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView() *= beta; + else matrix(c, *n, *n, *ldc).triangularView() *= beta; + } + + if(*k==0) + return 1; + + if(OP(*op)==NOTR) + { + if(UPLO(*uplo)==UP) + { + matrix(c, *n, *n, *ldc).triangularView() + += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose() + + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose(); + } + else if(UPLO(*uplo)==LO) + matrix(c, *n, *n, *ldc).triangularView() + += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose() + + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose(); + } + else if(OP(*op)==TR || OP(*op)==ADJ) + { + if(UPLO(*uplo)==UP) + matrix(c, *n, *n, *ldc).triangularView() + += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb) + + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda); + else if(UPLO(*uplo)==LO) + matrix(c, *n, *n, *ldc).triangularView() + += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb) + + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda); + } return 0; } @@ -342,19 +421,30 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa return 0; } - if(beta!=Scalar(1)) + if(beta==Scalar(0)) + matrix(c, *m, *n, *ldc).setZero(); + else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta; + + if(*m==0 || *n==0) + { + return 1; + } if(SIDE(*side)==LEFT) { - if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix + ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix + ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); else return 0; } else if(SIDE(*side)==RIGHT) { - if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView();/*ei_product_selfadjoint_matrix + ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);*/ + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix + ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); else return 0; } else @@ -421,24 +511,60 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp } // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n' -// c = alpha*conj(b')*a + conj(alpha)*conj(a')*b + beta*c, for op = 'C'or'c' +// c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c' int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + RealScalar beta = *pbeta; - if(*n<0 || *k<0) + if(*n<=0 || *k<0) { return 0; } - // TODO - std::cerr << "Eigen BLAS: _her2k is not implemented yet\n"; + if(beta!=RealScalar(1)) + { + if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView() *= beta; + else matrix(c, *n, *n, *ldc).triangularView() *= beta; - return 0; + matrix(c, *n, *n, *ldc).diagonal().real() *= beta; + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } + else if(*k>0 && alpha!=Scalar(0)) + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + + if(*k==0) + return 1; + + if(OP(*op)==NOTR) + { + if(UPLO(*uplo)==UP) + { + matrix(c, *n, *n, *ldc).triangularView() + += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint() + + ei_conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint(); + } + else if(UPLO(*uplo)==LO) + matrix(c, *n, *n, *ldc).triangularView() + += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint() + + ei_conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint(); + } + else if(OP(*op)==ADJ) + { + if(UPLO(*uplo)==UP) + matrix(c, *n, *n, *ldc).triangularView() + += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb) + + ei_conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda); + else if(UPLO(*uplo)==LO) + matrix(c, *n, *n, *ldc).triangularView() + += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb) + + ei_conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda); + } + + return 1; } #endif // ISCOMPLEX -- cgit v1.2.3