aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/device_factory.cc')
-rw-r--r--tensorflow/core/common_runtime/device_factory.cc106
1 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc
new file mode 100644
index 0000000000..7d391bde1d
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_factory.cc
@@ -0,0 +1,106 @@
+#include "tensorflow/core/common_runtime/device_factory.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+namespace {
+
+static mutex* get_device_factory_lock() {
+ static mutex device_factory_lock;
+ return &device_factory_lock;
+}
+
+struct FactoryItem {
+ std::unique_ptr<DeviceFactory> factory;
+ int priority;
+};
+
+std::unordered_map<string, FactoryItem>& device_factories() {
+ static std::unordered_map<string, FactoryItem>* factories =
+ new std::unordered_map<string, FactoryItem>;
+ return *factories;
+}
+} // namespace
+
+void DeviceFactory::Register(const string& device_type, DeviceFactory* factory,
+ int priority) {
+ mutex_lock l(*get_device_factory_lock());
+ std::unique_ptr<DeviceFactory> factory_ptr(factory);
+ std::unordered_map<string, FactoryItem>& factories = device_factories();
+ auto iter = factories.find(device_type);
+ if (iter == factories.end()) {
+ factories[device_type] = {std::move(factory_ptr), priority};
+ } else {
+ if (iter->second.priority < priority) {
+ iter->second = {std::move(factory_ptr), priority};
+ } else if (iter->second.priority == priority) {
+ LOG(FATAL) << "Duplicate registration of device factory for type "
+ << device_type << " with the same priority " << priority;
+ }
+ }
+}
+
+DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
+ mutex_lock l(*get_device_factory_lock()); // could use reader lock
+ auto it = device_factories().find(device_type);
+ if (it == device_factories().end()) {
+ return nullptr;
+ }
+ return it->second.factory.get();
+}
+
+void DeviceFactory::AddDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ // CPU first.
+ auto cpu_factory = GetFactory("CPU");
+ if (!cpu_factory) {
+ LOG(FATAL)
+ << "CPU Factory not registered. Did you link in threadpool_device?";
+ }
+ size_t init_size = devices->size();
+ cpu_factory->CreateDevices(options, name_prefix, devices);
+ if (devices->size() == init_size) {
+ LOG(FATAL) << "No CPU devices are available in this process";
+ }
+
+ // Then GPU.
+ auto gpu_factory = GetFactory("GPU");
+ if (gpu_factory) {
+ gpu_factory->CreateDevices(options, name_prefix, devices);
+ }
+
+ // Then the rest.
+ mutex_lock l(*get_device_factory_lock());
+ for (auto& p : device_factories()) {
+ auto factory = p.second.factory.get();
+ if (factory != cpu_factory && factory != gpu_factory) {
+ factory->CreateDevices(options, name_prefix, devices);
+ }
+ }
+}
+
+Device* DeviceFactory::NewDevice(const string& type,
+ const SessionOptions& options,
+ const string& name_prefix) {
+ auto device_factory = GetFactory(type);
+ if (!device_factory) {
+ return nullptr;
+ }
+ SessionOptions opt = options;
+ (*opt.config.mutable_device_count())[type] = 1;
+ std::vector<Device*> devices;
+ device_factory->CreateDevices(opt, name_prefix, &devices);
+ CHECK_EQ(devices.size(), 1);
+ return devices[0];
+}
+
+} // namespace tensorflow