1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
#include <string>
#include <vector>
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
class Device;
struct SessionOptions;
// Abstract base for objects that know how to create the local devices of one
// device type.
//
// Concrete subclasses are handed to the static Register() -- usually through
// REGISTER_LOCAL_DEVICE_FACTORY -- and the static helpers below consult the
// registered factories to build devices. NOTE(review): the registry's storage
// and any locking live in the .cc file and are not visible here.
class DeviceFactory {
 public:
  virtual ~DeviceFactory() = default;

  // Records `factory` as a factory for `device_type`. Several registrations
  // for the same type may coexist; the one with the highest `priority` is the
  // one that gets used. NOTE(review): dfactory::Registrar passes a freshly
  // new-ed object and never frees it, so this presumably takes ownership --
  // confirm against the implementation.
  static void Register(const string& device_type, DeviceFactory* factory,
                       int priority);

  // Returns the factory registered for `device_type`.
  // NOTE(review): presumably nullptr when nothing is registered -- confirm.
  static DeviceFactory* GetFactory(const string& device_type);

  // Appends to "*devices" every suitable device, honoring any device-type
  // specific properties/counts given in "options".
  //
  // CPU devices are added first.
  static void AddDevices(const SessionOptions& options,
                         const string& name_prefix,
                         std::vector<Device*>* devices);

  // Test helper: creates a single device of type "type". The returned device
  // is always numbered zero, so callers that need several devices of the
  // same type must pass distinct name_prefix arguments.
  static Device* NewDevice(const string& type, const SessionOptions& options,
                           const string& name_prefix);

  // Implemented by each concrete factory to append its devices to
  // "*devices" (presumably named using "name_prefix" -- confirm).
  // Most clients should call AddDevices() instead.
  virtual void CreateDevices(const SessionOptions& options,
                             const string& name_prefix,
                             std::vector<Device*>* devices) = 0;
};
namespace dfactory {
template <class Factory>
class Registrar {
public:
// Multiple registrations for the same device type with different priorities
// are allowed. The registration with the highest priority will be used.
explicit Registrar(const string& device_type, int priority = 0) {
DeviceFactory::Register(device_type, new Factory(), priority);
}
};
} // namespace dfactory
// Registers `device_factory` (a DeviceFactory subclass with a default
// constructor) as a local factory for `device_type`. Any extra arguments are
// forwarded to dfactory::Registrar after `device_type`, i.e. an optional
// integer priority:
//
//   REGISTER_LOCAL_DEVICE_FACTORY("CPU", MyCPUFactory);
//   REGISTER_LOCAL_DEVICE_FACTORY("CPU", MyOtherCPUFactory, 50);
//
// Each use defines a uniquely named static Registrar object whose constructor
// performs the registration. The GNU `##__VA_ARGS__` form drops the trailing
// comma when no priority is supplied.
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory,   \
                                         __COUNTER__, ##__VA_ARGS__)

// Implementation detail of REGISTER_LOCAL_DEVICE_FACTORY. Because `ctr` is
// not an operand of `#`/`##` here, __COUNTER__ has already been expanded to an
// integer literal before it is handed to ..._NAME().
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// __COUNTER__ must go through another macro to be properly expanded: `##`
// suppresses macro expansion of its operands, so pasting __COUNTER__ directly
// would produce the literal token __COUNTER__ instead of a fresh integer.
//
// The generated name deliberately avoids a leading underscore and any double
// underscore: such identifiers are reserved to the implementation in C++.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) \
  local_device_factory_registrar_##ctr##_object
} // namespace tensorflow
#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
|