diff options
author | 2010-02-22 11:08:37 +0100 | |
---|---|---|
committer | 2010-02-22 11:08:37 +0100 | |
commit | 3e62fafce8d9c11401e0fb6ebe5cd8bf5ef91eb6 (patch) | |
tree | 669537151a6b6b638dd2af3785471e1426e1f86f /Eigen/src/Core/products/GeneralMatrixMatrix.h | |
parent | b20935be9b41ece3b022eaea14fb5eac92bbaea0 (diff) |
clean a bit the parallelizer
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 54 |
1 files changed, 12 insertions, 42 deletions
```diff
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 7f449ac23..c13e09eac 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -128,33 +128,13 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
  : ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
 {};
 
-template<bool Prallelize,typename Functor>
-void ei_multithreaded_product(const Functor& func, int size)
+template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
+struct ei_gemm_functor
 {
-  if(!Prallelize)
-    return func(0,size);
-  #ifdef OMP
-  int threads = omp_get_num_procs();
-  #else
-  int threads = 1;
-  #endif
-  int blockSize = size / threads;
-  #pragma omp parallel for schedule(static,1)
-  for(int i=0; i<threads; ++i)
-  {
-    int blockStart = i*blockSize;
-    int actualBlockSize = std::min(blockSize, size - blockStart);
-
-    func(blockStart, actualBlockSize);
-  }
-}
-
-template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest> struct ei_gemm_callback
-{
-  ei_gemm_callback(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
+  ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
     : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
   {}
-
+
   void operator() (int start, int size) const
   {
     Gemm::run(m_lhs.rows(), size, m_lhs.cols(),
@@ -194,28 +174,18 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
       Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
                                  * RhsBlasTraits::extractScalarFactor(m_rhs);
 
-      typedef ei_gemm_callback<Scalar,ei_general_matrix_matrix_product<
-          Scalar,
-          (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
-          (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
-          (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
-        _ActualLhsType, _ActualRhsType, Dest> Functor;
-
-      #ifdef OMP
-      ei_multithreaded_product<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
-      #else
+      typedef ei_gemm_functor<
+        Scalar,
         ei_general_matrix_matrix_product<
           Scalar,
           (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
           (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
-          (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
-        ::run(
-          this->rows(), this->cols(), lhs.cols(),
-          (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
-          (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
-          (Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
-          actualAlpha);
-      #endif
+          (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+        _ActualLhsType,
+        _ActualRhsType,
+        Dest> Functor;
+
+      ei_run_parallel_1d<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
     }
 };
 
```