From fd88d721d2327e92a8c6c156dde266967dfb0d91 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 3 Nov 2010 22:03:12 +0100 Subject: implement proper error handling in level 3 routines --- blas/level3_impl.h | 150 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 98 insertions(+), 52 deletions(-) (limited to 'blas/level3_impl.h') diff --git a/blas/level3_impl.h b/blas/level3_impl.h index e5b781cd6..b07c20ee6 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -52,14 +52,18 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal Scalar* c = reinterpret_cast(pc); Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - - int code = OP(*opa) | (OP(*opb) << 2); - if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0)) - { - int info = 1; - xerbla_("GEMM", &info, 4); - return 0; - } + + int info = 0; + if(OP(*opa)==INVALID) info = 1; + else if(OP(*opb)==INVALID) info = 2; + else if(*m<0) info = 3; + else if(*n<0) info = 4; + else if(*k<0) info = 5; + else if(*lda blocking(*m,*n,*k); + int code = OP(*opa) | (OP(*opb) << 2); func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0); return 0; } @@ -125,13 +130,19 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar* b = reinterpret_cast(pb); Scalar alpha = *reinterpret_cast(palpha); + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(OP(*opa)==INVALID) info = 3; + else if(DIAG(*diag)==INVALID) info = 4; + else if(*m<0) info = 5; + else if(*n<0) info = 6; + else if(*lda=32 || func[code]==0 || *m<0 || *n <0) - { - int info=1; - xerbla_("TRSM",&info,4); - return 0; - } if(SIDE(*side)==LEFT) func[code](*m, *n, a, *lda, b, *ldb); @@ -197,13 +208,19 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar* b = reinterpret_cast(pb); Scalar alpha = *reinterpret_cast(palpha); + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(OP(*opa)==INVALID) info = 3; + else if(DIAG(*diag)==INVALID) info = 4; + else if(*m<0) info = 5; + else if(*n<0) info = 6; + else if(*lda=32 || func[code]==0 || *m<0 || *n <0) - { - int info=1; - xerbla_("TRMM",&info,4); - return 0; - } if(*m==0 || *n==0) return 1; @@ -230,12 +247,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - if(*m<0 || *n<0) - { - int info=1; - xerbla_("SYMM",&info,4); - return 0; - } + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(*m<0) info = 3; + else if(*n<0) info = 4; + else if(*lda(palpha); Scalar beta = *reinterpret_cast(pbeta); + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*op)==INVALID) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda=8 || func[code]==0 || *n<0 || *k<0) - { - int info=1; - xerbla_("SYRK",&info,4); - return 0; - } if(beta!=Scalar(1)) { @@ -358,12 +383,18 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar* c = reinterpret_cast(pc); Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - - if(*n<=0 || *k<0) - { - return 0; - } - + + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*op)==INVALID) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda() *= beta; @@ -416,10 +447,16 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; - if(*m<0 || *n<0) - { - return 0; - } + int info = 0; + if(SIDE(*side)==INVALID) info = 1; + else if(UPLO(*uplo)==INVALID) info = 2; + else if(*m<0) info = 3; + else if(*n<0) info = 4; + else if(*lda=8 || func[code]==0) - return 0; if(beta!=RealScalar(1)) { @@ -520,10 +560,16 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar alpha = *reinterpret_cast(palpha); RealScalar beta = *pbeta; - if(*n<=0 || *k<0) - { - return 0; - } + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2; + else if(*n<0) info = 3; + else if(*k<0) info = 4; + else if(*lda