about | summary | refs | log | tree | commit | diff | homepage
path: root/Eigen/src/Core/products
diff options
context:
space:
mode:
author Gael Guennebaud <g.gael@free.fr> 2010-03-05 10:16:25 +0100
committer Gael Guennebaud <g.gael@free.fr> 2010-03-05 10:16:25 +0100
commit 62ac0216060045619ff1e6035643ecf9dbefa14f (patch)
tree d0b87026608125212c47a74224f8aaa1057e95c3 /Eigen/src/Core/products
parent d13b877014928c80a7cf0ae2e563d4e2e60e2c3c (diff)
fix openmp version for scalar types different than float
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r-- Eigen/src/Core/products/GeneralMatrixMatrix.h 16
-rw-r--r-- Eigen/src/Core/products/Parallelizer.h 11
2 files changed, 18 insertions, 9 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index cbb389542..c1d42d387 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -40,7 +40,7 @@ struct ei_general_matrix_matrix_product<Scalar,LhsStorageOrder,ConjugateLhs,RhsS
const Scalar* rhs, int rhsStride,
Scalar* res, int resStride,
Scalar alpha,
- GemmParallelInfo* info = 0)
+ GemmParallelInfo<Scalar>* info = 0)
{
// transpose the product such that the result is column major
ei_general_matrix_matrix_product<Scalar,
@@ -66,7 +66,7 @@ static void run(int rows, int cols, int depth,
const Scalar* _rhs, int rhsStride,
Scalar* res, int resStride,
Scalar alpha,
- GemmParallelInfo* info = 0)
+ GemmParallelInfo<Scalar>* info = 0)
{
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
@@ -218,11 +218,13 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
struct ei_gemm_functor
{
+ typedef typename Rhs::Scalar BlockBScalar;
+
ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
{}
- void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo* info=0) const
+ void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo<BlockBScalar>* info=0) const
{
if(cols==-1)
cols = m_rhs.cols();
@@ -234,6 +236,12 @@ struct ei_gemm_functor
info);
}
+
+ int sharedBlockBSize() const
+ {
+ return std::min<int>(ei_product_blocking_traits<Scalar>::Max_kc,m_rhs.rows()) * m_rhs.cols();
+ }
+
protected:
const Lhs& m_lhs;
const Rhs& m_rhs;
@@ -275,7 +283,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
_ActualRhsType,
Dest> GemmFunctor;
- ei_parallelize_gemm<Dest::MaxRowsAtCompileTime>32>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
+ ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32)>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
}
};
diff --git a/Eigen/src/Core/products/Parallelizer.h b/Eigen/src/Core/products/Parallelizer.h
index 62cf16047..03d85c1ce 100644
--- a/Eigen/src/Core/products/Parallelizer.h
+++ b/Eigen/src/Core/products/Parallelizer.h
@@ -25,16 +25,16 @@
#ifndef EIGEN_PARALLELIZER_H
#define EIGEN_PARALLELIZER_H
-struct GemmParallelInfo
+template<typename BlockBScalar> struct GemmParallelInfo
{
- GemmParallelInfo() : sync(-1), users(0) {}
+ GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
int volatile sync;
int volatile users;
int rhs_start;
int rhs_length;
- float* blockB;
+ BlockBScalar* blockB;
};
template<bool Condition,typename Functor>
@@ -51,9 +51,10 @@ void ei_parallelize_gemm(const Functor& func, int rows, int cols)
int blockCols = (cols / threads) & ~0x3;
int blockRows = (rows / threads) & ~0x7;
- float* sharedBlockB = new float[2048*2048*4];
+ typedef typename Functor::BlockBScalar BlockBScalar;
+ BlockBScalar* sharedBlockB = new BlockBScalar[func.sharedBlockBSize()];
- GemmParallelInfo* info = new GemmParallelInfo[threads];
+ GemmParallelInfo<BlockBScalar>* info = new GemmParallelInfo<BlockBScalar>[threads];
#pragma omp parallel for schedule(static,1)
for(int i=0; i<threads; ++i)