aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/threadpool_device_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/threadpool_device_factory.cc')
-rw-r--r--tensorflow/core/common_runtime/threadpool_device_factory.cc31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc
new file mode 100644
index 0000000000..ee6319abad
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc
@@ -0,0 +1,31 @@
+// Register a factory that provides CPU devices.
+#include "tensorflow/core/common_runtime/threadpool_device.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// TODO(zhifengc/tucker): Figure out the bytes of available RAM.
+class ThreadPoolDeviceFactory : public DeviceFactory {
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override {
+ // TODO(zhifengc/tucker): Figure out the number of available CPUs
+ // and/or NUMA configuration.
+ int n = 1;
+ auto iter = options.config.device_count().find("CPU");
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
+ for (int i = 0; i < n; i++) {
+ string name = strings::StrCat(name_prefix, "/cpu:", i);
+ devices->push_back(new ThreadPoolDevice(options, name, Bytes(256 << 20),
+ BUS_ANY, cpu_allocator()));
+ }
+ }
+};
+REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory);
+
+} // namespace tensorflow