blob: 3b0465d9a66f7b0e00066d26d6fb04c052782732 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
#include "tensorflow/core/common_runtime/device_set.h"

#include <algorithm>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow {
// Neither the constructor nor the destructor has work of its own. In
// particular, the destructor does not delete the Device pointers handed to
// AddDevice(). NOTE(review): presumably the caller owns those devices; confirm
// against device_set.h.
DeviceSet::DeviceSet() = default;
DeviceSet::~DeviceSet() = default;
// Registers `device`. It is appended to devices_ (the list walked by
// FindMatchingDevices and PrioritizedDeviceTypeList) and indexed under its full
// name for FindDeviceByName().
//
// NOTE(review): insert() does not overwrite an existing key. If two devices
// share a name, devices_ holds both but the name lookup keeps returning the
// first. This assumes device_by_name_ has standard map insert semantics;
// confirm against device_set.h.
void DeviceSet::AddDevice(Device* device) {
  devices_.push_back(device);
  const auto& device_name = device->name();
  device_by_name_.insert({device_name, device});
}
// Replaces the contents of `*devices` with every registered device whose
// parsed name satisfies `spec`, as decided by
// DeviceNameUtils::IsCompleteSpecification. Matches are emitted in the order
// they appear in devices_.
//
// TODO(jeff): If we are going to repeatedly lookup the set of devices
// for the same spec, maybe we should have a cache of some sort
void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
                                    std::vector<Device*>* devices) const {
  devices->clear();
  for (Device* candidate : devices_) {
    const bool matches = DeviceNameUtils::IsCompleteSpecification(
        spec, candidate->parsed_name());
    if (!matches) continue;
    devices->push_back(candidate);
  }
}
// Returns the device registered under exactly `name`, or nullptr if no device
// with that name has been added.
Device* DeviceSet::FindDeviceByName(const string& name) const {
  const auto it = device_by_name_.find(name);
  if (it == device_by_name_.end()) {
    return nullptr;
  }
  return it->second;
}
// Higher result implies lower priority.
static int Order(const DeviceType& d) {
if (StringPiece(d.type()) == DEVICE_CPU) {
return 3;
} else if (StringPiece(d.type()) == DEVICE_GPU) {
return 2;
} else {
return 1;
}
}
// Sort comparator for device types. A lower Order() rank sorts first, and types
// with equal rank are ordered lexicographically by their type string.
static bool ByPriority(const DeviceType& a, const DeviceType& b) {
  const int rank_a = Order(a);
  const int rank_b = Order(b);
  if (rank_a != rank_b) {
    return rank_a < rank_b;
  }
  return StringPiece(a.type()) < StringPiece(b.type());
}
// Returns each distinct device type present in the set exactly once, sorted by
// ByPriority(). Lower Order() ranks come first, so CPU comes last; ties are
// broken by type name.
std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
  std::vector<DeviceType> result;
  std::set<string> seen;
  for (Device* d : devices_) {
    // Bind by reference. The type string need not be copied just to test it
    // against `seen`, because set::insert copies it only when it is new.
    const auto& t = d->device_type();
    if (seen.insert(t).second) {
      // Construct in place rather than building and then moving a temporary.
      result.emplace_back(t);
    }
  }
  // Every type in `result` is unique, so ByPriority is a strict total order
  // here and the (unstable) std::sort still yields a deterministic order.
  std::sort(result.begin(), result.end(), ByPriority);
  return result;
}
} // namespace tensorflow
|