aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/plugin_registry.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/plugin_registry.cc')
-rw-r--r--tensorflow/stream_executor/plugin_registry.cc228
1 files changed, 228 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc
new file mode 100644
index 0000000000..eda44d1146
--- /dev/null
+++ b/tensorflow/stream_executor/plugin_registry.cc
@@ -0,0 +1,228 @@
+#include "tensorflow/stream_executor/plugin_registry.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+namespace perftools {
+namespace gputools {
+
+const PluginId kNullPlugin = nullptr;
+
+// Returns the string representation of the specified PluginKind.
+string PluginKindString(PluginKind plugin_kind) {
+ switch (plugin_kind) {
+ case PluginKind::kBlas:
+ return "BLAS";
+ case PluginKind::kDnn:
+ return "DNN";
+ case PluginKind::kFft:
+ return "FFT";
+ case PluginKind::kRng:
+ return "RNG";
+ case PluginKind::kInvalid:
+ default:
+ return "kInvalid";
+ }
+}
+
+PluginRegistry::DefaultFactories::DefaultFactories() :
+ blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
+
+/* static */ mutex PluginRegistry::mu_(LINKER_INITIALIZED);
+/* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
+
+PluginRegistry::PluginRegistry() {}
+
+/* static */ PluginRegistry* PluginRegistry::Instance() {
+ mutex_lock lock{mu_};
+ if (instance_ == nullptr) {
+ instance_ = new PluginRegistry();
+ }
+ return instance_;
+}
+
+void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
+ Platform::Id platform_id) {
+ platform_id_by_kind_[platform_kind] = platform_id;
+}
+
+template <typename FACTORY_TYPE>
+port::Status PluginRegistry::RegisterFactoryInternal(
+ PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
+ std::map<PluginId, FACTORY_TYPE>* factories) {
+ mutex_lock lock{mu_};
+
+ if (factories->find(plugin_id) != factories->end()) {
+ return port::Status{
+ port::error::ALREADY_EXISTS,
+ port::Printf("Attempting to register factory for plugin %s when "
+ "one has already been registered",
+ plugin_name.c_str())};
+ }
+
+ (*factories)[plugin_id] = factory;
+ plugin_names_[plugin_id] = plugin_name;
+ return port::Status::OK();
+}
+
+template <typename FACTORY_TYPE>
+port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
+ PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
+ const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
+ auto iter = factories.find(plugin_id);
+ if (iter == factories.end()) {
+ iter = generic_factories.find(plugin_id);
+ if (iter == generic_factories.end()) {
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::Printf("Plugin ID %p not registered.", plugin_id)};
+ }
+ }
+
+ return iter->second;
+}
+
+bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
+ PluginKind plugin_kind,
+ PluginId plugin_id) {
+ if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
+ port::StatusOr<Platform*> status =
+ MultiPlatformManager::PlatformWithId(platform_id);
+ string platform_name = "<unregistered platform>";
+ if (status.ok()) {
+ platform_name = status.ValueOrDie()->Name();
+ }
+
+ LOG(ERROR) << "A factory must be registered for a platform before being "
+ << "set as default! "
+ << "Platform name: " << platform_name
+ << ", PluginKind: " << PluginKindString(plugin_kind)
+ << ", PluginId: " << plugin_id;
+ return false;
+ }
+
+ switch (plugin_kind) {
+ case PluginKind::kBlas:
+ default_factories_[platform_id].blas = plugin_id;
+ break;
+ case PluginKind::kDnn:
+ default_factories_[platform_id].dnn = plugin_id;
+ break;
+ case PluginKind::kFft:
+ default_factories_[platform_id].fft = plugin_id;
+ break;
+ case PluginKind::kRng:
+ default_factories_[platform_id].rng = plugin_id;
+ break;
+ default:
+ LOG(ERROR) << "Invalid plugin kind specified: "
+ << static_cast<int>(plugin_kind);
+ return false;
+ }
+
+ return true;
+}
+
+bool PluginRegistry::HasFactory(const PluginFactories& factories,
+ PluginKind plugin_kind,
+ PluginId plugin_id) const {
+ switch (plugin_kind) {
+ case PluginKind::kBlas:
+ return factories.blas.find(plugin_id) != factories.blas.end();
+ case PluginKind::kDnn:
+ return factories.dnn.find(plugin_id) != factories.dnn.end();
+ case PluginKind::kFft:
+ return factories.fft.find(plugin_id) != factories.fft.end();
+ case PluginKind::kRng:
+ return factories.rng.find(plugin_id) != factories.rng.end();
+ default:
+ LOG(ERROR) << "Invalid plugin kind specified: "
+ << PluginKindString(plugin_kind);
+ return false;
+ }
+}
+
+bool PluginRegistry::HasFactory(Platform::Id platform_id,
+ PluginKind plugin_kind,
+ PluginId plugin_id) const {
+ auto iter = factories_.find(platform_id);
+ if (iter != factories_.end()) {
+ if (HasFactory(iter->second, plugin_kind, plugin_id)) {
+ return true;
+ }
+ }
+
+ return HasFactory(generic_factories_, plugin_kind, plugin_id);
+}
+
+// Explicit instantiations to support types exposed in user/public API.
+#define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
+ template port::StatusOr<PluginRegistry::FACTORY_TYPE> \
+ PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
+ PluginId plugin_id, \
+ const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \
+ const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \
+ generic_factories) const; \
+ \
+ template port::Status \
+ PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
+ PluginId plugin_id, const string& plugin_name, \
+ PluginRegistry::FACTORY_TYPE factory, \
+ std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \
+ \
+ template <> \
+ port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
+ Platform::Id platform_id, PluginId plugin_id, const string& name, \
+ PluginRegistry::FACTORY_TYPE factory) { \
+ return RegisterFactoryInternal(plugin_id, name, factory, \
+ &factories_[platform_id].FACTORY_VAR); \
+ } \
+ \
+ template <> \
+ port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
+ PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \
+ PluginRegistry::FACTORY_TYPE factory) { \
+ return RegisterFactoryInternal(plugin_id, name, factory, \
+ &generic_factories_.FACTORY_VAR); \
+ } \
+ \
+ template <> \
+ port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
+ Platform::Id platform_id, PluginId plugin_id) { \
+ if (plugin_id == PluginConfig::kDefault) { \
+ plugin_id = default_factories_[platform_id].FACTORY_VAR; \
+ \
+ if (plugin_id == kNullPlugin) { \
+ return port::Status{port::error::FAILED_PRECONDITION, \
+ "No suitable " PLUGIN_STRING \
+ " plugin registered, default or otherwise."}; \
+ } else { \
+ VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
+ << plugin_names_[plugin_id]; \
+ } \
+ } \
+ return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
+ generic_factories_.FACTORY_VAR); \
+ } \
+ \
+ /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \
+ template <> \
+ port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
+ PlatformKind platform_kind, PluginId plugin_id) { \
+ auto iter = platform_id_by_kind_.find(platform_kind); \
+ if (iter == platform_id_by_kind_.end()) { \
+ return port::Status{port::error::FAILED_PRECONDITION, \
+ port::Printf("Platform kind %d not registered.", \
+ static_cast<int>(platform_kind))}; \
+ } \
+ return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
+ }
+
+EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
+EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
+EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
+EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
+
+} // namespace gputools
+} // namespace perftools