diff options
Diffstat (limited to 'tensorflow/core/common_runtime/device_factory.h')
-rw-r--r-- | tensorflow/core/common_runtime/device_factory.h | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h new file mode 100644 index 0000000000..57b625b3e5 --- /dev/null +++ b/tensorflow/core/common_runtime/device_factory.h @@ -0,0 +1,69 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_ +#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_ + +#include <string> +#include <vector> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Device; +struct SessionOptions; + +class DeviceFactory { + public: + virtual ~DeviceFactory() {} + static void Register(const string& device_type, DeviceFactory* factory, + int priority); + static DeviceFactory* GetFactory(const string& device_type); + + // Append to "*devices" all suitable devices, respecting + // any device type specific properties/counts listed in "options". + // + // CPU devices are added first. + static void AddDevices(const SessionOptions& options, + const string& name_prefix, + std::vector<Device*>* devices); + + // Helper for tests. Create a single device of type "type". The + // returned device is always numbered zero, so if creating multiple + // devices of the same type, supply distinct name_prefix arguments. + static Device* NewDevice(const string& type, const SessionOptions& options, + const string& name_prefix); + + // Most clients should call AddDevices() instead. + virtual void CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector<Device*>* devices) = 0; +}; + +namespace dfactory { + +template <class Factory> +class Registrar { + public: + // Multiple registrations for the same device type with different priorities + // are allowed. The registration with the highest priority will be used. + explicit Registrar(const string& device_type, int priority = 0) { + DeviceFactory::Register(device_type, new Factory(), priority); + } +}; + +} // namespace dfactory + +#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + __COUNTER__, ##__VA_ARGS__) + +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + ctr, ...) \ + static ::tensorflow::dfactory::Registrar<device_factory> \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \ + ##__VA_ARGS__) + +// __COUNTER__ must go through another macro to be properly expanded +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_ + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_ |