From 2a820d41df2fcbf34d14d538ba8280271a96ad92 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sat, 17 Jul 2010 13:49:43 +0200 Subject: finish/fix level1 blas, all test pass --- bench/btl/libs/C_BLAS/blas.h | 8 ++-- blas/common.h | 1 + blas/level1_impl.h | 99 ++++++++++++++++++++++++++++++-------------- 3 files changed, 73 insertions(+), 35 deletions(-) diff --git a/bench/btl/libs/C_BLAS/blas.h b/bench/btl/libs/C_BLAS/blas.h index 07cd9efd2..ab3d44052 100644 --- a/bench/btl/libs/C_BLAS/blas.h +++ b/bench/btl/libs/C_BLAS/blas.h @@ -38,10 +38,10 @@ void BLASFUNC(zdotc) (double *, int *, double *, int *, double *, int *); void BLASFUNC(xdotu) (double *, int *, double *, int *, double *, int *); void BLASFUNC(xdotc) (double *, int *, double *, int *, double *, int *); #else -float BLASFUNC(cdotu) (int *, float *, int *, float *, int *); -float BLASFUNC(cdotc) (int *, float *, int *, float *, int *); -double BLASFUNC(zdotu) (int *, double *, int *, double *, int *); -double BLASFUNC(zdotc) (int *, double *, int *, double *, int *); +std::complex BLASFUNC(cdotu) (int *, float *, int *, float *, int *); +std::complex BLASFUNC(cdotc) (int *, float *, int *, float *, int *); +std::complex BLASFUNC(zdotu) (int *, double *, int *, double *, int *); +std::complex BLASFUNC(zdotc) (int *, double *, int *, double *, int *); double BLASFUNC(xdotu) (int *, double *, int *, double *, int *); double BLASFUNC(xdotc) (int *, double *, int *, double *, int *); #endif diff --git a/blas/common.h b/blas/common.h index 843e64ab7..5e7117dee 100644 --- a/blas/common.h +++ b/blas/common.h @@ -26,6 +26,7 @@ #define EIGEN_BLAS_COMMON_H #include +#include #ifndef SCALAR #error the token SCALAR must be defined to compile this file diff --git a/blas/level1_impl.h b/blas/level1_impl.h index 5ea80064f..1665310c4 100644 --- a/blas/level1_impl.h +++ b/blas/level1_impl.h @@ -153,44 +153,36 @@ Scalar EIGEN_BLAS_FUNC(sdot)(int *n, RealScalar *px, int *incx, RealScalar *py, #if ISCOMPLEX // computes a dot product of a conjugated vector with another vector. -void EIGEN_BLAS_FUNC(dotc)(RealScalar* dot, int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) +Scalar EIGEN_BLAS_FUNC(dotc)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { - - std::cerr << "Eigen BLAS: _dotc is not implemented yet\n"; - - return; - - // TODO: find how to return a complex to fortran - // std::cerr << "_dotc " << *n << " " << *incx << " " << *incy << "\n"; Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - if(*incx==1 && *incy==1) - *reinterpret_cast(dot) = vector(x,*n).dot(vector(y,*n)); - else - *reinterpret_cast(dot) = vector(x,*n,*incx).dot(vector(y,*n,*incy)); + Scalar res; + if(*incx==1 && *incy==1) res = (vector(x,*n).dot(vector(y,*n))); + else if(*incx>0 && *incy>0) res = (vector(x,*n,*incx).dot(vector(y,*n,*incy))); + else if(*incx<0 && *incy>0) res = (vector(x,*n,-*incx).reverse().dot(vector(y,*n,*incy))); + else if(*incx>0 && *incy<0) res = (vector(x,*n,*incx).dot(vector(y,*n,-*incy).reverse())); + else if(*incx<0 && *incy<0) res = (vector(x,*n,-*incx).reverse().dot(vector(y,*n,-*incy).reverse())); + return res; } // computes a vector-vector dot product without complex conjugation. -void EIGEN_BLAS_FUNC(dotu)(RealScalar* dot, int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) +Scalar EIGEN_BLAS_FUNC(dotu)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { - std::cerr << "Eigen BLAS: _dotu is not implemented yet\n"; - - return; - - // TODO: find how to return a complex to fortran - // std::cerr << "_dotu " << *n << " " << *incx << " " << *incy << "\n"; Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - - if(*incx==1 && *incy==1) - *reinterpret_cast(dot) = (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); - else - *reinterpret_cast(dot) = (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); + Scalar res; + if(*incx==1 && *incy==1) res = (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); + else if(*incx>0 && *incy>0) res = (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); + else if(*incx<0 && *incy>0) res = (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,*incy))).sum(); + else if(*incx>0 && *incy<0) res = (vector(x,*n,*incx).cwiseProduct(vector(y,*n,-*incy).reverse())).sum(); + else if(*incx<0 && *incy<0) res = (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,-*incy).reverse())).sum(); + return res; } #endif // ISCOMPLEX @@ -251,15 +243,60 @@ int EIGEN_BLAS_FUNC(rot)(int *n, RealScalar *px, int *incx, RealScalar *py, int int EIGEN_BLAS_FUNC(rotg)(RealScalar *pa, RealScalar *pb, RealScalar *pc, RealScalar *ps) { - Scalar a = *reinterpret_cast(pa); - Scalar b = *reinterpret_cast(pb); - Scalar* c = reinterpret_cast(pc); + Scalar& a = *reinterpret_cast(pa); + Scalar& b = *reinterpret_cast(pb); + RealScalar* c = pc; Scalar* s = reinterpret_cast(ps); - PlanarRotation r; - r.makeGivens(a,b); - *c = r.c(); - *s = r.s(); + #if !ISCOMPLEX + Scalar r,z; + Scalar aa = ei_abs(a); + Scalar ab = ei_abs(b); + if((aa+ab)==Scalar(0)) + { + *c = 1; + *s = 0; + r = 0; + z = 0; + } + else + { + r = ei_sqrt(a*a + b*b); + Scalar amax = aa>ab ? a : b; + r = amax>0 ? r : -r; + *c = a/r; + *s = b/r; + z = 1; + if (aa > ab) z = *s; + if (ab > aa && *c!=RealScalar(0)) + z = Scalar(1)/ *c; + } + *pa = r; + *pb = z; + #else + Scalar alpha; + RealScalar norm,scale; + if(ei_abs(a)==RealScalar(0)) + { + *c = RealScalar(0); + *s = Scalar(1); + a = b; + } + else + { + scale = ei_abs(a) + ei_abs(b); + norm = scale*ei_sqrt((ei_abs2(a/scale))+ (ei_abs2(b/scale))); + alpha = a/ei_abs(a); + *c = ei_abs(a)/norm; + *s = alpha*ei_conj(b)/norm; + a = alpha*norm; + } + #endif + +// PlanarRotation r; +// r.makeGivens(a,b); +// *c = r.c(); +// *s = r.s(); return 0; } -- cgit v1.2.3