aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/local_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/local_device.cc')
-rw-r--r--tensorflow/core/common_runtime/local_device.cc51
1 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
new file mode 100644
index 0000000000..6a75346805
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -0,0 +1,51 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace {
+
+DeviceBase::CpuWorkerThreads eigen_worker_threads;
+Eigen::ThreadPoolInterface* eigen_thread_pool = nullptr;
+Eigen::ThreadPoolDevice* eigen_device = nullptr;
+
+static bool InitModule(const SessionOptions& options) {
+ int32 intra_op_parallelism_threads =
+ options.config.intra_op_parallelism_threads();
+ if (intra_op_parallelism_threads == 0) {
+ intra_op_parallelism_threads = port::NumSchedulableCPUs();
+ }
+ LOG(INFO) << "Local device intra op parallelism threads: "
+ << intra_op_parallelism_threads;
+ eigen_worker_threads.num_threads = intra_op_parallelism_threads;
+ eigen_worker_threads.workers = new thread::ThreadPool(
+ options.env, "Eigen", intra_op_parallelism_threads);
+ eigen_thread_pool = new EigenThreadPoolWrapper(eigen_worker_threads.workers);
+ eigen_device = new Eigen::ThreadPoolDevice(eigen_thread_pool,
+ eigen_worker_threads.num_threads);
+ return true;
+}
+} // end namespace
+
+// LocalDevice ----------------------------------------------------------------
+
+LocalDevice::LocalDevice(const SessionOptions& options,
+ const DeviceAttributes& attributes,
+ Allocator* device_allocator)
+ : Device(options.env, attributes, device_allocator) {
+ // All ThreadPoolDevices in the process will use this single fixed
+ // sized threadpool for numerical computations.
+ static bool init = InitModule(options);
+ CHECK(init); // Avoids compiler warning that init is unused.
+ set_tensorflow_cpu_worker_threads(&eigen_worker_threads);
+ set_eigen_cpu_device(eigen_device);
+}
+
+} // namespace tensorflow