From 1ac9124fac72c11eab3d831e142bba8927c140d0 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sat, 20 Nov 2010 23:29:20 +0100 Subject: implements TRMV level 2 blas routine --- blas/level2_impl.h | 64 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 23 deletions(-) (limited to 'blas') diff --git a/blas/level2_impl.h b/blas/level2_impl.h index ba3868145..094d6b5d3 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -139,11 +139,8 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb) { - return 0; - // TODO - - typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int); - functype func[16]; + typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); + static functype func[16]; static bool init = false; if(!init) @@ -151,21 +148,21 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar for(int k=0; k<16; ++k) func[k] = 0; -// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// -// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// -// func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// -// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + + func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + + func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + + func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); init = true; } @@ -173,11 +170,32 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*opa)==INVALID) info = 2; + else if(DIAG(*diag)==INVALID) info = 3; + else if(*n<0) info = 4; + else if(*lda res(*n); + res.setZero(); + int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3); if(code>=16 || func[code]==0) return 0; - func[code](*n, a, *lda, b, *incb, b, *incb); + func[code](*n, *n, a, *lda, actual_b, 1, res.data(), 1, Scalar(1)); + + copy_back(res.data(),b,*n,*incb); + if(actual_b!=b) delete[] actual_b; + return 0; } @@ -194,7 +212,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, { // typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar); -// functype func[2]; +// static functype func[2]; // static bool init = false; // if(!init) @@ -241,7 +259,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc) { // typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); -// functype func[2]; +// static functype func[2]; // // static bool init = false; // if(!init) -- cgit v1.2.3