aboutsummaryrefslogtreecommitdiffhomepage
path: root/blas/level2_impl.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-11-05 14:14:50 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-11-05 14:14:50 +0100
commit0e30c4ae3f35523302f1449f67a2be714e30beb8 (patch)
treed723cecf1cd7faf84e51162e035d1b91ffac6c36 /blas/level2_impl.h
parent3fdea699b80c429738ac0af8c9b7479594b90583 (diff)
blas level2: gemv and trsv are green
Diffstat (limited to 'blas/level2_impl.h')
-rw-r--r--blas/level2_impl.h130
1 files changed, 83 insertions, 47 deletions
diff --git a/blas/level2_impl.h b/blas/level2_impl.h
index 2749cf5b3..55851ddb3 100644
--- a/blas/level2_impl.h
+++ b/blas/level2_impl.h
@@ -24,8 +24,39 @@
#include "common.h"
+#define MAKE_ACTUAL_VECTOR(X,INCX,N,COND) \
+ Scalar* actual_##X = X; \
+ if(COND) { \
+ actual_##X = new Scalar[N]; \
+ if((INCX)<0) vector(actual_##X,(N)) = vector(X,(N),-(INCX)).reverse(); \
+ else vector(actual_##X,(N)) = vector(X,(N), (INCX)); \
+ }
+
+#define RELEASE_ACTUAL_VECTOR(X,INCX,N,COND) \
+ if(COND) { \
+ if((INCX)<0) vector(X,(N),-(INCX)).reverse() = vector(actual_##X,(N)); \
+ else vector(X,(N), (INCX)) = vector(actual_##X,(N)); \
+ delete[] actual_##X; \
+ }
+
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc)
{
+ typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
+ static functype func[4];
+
+ static bool init = false;
+ if(!init)
+ {
+ for(int k=0; k<4; ++k)
+ func[k] = 0;
+
+ func[NOTR] = (internal::general_matrix_vector_product<int,Scalar,ColMajor,false,Scalar,false>::run);
+ func[TR ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,false,Scalar,false>::run);
+ func[ADJ ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,Conj, Scalar,false>::run);
+
+ init = true;
+ }
+
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
@@ -34,9 +65,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
// check arguments
int info = 0;
- if( OP(*opa)!=NOTR
- && OP(*opa)!=TR
- && OP(*opa)!=ADJ) info = 1;
+ if(OP(*opa)==INVALID) info = 1;
else if(*m<0) info = 2;
else if(*n<0) info = 3;
else if(*lda<std::max(1,*m)) info = 6;
@@ -44,39 +73,34 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
else if(*incc==0) info = 11;
if(info)
return xerbla_(SCALAR_SUFFIX_UP"GEMV ",&info,6);
-// return xerbla_("SGEMV ",&info,sizeof("SGEMV "));
-
- if(beta!=Scalar(1))
- vector(c, *m, *incc) *= beta;
-
- if(OP(*opa)==NOTR)
- if(*incc==1)
- vector(c,*m) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb);
- else
- vector(c,*m,*incc) += alpha * matrix(a,*m,*n,*lda) * vector(b,*n,*incb);
- else if(OP(*opa)==TR)
- if(*incb==1)
- vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n);
- else
- vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).transpose() * vector(b,*n,*incb);
- else if(OP(*opa)==TR)
- if(*incb==1)
- vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n);
- else
- vector(c,*m,*incc) += alpha * matrix(a,*n,*m,*lda).adjoint() * vector(b,*n,*incb);
- else
+
+ if(*m==0 || *n==0)
return 0;
+ int actual_m = *m;
+ int actual_n = *n;
+ if(OP(*opa)!=NOTR)
+ std::swap(actual_m,actual_n);
+
+ MAKE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1)
+ MAKE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1)
+
+ if(beta!=Scalar(1))
+ vector(actual_c, actual_m, 1) *= beta;
+
+ int code = OP(*opa);
+ func[code](actual_m, actual_n, a, *lda, actual_b, 1, actual_c, 1, alpha);
+
+ RELEASE_ACTUAL_VECTOR(b,*incb,actual_n,*incb!=1)
+ RELEASE_ACTUAL_VECTOR(c,*incc,actual_m,*incc!=1)
+
return 1;
}
-
int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
{
- return 0;
-
- typedef void (*functype)(int, const Scalar *, int, Scalar *, int);
- functype func[16];
+ typedef void (*functype)(int, const Scalar *, int, Scalar *);
+ static functype func[16];
static bool init = false;
if(!init)
@@ -84,21 +108,21 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
for(int k=0; k<16; ++k)
func[k] = 0;
-// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,ColMajor,ColMajor>::run);
-// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, false,RowMajor,ColMajor>::run);
-// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|0, Conj, RowMajor,ColMajor>::run);
-//
-// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,ColMajor,ColMajor>::run);
-// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, false,RowMajor,ColMajor>::run);
-// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|0, Conj, RowMajor,ColMajor>::run);
-//
-// func[NOTR | (UP << 3) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
-// func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
-// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, UpperTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
-//
-// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,ColMajor,ColMajor>::run);
-// func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,false,RowMajor,ColMajor>::run);
-// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar, LowerTriangular|UnitDiagBit,Conj, RowMajor,ColMajor>::run);
+ func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,ColMajor>::run);
+ func[TR | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,RowMajor>::run);
+ func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, Conj, RowMajor>::run);
+
+ func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|0, false,ColMajor>::run);
+ func[TR | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, false,RowMajor>::run);
+ func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|0, Conj, RowMajor>::run);
+
+ func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,ColMajor>::run);
+ func[TR | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,RowMajor>::run);
+ func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,Conj, RowMajor>::run);
+
+ func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Lower|UnitDiag,false,ColMajor>::run);
+ func[TR | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,false,RowMajor>::run);
+ func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::triangular_solve_vector<Scalar,Scalar,int,OnTheLeft, Upper|UnitDiag,Conj, RowMajor>::run);
init = true;
}
@@ -106,11 +130,23 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
+ int info = 0;
+ if(UPLO(*uplo)==INVALID) info = 1;
+ else if(OP(*opa)==INVALID) info = 2;
+ else if(DIAG(*diag)==INVALID) info = 3;
+ else if(*n<0) info = 4;
+ else if(*lda<std::max(1,*n)) info = 6;
+ else if(*incb==0) info = 8;
+ if(info)
+ return xerbla_(SCALAR_SUFFIX_UP"TRSV ",&info,6);
+
+ MAKE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1)
+
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
- if(code>=16 || func[code]==0)
- return 0;
+ func[code](*n, a, *lda, actual_b);
- func[code](*n, a, *lda, b, *incb);
+ RELEASE_ACTUAL_VECTOR(b,*incb,*n,*incb!=1)
+
return 0;
}