aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_factory.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/device_factory.h')
-rw-r--r--tensorflow/core/common_runtime/device_factory.h69
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
new file mode 100644
index 0000000000..57b625b3e5
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -0,0 +1,69 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Device;
+struct SessionOptions;
+
+class DeviceFactory {
+ public:
+ virtual ~DeviceFactory() {}
+ static void Register(const string& device_type, DeviceFactory* factory,
+ int priority);
+ static DeviceFactory* GetFactory(const string& device_type);
+
+ // Append to "*devices" all suitable devices, respecting
+ // any device type specific properties/counts listed in "options".
+ //
+ // CPU devices are added first.
+ static void AddDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices);
+
+ // Helper for tests. Create a single device of type "type". The
+ // returned device is always numbered zero, so if creating multiple
+ // devices of the same type, supply distinct name_prefix arguments.
+ static Device* NewDevice(const string& type, const SessionOptions& options,
+ const string& name_prefix);
+
+ // Most clients should call AddDevices() instead.
+ virtual void CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) = 0;
+};
+
+namespace dfactory {
+
+template <class Factory>
+class Registrar {
+ public:
+ // Multiple registrations for the same device type with different priorities
+ // are allowed. The registration with the highest priority will be used.
+ explicit Registrar(const string& device_type, int priority = 0) {
+ DeviceFactory::Register(device_type, new Factory(), priority);
+ }
+};
+
+} // namespace dfactory
+
+#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
+ INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
+ __COUNTER__, ##__VA_ARGS__)
+
+#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
+ ctr, ...) \
+ static ::tensorflow::dfactory::Registrar<device_factory> \
+ INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \
+ ##__VA_ARGS__)
+
+// __COUNTER__ must go through another macro to be properly expanded
+#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_