diff options
Diffstat (limited to 'tensorflow/stream_executor/plugin_registry.cc')
-rw-r--r-- | tensorflow/stream_executor/plugin_registry.cc | 228 |
1 files changed, 228 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc new file mode 100644 index 0000000000..eda44d1146 --- /dev/null +++ b/tensorflow/stream_executor/plugin_registry.cc @@ -0,0 +1,228 @@ +#include "tensorflow/stream_executor/plugin_registry.h" + +#include "tensorflow/stream_executor/lib/error.h" +#include "tensorflow/stream_executor/lib/stringprintf.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" + +namespace perftools { +namespace gputools { + +const PluginId kNullPlugin = nullptr; + +// Returns the string representation of the specified PluginKind. +string PluginKindString(PluginKind plugin_kind) { + switch (plugin_kind) { + case PluginKind::kBlas: + return "BLAS"; + case PluginKind::kDnn: + return "DNN"; + case PluginKind::kFft: + return "FFT"; + case PluginKind::kRng: + return "RNG"; + case PluginKind::kInvalid: + default: + return "kInvalid"; + } +} + +PluginRegistry::DefaultFactories::DefaultFactories() : + blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { } + +/* static */ mutex PluginRegistry::mu_(LINKER_INITIALIZED); +/* static */ PluginRegistry* PluginRegistry::instance_ = nullptr; + +PluginRegistry::PluginRegistry() {} + +/* static */ PluginRegistry* PluginRegistry::Instance() { + mutex_lock lock{mu_}; + if (instance_ == nullptr) { + instance_ = new PluginRegistry(); + } + return instance_; +} + +void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind, + Platform::Id platform_id) { + platform_id_by_kind_[platform_kind] = platform_id; +} + +template <typename FACTORY_TYPE> +port::Status PluginRegistry::RegisterFactoryInternal( + PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory, + std::map<PluginId, FACTORY_TYPE>* factories) { + mutex_lock lock{mu_}; + + if (factories->find(plugin_id) != factories->end()) { + return port::Status{ + port::error::ALREADY_EXISTS, + port::Printf("Attempting to register factory for plugin %s when " + "one has already been registered", + plugin_name.c_str())}; + } + + (*factories)[plugin_id] = factory; + plugin_names_[plugin_id] = plugin_name; + return port::Status::OK(); +} + +template <typename FACTORY_TYPE> +port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal( + PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories, + const std::map<PluginId, FACTORY_TYPE>& generic_factories) const { + auto iter = factories.find(plugin_id); + if (iter == factories.end()) { + iter = generic_factories.find(plugin_id); + if (iter == generic_factories.end()) { + return port::Status{ + port::error::NOT_FOUND, + port::Printf("Plugin ID %p not registered.", plugin_id)}; + } + } + + return iter->second; +} + +bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id, + PluginKind plugin_kind, + PluginId plugin_id) { + if (!HasFactory(platform_id, plugin_kind, plugin_id)) { + port::StatusOr<Platform*> status = + MultiPlatformManager::PlatformWithId(platform_id); + string platform_name = "<unregistered platform>"; + if (status.ok()) { + platform_name = status.ValueOrDie()->Name(); + } + + LOG(ERROR) << "A factory must be registered for a platform before being " + << "set as default! " + << "Platform name: " << platform_name + << ", PluginKind: " << PluginKindString(plugin_kind) + << ", PluginId: " << plugin_id; + return false; + } + + switch (plugin_kind) { + case PluginKind::kBlas: + default_factories_[platform_id].blas = plugin_id; + break; + case PluginKind::kDnn: + default_factories_[platform_id].dnn = plugin_id; + break; + case PluginKind::kFft: + default_factories_[platform_id].fft = plugin_id; + break; + case PluginKind::kRng: + default_factories_[platform_id].rng = plugin_id; + break; + default: + LOG(ERROR) << "Invalid plugin kind specified: " + << static_cast<int>(plugin_kind); + return false; + } + + return true; +} + +bool PluginRegistry::HasFactory(const PluginFactories& factories, + PluginKind plugin_kind, + PluginId plugin_id) const { + switch (plugin_kind) { + case PluginKind::kBlas: + return factories.blas.find(plugin_id) != factories.blas.end(); + case PluginKind::kDnn: + return factories.dnn.find(plugin_id) != factories.dnn.end(); + case PluginKind::kFft: + return factories.fft.find(plugin_id) != factories.fft.end(); + case PluginKind::kRng: + return factories.rng.find(plugin_id) != factories.rng.end(); + default: + LOG(ERROR) << "Invalid plugin kind specified: " + << PluginKindString(plugin_kind); + return false; + } +} + +bool PluginRegistry::HasFactory(Platform::Id platform_id, + PluginKind plugin_kind, + PluginId plugin_id) const { + auto iter = factories_.find(platform_id); + if (iter != factories_.end()) { + if (HasFactory(iter->second, plugin_kind, plugin_id)) { + return true; + } + } + + return HasFactory(generic_factories_, plugin_kind, plugin_id); +} + +// Explicit instantiations to support types exposed in user/public API. +#define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \ + template port::StatusOr<PluginRegistry::FACTORY_TYPE> \ + PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \ + PluginId plugin_id, \ + const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \ + const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \ + generic_factories) const; \ + \ + template port::Status \ + PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \ + PluginId plugin_id, const string& plugin_name, \ + PluginRegistry::FACTORY_TYPE factory, \ + std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \ + \ + template <> \ + port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \ + Platform::Id platform_id, PluginId plugin_id, const string& name, \ + PluginRegistry::FACTORY_TYPE factory) { \ + return RegisterFactoryInternal(plugin_id, name, factory, \ + &factories_[platform_id].FACTORY_VAR); \ + } \ + \ + template <> \ + port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \ + PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \ + PluginRegistry::FACTORY_TYPE factory) { \ + return RegisterFactoryInternal(plugin_id, name, factory, \ + &generic_factories_.FACTORY_VAR); \ + } \ + \ + template <> \ + port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \ + Platform::Id platform_id, PluginId plugin_id) { \ + if (plugin_id == PluginConfig::kDefault) { \ + plugin_id = default_factories_[platform_id].FACTORY_VAR; \ + \ + if (plugin_id == kNullPlugin) { \ + return port::Status{port::error::FAILED_PRECONDITION, \ + "No suitable " PLUGIN_STRING \ + " plugin registered, default or otherwise."}; \ + } else { \ + VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \ + << plugin_names_[plugin_id]; \ + } \ + } \ + return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \ + generic_factories_.FACTORY_VAR); \ + } \ + \ + /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \ + template <> \ + port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \ + PlatformKind platform_kind, PluginId plugin_id) { \ + auto iter = platform_id_by_kind_.find(platform_kind); \ + if (iter == platform_id_by_kind_.end()) { \ + return port::Status{port::error::FAILED_PRECONDITION, \ + port::Printf("Platform kind %d not registered.", \ + static_cast<int>(platform_kind))}; \ + } \ + return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \ + } + +EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS"); +EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN"); +EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT"); +EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG"); + +} // namespace gputools +} // namespace perftools |