aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_mgr.h
blob: c57d0222aacb86300cd51827a1100aae0193e4e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
#define TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

class DeviceAttributes;

class DeviceMgr {
 public:
  // TODO(zhifengc): Other initialization information.
  explicit DeviceMgr(const std::vector<Device*>& devices);
  ~DeviceMgr();

  // Returns attributes of all devices.
  void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;

  std::vector<Device*> ListDevices() const;

  // Returns a string listing all devices.
  string DebugString() const;

  // Returns a string of all the device mapping.
  string DeviceMappingString() const;

  // Assigns *device with pointer to Device of the given name.
  // Accepts either a full device name, or just the replica-local suffix.
  Status LookupDevice(const string& name, Device** device) const;

  // Clears given containers of all devices if 'container' is
  // non-empty. Otherwise, clears default containers of all devices.
  void ClearContainers(gtl::ArraySlice<string> containers) const;

  int NumDeviceType(const string& type) const;

 private:
  typedef gtl::InlinedVector<Device*, 8> DeviceVec;
  DeviceVec devices_;
  std::unordered_map<string, Device*> device_map_;
  std::unordered_map<string, int> device_type_counts_;

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_