From a2d7c239f54190ddb40febb6b4b65d74c261f008 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 2 Mar 2010 12:44:40 +0100 Subject: blas: fix HEMM and HERK --- blas/level3_impl.h | 90 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 53 insertions(+), 37 deletions(-) (limited to 'blas/level3_impl.h') diff --git a/blas/level3_impl.h b/blas/level3_impl.h index 32b49b118..c9023ab37 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -218,8 +218,7 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, // c = alpha*b*a + beta*c for side = 'R'or'r int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { -// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << " " -// << pa << " " << pb << " " << pc << "\n"; +// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n"; Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); @@ -234,25 +233,17 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa } if(beta!=Scalar(1)) - if(beta==Scalar(0)) - matrix(c, *m, *n, *ldc).setZero(); - else - matrix(c, *m, *n, *ldc) *= beta; + if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); + else matrix(c, *m, *n, *ldc) *= beta; if(SIDE(*side)==LEFT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else return 0; else if(SIDE(*side)==RIGHT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else return 0; else return 0; @@ -334,27 +325,30 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); +// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; + + if(*m<0 || *n<0) + { + return 0; + } + if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta; if(SIDE(*side)==LEFT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else return 0; else if(SIDE(*side)==RIGHT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else return 0; else + { return 0; + } - return 1; + return 0; } // c = alpha*a*conj(a') + beta*c for op = 'N'or'n' @@ -381,18 +375,35 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp Scalar* a = reinterpret_cast(pa); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + RealScalar alpha = *palpha; + RealScalar beta = *pbeta; + +// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; + + if(*n<0 || *k<0) + { + return 0; + } int code = OP(*op) | (UPLO(*uplo) << 2); if(code>=8 || func[code]==0) return 0; - if(beta!=Scalar(1)) - matrix(c, *n, *n, *ldc) *= beta; + if(beta!=RealScalar(1)) + { + if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView() *= beta; + else matrix(c, *n, *n, *ldc).triangularView() *= beta; - func[code](*n, *k, a, *lda, c, *ldc, alpha); - return 1; + matrix(c, *n, *n, *ldc).diagonal().real() *= beta; + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } + + if(*k>0 && alpha!=RealScalar(0)) + { + func[code](*n, *k, a, *lda, c, *ldc, alpha); + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } + return 0; } // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n' @@ -405,6 +416,11 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); + if(*n<0 || *k<0) + { + return 0; + } + // TODO return 0; -- cgit v1.2.3