From a76c296e7f56e912e265ee44e565c284cbdd011e Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 2 Mar 2010 14:45:43 +0100 Subject: blas: fix most of level1 functions --- blas/level1_impl.h | 222 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 168 insertions(+), 54 deletions(-) (limited to 'blas/level1_impl.h') diff --git a/blas/level1_impl.h b/blas/level1_impl.h index 5326c6917..fd680b819 100644 --- a/blas/level1_impl.h +++ b/blas/level1_impl.h @@ -30,52 +30,111 @@ int EIGEN_BLAS_FUNC(axpy)(int *n, RealScalar *palpha, RealScalar *px, int *incx, Scalar* y = reinterpret_cast(py); Scalar alpha = *reinterpret_cast(palpha); - if(*incx==1 && *incy==1) - vector(y,*n) += alpha * vector(x,*n); - else - vector(y,*n,*incy) += alpha * vector(x,*n,*incx); +// std::cerr << "axpy " << *n << " " << alpha << " " << *incx << " " << *incy << "\n"; - return 1; + if(*incx==1 && *incy==1) vector(y,*n) += alpha * vector(x,*n); + else if(*incx>0 && *incy>0) vector(y,*n,*incy) += alpha * vector(x,*n,*incx); + else if(*incx>0 && *incy<0) vector(y,*n,-*incy).reverse() += alpha * vector(x,*n,*incx); + else if(*incx<0 && *incy>0) vector(y,*n,*incy) += alpha * vector(x,*n,-*incx).reverse(); + else if(*incx<0 && *incy<0) vector(y,*n,-*incy).reverse() += alpha * vector(x,*n,-*incx).reverse(); + + return 0; } +#if !ISCOMPLEX // computes the sum of magnitudes of all vector elements or, for a complex vector x, the sum // res = |Rex1| + |Imx1| + |Rex2| + |Imx2| + ... + |Rexn| + |Imxn|, where x is a vector of order n RealScalar EIGEN_BLAS_FUNC(asum)(int *n, RealScalar *px, int *incx) { - int size = IsComplex ? 2* *n : *n; +// std::cerr << "_asum " << *n << " " << *incx << "\n"; - if(*incx==1) - return vector(px,size).cwiseAbs().sum(); - else - return vector(px,size,*incx).cwiseAbs().sum(); + Scalar* x = reinterpret_cast(px); - return 1; + if(*n<=0) return 0; + + if(*incx==1) return vector(x,*n).cwiseAbs().sum(); + else return vector(x,*n,std::abs(*incx)).cwiseAbs().sum(); } +#else + +struct ei_scalar_norm1_op { + typedef RealScalar result_type; + EIGEN_EMPTY_STRUCT_CTOR(ei_scalar_norm1_op) + inline RealScalar operator() (const Scalar& a) const { return ei_norm1(a); } +}; +namespace Eigen { +template<> struct ei_functor_traits +{ + enum { Cost = 3 * NumTraits::AddCost, PacketAccess = 0 }; +}; +} + +RealScalar EIGEN_CAT(EIGEN_CAT(REAL_SCALAR_SUFFIX,SCALAR_SUFFIX),asum_)(int *n, RealScalar *px, int *incx) +{ +// std::cerr << "__asum " << *n << " " << *incx << "\n"; + + Complex* x = reinterpret_cast(px); + + if(*n<=0) return 0; + + if(*incx==1) return vector(x,*n).unaryExpr().sum(); + else return vector(x,*n,std::abs(*incx)).unaryExpr().sum(); +} +#endif int EIGEN_BLAS_FUNC(copy)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { - int size = IsComplex ? 2* *n : *n; +// std::cerr << "_copy " << *n << " " << *incx << " " << *incy << "\n"; - if(*incx==1 && *incy==1) - vector(py,size) = vector(px,size); - else - vector(py,size,*incy) = vector(px,size,*incx); + Scalar* x = reinterpret_cast(px); + Scalar* y = reinterpret_cast(py); - return 1; + if(*incx==1 && *incy==1) vector(y,*n) = vector(x,*n); + else if(*incx>0 && *incy>0) vector(y,*n,*incy) = vector(x,*n,*incx); + else if(*incx>0 && *incy<0) vector(y,*n,-*incy).reverse() = vector(x,*n,*incx); + else if(*incx<0 && *incy>0) vector(y,*n,*incy) = vector(x,*n,-*incx).reverse(); + else if(*incx<0 && *incy<0) vector(y,*n,-*incy).reverse() = vector(x,*n,-*incx).reverse(); + + return 0; } // computes a vector-vector dot product. Scalar EIGEN_BLAS_FUNC(dot)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { +// std::cerr << "_dot " << *n << " " << *incx << " " << *incy << "\n"; + + if(*n<=0) + return 0; + Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - if(*incx==1 && *incy==1) - return (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); + if(*incx==1 && *incy==1) return (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); + else if(*incx>0 && *incy>0) return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); + else if(*incx<0 && *incy>0) return (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,*incy))).sum(); + else if(*incx>0 && *incy<0) return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,-*incy).reverse())).sum(); + else if(*incx<0 && *incy<0) return (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,-*incy).reverse())).sum(); + else return 0; +} + +int EIGEN_CAT(EIGEN_CAT(i,SCALAR_SUFFIX),amax_)(int *n, RealScalar *px, int *incx) +{ +// std::cerr << "i_amax " << *n << " " << *incx << "\n"; + + Scalar* x = reinterpret_cast(px); - return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); + if(*n<=0) + return 0; + + int ret; + + if(*incx==1) vector(x,*n).cwiseAbs().maxCoeff(&ret); + else vector(x,*n,std::abs(*incx)).cwiseAbs().maxCoeff(&ret); + + return ret+1; } + /* // computes a vector-vector dot product with extended precision. @@ -96,53 +155,95 @@ Scalar EIGEN_BLAS_FUNC(sdot)(int *n, RealScalar *px, int *incx, RealScalar *py, #if ISCOMPLEX // computes a dot product of a conjugated vector with another vector. -Scalar EIGEN_BLAS_FUNC(dotc)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) +void EIGEN_BLAS_FUNC(dotc)(RealScalar* dot, int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { + return; + + // TODO: find how to return a complex to fortran + +// std::cerr << "_dotc " << *n << " " << *incx << " " << *incy << "\n"; + Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); if(*incx==1 && *incy==1) - return vector(x,*n).dot(vector(y,*n)); - - return vector(x,*n,*incx).dot(vector(y,*n,*incy)); + *reinterpret_cast(dot) = vector(x,*n).dot(vector(y,*n)); + else + *reinterpret_cast(dot) = vector(x,*n,*incx).dot(vector(y,*n,*incy)); } // computes a vector-vector dot product without complex conjugation. -Scalar EIGEN_BLAS_FUNC(dotu)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) +void EIGEN_BLAS_FUNC(dotu)(RealScalar* dot, int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { + return; + + // TODO: find how to return a complex to fortran + +// std::cerr << "_dotu " << *n << " " << *incx << " " << *incy << "\n"; + Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); if(*incx==1 && *incy==1) - return (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); - - return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); + *reinterpret_cast(dot) = (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); + else + *reinterpret_cast(dot) = (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); } #endif // ISCOMPLEX +#if !ISCOMPLEX // computes the Euclidean norm of a vector. Scalar EIGEN_BLAS_FUNC(nrm2)(int *n, RealScalar *px, int *incx) { +// std::cerr << "_nrm2 " << *n << " " << *incx << "\n"; Scalar* x = reinterpret_cast(px); + if(*n<=0) + return 0; + + if(*incx==1) return vector(x,*n).norm(); + else return vector(x,*n,std::abs(*incx)).norm(); +} +#else +RealScalar EIGEN_CAT(EIGEN_CAT(REAL_SCALAR_SUFFIX,SCALAR_SUFFIX),nrm2_)(int *n, RealScalar *px, int *incx) +{ +// std::cerr << "__nrm2 " << *n << " " << *incx << "\n"; + Scalar* x = reinterpret_cast(px); + + if(*n<=0) + return 0; + if(*incx==1) return vector(x,*n).norm(); return vector(x,*n,*incx).norm(); } +#endif int EIGEN_BLAS_FUNC(rot)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, RealScalar *ps) { +// std::cerr << "_rot " << *n << " " << *incx << " " << *incy << "\n"; Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); Scalar c = *reinterpret_cast(pc); Scalar s = *reinterpret_cast(ps); - StridedVectorType vx(vector(x,*n,*incx)); - StridedVectorType vy(vector(y,*n,*incy)); - ei_apply_rotation_in_the_plane(vx, vy, PlanarRotation(c,s)); - return 1; + if(*n<=0) + return 0; + + StridedVectorType vx(vector(x,*n,std::abs(*incx))); + StridedVectorType vy(vector(y,*n,std::abs(*incy))); + + Reverse rvx(vx); + Reverse rvy(vy); + + if(*incx<0 && *incy>0) ei_apply_rotation_in_the_plane(rvx, vy, PlanarRotation(c,s)); + else if(*incx>0 && *incy<0) ei_apply_rotation_in_the_plane(vx, rvy, PlanarRotation(c,s)); + else ei_apply_rotation_in_the_plane(vx, vy, PlanarRotation(c,s)); + + + return 0; } int EIGEN_BLAS_FUNC(rotg)(RealScalar *pa, RealScalar *pb, RealScalar *pc, RealScalar *ps) @@ -157,7 +258,7 @@ int EIGEN_BLAS_FUNC(rotg)(RealScalar *pa, RealScalar *pb, RealScalar *pc, RealSc *c = r.c(); *s = r.s(); - return 1; + return 0; } #if !ISCOMPLEX @@ -183,43 +284,56 @@ int EIGEN_BLAS_FUNC(rotmg)(RealScalar *d1, RealScalar *d2, RealScalar *x1, RealS */ #endif // !ISCOMPLEX -int EIGEN_BLAS_FUNC(scal)(int *n, RealScalar *px, int *incx, RealScalar *palpha) +int EIGEN_BLAS_FUNC(scal)(int *n, RealScalar *palpha, RealScalar *px, int *incx) { Scalar* x = reinterpret_cast(px); Scalar alpha = *reinterpret_cast(palpha); - if(*incx==1) - vector(x,*n) *= alpha; + std::cerr << "_scal " << *n << " " << alpha << " " << *incx << "\n"; - vector(x,*n,*incx) *= alpha; + if(*n<=0) + return 0; - return 1; + if(*incx==1) vector(x,*n) *= alpha; + else vector(x,*n,std::abs(*incx)) *= alpha; + + return 0; } -int EIGEN_BLAS_FUNC(swap)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) +#if ISCOMPLEX +int EIGEN_CAT(EIGEN_CAT(SCALAR_SUFFIX,REAL_SCALAR_SUFFIX),scal_)(int *n, RealScalar *palpha, RealScalar *px, int *incx) { - int size = IsComplex ? 2* *n : *n; + Scalar* x = reinterpret_cast(px); + RealScalar alpha = *palpha; - if(*incx==1 && *incy==1) - vector(py,size).swap(vector(px,size)); - else - vector(py,size,*incy).swap(vector(px,size,*incx)); + std::cerr << "__scal " << *n << " " << alpha << " " << *incx << "\n"; - return 1; -} + if(*n<=0) + return 0; -#if !ISCOMPLEX + if(*incx==1) vector(x,*n) *= alpha; + else vector(x,*n,std::abs(*incx)) *= alpha; -RealScalar EIGEN_BLAS_FUNC(casum)(int *n, RealScalar *px, int *incx) + return 0; +} +#endif // ISCOMPLEX + +int EIGEN_BLAS_FUNC(swap)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) { - Complex* x = reinterpret_cast(px); + std::cerr << "_swap " << *n << " " << *incx << " " << *incy << "\n"; - if(*incx==1) - return vector(x,*n).cwiseAbs().sum(); - else - return vector(x,*n,*incx).cwiseAbs().sum(); + Scalar* x = reinterpret_cast(px); + Scalar* y = reinterpret_cast(py); + + if(*n<=0) + return 0; + + if(*incx==1 && *incy==1) vector(y,*n).swap(vector(x,*n)); + else if(*incx>0 && *incy>0) vector(y,*n,*incy).swap(vector(x,*n,*incx)); + else if(*incx>0 && *incy<0) vector(y,*n,-*incy).reverse().swap(vector(x,*n,*incx)); + else if(*incx<0 && *incy>0) vector(y,*n,*incy).swap(vector(x,*n,-*incx).reverse()); + else if(*incx<0 && *incy<0) vector(y,*n,-*incy).reverse().swap(vector(x,*n,-*incx).reverse()); return 1; } -#endif // ISCOMPLEX -- cgit v1.2.3