aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2013-01-23 23:56:57 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2013-01-23 23:56:57 +0100
commit691e607d8578a43ef238fee50b4d8cd2c0c0f15e (patch)
tree4f0a4f32bec7a6cf606637e528ccddaa14b2fcb3 /unsupported
parentc22f7cef83cedb1ed445bb309cc07db58303f4cb (diff)
Specialize GEBP traits and kernel for mpreal to by-pass mpreal and remove the costly creation of many temporaries.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/MPRealSupport70
1 files changed, 69 insertions, 1 deletions
diff --git a/unsupported/Eigen/MPRealSupport b/unsupported/Eigen/MPRealSupport
index dfef0fe4e..8e699210f 100644
--- a/unsupported/Eigen/MPRealSupport
+++ b/unsupported/Eigen/MPRealSupport
@@ -12,8 +12,8 @@
#ifndef EIGEN_MPREALSUPPORT_MODULE_H
#define EIGEN_MPREALSUPPORT_MODULE_H
-#include <mpreal.h>
#include <Eigen/Core>
+#include <mpreal.h>
namespace Eigen {
@@ -131,6 +131,74 @@ int main()
template<> inline int cast<mpfr::mpreal,int>(const mpfr::mpreal& x)
{ return int(x.toLong()); }
+ // Specialize GEBP kernel and traits for mpreal (no need for peeling, nor complicated stuff)
+ // This also permits to directly call mpfr's routines and avoid many temporaries produced by mpreal
+ template<>
+ class gebp_traits<mpfr::mpreal, mpfr::mpreal, false, false>
+ {
+ public:
+ typedef mpfr::mpreal ResScalar;
+ enum {
+ nr = 2, // must be 2 for proper packing...
+ mr = 1,
+ WorkSpaceFactor = nr,
+ LhsProgress = 1,
+ RhsProgress = 1
+ };
+ };
+
+ template<typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+ struct gebp_kernel<mpfr::mpreal,mpfr::mpreal,Index,mr,nr,ConjugateLhs,ConjugateRhs>
+ {
+ typedef mpfr::mpreal mpreal;
+
+ EIGEN_DONT_INLINE
+ void operator()(mpreal* res, Index resStride, const mpreal* blockA, const mpreal* blockB, Index rows, Index depth, Index cols, mpreal alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, mpreal* /*unpackedB*/ = 0)
+ {
+ mpreal acc1, acc2, tmp;
+
+ if(strideA==-1) strideA = depth;
+ if(strideB==-1) strideB = depth;
+
+ for(Index j=0; j<cols; j+=nr)
+ {
+ Index actual_nr = (std::min<Index>)(nr,cols-j);
+ mpreal *C1 = res + j*resStride;
+ mpreal *C2 = res + (j+1)*resStride;
+ for(Index i=0; i<rows; i++)
+ {
+ mpreal *B = const_cast<mpreal*>(blockB) + j*strideB + offsetB*actual_nr;
+ mpreal *A = const_cast<mpreal*>(blockA) + i*strideA + offsetA;
+ acc1 = 0;
+ acc2 = 0;
+ for(Index k=0; k<depth; k++)
+ {
+ mpreal a = A[k];
+ mpreal b = B[0];
+ mpfr_mul(tmp.mpfr_ptr(), A[k].mpfr_ptr(), B[0].mpfr_ptr(), mpreal::get_default_rnd());
+ mpfr_add(acc1.mpfr_ptr(), acc1.mpfr_ptr(), tmp.mpfr_ptr(), mpreal::get_default_rnd());
+
+ if(actual_nr==2) {
+ mpfr_mul(tmp.mpfr_ptr(), A[k].mpfr_ptr(), B[1].mpfr_ptr(), mpreal::get_default_rnd());
+ mpfr_add(acc2.mpfr_ptr(), acc2.mpfr_ptr(), tmp.mpfr_ptr(), mpreal::get_default_rnd());
+ }
+
+ B+=actual_nr;
+ }
+
+ mpfr_mul(acc1.mpfr_ptr(), acc1.mpfr_ptr(), alpha.mpfr_ptr(), mpreal::get_default_rnd());
+ mpfr_add(C1[i].mpfr_ptr(), C1[i].mpfr_ptr(), acc1.mpfr_ptr(), mpreal::get_default_rnd());
+
+ if(actual_nr==2) {
+ mpfr_mul(acc2.mpfr_ptr(), acc2.mpfr_ptr(), alpha.mpfr_ptr(), mpreal::get_default_rnd());
+ mpfr_add(C2[i].mpfr_ptr(), C2[i].mpfr_ptr(), acc2.mpfr_ptr(), mpreal::get_default_rnd());
+ }
+ }
+ }
+ }
+ };
+
} // end namespace internal
}