From 1b61aadcbe0f3d8c6898d28b7605347a7643190d Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Sat, 8 Sep 2012 02:06:45 +0800 Subject: Implement SDSDOT with DSDOT and avoid allocating buffers in DSDOT. --- blas/double.cpp | 26 ++++++++------------------ blas/single.cpp | 14 +------------- 2 files changed, 9 insertions(+), 31 deletions(-) (limited to 'blas') diff --git a/blas/double.cpp b/blas/double.cpp index 8a3b00175..8fd0709ba 100644 --- a/blas/double.cpp +++ b/blas/double.cpp @@ -19,25 +19,15 @@ #include "level2_real_impl.h" #include "level3_impl.h" -// currently used by DSDOT only -double* cast_vector_to_double(float* x, int n, int incx) +double BLASFUNC(dsdot)(int* n, float* x, int* incx, float* y, int* incy) { - double* ret = new double[n]; - if(incx<0) vector(ret,n) = vector(x,n,-incx).reverse().cast(); - else vector(ret,n) = vector(x,n, incx).cast(); - return ret; -} - -double BLASFUNC(dsdot)(int* n, float* px, int* incx, float* py, int* incy) -{ - if(*n <= 0) return 0; - - double* x = cast_vector_to_double(px, *n, *incx); - double* y = cast_vector_to_double(py, *n, *incy); - double res = vector(x,*n).cwiseProduct(vector(y,*n)).sum(); + if(*n<=0) return 0; - delete[] x; - delete[] y; - return res; + if(*incx==1 && *incy==1) return (vector(x,*n).cast().cwiseProduct(vector(y,*n).cast())).sum(); + else if(*incx>0 && *incy>0) return (vector(x,*n,*incx).cast().cwiseProduct(vector(y,*n,*incy).cast())).sum(); + else if(*incx<0 && *incy>0) return (vector(x,*n,-*incx).reverse().cast().cwiseProduct(vector(y,*n,*incy).cast())).sum(); + else if(*incx>0 && *incy<0) return (vector(x,*n,*incx).cast().cwiseProduct(vector(y,*n,-*incy).reverse().cast())).sum(); + else if(*incx<0 && *incy<0) return (vector(x,*n,-*incx).reverse().cast().cwiseProduct(vector(y,*n,-*incy).reverse().cast())).sum(); + else return 0; } diff --git a/blas/single.cpp b/blas/single.cpp index 9516398ba..836e3eee2 100644 --- a/blas/single.cpp +++ b/blas/single.cpp @@ -2,7 +2,6 @@ // for linear algebra. // // Copyright (C) 2009 Gael Guennebaud -// Copyright (C) 2012 Chen-Pang He // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -20,15 +19,4 @@ #include "level3_impl.h" float BLASFUNC(sdsdot)(int* n, float* alpha, float* x, int* incx, float* y, int* incy) -{ - float res = *alpha; - - if(*n>0) { - if(*incx==1 && *incy==1) res += (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); - else if(*incx>0 && *incy>0) res += (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); - else if(*incx<0 && *incy>0) res += (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,*incy))).sum(); - else if(*incx>0 && *incy<0) res += (vector(x,*n,*incx).cwiseProduct(vector(y,*n,-*incy).reverse())).sum(); - else if(*incx<0 && *incy<0) res += (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,-*incy).reverse())).sum(); - } - return res; -} +{ return *alpha + BLASFUNC(dsdot)(n, x, incx, y, incy); } -- cgit v1.2.3