diff options
Diffstat (limited to 'tensorflow/stream_executor/plugin.h')
-rw-r--r-- | tensorflow/stream_executor/plugin.h | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/plugin.h b/tensorflow/stream_executor/plugin.h new file mode 100644 index 0000000000..5dc39b7928 --- /dev/null +++ b/tensorflow/stream_executor/plugin.h @@ -0,0 +1,74 @@ +#ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_ +#define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_ + +namespace perftools { +namespace gputools { + +// A plugin ID is a unique identifier for each registered plugin type. +typedef void* PluginId; + +// Helper macro to define a plugin ID. To be used only inside plugin +// implementation files. Works by "reserving" an address/value (guaranteed to be +// unique) inside a process space. +#define PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(ID_VAR_NAME) \ + namespace { \ + int plugin_id_value; \ + } \ + const PluginId ID_VAR_NAME = &plugin_id_value; + +// kNullPlugin denotes an invalid plugin identifier. +extern const PluginId kNullPlugin; + +// Enumeration to list the supported types of plugins / support libraries. +enum class PluginKind { + kInvalid, + kBlas, + kDnn, + kFft, + kRng, +}; + +// A PluginConfig describes the set of plugins to be used by a StreamExecutor +// instance. Each plugin is defined by an arbitrary identifier, usually best set +// to the address static member in the implementation (to avoid conflicts). +// +// A PluginConfig may be passed to the StreamExecutor constructor - the plugins +// described therein will be used to provide BLAS, DNN, FFT, and RNG +// functionality. Platform-approprate defaults will be used for any un-set +// libraries. If a platform does not support a specified plugin (ex. cuBLAS on +// an OpenCL executor), then an error will be logged and no plugin operations +// will succeed. +// +// The StreamExecutor BUILD target does not link ANY plugin libraries - even +// common host fallbacks! Any plugins must be explicitly linked by dependent +// targets. See the cuda, opencl and host BUILD files for implemented plugin +// support (search for "plugin"). +class PluginConfig { + public: + // Value specifying the platform's default option for that plugin. + static const PluginId kDefault; + + // Initializes all members to the default options. + PluginConfig(); + + bool operator==(const PluginConfig& rhs) const; + + // Sets the appropriate library kind to that passed in. + PluginConfig& SetBlas(PluginId blas); + PluginConfig& SetDnn(PluginId dnn); + PluginConfig& SetFft(PluginId fft); + PluginConfig& SetRng(PluginId rng); + + PluginId blas() const { return blas_; } + PluginId dnn() const { return dnn_; } + PluginId fft() const { return fft_; } + PluginId rng() const { return rng_; } + + private: + PluginId blas_, dnn_, fft_, rng_; +}; + +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_ |