diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-11-22 18:49:12 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-11-22 18:49:12 +0100 |
commit | f5f288b741b173a271b9c939ac5231639135dd93 (patch) | |
tree | 977b85dc7b88aa9faf58696eabba0163ca7c2235 /blas/level3_impl.h | |
parent | a6f483e86b0c4c1d82622eec99fb051c804bf13d (diff) |
Split the level 1 and level 2 implementation files into smaller ones, and fix a couple of numerical and other tricky issues discovered by the LAPACK test suite.
Diffstat (limited to 'blas/level3_impl.h')
-rw-r--r-- | blas/level3_impl.h | 64 |
1 files changed, 42 insertions, 22 deletions
diff --git a/blas/level3_impl.h b/blas/level3_impl.h index 5b28a1d52..6c53dc679 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -52,7 +52,7 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); - + int info = 0; if(OP(*opa)==INVALID) info = 1; else if(OP(*opb)==INVALID) info = 2; @@ -342,13 +342,17 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp else if(*ldc<std::max(1,*n)) info = 10; if(info) return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6); - + int code = OP(*op) | (UPLO(*uplo) << 2); if(beta!=Scalar(1)) { - if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; - else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta; + if(UPLO(*uplo)==UP) + if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; + else + if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta; } #if ISCOMPLEX @@ -383,7 +387,7 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); - + int info = 0; if(UPLO(*uplo)==INVALID) info = 1; else if(OP(*op)==INVALID) info = 2; @@ -394,11 +398,15 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal else if(*ldc<std::max(1,*n)) info = 12; if(info) return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6); - + if(beta!=Scalar(1)) { - if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; - else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta; + if(UPLO(*uplo)==UP) + if(beta==Scalar(0)) matrix(c, *n, *n, 
*ldc).triangularView<Upper>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; + else + if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta; } if(*k==0) @@ -458,11 +466,9 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa if(info) return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6); - if(beta==Scalar(0)) - matrix(c, *m, *n, *ldc).setZero(); - else if(beta!=Scalar(1)) - matrix(c, *m, *n, *ldc) *= beta; - + if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); + else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta; + if(*m==0 || *n==0) { return 1; @@ -535,11 +541,18 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp if(beta!=RealScalar(1)) { - if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; - else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; - - matrix(c, *n, *n, *ldc).diagonal().real() *= beta; - matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + if(UPLO(*uplo)==UP) + if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; + else + if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; + + if(beta!=Scalar(0)) + { + matrix(c, *n, *n, *ldc).diagonal().real() *= beta; + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } } if(*k>0 && alpha!=RealScalar(0)) @@ -573,11 +586,18 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal if(beta!=RealScalar(1)) { - if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; - else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; + if(UPLO(*uplo)==UP) + if(beta==Scalar(0)) matrix(c, *n, *n, 
*ldc).triangularView<Upper>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; + else + if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); + else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; - matrix(c, *n, *n, *ldc).diagonal().real() *= beta; - matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + if(beta!=Scalar(0)) + { + matrix(c, *n, *n, *ldc).diagonal().real() *= beta; + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } } else if(*k>0 && alpha!=Scalar(0)) matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); |