#ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_

namespace perftools {
namespace gputools {

// A plugin ID is a unique identifier for each registered plugin type.
typedef void* PluginId;

// Helper macro to define a plugin ID. To be used only inside plugin
// implementation files (it expands to an anonymous namespace, which must not
// appear in a header). Works by "reserving" an address/value (guaranteed to be
// unique) inside a process space.
//
// Expands to a file-local int named plugin_id_value plus a constant
// ID_VAR_NAME holding that int's address, so it must be invoked at namespace
// scope. Because the int's name is fixed, a second invocation in the same
// scope would redefine it.
#define PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(ID_VAR_NAME) \
  namespace {                                         \
  int plugin_id_value;                                \
  }                                                   \
  const PluginId ID_VAR_NAME = &plugin_id_value;

// kNullPlugin denotes an invalid plugin identifier.
// Declared here only; its definition is not part of this header.
extern const PluginId kNullPlugin;

// Enumeration to list the supported types of plugins / support libraries.
enum class PluginKind {
  kInvalid,
  kBlas,
  kDnn,
  kFft,
  kRng,
};

// A PluginConfig describes the set of plugins to be used by a StreamExecutor
// instance. Each plugin is defined by an arbitrary identifier, usually best set
// to the address of a static member in the implementation (to avoid
// conflicts).
//
// A PluginConfig may be passed to the StreamExecutor constructor - the plugins
// described therein will be used to provide BLAS, DNN, FFT, and RNG
// functionality. Platform-appropriate defaults will be used for any un-set
// libraries. If a platform does not support a specified plugin (ex. cuBLAS on
// an OpenCL executor), then an error will be logged and no plugin operations
// will succeed.
//
// The StreamExecutor BUILD target does not link ANY plugin libraries - even
// common host fallbacks! Any plugins must be explicitly linked by dependent
// targets. See the cuda, opencl and host BUILD files for implemented plugin
// support (search for "plugin").
class PluginConfig {
 public:
  // Value specifying the platform's default option for that plugin.
  static const PluginId kDefault;

  // Initializes all members to the default options.
  PluginConfig();

  // Compares this configuration against rhs. Defined out of line.
  bool operator==(const PluginConfig& rhs) const;

  // Sets the appropriate library kind to that passed in.
  // Each setter returns a PluginConfig reference; presumably *this so that
  // calls can be chained - confirm against the out-of-line definitions.
  PluginConfig& SetBlas(PluginId blas);
  PluginConfig& SetDnn(PluginId dnn);
  PluginConfig& SetFft(PluginId fft);
  PluginConfig& SetRng(PluginId rng);

  // Accessors returning the plugin ID currently stored for each library kind.
  PluginId blas() const { return blas_; }
  PluginId dnn() const { return dnn_; }
  PluginId fft() const { return fft_; }
  PluginId rng() const { return rng_; }

 private:
  // One stored plugin ID per supported library kind.
  PluginId blas_;
  PluginId dnn_;
  PluginId fft_;
  PluginId rng_;
};

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_