diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-11-19 16:09:25 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-11-19 16:09:25 +0100 |
commit | e14f14642d6101f94274459cfd499d2ec28925e4 (patch) | |
tree | e747b94f138bb642756ab4adc3c9ea37c3b68b90 /blas/level2_impl.h | |
parent | 661ef6c127fb608bd91078d23032613194d3d512 (diff) |
implement SYR and SYR2
Diffstat (limited to 'blas/level2_impl.h')
-rw-r--r-- | blas/level2_impl.h | 143 |
1 files changed, 99 insertions, 44 deletions
diff --git a/blas/level2_impl.h b/blas/level2_impl.h index 55851ddb3..2dc059c14 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -204,70 +204,125 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p // TODO } -int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *inca, RealScalar *pc, int *ldc) +// C := alpha*x*x' + C +int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pc, int *ldc) { - return 0; - - // TODO - typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar); - functype func[2]; - - static bool init = false; - if(!init) - { - for(int k=0; k<2; ++k) - func[k] = 0; - + +// typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar); +// functype func[2]; + +// static bool init = false; +// if(!init) +// { +// for(int k=0; k<2; ++k) +// func[k] = 0; +// // func[UP] = (internal::selfadjoint_product<Scalar,ColMajor,ColMajor,false,UpperTriangular>::run); // func[LO] = (internal::selfadjoint_product<Scalar,ColMajor,ColMajor,false,LowerTriangular>::run); - init = true; - } +// init = true; +// } - Scalar* a = reinterpret_cast<Scalar*>(pa); + Scalar* x = reinterpret_cast<Scalar*>(px); Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); + + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(*n<0) info = 2; + else if(*incx==0) info = 5; + else if(*ldc<std::max(1,*n)) info = 7; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"SYR ",&info,6); + + if(alpha==Scalar(0)) + return 1; + + // if the increment is not 1, let's copy it to a temporary vector to enable vectorization + Scalar* x_cpy = x; + if(*incx!=1) + { + x_cpy = new Scalar[*n]; + if(*incx<0) vector(x_cpy,*n) = vector(x,*n,-*incx).reverse(); + else vector(x_cpy,*n) = vector(x,*n,*incx); + } + + // TODO perform direct calls to underlying implementation + if(UPLO(*uplo)==LO) matrix(c,*n,*n,*ldc).selfadjointView<Lower>().rankUpdate(vector(x_cpy,*n), alpha); + else if(UPLO(*uplo)==UP) matrix(c,*n,*n,*ldc).selfadjointView<Upper>().rankUpdate(vector(x_cpy,*n), alpha); + + if(*incx!=1) + delete[] x_cpy; - int code = UPLO(*uplo); - if(code>=2 || func[code]==0) - return 0; - - func[code](*n, a, *inca, c, *ldc, alpha); +// func[code](*n, a, *inca, c, *ldc, alpha); return 1; } - -int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *inca, RealScalar *pb, int *incb, RealScalar *pc, int *ldc) +// C := alpha*x*y' + alpha*y*x' + C +int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc) { - return 0; - - // TODO - typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[2]; - - static bool init = false; - if(!init) - { - for(int k=0; k<2; ++k) - func[k] = 0; - +// typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); +// functype func[2]; +// +// static bool init = false; +// if(!init) +// { +// for(int k=0; k<2; ++k) +// func[k] = 0; +// // func[UP] = (internal::selfadjoint_product<Scalar,ColMajor,ColMajor,false,UpperTriangular>::run); // func[LO] = (internal::selfadjoint_product<Scalar,ColMajor,ColMajor,false,LowerTriangular>::run); +// +// init = true; +// } - init = true; - } - - Scalar* a = reinterpret_cast<Scalar*>(pa); - Scalar* b = reinterpret_cast<Scalar*>(pb); + Scalar* x = reinterpret_cast<Scalar*>(px); + Scalar* y = reinterpret_cast<Scalar*>(py); Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar alpha = *reinterpret_cast<Scalar*>(palpha); + + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(*n<0) info = 2; + else if(*incx==0) info = 5; + else if(*incy==0) info = 7; + else if(*ldc<std::max(1,*n)) info = 9; + if(info) + return xerbla_(SCALAR_SUFFIX_UP"SYR2 ",&info,6); + + if(alpha==Scalar(0)) + return 1; + + // if the increment is not 1, let's copy it to a temporary vector to enable vectorization + Scalar* x_cpy = x; + if(*incx!=1) + { + x_cpy = new Scalar[*n]; + if(*incx<0) vector(x_cpy,*n) = vector(x,*n,-*incx).reverse(); + else vector(x_cpy,*n) = vector(x,*n, *incx); + } + + Scalar* y_cpy = y; + if(*incy!=1) + { + y_cpy = new Scalar[*n]; + if(*incy<0) vector(y_cpy,*n) = vector(y,*n,-*incy).reverse(); + else vector(y_cpy,*n) = vector(y,*n, *incy); + } + + // TODO perform direct calls to underlying implementation + if(UPLO(*uplo)==LO) matrix(c,*n,*n,*ldc).selfadjointView<Lower>().rankUpdate(vector(x_cpy,*n), vector(y_cpy,*n), alpha); + else if(UPLO(*uplo)==UP) matrix(c,*n,*n,*ldc).selfadjointView<Upper>().rankUpdate(vector(x_cpy,*n), vector(y_cpy,*n), alpha); + + if(*incx!=1) delete[] x_cpy; + if(*incy!=1) delete[] y_cpy; - int code = UPLO(*uplo); - if(code>=2 || func[code]==0) - return 0; +// int code = UPLO(*uplo); +// if(code>=2 || func[code]==0) +// return 0; - func[code](*n, a, *inca, b, *incb, c, *ldc, alpha); +// func[code](*n, a, *inca, b, *incb, c, *ldc, alpha); return 1; } |