1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
#include "tensorflow/core/common_runtime/device_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {

// Returns the mutex that guards the device-factory registry below.
// The mutex is a function-local static: it is initialized on the first call
// (thread-safe since C++11) instead of at namespace scope, so its lifetime
// does not depend on cross-translation-unit static initialization order.
mutex* get_device_factory_lock() {
  static mutex registry_mutex;
  return &registry_mutex;
}

// A single registry entry: the owned factory plus the priority it was
// registered with (a higher priority replaces a lower one in Register).
struct FactoryItem {
  std::unique_ptr<DeviceFactory> factory;
  int priority;
};

// Returns the registry mapping a device type name to its FactoryItem.
// The map is heap-allocated on first use and never deleted, which keeps it
// valid even while other static objects are being destroyed.
// Every accessor in this file touches it only while holding
// *get_device_factory_lock().
std::unordered_map<string, FactoryItem>& device_factories() {
  static auto* const registry = new std::unordered_map<string, FactoryItem>();
  return *registry;
}

}  // namespace
void DeviceFactory::Register(const string& device_type, DeviceFactory* factory,
int priority) {
mutex_lock l(*get_device_factory_lock());
std::unique_ptr<DeviceFactory> factory_ptr(factory);
std::unordered_map<string, FactoryItem>& factories = device_factories();
auto iter = factories.find(device_type);
if (iter == factories.end()) {
factories[device_type] = {std::move(factory_ptr), priority};
} else {
if (iter->second.priority < priority) {
iter->second = {std::move(factory_ptr), priority};
} else if (iter->second.priority == priority) {
LOG(FATAL) << "Duplicate registration of device factory for type "
<< device_type << " with the same priority " << priority;
}
}
}
// Looks up the factory registered for `device_type`.
// Returns nullptr when no factory is registered for that type. The registry
// keeps ownership of the returned factory; the caller must not delete it.
DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
  mutex_lock guard(*get_device_factory_lock());  // a reader lock would suffice
  const auto& registry = device_factories();
  const auto found = registry.find(device_type);
  return found == registry.end() ? nullptr : found->second.factory.get();
}
void DeviceFactory::AddDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
// CPU first.
auto cpu_factory = GetFactory("CPU");
if (!cpu_factory) {
LOG(FATAL)
<< "CPU Factory not registered. Did you link in threadpool_device?";
}
size_t init_size = devices->size();
cpu_factory->CreateDevices(options, name_prefix, devices);
if (devices->size() == init_size) {
LOG(FATAL) << "No CPU devices are available in this process";
}
// Then GPU.
auto gpu_factory = GetFactory("GPU");
if (gpu_factory) {
gpu_factory->CreateDevices(options, name_prefix, devices);
}
// Then the rest.
mutex_lock l(*get_device_factory_lock());
for (auto& p : device_factories()) {
auto factory = p.second.factory.get();
if (factory != cpu_factory && factory != gpu_factory) {
factory->CreateDevices(options, name_prefix, devices);
}
}
}
// Creates exactly one device of type `type` using its registered factory.
//
// Returns nullptr if no factory is registered for `type`. Otherwise a copy
// of `options` has its device count for `type` forced to 1, the factory's
// CreateDevices is invoked with it, and the process aborts (CHECK) unless
// exactly one device was produced. The returned Device is not retained
// anywhere in this function, so the caller is responsible for it.
Device* DeviceFactory::NewDevice(const string& type,
                                 const SessionOptions& options,
                                 const string& name_prefix) {
  DeviceFactory* device_factory = GetFactory(type);
  if (device_factory == nullptr) {
    return nullptr;
  }
  // Modify a copy so the caller's options are left untouched.
  SessionOptions opt = options;
  (*opt.config.mutable_device_count())[type] = 1;
  std::vector<Device*> devices;
  device_factory->CreateDevices(opt, name_prefix, &devices);
  // devices.size() is size_t; comparing it against a plain `int` literal is a
  // signed/unsigned mismatch (-Wsign-compare) inside the CHECK_EQ expansion.
  CHECK_EQ(devices.size(), size_t{1});
  return devices.front();
}
} // namespace tensorflow
|