aboutsummaryrefslogtreecommitdiffhomepage
path: root/blas/level3_impl.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-11-03 22:03:12 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-11-03 22:03:12 +0100
commitfd88d721d2327e92a8c6c156dde266967dfb0d91 (patch)
tree4aca8c505d28ebc405c6e4dd5f8c9889105361d4 /blas/level3_impl.h
parenta8fb6b0ad33a4620424e31e842b54c8cd255c6d2 (diff)
implement proper error handling in level 3 routines
Diffstat (limited to 'blas/level3_impl.h')
-rw-r--r--blas/level3_impl.h150
1 files changed, 98 insertions, 52 deletions
diff --git a/blas/level3_impl.h b/blas/level3_impl.h
index e5b781cd6..b07c20ee6 100644
--- a/blas/level3_impl.h
+++ b/blas/level3_impl.h
@@ -52,14 +52,18 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
-
- int code = OP(*opa) | (OP(*opb) << 2);
- if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0))
- {
- int info = 1;
- xerbla_("GEMM", &info, 4);
- return 0;
- }
+
+ int info = 0;
+ if(OP(*opa)==INVALID) info = 1;
+ else if(OP(*opb)==INVALID) info = 2;
+ else if(*m<0) info = 3;
+ else if(*n<0) info = 4;
+ else if(*k<0) info = 5;
+ else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8;
+ else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10;
+ else if(*ldc<std::max(1,*m)) info = 13;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6);
if(beta!=Scalar(1))
{
@@ -69,6 +73,7 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
ei_gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k);
+ int code = OP(*opa) | (OP(*opb) << 2);
func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
return 0;
}
@@ -125,13 +130,19 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
+ int info = 0;
+ if(SIDE(*side)==INVALID) info = 1;
+ else if(UPLO(*uplo)==INVALID) info = 2;
+ else if(OP(*opa)==INVALID) info = 3;
+ else if(DIAG(*diag)==INVALID) info = 4;
+ else if(*m<0) info = 5;
+ else if(*n<0) info = 6;
+ else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
+ else if(*ldb<std::max(1,*m)) info = 11;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
+
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
- if(code>=32 || func[code]==0 || *m<0 || *n <0)
- {
- int info=1;
- xerbla_("TRSM",&info,4);
- return 0;
- }
if(SIDE(*side)==LEFT)
func[code](*m, *n, a, *lda, b, *ldb);
@@ -197,13 +208,19 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
+ int info = 0;
+ if(SIDE(*side)==INVALID) info = 1;
+ else if(UPLO(*uplo)==INVALID) info = 2;
+ else if(OP(*opa)==INVALID) info = 3;
+ else if(DIAG(*diag)==INVALID) info = 4;
+ else if(*m<0) info = 5;
+ else if(*n<0) info = 6;
+ else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
+ else if(*ldb<std::max(1,*m)) info = 11;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6);
+
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
- if(code>=32 || func[code]==0 || *m<0 || *n <0)
- {
- int info=1;
- xerbla_("TRMM",&info,4);
- return 0;
- }
if(*m==0 || *n==0)
return 1;
@@ -230,12 +247,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
- if(*m<0 || *n<0)
- {
- int info=1;
- xerbla_("SYMM",&info,4);
- return 0;
- }
+ int info = 0;
+ if(SIDE(*side)==INVALID) info = 1;
+ else if(UPLO(*uplo)==INVALID) info = 2;
+ else if(*m<0) info = 3;
+ else if(*n<0) info = 4;
+ else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
+ else if(*ldb<std::max(1,*m)) info = 9;
+ else if(*ldc<std::max(1,*m)) info = 12;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6);
if(beta!=Scalar(1))
{
@@ -312,13 +333,17 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
+ int info = 0;
+ if(UPLO(*uplo)==INVALID) info = 1;
+ else if(OP(*op)==INVALID) info = 2;
+ else if(*n<0) info = 3;
+ else if(*k<0) info = 4;
+ else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
+ else if(*ldc<std::max(1,*n)) info = 10;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
+
int code = OP(*op) | (UPLO(*uplo) << 2);
- if(code>=8 || func[code]==0 || *n<0 || *k<0)
- {
- int info=1;
- xerbla_("SYRK",&info,4);
- return 0;
- }
if(beta!=Scalar(1))
{
@@ -358,12 +383,18 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
-
- if(*n<=0 || *k<0)
- {
- return 0;
- }
-
+
+ int info = 0;
+ if(UPLO(*uplo)==INVALID) info = 1;
+ else if(OP(*op)==INVALID) info = 2;
+ else if(*n<0) info = 3;
+ else if(*k<0) info = 4;
+ else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
+ else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
+ else if(*ldc<std::max(1,*n)) info = 12;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
+
if(beta!=Scalar(1))
{
if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
@@ -416,10 +447,16 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
- if(*m<0 || *n<0)
- {
- return 0;
- }
+ int info = 0;
+ if(SIDE(*side)==INVALID) info = 1;
+ else if(UPLO(*uplo)==INVALID) info = 2;
+ else if(*m<0) info = 3;
+ else if(*n<0) info = 4;
+ else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
+ else if(*ldb<std::max(1,*m)) info = 9;
+ else if(*ldc<std::max(1,*m)) info = 12;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
if(beta==Scalar(0))
matrix(c, *m, *n, *ldc).setZero();
@@ -484,14 +521,17 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
- if(*n<0 || *k<0)
- {
- return 0;
- }
+ int info = 0;
+ if(UPLO(*uplo)==INVALID) info = 1;
+ else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
+ else if(*n<0) info = 3;
+ else if(*k<0) info = 4;
+ else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
+ else if(*ldc<std::max(1,*n)) info = 10;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6);
int code = OP(*op) | (UPLO(*uplo) << 2);
- if(code>=8 || func[code]==0)
- return 0;
if(beta!=RealScalar(1))
{
@@ -520,10 +560,16 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
RealScalar beta = *pbeta;
- if(*n<=0 || *k<0)
- {
- return 0;
- }
+ int info = 0;
+ if(UPLO(*uplo)==INVALID) info = 1;
+ else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
+ else if(*n<0) info = 3;
+ else if(*k<0) info = 4;
+ else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
+ else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
+ else if(*ldc<std::max(1,*n)) info = 12;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);
if(beta!=RealScalar(1))
{