aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_set.h
blob: 130d96589176fdb1e4b78a5901e31321c8ecfc38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_

#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// A DeviceSet collects the devices a model can be placed on. It keeps
// only non-owning Device pointers: ownership of every Device stays with
// whoever handed it to AddDevice().
class DeviceSet {
 public:
  DeviceSet();
  ~DeviceSet();

  // A DeviceSet can be neither copy-constructed nor copy-assigned.
  DeviceSet(const DeviceSet&) = delete;
  DeviceSet& operator=(const DeviceSet&) = delete;

  // Registers 'device' in this set. Ownership of 'device' stays with the
  // caller.
  void AddDevice(Device* device);

  // Marks 'device' as the "client" device. The caller is expected to have
  // registered the same device through AddDevice(); that precondition is
  // not checked here.
  void set_client_device(Device* device) {
    client_device_ = device;
  }

  // The device marked via set_client_device(). The underlying member
  // defaults to nullptr.
  Device* client_device() const {
    return client_device_;
  }

  // Every device held by this set.
  const std::vector<Device*>& devices() const {
    return devices_;
  }

  // Writes into "*devices" each device of this set that matches "spec".
  // Any component of "spec" may be left as a wildcard.
  void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
                           std::vector<Device*>* devices) const;

  // Looks up the device whose full name equals "fullname"; yields nullptr
  // when no such device is present.
  Device* FindDeviceByName(const string& fullname) const;

  // The distinct device types present in this set, sorted so that the
  // most preferred type comes first.
  std::vector<DeviceType> PrioritizedDeviceTypeList() const;

 private:
  // The devices making up this set. These pointers are not owned.
  std::vector<Device*> devices_;

  // Maps a device's full name to the matching entry of devices_.
  std::unordered_map<string, Device*> device_by_name_;

  // The entry of devices_ treated as the client device of this local
  // process, or nullptr when none has been designated.
  Device* client_device_ = nullptr;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_