aboutsummaryrefslogtreecommitdiffhomepage
path: root/blas/level3_impl.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-11-22 18:49:12 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-11-22 18:49:12 +0100
commitf5f288b741b173a271b9c939ac5231639135dd93 (patch)
tree977b85dc7b88aa9faf58696eabba0163ca7c2235 /blas/level3_impl.h
parenta6f483e86b0c4c1d82622eec99fb051c804bf13d (diff)
split level 1 and 2 implementation files into smaller ones and fix a couple of numerical and tricky issues discovered by the lapack test suite
Diffstat (limited to 'blas/level3_impl.h')
-rw-r--r--blas/level3_impl.h64
1 files changed, 42 insertions, 22 deletions
diff --git a/blas/level3_impl.h b/blas/level3_impl.h
index 5b28a1d52..6c53dc679 100644
--- a/blas/level3_impl.h
+++ b/blas/level3_impl.h
@@ -52,7 +52,7 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
-
+
int info = 0;
if(OP(*opa)==INVALID) info = 1;
else if(OP(*opb)==INVALID) info = 2;
@@ -342,13 +342,17 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
else if(*ldc<std::max(1,*n)) info = 10;
if(info)
return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
-
+
int code = OP(*op) | (UPLO(*uplo) << 2);
if(beta!=Scalar(1))
{
- if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
- else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
+ if(UPLO(*uplo)==UP)
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
+ else
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
}
#if ISCOMPLEX
@@ -383,7 +387,7 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
-
+
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
else if(OP(*op)==INVALID) info = 2;
@@ -394,11 +398,15 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
else if(*ldc<std::max(1,*n)) info = 12;
if(info)
return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
-
+
if(beta!=Scalar(1))
{
- if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
- else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
+ if(UPLO(*uplo)==UP)
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
+ else
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
}
if(*k==0)
@@ -458,11 +466,9 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
if(info)
return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
- if(beta==Scalar(0))
- matrix(c, *m, *n, *ldc).setZero();
- else if(beta!=Scalar(1))
- matrix(c, *m, *n, *ldc) *= beta;
-
+ if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
+ else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta;
+
if(*m==0 || *n==0)
{
return 1;
@@ -535,11 +541,18 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
if(beta!=RealScalar(1))
{
- if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
- else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
-
- matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
- matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
+ if(UPLO(*uplo)==UP)
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
+ else
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
+
+ if(beta!=Scalar(0))
+ {
+ matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
+ matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
+ }
}
if(*k>0 && alpha!=RealScalar(0))
@@ -573,11 +586,18 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
if(beta!=RealScalar(1))
{
- if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
- else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
+ if(UPLO(*uplo)==UP)
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
+ else
+ if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
+ else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
- matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
- matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
+ if(beta!=Scalar(0))
+ {
+ matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
+ matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
+ }
}
else if(*k>0 && alpha!=Scalar(0))
matrix(c, *n, *n, *ldc).diagonal().imag().setZero();