aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-03-05 10:44:31 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-03-05 10:44:31 +0100
commit48d0595c29eef24ef98b82d23ed075de4819e39c (patch)
treedb3899fab5884ee01bbc3ac4ace67f84226dcad4 /Eigen/src
parentdd961f8c60ed684b7e3683b348544fc28f391d8f (diff)
* dynamically adjust the number of threads
* disbale parallelisation if we already are in a parallel session
Diffstat (limited to 'Eigen/src')
-rw-r--r--Eigen/src/Core/products/Parallelizer.h22
1 files changed, 19 insertions, 3 deletions
diff --git a/Eigen/src/Core/products/Parallelizer.h b/Eigen/src/Core/products/Parallelizer.h
index 03d85c1ce..304dc7ed0 100644
--- a/Eigen/src/Core/products/Parallelizer.h
+++ b/Eigen/src/Core/products/Parallelizer.h
@@ -44,8 +44,24 @@ void ei_parallelize_gemm(const Functor& func, int rows, int cols)
func(0,rows, 0,cols);
#else
- int threads = omp_get_max_threads();
- if((!Condition)||(threads==1))
+ // Dynamically check whether we should enable or disable OpenMP.
+ // The conditions are:
+ // - the max number of threads we can create is greater than 1
+ // - we are not already in a parallel code
+ // - the sizes are large enough
+
+ // 1- are we already in a parallel session?
+ if((!Condition) || (omp_get_num_threads()>1))
+ return func(0,rows, 0,cols);
+
+ // 2- compute the maximal number of threads from the size of the product:
+ // FIXME this has to be fine tuned
+ int max_threads = std::max(1,rows / 32);
+
+ // 3 - compute the number of threads we are going to use
+ int threads = std::min(omp_get_max_threads(), max_threads);
+
+ if(threads==1)
return func(0,rows, 0,cols);
int blockCols = (cols / threads) & ~0x3;
@@ -56,7 +72,7 @@ void ei_parallelize_gemm(const Functor& func, int rows, int cols)
GemmParallelInfo<BlockBScalar>* info = new GemmParallelInfo<BlockBScalar>[threads];
- #pragma omp parallel for schedule(static,1)
+ #pragma omp parallel for schedule(static,1) num_threads(threads)
for(int i=0; i<threads; ++i)
{
int r0 = i*blockRows;