aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/local_device.cc
blob: 6a753468053290ca13b71233c974687b93ee2363 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/session_options.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

namespace {

DeviceBase::CpuWorkerThreads eigen_worker_threads;
Eigen::ThreadPoolInterface* eigen_thread_pool = nullptr;
Eigen::ThreadPoolDevice* eigen_device = nullptr;

static bool InitModule(const SessionOptions& options) {
  int32 intra_op_parallelism_threads =
      options.config.intra_op_parallelism_threads();
  if (intra_op_parallelism_threads == 0) {
    intra_op_parallelism_threads = port::NumSchedulableCPUs();
  }
  LOG(INFO) << "Local device intra op parallelism threads: "
            << intra_op_parallelism_threads;
  eigen_worker_threads.num_threads = intra_op_parallelism_threads;
  eigen_worker_threads.workers = new thread::ThreadPool(
      options.env, "Eigen", intra_op_parallelism_threads);
  eigen_thread_pool = new EigenThreadPoolWrapper(eigen_worker_threads.workers);
  eigen_device = new Eigen::ThreadPoolDevice(eigen_thread_pool,
                                             eigen_worker_threads.num_threads);
  return true;
}
}  // end namespace

// LocalDevice ----------------------------------------------------------------

// Constructs a device whose CPU numerical work runs on the single,
// fixed-size thread pool shared by every ThreadPoolDevice in the process.
//
// The shared pool is created lazily by the first LocalDevice to be
// constructed. Its size therefore comes from that first caller's
// SessionOptions; options passed to later constructions do not resize it.
LocalDevice::LocalDevice(const SessionOptions& options,
                         const DeviceAttributes& attributes,
                         Allocator* device_allocator)
    : Device(options.env, attributes, device_allocator) {
  // A function-local static is initialized only once, so InitModule() runs a
  // single time per process no matter how many devices are constructed.
  static const bool module_initialized = InitModule(options);
  // Asserts the one-time setup succeeded (and keeps the static "used").
  CHECK(module_initialized);

  set_tensorflow_cpu_worker_threads(&eigen_worker_threads);
  set_eigen_cpu_device(eigen_device);
}

}  // namespace tensorflow