about | summary | refs | log | tree | commit | diff | homepage
path: root/Eigen/src/Core/products/GeneralMatrixMatrix.h
diff options
context:
space:
mode:
author: Gael Guennebaud <g.gael@free.fr> 2010-02-22 11:08:37 +0100
committer: Gael Guennebaud <g.gael@free.fr> 2010-02-22 11:08:37 +0100
commit: 3e62fafce8d9c11401e0fb6ebe5cd8bf5ef91eb6 (patch)
tree: 669537151a6b6b638dd2af3785471e1426e1f86f /Eigen/src/Core/products/GeneralMatrixMatrix.h
parent: b20935be9b41ece3b022eaea14fb5eac92bbaea0 (diff)
clean a bit the parallelizer
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r-- Eigen/src/Core/products/GeneralMatrixMatrix.h 54
1 files changed, 12 insertions, 42 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 7f449ac23..c13e09eac 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -128,33 +128,13 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
: ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
{};
-template<bool Prallelize,typename Functor>
-void ei_multithreaded_product(const Functor& func, int size)
+template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
+struct ei_gemm_functor
{
- if(!Prallelize)
- return func(0,size);
- #ifdef OMP
- int threads = omp_get_num_procs();
- #else
- int threads = 1;
- #endif
- int blockSize = size / threads;
- #pragma omp parallel for schedule(static,1)
- for(int i=0; i<threads; ++i)
- {
- int blockStart = i*blockSize;
- int actualBlockSize = std::min(blockSize, size - blockStart);
-
- func(blockStart, actualBlockSize);
- }
-}
-
-template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest> struct ei_gemm_callback
-{
- ei_gemm_callback(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
+ ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
{}
-
+
void operator() (int start, int size) const
{
Gemm::run(m_lhs.rows(), size, m_lhs.cols(),
@@ -194,28 +174,18 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
- typedef ei_gemm_callback<Scalar,ei_general_matrix_matrix_product<
- Scalar,
- (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
- (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
- _ActualLhsType, _ActualRhsType, Dest> Functor;
-
- #ifdef OMP
- ei_multithreaded_product<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
- #else
+ typedef ei_gemm_functor<
+ Scalar,
ei_general_matrix_matrix_product<
Scalar,
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
- ::run(
- this->rows(), this->cols(), lhs.cols(),
- (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
- (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
- (Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
- actualAlpha);
- #endif
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ _ActualLhsType,
+ _ActualRhsType,
+ Dest> Functor;
+
+ ei_run_parallel_1d<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
}
};