blob: 130d96589176fdb1e4b78a5901e31321c8ecfc38 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// DeviceSet collects the devices a model may use and offers lookups over
// them: by full name, by a (possibly wildcarded) parsed name spec, and by
// device-type preference.
//
// Ownership: a DeviceSet stores raw Device pointers and never owns them;
// whoever registers a Device must keep it alive while this set refers to it.
// NOTE(review): nothing in this declaration synchronizes access; presumably
// callers serialize AddDevice()/set_client_device() against readers — confirm.
class DeviceSet {
 public:
  DeviceSet();
  ~DeviceSet();

  // ---- Registration -------------------------------------------------------

  // Adds `device` to this set. Only the pointer is recorded; ownership stays
  // with the caller.
  void AddDevice(Device* device);

  // Every device that has been added to this set.
  const std::vector<Device*>& devices() const {
    return devices_;
  }

  // ---- Lookup -------------------------------------------------------------

  // Looks up a device by its full name. Yields nullptr when no registered
  // device carries that name.
  Device* FindDeviceByName(const string& fullname) const;

  // Populates `*devices` with each device in this set that matches `spec`.
  // Components of `spec` may be left unspecified and then act as wildcards.
  // NOTE(review): whether `*devices` is cleared first or appended to is
  // decided in the out-of-line definition — confirm before relying on it.
  void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
                           std::vector<Device*>* devices) const;

  // Lists each distinct device type present in this set once, with more
  // preferable types appearing before less preferable ones.
  std::vector<DeviceType> PrioritizedDeviceTypeList() const;

  // ---- Client device ------------------------------------------------------

  // Designates `device` as the "client" device. The caller must also have
  // registered `device` through AddDevice(); this setter simply stores the
  // pointer and does not verify that.
  void set_client_device(Device* device) {
    client_device_ = device;
  }

  // The device designated via set_client_device(). The backing member is
  // initialized to nullptr by its default member initializer.
  Device* client_device() const {
    return client_device_;
  }

 private:
  // Registered devices (not owned).
  std::vector<Device*> devices_;

  // Index from a device's full name to the matching entry of `devices_`.
  std::unordered_map<string, Device*> device_by_name_;

  // Points at the element of `devices_` treated as the client device for
  // this local process; nullptr until one is designated.
  Device* client_device_ = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
|