aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_factory.h
blob: 57b625b3e51d7632594f446d565514424d80c069 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_

#include <string>
#include <vector>
#include "tensorflow/core/platform/port.h"

namespace tensorflow {

class Device;
struct SessionOptions;

// Abstract interface for objects that create Device instances, combined with
// a static registry interface keyed by device type name.
//
// Concrete factories implement CreateDevices(). The static members declared
// here have no bodies in this header; they are defined elsewhere.
class DeviceFactory {
 public:
  // Virtual so that a concrete factory can be destroyed through a
  // DeviceFactory pointer.
  virtual ~DeviceFactory() {}

  // Registers "factory" as a factory for "device_type" with the given
  // "priority".
  // NOTE(review): ownership of "factory" is not visible from this header.
  // dfactory::Registrar passes a heap-allocated object that it never deletes,
  // so the registry presumably takes ownership -- confirm in the
  // implementation.
  static void Register(const string& device_type, DeviceFactory* factory,
                       int priority);

  // Looks up the factory registered for "device_type".
  // NOTE(review): the result for an unregistered type is decided by the
  // implementation, not by this header -- confirm before relying on it.
  static DeviceFactory* GetFactory(const string& device_type);

  // Append to "*devices" all suitable devices, respecting
  // any device type specific properties/counts listed in "options".
  //
  // CPU devices are added first.
  static void AddDevices(const SessionOptions& options,
                         const string& name_prefix,
                         std::vector<Device*>* devices);

  // Helper for tests.  Create a single device of type "type".  The
  // returned device is always numbered zero, so if creating multiple
  // devices of the same type, supply distinct name_prefix arguments.
  // NOTE(review): who owns the returned Device is not stated here --
  // presumably the caller; confirm against the implementation.
  static Device* NewDevice(const string& type, const SessionOptions& options,
                           const string& name_prefix);

  // Most clients should call AddDevices() instead.
  //
  // Pure virtual hook that each concrete factory implements to create its
  // devices. It receives the same "options", "name_prefix" and "devices"
  // arguments as AddDevices(); presumably the created devices are appended
  // to "*devices" -- confirm against an implementation.
  virtual void CreateDevices(const SessionOptions& options,
                             const string& name_prefix,
                             std::vector<Device*>* devices) = 0;
};

namespace dfactory {

template <class Factory>
class Registrar {
 public:
  // Multiple registrations for the same device type with different priorities
  // are allowed. The registration with the highest priority will be used.
  explicit Registrar(const string& device_type, int priority = 0) {
    DeviceFactory::Register(device_type, new Factory(), priority);
  }
};

}  // namespace dfactory

// Registers "device_factory" (a DeviceFactory subclass) for "device_type" by
// defining a static dfactory::Registrar<device_factory> object; the
// Registrar's constructor performs the registration.
//
// Any extra arguments (for example a priority) are forwarded to the Registrar
// constructor after "device_type". __COUNTER__ gives every expansion in a
// translation unit a distinct object name, so the macro may be used several
// times in one file.
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory,   \
                                         __COUNTER__, ##__VA_ARGS__)

// Implementation detail of REGISTER_LOCAL_DEVICE_FACTORY; do not use
// directly. "ctr" is the already-expanded counter value used to build the
// object name. "##__VA_ARGS__" drops the preceding comma when no extra
// arguments are supplied.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// __COUNTER__ must go through another macro to be properly expanded.
//
// The generated name deliberately has no leading underscore and no "__":
// identifiers of that form are reserved for the C++ implementation.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) \
  device_factory_registrar_##ctr##_object_

}  // namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_