diff options
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 76 |
1 files changed, 65 insertions, 11 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index beec17ee4..7f449ac23 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -128,6 +128,49 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> > : ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> > {}; +template<bool Prallelize,typename Functor> +void ei_multithreaded_product(const Functor& func, int size) +{ + if(!Prallelize) + return func(0,size); + #ifdef OMP + int threads = omp_get_num_procs(); + #else + int threads = 1; + #endif + int blockSize = size / threads; + #pragma omp parallel for schedule(static,1) + for(int i=0; i<threads; ++i) + { + int blockStart = i*blockSize; + int actualBlockSize = std::min(blockSize, size - blockStart); + + func(blockStart, actualBlockSize); + } +} + +template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest> struct ei_gemm_callback +{ + ei_gemm_callback(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha) + : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha) + {} + + void operator() (int start, int size) const + { + Gemm::run(m_lhs.rows(), size, m_lhs.cols(), + (const Scalar*)&(m_lhs.const_cast_derived().coeffRef(0,0)), m_lhs.stride(), + (const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,start)), m_rhs.stride(), + (Scalar*)&(m_dest.coeffRef(0,start)), m_dest.stride(), + m_actualAlpha); + } + + protected: + const Lhs& m_lhs; + const Rhs& m_rhs; + mutable Dest& m_dest; + Scalar m_actualAlpha; +}; + template<typename Lhs, typename Rhs> class GeneralProduct<Lhs, Rhs, GemmProduct> : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> @@ -151,17 +194,28 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); - ei_general_matrix_matrix_product< - Scalar, - (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), - (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor> - ::run( - this->rows(), this->cols(), lhs.cols(), - (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(), - (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), - (Scalar*)&(dst.coeffRef(0,0)), dst.stride(), - actualAlpha); + typedef ei_gemm_callback<Scalar,ei_general_matrix_matrix_product< + Scalar, + (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), + (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, + _ActualLhsType, _ActualRhsType, Dest> Functor; + + #ifdef OMP + ei_multithreaded_product<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols()); + #else + ei_general_matrix_matrix_product< + Scalar, + (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), + (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor> + ::run( + this->rows(), this->cols(), lhs.cols(), + (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(), + (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), + (Scalar*)&(dst.coeffRef(0,0)), dst.stride(), + actualAlpha); + #endif } }; |