aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_mgr.cc
blob: 4fa13f6b4bb06a89be39e6b628ae0d05dfc6c6a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include "tensorflow/core/common_runtime/device_mgr.h"

#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// Builds the manager from `devices`. Each device is appended to the ordered
// device list, entered into the name lookup table, and counted in the
// per-device-type tally.
// NOTE(review): the destructor deletes every stored pointer, so the manager
// presumably takes ownership of `devices` -- confirm against the header.
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) {
  for (Device* device : devices) {
    devices_.push_back(device);

    // A device is reachable both by its full name and by the local form
    // that DeviceNameUtils::LocalName derives from that full name.
    const auto& full_name = device->name();
    device_map_[full_name] = device;
    device_map_[DeviceNameUtils::LocalName(full_name)] = device;

    ++device_type_counts_[device->device_type()];
  }
}

// Deletes every device pointer stored in devices_.
DeviceMgr::~DeviceMgr() {
  for (Device* device : devices_) {
    delete device;
  }
}

void DeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  devices->reserve(devices_.size());
  for (Device* dev : devices_) {
    devices->emplace_back(dev->attributes());
  }
}

std::vector<Device*> DeviceMgr::ListDevices() const {
  return std::vector<Device*>(devices_.begin(), devices_.end());
}

// Returns the names of all managed devices, one per line (each name is
// followed by "\n"), in devices_ order. Empty when there are no devices.
string DeviceMgr::DebugString() const {
  string names;
  for (Device* device : devices_) {
    strings::StrAppend(&names, device->name(), "\n");
  }
  return names;
}

// Returns one "<device name> -> <physical device description>" line per
// managed device, in devices_ order. Devices whose attributes carry an empty
// physical_device_desc are left out of the result.
string DeviceMgr::DeviceMappingString() const {
  string mapping;
  for (Device* device : devices_) {
    const auto& attrs = device->attributes();
    if (attrs.physical_device_desc().empty()) {
      continue;  // Nothing to report for this device.
    }
    strings::StrAppend(&mapping, device->name(), " -> ",
                       attrs.physical_device_desc(), "\n");
  }
  return mapping;
}

// Looks up a device by name in device_map_.
//
// The constructor registers every device under both its full name and the
// local name produced by DeviceNameUtils::LocalName, so either form is
// accepted here.
//
// On success stores the matching device in `*device` and returns OK.
// If `name` is not registered, returns InvalidArgument and leaves `*device`
// untouched.
Status DeviceMgr::LookupDevice(const string& name, Device** device) const {
  const auto iter = device_map_.find(name);
  if (iter == device_map_.end()) {
    return errors::InvalidArgument(name, " unknown device.");
  }
  *device = iter->second;
  return Status::OK();
}

// Asks the resource manager of every managed device to clean up resource
// containers.
//
// If `containers` is empty, only each resource manager's default container is
// cleaned up; otherwise every container named in `containers` is cleaned up on
// every device.
//
// Cleanup is best effort: failures are logged as warnings and are not
// reported to the caller.
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
  for (Device* dev : devices_) {
    // The status is scoped to one device. A status shared across the whole
    // loop would stay in the error state after the first failure, so the same
    // warning would be logged again for every later device, including ones
    // whose cleanup succeeded.
    Status s;
    auto* resource_mgr = dev->resource_manager();
    if (containers.empty()) {
      s.Update(resource_mgr->Cleanup(resource_mgr->default_container()));
    } else {
      for (const string& c : containers) {
        s.Update(resource_mgr->Cleanup(c));
      }
    }
    if (!s.ok()) {
      LOG(WARNING) << s;
    }
  }
}

// Returns the number of managed devices whose device_type() equals `type`,
// as tallied by the constructor; 0 when no device of that type was registered.
int DeviceMgr::NumDeviceType(const string& type) const {
  const auto entry = device_type_counts_.find(type);
  if (entry == device_type_counts_.end()) {
    return 0;
  }
  return entry->second;
}

}  // namespace tensorflow