diff options
author | 2010-03-05 10:16:25 +0100 | |
---|---|---|
committer | 2010-03-05 10:16:25 +0100 | |
commit | 62ac0216060045619ff1e6035643ecf9dbefa14f (patch) | |
tree | d0b87026608125212c47a74224f8aaa1057e95c3 /Eigen/src/Core/products | |
parent | d13b877014928c80a7cf0ae2e563d4e2e60e2c3c (diff) |
fix openmp version for scalar types different than float
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 16 | ||||
-rw-r--r-- | Eigen/src/Core/products/Parallelizer.h | 11 |
2 files changed, 18 insertions, 9 deletions
```diff
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index cbb389542..c1d42d387 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -40,7 +40,7 @@ struct ei_general_matrix_matrix_product<Scalar,LhsStorageOrder,ConjugateLhs,RhsS
     const Scalar* rhs, int rhsStride,
     Scalar* res, int resStride,
     Scalar alpha,
-    GemmParallelInfo* info = 0)
+    GemmParallelInfo<Scalar>* info = 0)
   {
     // transpose the product such that the result is column major
     ei_general_matrix_matrix_product<Scalar,
@@ -66,7 +66,7 @@ static void run(int rows, int cols, int depth,
   const Scalar* _rhs, int rhsStride,
   Scalar* res, int resStride,
   Scalar alpha,
-  GemmParallelInfo* info = 0)
+  GemmParallelInfo<Scalar>* info = 0)
 {
   ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
   ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
@@ -218,11 +218,13 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
 template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
 struct ei_gemm_functor
 {
+  typedef typename Rhs::Scalar BlockBScalar;
+
   ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
     : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
   {}
 
-  void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo* info=0) const
+  void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo<BlockBScalar>* info=0) const
   {
     if(cols==-1)
       cols = m_rhs.cols();
@@ -234,6 +236,12 @@ struct ei_gemm_functor
               info);
   }
 
+
+  int sharedBlockBSize() const
+  {
+    return std::min<int>(ei_product_blocking_traits<Scalar>::Max_kc,m_rhs.rows()) * m_rhs.cols();
+  }
+
   protected:
     const Lhs& m_lhs;
     const Rhs& m_rhs;
@@ -275,7 +283,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
       _ActualRhsType,
       Dest> GemmFunctor;
 
-    ei_parallelize_gemm<Dest::MaxRowsAtCompileTime>32>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
+    ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32)>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
   }
 };
 
diff --git a/Eigen/src/Core/products/Parallelizer.h b/Eigen/src/Core/products/Parallelizer.h
index 62cf16047..03d85c1ce 100644
--- a/Eigen/src/Core/products/Parallelizer.h
+++ b/Eigen/src/Core/products/Parallelizer.h
@@ -25,16 +25,16 @@
 #ifndef EIGEN_PARALLELIZER_H
 #define EIGEN_PARALLELIZER_H
 
-struct GemmParallelInfo
+template<typename BlockBScalar> struct GemmParallelInfo
 {
-  GemmParallelInfo() : sync(-1), users(0) {}
+  GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
 
   int volatile sync;
   int volatile users;
 
   int rhs_start;
   int rhs_length;
-  float* blockB;
+  BlockBScalar* blockB;
 };
 
 template<bool Condition,typename Functor>
@@ -51,9 +51,10 @@ void ei_parallelize_gemm(const Functor& func, int rows, int cols)
   int blockCols = (cols / threads) & ~0x3;
   int blockRows = (rows / threads) & ~0x7;
 
-  float* sharedBlockB = new float[2048*2048*4];
+  typedef typename Functor::BlockBScalar BlockBScalar;
+  BlockBScalar* sharedBlockB = new BlockBScalar[func.sharedBlockBSize()];
 
-  GemmParallelInfo* info = new GemmParallelInfo[threads];
+  GemmParallelInfo<BlockBScalar>* info = new GemmParallelInfo<BlockBScalar>[threads];
 
   #pragma omp parallel for schedule(static,1)
   for(int i=0; i<threads; ++i)
```