aboutsummaryrefslogtreecommitdiffhomepage
path: root/blas
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-11-20 23:29:20 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-11-20 23:29:20 +0100
commit1ac9124fac72c11eab3d831e142bba8927c140d0 (patch)
tree867d98bad66f7f64eb4f1f97be95c44d50f900b8 /blas
parentd72a8f1e50e98593c79bb05175b14a910f7b4a69 (diff)
implements TRMV level 2 blas routine
Diffstat (limited to 'blas')
-rw-r--r--blas/level2_impl.h64
1 files changed, 41 insertions, 23 deletions
diff --git a/blas/level2_impl.h b/blas/level2_impl.h
index ba3868145..094d6b5d3 100644
--- a/blas/level2_impl.h
+++ b/blas/level2_impl.h
@@ -139,11 +139,8 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
{
- return 0;
- // TODO
-
- typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int);
- functype func[16];
+ typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
+ static functype func[16];
static bool init = false;
if(!init)
@@ -151,21 +148,21 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
for(int k=0; k<16; ++k)
func[k] = 0;
-// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
-// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
-// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
-//
-// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
-// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
-// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
-//
-// func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|UnitDiagBit,true, ColMajor,false,ColMajor,false,ColMajor>::run);
-// func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|UnitDiagBit,true, RowMajor,false,ColMajor,false,ColMajor>::run);
-// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|UnitDiagBit,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
-//
-// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|UnitDiagBit,true, ColMajor,false,ColMajor,false,ColMajor>::run);
-// func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|UnitDiagBit,true, RowMajor,false,ColMajor,false,ColMajor>::run);
-// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|UnitDiagBit,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
+ func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|0, Scalar,false,Scalar,false,ColMajor>::run);
+ func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|0, Scalar,false,Scalar,false,RowMajor>::run);
+ func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|0, Scalar,Conj, Scalar,false,RowMajor>::run);
+
+ func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|0, Scalar,false,Scalar,false,ColMajor>::run);
+ func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|0, Scalar,false,Scalar,false,RowMajor>::run);
+ func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|0, Scalar,Conj, Scalar,false,RowMajor>::run);
+
+ func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|UnitDiag,Scalar,false,Scalar,false,ColMajor>::run);
+ func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|UnitDiag,Scalar,false,Scalar,false,RowMajor>::run);
+ func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|UnitDiag,Scalar,Conj, Scalar,false,RowMajor>::run);
+
+ func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|UnitDiag,Scalar,false,Scalar,false,ColMajor>::run);
+ func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|UnitDiag,Scalar,false,Scalar,false,RowMajor>::run);
+ func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|UnitDiag,Scalar,Conj, Scalar,false,RowMajor>::run);
init = true;
}
@@ -173,11 +170,32 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
+ int info = 0;
+ if(UPLO(*uplo)==INVALID) info = 1;
+ else if(OP(*opa)==INVALID) info = 2;
+ else if(DIAG(*diag)==INVALID) info = 3;
+ else if(*n<0) info = 4;
+ else if(*lda<std::max(1,*n)) info = 6;
+ else if(*incb==0) info = 8;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"TRMV ",&info,6);
+
+ if(*n==0)
+ return 1;
+
+ Scalar* actual_b = get_compact_vector(b,*n,*incb);
+ Matrix<Scalar,Dynamic,1> res(*n);
+ res.setZero();
+
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
if(code>=16 || func[code]==0)
return 0;
- func[code](*n, a, *lda, b, *incb, b, *incb);
+ func[code](*n, *n, a, *lda, actual_b, 1, res.data(), 1, Scalar(1));
+
+ copy_back(res.data(),b,*n,*incb);
+ if(actual_b!=b) delete[] actual_b;
+
return 0;
}
@@ -194,7 +212,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
{
// typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar);
-// functype func[2];
+// static functype func[2];
// static bool init = false;
// if(!init)
@@ -241,7 +259,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc)
{
// typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
-// functype func[2];
+// static functype func[2];
//
// static bool init = false;
// if(!init)