diff options
author | Gael Guennebaud <g.gael@free.fr> | 2013-01-23 23:56:57 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2013-01-23 23:56:57 +0100 |
commit | 691e607d8578a43ef238fee50b4d8cd2c0c0f15e (patch) | |
tree | 4f0a4f32bec7a6cf606637e528ccddaa14b2fcb3 /unsupported | |
parent | c22f7cef83cedb1ed445bb309cc07db58303f4cb (diff) |
Specialize GEBP traits and kernel for mpreal to by-pass mpreal and remove the costly creation of many temporaries.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/MPRealSupport | 70 |
1 files changed, 69 insertions, 1 deletions
diff --git a/unsupported/Eigen/MPRealSupport b/unsupported/Eigen/MPRealSupport index dfef0fe4e..8e699210f 100644 --- a/unsupported/Eigen/MPRealSupport +++ b/unsupported/Eigen/MPRealSupport @@ -12,8 +12,8 @@ #ifndef EIGEN_MPREALSUPPORT_MODULE_H #define EIGEN_MPREALSUPPORT_MODULE_H -#include <mpreal.h> #include <Eigen/Core> +#include <mpreal.h> namespace Eigen { @@ -131,6 +131,74 @@ int main() template<> inline int cast<mpfr::mpreal,int>(const mpfr::mpreal& x) { return int(x.toLong()); } + // Specialize GEBP kernel and traits for mpreal (no need for peeling, nor complicated stuff) + // This also permits to directly call mpfr's routines and avoid many temporaries produced by mpreal + template<> + class gebp_traits<mpfr::mpreal, mpfr::mpreal, false, false> + { + public: + typedef mpfr::mpreal ResScalar; + enum { + nr = 2, // must be 2 for proper packing... + mr = 1, + WorkSpaceFactor = nr, + LhsProgress = 1, + RhsProgress = 1 + }; + }; + + template<typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> + struct gebp_kernel<mpfr::mpreal,mpfr::mpreal,Index,mr,nr,ConjugateLhs,ConjugateRhs> + { + typedef mpfr::mpreal mpreal; + + EIGEN_DONT_INLINE + void operator()(mpreal* res, Index resStride, const mpreal* blockA, const mpreal* blockB, Index rows, Index depth, Index cols, mpreal alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, mpreal* /*unpackedB*/ = 0) + { + mpreal acc1, acc2, tmp; + + if(strideA==-1) strideA = depth; + if(strideB==-1) strideB = depth; + + for(Index j=0; j<cols; j+=nr) + { + Index actual_nr = (std::min<Index>)(nr,cols-j); + mpreal *C1 = res + j*resStride; + mpreal *C2 = res + (j+1)*resStride; + for(Index i=0; i<rows; i++) + { + mpreal *B = const_cast<mpreal*>(blockB) + j*strideB + offsetB*actual_nr; + mpreal *A = const_cast<mpreal*>(blockA) + i*strideA + offsetA; + acc1 = 0; + acc2 = 0; + for(Index k=0; k<depth; k++) + { + mpreal a = A[k]; + mpreal b = B[0]; + mpfr_mul(tmp.mpfr_ptr(), A[k].mpfr_ptr(), B[0].mpfr_ptr(), mpreal::get_default_rnd()); + mpfr_add(acc1.mpfr_ptr(), acc1.mpfr_ptr(), tmp.mpfr_ptr(), mpreal::get_default_rnd()); + + if(actual_nr==2) { + mpfr_mul(tmp.mpfr_ptr(), A[k].mpfr_ptr(), B[1].mpfr_ptr(), mpreal::get_default_rnd()); + mpfr_add(acc2.mpfr_ptr(), acc2.mpfr_ptr(), tmp.mpfr_ptr(), mpreal::get_default_rnd()); + } + + B+=actual_nr; + } + + mpfr_mul(acc1.mpfr_ptr(), acc1.mpfr_ptr(), alpha.mpfr_ptr(), mpreal::get_default_rnd()); + mpfr_add(C1[i].mpfr_ptr(), C1[i].mpfr_ptr(), acc1.mpfr_ptr(), mpreal::get_default_rnd()); + + if(actual_nr==2) { + mpfr_mul(acc2.mpfr_ptr(), acc2.mpfr_ptr(), alpha.mpfr_ptr(), mpreal::get_default_rnd()); + mpfr_add(C2[i].mpfr_ptr(), C2[i].mpfr_ptr(), acc2.mpfr_ptr(), mpreal::get_default_rnd()); + } + } + } + } + }; + } // end namespace internal } |