aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/plugin.h
blob: 5dc39b792888546be49e850fb4375400bde4a7fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_

namespace perftools {
namespace gputools {

// Opaque handle type used to identify a registered plugin type. Each plugin
// type is expected to have its own unique PluginId value.
using PluginId = void*;

// Helper macro to define a plugin ID. To be used only inside plugin
// implementation files. Works by "reserving" an address/value (guaranteed to be
// unique) inside a process space: it defines an int in an anonymous namespace
// and initializes ID_VAR_NAME with the address of that int.
//
// The backing int's name is suffixed with the line number of the macro
// invocation, so the macro may be invoked more than once within the same scope
// of a translation unit (on different lines) without a redefinition error.
//
// Note that the expansion already ends with a semicolon.
//
// The two-level concatenation helpers are needed so that __LINE__ is expanded
// to its numeric value before being pasted onto the variable name.
#define PLUGIN_REGISTRY_INTERNAL_CONCAT_(a, b) a##b
#define PLUGIN_REGISTRY_INTERNAL_CONCAT(a, b) \
  PLUGIN_REGISTRY_INTERNAL_CONCAT_(a, b)
#define PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(ID_VAR_NAME)              \
  namespace {                                                      \
  int PLUGIN_REGISTRY_INTERNAL_CONCAT(plugin_id_value_, __LINE__); \
  }                                                                \
  const PluginId ID_VAR_NAME =                                     \
      &PLUGIN_REGISTRY_INTERNAL_CONCAT(plugin_id_value_, __LINE__);

// kNullPlugin denotes an invalid plugin identifier. This is only an extern
// declaration; the definition (and therefore the actual value) is not in this
// header.
extern const PluginId kNullPlugin;

// The kinds of plugins / support libraries that are supported.
enum class PluginKind {
  kInvalid = 0,  // Placeholder for "no valid kind".
  kBlas = 1,     // BLAS support library.
  kDnn = 2,      // DNN support library.
  kFft = 3,      // FFT support library.
  kRng = 4,      // RNG support library.
};

// A PluginConfig describes the set of plugins to be used by a StreamExecutor
// instance. Each plugin is defined by an arbitrary identifier, usually best set
// to the address of a static member in the implementation (to avoid
// conflicts).
//
// A PluginConfig may be passed to the StreamExecutor constructor - the plugins
// described therein will be used to provide BLAS, DNN, FFT, and RNG
// functionality. Platform-appropriate defaults will be used for any un-set
// libraries. If a platform does not support a specified plugin (e.g. cuBLAS on
// an OpenCL executor), then an error will be logged and no plugin operations
// will succeed.
//
// The StreamExecutor BUILD target does not link ANY plugin libraries - even
// common host fallbacks! Any plugins must be explicitly linked by dependent
// targets. See the cuda, opencl and host BUILD files for implemented plugin
// support (search for "plugin").
class PluginConfig {
 public:
  // Sentinel ID meaning "use the platform's default option for this plugin".
  // Only declared here; the definition is out of line.
  static const PluginId kDefault;

  // Constructs a config in which every plugin is set to the default option.
  PluginConfig();

  // Equality comparison against another config. Defined out of line.
  bool operator==(const PluginConfig& rhs) const;

  // Setters: each one sets the library of the corresponding kind to the ID
  // passed in and returns a PluginConfig reference.
  PluginConfig& SetBlas(PluginId blas);
  PluginConfig& SetDnn(PluginId dnn);
  PluginConfig& SetFft(PluginId fft);
  PluginConfig& SetRng(PluginId rng);

  // Accessors: each one returns the currently stored ID for that library kind.
  PluginId blas() const { return blas_; }
  PluginId dnn() const { return dnn_; }
  PluginId fft() const { return fft_; }
  PluginId rng() const { return rng_; }

 private:
  // Stored plugin IDs, one per library kind.
  PluginId blas_;
  PluginId dnn_;
  PluginId fft_;
  PluginId rng_;
};

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_