diff options
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/products/Parallelizer.h | 48 | ||||
-rw-r--r-- | bench/bench_gemm.cpp | 3 |
3 files changed, 50 insertions, 5 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 801ed2792..3513d118e 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -241,8 +241,8 @@ struct ei_gemm_functor
 
   Index sharedBlockBSize() const
   {
-    int maxKc, maxMc;
-    getBlockingSizes<Scalar>(maxKc,maxMc);
+    Index maxKc, maxMc, maxNc;
+    getBlockingSizes<Scalar>(maxKc, maxMc, maxNc);
     return std::min<Index>(maxKc,m_rhs.rows()) * m_rhs.cols();
   }
 
diff --git a/Eigen/src/Core/products/Parallelizer.h b/Eigen/src/Core/products/Parallelizer.h
index f7bdceab7..588f78b4c 100644
--- a/Eigen/src/Core/products/Parallelizer.h
+++ b/Eigen/src/Core/products/Parallelizer.h
@@ -25,6 +25,50 @@
 #ifndef EIGEN_PARALLELIZER_H
 #define EIGEN_PARALLELIZER_H
 
+/** \internal */
+inline void ei_manage_multi_threading(Action action, int* v)
+{
+  static int m_maxThreads = -1;
+
+  if(action==SetAction)
+  {
+    ei_internal_assert(v!=0);
+    m_maxThreads = *v;
+  }
+  else if(action==GetAction)
+  {
+    ei_internal_assert(v!=0);
+    #ifdef EIGEN_HAS_OPENMP
+    if(m_maxThreads>0)
+      *v = m_maxThreads;
+    else
+      *v = omp_get_max_threads();
+    #else
+    *v = 1;
+    #endif
+  }
+  else
+  {
+    ei_internal_assert(false);
+  }
+}
+
+/** \returns the max number of threads reserved for Eigen
+  * \sa setNbThreads */
+inline int nbThreads()
+{
+  int ret;
+  ei_manage_multi_threading(GetAction, &ret);
+  return ret;
+}
+
+/** Sets the max number of threads reserved for Eigen
+  * \sa nbThreads */
+inline void setNbThreads(int v)
+{
+  ei_manage_multi_threading(SetAction, &v);
+}
+
 template<typename BlockBScalar, typename Index> struct GemmParallelInfo
 {
   GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
@@ -57,10 +101,10 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
 
   // 2- compute the maximal number of threads from the size of the product:
   // FIXME this has to be fine tuned
-  Index max_threads = std::max(1,rows / 32);
+  Index max_threads = std::max<Index>(1,rows / 32);
 
   // 3 - compute the number of threads we are going to use
-  Index threads = std::min<Index>(omp_get_max_threads(), max_threads);
+  Index threads = std::min<Index>(nbThreads(), max_threads);
 
   if(threads==1)
     return func(0,rows, 0,cols);
diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp
index 5c55d4b7c..77cc420f4 100644
--- a/bench/bench_gemm.cpp
+++ b/bench/bench_gemm.cpp
@@ -112,7 +112,8 @@ int main(int argc, char ** argv)
   if(procs>1)
   {
     BenchTimer tmono;
-    omp_set_num_threads(1);
+    //omp_set_num_threads(1);
+    Eigen::setNbThreads(1);
     BENCH(tmono, tries, rep, gemm(a,b,c));
     std::cout << "eigen mono cpu " << tmono.best(CPU_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/tmono.best(CPU_TIMER))*1e-9 << " GFLOPS \t(" << tmono.total(CPU_TIMER) << "s)\n";
     std::cout << "eigen mono real " << tmono.best(REAL_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/tmono.best(REAL_TIMER))*1e-9 << " GFLOPS \t(" << tmono.total(REAL_TIMER) << "s)\n";