diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-11-05 14:14:50 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-11-05 14:14:50 +0100 |
commit | 0e30c4ae3f35523302f1449f67a2be714e30beb8 (patch) | |
tree | d723cecf1cd7faf84e51162e035d1b91ffac6c36 /blas/level2_impl.h | |
parent | 3fdea699b80c429738ac0af8c9b7479594b90583 (diff) |
blas level2: gemv and trsv are green
Diffstat (limited to 'blas/level2_impl.h')
-rw-r--r-- | blas/level2_impl.h | 130 |
1 files changed, 83 insertions, 47 deletions
diff --git a/blas/level2_impl.h b/blas/level2_impl.h index 2749cf5b3..55851ddb3 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -24,8 +24,39 @@ #include "common.h" +#define MAKE_ACTUAL_VECTOR(X,INCX,N,COND) \ + Scalar* actual_##X = X; \ + if(COND) { \ + actual_##X = new Scalar[N]; \ + if((INCX)<0) vector(actual_##X,(N)) = vector(X,(N),-(INCX)).reverse(); \ + else vector(actual_##X,(N)) = vector(X,(N), (INCX)); \ + } + +#define RELEASE_ACTUAL_VECTOR(X,INCX,N,COND) \ + if(COND) { \ + if((INCX)<0) vector(X,(N),-(INCX)).reverse() = vector(actual_##X,(N)); \ + else vector(X,(N), (INCX)) = vector(actual_##X,(N)); \ + delete[] actual_##X; \ + } + int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc) { + typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar); + static functype func[4]; + + static bool init = false; + if(!init) + { + for(int k=0; k<4; ++k) + func[k] = 0; + + func[NOTR] = (internal::general_matrix_vector_product<int,Scalar,ColMajor,false,Scalar,false>::run); + func[TR ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,false,Scalar,false>::run); + func[ADJ ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,Conj, Scalar,false>::run); + + init = true; + } + Scalar* a = reinterpret_cast<Scalar*>(pa); Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar* c = reinterpret_cast<Scalar*>(pc); @@ -34,9 +65,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca // check arguments int info = 0; - if( OP(*opa)!=NOTR - && OP(*opa)!=TR - && OP(*opa)!=ADJ) info = 1; + if(OP(*opa)==INVALID) info = 1; else if(*m<0) info = 2; else if(*n<0) info = 3; else if(*lda<std::max(1,*m)) info = 6; @@ -44,39 +73,34 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca else if(*incc==0) info = 11; if(info) return xerbla_(SCALAR_SUFFIX_UP"GEMV ",&info,6); -// return xerbla_("SGEMV ",&info,sizeof("SGEMV ")); - - if(beta!=Scalar(1)) - vector(c, *m, *incc) *= beta; - - if(OP(*opa)==NOTR) - if(*incc==1) - vector(c,*m) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb); - else - vector(c,*m,*incc) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb); - else if(OP(*opa)==TR) - if(*incb==1) - vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n); - else - vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n,*incb); - else if(OP(*opa)==TR) - if(*incb==1) - vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n); - else - vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n,*incb); - else + + if(*m==0 || *n==0) return 0; + int actual_m = *m; + int actual_n = *n; + if(OP(*opa)!=NOTR) + std::swap(actual_m,actual_n); + + MAKE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1) + MAKE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1) + + if(beta!=Scalar(1)) + vector(actual_c, actual_m, 1) *= beta; + + int code = OP(*opa); + func[code](actual_m, actual_n, a, *lda, actual_b, 1, actual_c, 1, alpha); + + RELEASE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1) + RELEASE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1) + return 1; } - int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb) { - return 0; - - typedef void (*functype)(int, const Scalar *, int, Scalar *, int); - functype func[16]; + typedef void (*functype)(int, const Scalar *, int, Scalar *); + static functype func[16]; static bool init = false; if(!init) @@ -84,21 +108,21 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar for(int k=0; k<16; ++k) func[k] = 0; -// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,ColMajor,ColMajor>::run); -// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,RowMajor,ColMajor>::run); -// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, Conj, RowMajor,ColMajor>::run); -// -// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,ColMajor,ColMajor>::run); -// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,RowMajor,ColMajor>::run); -// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, Conj, RowMajor,ColMajor>::run); -// -// func[NOTR | (UP << 3) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run); -// func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run); -// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run); -// -// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run); -// func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run); -// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run); + func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,ColMajor>::run); + func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,RowMajor>::run); + func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, Conj, RowMajor>::run); + + func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,ColMajor>::run); + func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,RowMajor>::run); + func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, Conj, RowMajor>::run); + + func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,ColMajor>::run); + func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,RowMajor>::run); + func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,Conj, RowMajor>::run); + + func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,ColMajor>::run); + func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,RowMajor>::run); + func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,Conj, RowMajor>::run); init = true; } @@ -106,11 +130,23 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar Scalar* a = reinterpret_cast<Scalar*>(pa); Scalar* b = reinterpret_cast<Scalar*>(pb); + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*opa)==INVALID) info = 2; + else if(DIAG(*diag)==INVALID) info = 3; + else if(*n<0) info = 4; + else if(*lda<std::max(1,*n)) info = 6; + else if(*incb==0) info = 8; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"TRSV ",&info,6); + + MAKE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1) + int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3); - if(code>=16 || func[code]==0) - return 0; + func[code](*n, a, *lda, actual_b); - func[code](*n, a, *lda, b, *incb); + RELEASE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1) + return 0; } |