// Source path: tensorflow/core/common_runtime/device_set.cc
// (blob 3b0465d9a66f7b0e00066d26d6fb04c052782732)
#include "tensorflow/core/common_runtime/device_set.h"

#include <algorithm>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace tensorflow {

// Constructs an empty set; no explicit initialization is performed here.
DeviceSet::DeviceSet() = default;

// Performs no explicit cleanup of the registered Device pointers.
DeviceSet::~DeviceSet() = default;

// Registers `device` with this set: appends it to the ordered device list
// and indexes it by its name for FindDeviceByName().
//
// NOTE(review): assuming device_by_name_ is a unique-key map, a second
// device with an already-registered name is still appended to devices_ but
// does not replace the existing name -> device entry. Confirm this is the
// intended behavior.
void DeviceSet::AddDevice(Device* device) {
  devices_.emplace_back(device);
  device_by_name_.emplace(device->name(), device);
}

// Fills `*devices` with every registered device whose parsed name satisfies
// DeviceNameUtils::IsCompleteSpecification(spec, name). Any previous contents
// of `*devices` are discarded, and matches appear in registration order.
void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
                                    std::vector<Device*>* devices) const {
  // TODO(jeff): If we are going to repeatedly lookup the set of devices
  // for the same spec, maybe we should have a cache of some sort
  devices->clear();
  for (Device* candidate : devices_) {
    const bool matches = DeviceNameUtils::IsCompleteSpecification(
        spec, candidate->parsed_name());
    if (!matches) {
      continue;
    }
    devices->push_back(candidate);
  }
}

// Returns the device registered under `name`, or nullptr if no device with
// that name has been added.
Device* DeviceSet::FindDeviceByName(const string& name) const {
  const auto entry = device_by_name_.find(name);
  if (entry == device_by_name_.end()) {
    return nullptr;
  }
  return entry->second;
}

// Higher result implies lower priority.
static int Order(const DeviceType& d) {
  if (StringPiece(d.type()) == DEVICE_CPU) {
    return 3;
  } else if (StringPiece(d.type()) == DEVICE_GPU) {
    return 2;
  } else {
    return 1;
  }
}

static bool ByPriority(const DeviceType& a, const DeviceType& b) {
  // Order by "order number"; break ties lexicographically.
  return std::make_pair(Order(a), StringPiece(a.type())) <
         std::make_pair(Order(b), StringPiece(b.type()));
}

// Returns the distinct device types present in this set, sorted from highest
// to lowest priority as defined by ByPriority(): types other than GPU/CPU
// first, then GPU, then CPU, with ties broken lexicographically by type name.
// Each device type appears exactly once regardless of how many devices of
// that type are registered.
std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
  std::vector<DeviceType> result;
  std::set<string> seen;  // Type names already emitted into `result`.
  for (Device* d : devices_) {
    // Bind by reference: a plain `auto` would copy the type string for every
    // device even when the type has already been seen.
    const auto& type_name = d->device_type();
    if (seen.insert(type_name).second) {
      result.emplace_back(DeviceType(type_name));
    }
  }
  // std::sort lives in <algorithm>; an unstable sort is sufficient because
  // the entries are distinct and ByPriority orders them totally.
  std::sort(result.begin(), result.end(), ByPriority);
  return result;
}

}  // namespace tensorflow