diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-11-03 22:03:12 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-11-03 22:03:12 +0100 |
commit | fd88d721d2327e92a8c6c156dde266967dfb0d91 (patch) | |
tree | 4aca8c505d28ebc405c6e4dd5f8c9889105361d4 /blas | |
parent | a8fb6b0ad33a4620424e31e842b54c8cd255c6d2 (diff) |
implement proper error handling in level 3 routines
Diffstat (limited to 'blas')
-rw-r--r-- | blas/common.h | 26 | ||||
-rw-r--r-- | blas/complex_double.cpp | 1 | ||||
-rw-r--r-- | blas/complex_single.cpp | 1 | ||||
-rw-r--r-- | blas/double.cpp | 1 | ||||
-rw-r--r-- | blas/level2_impl.h | 14 | ||||
-rw-r--r-- | blas/level3_impl.h | 150 | ||||
-rw-r--r-- | blas/single.cpp | 1 |
7 files changed, 138 insertions, 56 deletions
diff --git a/blas/common.h b/blas/common.h index c91cdc9a1..d56815ce3 100644 --- a/blas/common.h +++ b/blas/common.h @@ -56,22 +56,40 @@ extern "C" #define NUNIT 0 #define UNIT 1 +#define INVALID 0xff + #define OP(X) ( ((X)=='N' || (X)=='n') ? NOTR \ : ((X)=='T' || (X)=='t') ? TR \ : ((X)=='C' || (X)=='c') ? ADJ \ - : 0xff) + : INVALID) #define SIDE(X) ( ((X)=='L' || (X)=='l') ? LEFT \ : ((X)=='R' || (X)=='r') ? RIGHT \ - : 0xff) + : INVALID) #define UPLO(X) ( ((X)=='U' || (X)=='u') ? UP \ : ((X)=='L' || (X)=='l') ? LO \ - : 0xff) + : INVALID) #define DIAG(X) ( ((X)=='N' || (X)=='N') ? NUNIT \ : ((X)=='U' || (X)=='u') ? UNIT \ - : 0xff) + : INVALID) + + +inline bool check_op(const char* op) +{ + return OP(*op)!=0xff; +} + +inline bool check_side(const char* side) +{ + return SIDE(*side)!=0xff; +} + +inline bool check_uplo(const char* uplo) +{ + return UPLO(*uplo)!=0xff; +} #include <Eigen/Core> #include <Eigen/Jacobi> diff --git a/blas/complex_double.cpp b/blas/complex_double.cpp index f3065c1d6..bd7674cda 100644 --- a/blas/complex_double.cpp +++ b/blas/complex_double.cpp @@ -24,6 +24,7 @@ #define SCALAR std::complex<double> #define SCALAR_SUFFIX z +#define SCALAR_SUFFIX_UP "Z" #define REAL_SCALAR_SUFFIX d #define ISCOMPLEX 1 diff --git a/blas/complex_single.cpp b/blas/complex_single.cpp index b88afb667..4cf19378f 100644 --- a/blas/complex_single.cpp +++ b/blas/complex_single.cpp @@ -24,6 +24,7 @@ #define SCALAR std::complex<float> #define SCALAR_SUFFIX c +#define SCALAR_SUFFIX_UP "C" #define REAL_SCALAR_SUFFIX s #define ISCOMPLEX 1 diff --git a/blas/double.cpp b/blas/double.cpp index 7f2c58484..10373d585 100644 --- a/blas/double.cpp +++ b/blas/double.cpp @@ -24,6 +24,7 @@ #define SCALAR double #define SCALAR_SUFFIX d +#define SCALAR_SUFFIX_UP "D" #define ISCOMPLEX 0 #include "level1_impl.h" diff --git a/blas/level2_impl.h b/blas/level2_impl.h index a7d5adb64..3489a426d 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -32,6 +32,20 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); + // check arguments + int info = 0; + if( OP(*opa)!=NOTR + && OP(*opa)!=TR + && OP(*opa)!=ADJ) info = 1; + else if(*m<0) info = 2; + else if(*n<0) info = 3; + else if(*lda<std::max(1,*m)) info = 6; + else if(*incb==0) info = 8; + else if(*incc==0) info = 11; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"GEMV ",&info,6); +// return xerbla_("SGEMV ",&info,sizeof("SGEMV ")); + if(beta!=Scalar(1)) vector(c, *m, *incc) *= beta; diff --git a/blas/level3_impl.h b/blas/level3_impl.h index e5b781cd6..b07c20ee6 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -52,14 +52,18 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); - - int code = OP(*opa) | (OP(*opb) << 2); - if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0)) - { - int info = 1; - xerbla_("GEMM", &info, 4); - return 0; - } + + int info = 0; + if(OP(*opa)==INVALID) info = 1; + else if(OP(*opb)==INVALID) info = 2; + else if(*m<0) info = 3; + else if(*n<0) info = 4; + else if(*k<0) info = 5; + else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8; + else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10; + else if(*ldc<std::max(1,*m)) info = 13; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6); if(beta!=Scalar(1)) { @@ -69,6 +73,7 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal ei_gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k); + int code = OP(*opa) | (OP(*opb) << 2); func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0); return 0; } @@ -125,13 +130,19 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(OP(*opa)==INVALID) info = 3; + else if(DIAG(*diag)==INVALID) info = 4; + else if(*m<0) info = 5; + else if(*n<0) info = 6; + else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9; + else if(*ldb<std::max(1,*m)) info = 11; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6); + int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); - if(code>=32 || func[code]==0 || *m<0 || *n <0) - { - int info=1; - xerbla_("TRSM",&info,4); - return 0; - } if(SIDE(*side)==LEFT) func[code](*m, *n, a, *lda, b, *ldb); @@ -197,13 +208,19 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(OP(*opa)==INVALID) info = 3; + else if(DIAG(*diag)==INVALID) info = 4; + else if(*m<0) info = 5; + else if(*n<0) info = 6; + else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9; + else if(*ldb<std::max(1,*m)) info = 11; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6); + int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); - if(code>=32 || func[code]==0 || *m<0 || *n <0) - { - int info=1; - xerbla_("TRMM",&info,4); - return 0; - } if(*m==0 || *n==0) return 1; @@ -230,12 +247,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); - if(*m<0 || *n<0) - { - int info=1; - xerbla_("SYMM",&info,4); - return 0; - } + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(*m<0) info = 3; + else if(*n<0) info = 4; + else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7; + else if(*ldb<std::max(1,*m)) info = 9; + else if(*ldc<std::max(1,*m)) info = 12; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6); if(beta!=Scalar(1)) { @@ -312,13 +333,17 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*op)==INVALID) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; + else if(*ldc<std::max(1,*n)) info = 10; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6); + int code = OP(*op) | (UPLO(*uplo) << 2); - if(code>=8 || func[code]==0 || *n<0 || *k<0) - { - int info=1; - xerbla_("SYRK",&info,4); - return 0; - } if(beta!=Scalar(1)) { @@ -358,12 +383,18 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); - - if(*n<=0 || *k<0) - { - return 0; - } - + + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*op)==INVALID) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; + else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9; + else if(*ldc<std::max(1,*n)) info = 12; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6); + if(beta!=Scalar(1)) { if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; @@ -416,10 +447,16 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; - if(*m<0 || *n<0) - { - return 0; - } + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(*m<0) info = 3; + else if(*n<0) info = 4; + else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7; + else if(*ldb<std::max(1,*m)) info = 9; + else if(*ldc<std::max(1,*m)) info = 12; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6); if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); @@ -484,14 +521,17 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; - if(*n<0 || *k<0) - { - return 0; - } + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; + else if(*ldc<std::max(1,*n)) info = 10; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6); int code = OP(*op) | (UPLO(*uplo) << 2); - if(code>=8 || func[code]==0) - return 0; if(beta!=RealScalar(1)) { @@ -520,10 +560,16 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar alpha = *reinterpret_cast<Scalar*>(palpha); RealScalar beta = *pbeta; - if(*n<=0 || *k<0) - { - return 0; - } + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; + else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9; + else if(*ldc<std::max(1,*n)) info = 12; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6); if(beta!=RealScalar(1)) { diff --git a/blas/single.cpp b/blas/single.cpp index dd8d5cde9..9ee2b78dd 100644 --- a/blas/single.cpp +++ b/blas/single.cpp @@ -24,6 +24,7 @@ #define SCALAR float #define SCALAR_SUFFIX s +#define SCALAR_SUFFIX_UP "S" #define ISCOMPLEX 0 #include "level1_impl.h" |