diff options
Diffstat (limited to 'tensorflow/core/common_runtime/threadpool_device.cc')
-rw-r--r-- | tensorflow/core/common_runtime/threadpool_device.cc | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index f7a07fe503..74a87215e1 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -31,7 +31,11 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #ifdef INTEL_MKL +#ifdef _OPENMP +#include <omp.h> +#endif #include "tensorflow/core/common_runtime/mkl_cpu_allocator.h" +#include "tensorflow/core/platform/cpu_info.h" #endif namespace tensorflow { @@ -43,7 +47,26 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, : LocalDevice(options, Device::BuildDeviceAttributes( name, DEVICE_CPU, memory_limit, locality)), allocator_(allocator), - scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {} + scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) { +#ifdef INTEL_MKL +#ifdef _OPENMP + const char* user_omp_threads = getenv("OMP_NUM_THREADS"); + if (user_omp_threads == nullptr) { + // OMP_NUM_THREADS controls MKL's intra-op parallelization + // Default to available physical cores + const int mkl_intra_op = port::NumSchedulableCPUs(); + const int ht = port::NumHyperthreadsPerCore(); + omp_set_num_threads((mkl_intra_op + ht - 1) / ht); + } else { + uint64 user_val = 0; + if (strings::safe_strtou64(user_omp_threads, &user_val)) { + // Superflous but triggers OpenMP loading + omp_set_num_threads(user_val); + } + } +#endif // _OPENMP +#endif // INTEL_MKL +} ThreadPoolDevice::~ThreadPoolDevice() {} |