aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/threadpool_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/threadpool_device.cc')
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index f7a07fe503..74a87215e1 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -31,7 +31,11 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#ifdef INTEL_MKL
+#ifdef _OPENMP
+#include <omp.h>
+#endif
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
+#include "tensorflow/core/platform/cpu_info.h"
#endif
namespace tensorflow {
@@ -43,7 +47,26 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
: LocalDevice(options, Device::BuildDeviceAttributes(
name, DEVICE_CPU, memory_limit, locality)),
allocator_(allocator),
- scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {}
+ scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
+#ifdef INTEL_MKL
+#ifdef _OPENMP
+ const char* user_omp_threads = getenv("OMP_NUM_THREADS");
+ if (user_omp_threads == nullptr) {
+ // OMP_NUM_THREADS controls MKL's intra-op parallelization
+ // Default to available physical cores
+ const int mkl_intra_op = port::NumSchedulableCPUs();
+ const int ht = port::NumHyperthreadsPerCore();
+ omp_set_num_threads((mkl_intra_op + ht - 1) / ht);
+ } else {
+ uint64 user_val = 0;
+ if (strings::safe_strtou64(user_omp_threads, &user_val)) {
+ // Superflous but triggers OpenMP loading
+ omp_set_num_threads(user_val);
+ }
+ }
+#endif // _OPENMP
+#endif // INTEL_MKL
+}
ThreadPoolDevice::~ThreadPoolDevice() {}