aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-02-22 09:40:34 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-02-22 09:40:34 +0100
commitb20935be9b41ece3b022eaea14fb5eac92bbaea0 (patch)
tree9abb90183df8fe8311ad3887a799babdc28fec1b /Eigen/src
parent1a70f3b48d54d505f60613395f83dd181e9e51dc (diff)
add initial openmp support for matrix-matrix products
=> x1.9 speedup on my core2 duo
Diffstat (limited to 'Eigen/src')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h76
1 files changed, 65 insertions, 11 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index beec17ee4..7f449ac23 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -128,6 +128,49 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
: ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
{};
/** \internal
  * Splits the range [0, size) into contiguous blocks and calls
  * \a func(blockStart, blockLength) once for every non-empty block.
  *
  * \tparam Parallelize when false, \a func is called exactly once as func(0, size)
  *                     and no splitting takes place.
  * \tparam Functor     callable as func(int start, int length).
  * \param  func        the work to perform on each block.
  * \param  size        total number of items to distribute.
  *
  * When the OMP macro is defined the number of blocks is omp_get_num_procs(),
  * otherwise a single block is used. The loop over the blocks carries an
  * OpenMP "parallel for" pragma, so \a func must be safe to call concurrently
  * on disjoint blocks when OpenMP is enabled.
  */
template<bool Parallelize,typename Functor>
void ei_multithreaded_product(const Functor& func, int size)
{
  if(!Parallelize)
  {
    func(0,size);
    return;
  }

  // NOTE(review): OMP is not a macro that compilers define themselves (OpenMP
  // implementations define _OPENMP). The guard is kept as is so that it stays in
  // sync with the caller's own "#ifdef OMP" -- confirm where OMP gets defined.
  #ifdef OMP
  int threads = omp_get_num_procs();
  #else
  int threads = 1;
  #endif
  if(threads<1)
    threads = 1;

  // Every block receives size/threads items; the size%threads leftover items are
  // handed out one by one to the first blocks. A plain "size/threads" block size
  // would silently drop the leftover items whenever size is not a multiple of
  // threads (and do no work at all when size < threads).
  const int baseSize  = size / threads;
  const int remainder = size % threads;

  #pragma omp parallel for schedule(static,1)
  for(int i=0; i<threads; ++i)
  {
    const int blockStart      = i*baseSize + std::min(i, remainder);
    const int actualBlockSize = baseSize + (i<remainder ? 1 : 0);

    // Blocks are empty when there are more threads than items.
    if(actualBlockSize>0)
      func(blockStart, actualBlockSize);
  }
}
+
/** \internal
  * Functor binding the operands of a matrix-matrix product so that a range of
  * columns of the result can be computed through Gemm::run().
  *
  * \tparam Scalar scalar type the operand storage is viewed as.
  * \tparam Gemm   kernel type providing a static run() taking
  *                (rows, cols, depth, lhs, lhsStride, rhs, rhsStride, res, resStride, alpha).
  * \tparam Lhs    left operand type; must provide rows(), cols(), stride(),
  *                const_cast_derived() and coeffRef(int,int).
  * \tparam Rhs    right operand type; must provide stride(), const_cast_derived()
  *                and coeffRef(int,int).
  * \tparam Dest   destination type; must provide stride() and coeffRef(int,int).
  *
  * The operands are held by reference and therefore must outlive the functor.
  */
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest> struct ei_gemm_callback
{
  ei_gemm_callback(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
    : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
  {}

  /** Runs the kernel on \a size columns starting at column \a start: the whole
    * lhs is used, while rhs and dest are addressed from coefficient (0,start). */
  void operator() (int start, int size) const
  {
    Gemm::run(m_lhs.rows(), size, m_lhs.cols(),
              (const Scalar*)&(m_lhs.const_cast_derived().coeffRef(0,0)), m_lhs.stride(),
              (const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,start)), m_rhs.stride(),
              (Scalar*)&(m_dest.coeffRef(0,start)), m_dest.stride(),
              m_actualAlpha);
  }

  protected:
    const Lhs& m_lhs;
    const Rhs& m_rhs;
    // Plain reference on purpose: 'mutable' is not allowed on reference members,
    // and it is not needed either because a const member function may already
    // modify the object a reference member refers to.
    Dest& m_dest;
    Scalar m_actualAlpha;
};
+
template<typename Lhs, typename Rhs>
class GeneralProduct<Lhs, Rhs, GemmProduct>
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
@@ -151,17 +194,28 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
- ei_general_matrix_matrix_product<
- Scalar,
- (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
- (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
- ::run(
- this->rows(), this->cols(), lhs.cols(),
- (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
- (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
- (Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
- actualAlpha);
+ typedef ei_gemm_callback<Scalar,ei_general_matrix_matrix_product<
+ Scalar,
+ (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
+ (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ _ActualLhsType, _ActualRhsType, Dest> Functor;
+
+ #ifdef OMP
+ ei_multithreaded_product<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
+ #else
+ ei_general_matrix_matrix_product<
+ Scalar,
+ (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
+ (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
+ ::run(
+ this->rows(), this->cols(), lhs.cols(),
+ (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
+ (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
+ (Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
+ actualAlpha);
+ #endif
}
};