#include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" namespace perftools { namespace gputools { const PluginId kNullPlugin = nullptr; // Returns the string representation of the specified PluginKind. string PluginKindString(PluginKind plugin_kind) { switch (plugin_kind) { case PluginKind::kBlas: return "BLAS"; case PluginKind::kDnn: return "DNN"; case PluginKind::kFft: return "FFT"; case PluginKind::kRng: return "RNG"; case PluginKind::kInvalid: default: return "kInvalid"; } } PluginRegistry::DefaultFactories::DefaultFactories() : blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { } /* static */ mutex PluginRegistry::mu_(LINKER_INITIALIZED); /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr; PluginRegistry::PluginRegistry() {} /* static */ PluginRegistry* PluginRegistry::Instance() { mutex_lock lock{mu_}; if (instance_ == nullptr) { instance_ = new PluginRegistry(); } return instance_; } void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind, Platform::Id platform_id) { platform_id_by_kind_[platform_kind] = platform_id; } template port::Status PluginRegistry::RegisterFactoryInternal( PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory, std::map* factories) { mutex_lock lock{mu_}; if (factories->find(plugin_id) != factories->end()) { return port::Status{ port::error::ALREADY_EXISTS, port::Printf("Attempting to register factory for plugin %s when " "one has already been registered", plugin_name.c_str())}; } (*factories)[plugin_id] = factory; plugin_names_[plugin_id] = plugin_name; return port::Status::OK(); } template port::StatusOr PluginRegistry::GetFactoryInternal( PluginId plugin_id, const std::map& factories, const std::map& generic_factories) const { auto iter = factories.find(plugin_id); if (iter == factories.end()) { iter = generic_factories.find(plugin_id); if (iter == generic_factories.end()) { return port::Status{ port::error::NOT_FOUND, port::Printf("Plugin ID %p not registered.", plugin_id)}; } } return iter->second; } bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id, PluginKind plugin_kind, PluginId plugin_id) { if (!HasFactory(platform_id, plugin_kind, plugin_id)) { port::StatusOr status = MultiPlatformManager::PlatformWithId(platform_id); string platform_name = ""; if (status.ok()) { platform_name = status.ValueOrDie()->Name(); } LOG(ERROR) << "A factory must be registered for a platform before being " << "set as default! " << "Platform name: " << platform_name << ", PluginKind: " << PluginKindString(plugin_kind) << ", PluginId: " << plugin_id; return false; } switch (plugin_kind) { case PluginKind::kBlas: default_factories_[platform_id].blas = plugin_id; break; case PluginKind::kDnn: default_factories_[platform_id].dnn = plugin_id; break; case PluginKind::kFft: default_factories_[platform_id].fft = plugin_id; break; case PluginKind::kRng: default_factories_[platform_id].rng = plugin_id; break; default: LOG(ERROR) << "Invalid plugin kind specified: " << static_cast(plugin_kind); return false; } return true; } bool PluginRegistry::HasFactory(const PluginFactories& factories, PluginKind plugin_kind, PluginId plugin_id) const { switch (plugin_kind) { case PluginKind::kBlas: return factories.blas.find(plugin_id) != factories.blas.end(); case PluginKind::kDnn: return factories.dnn.find(plugin_id) != factories.dnn.end(); case PluginKind::kFft: return factories.fft.find(plugin_id) != factories.fft.end(); case PluginKind::kRng: return factories.rng.find(plugin_id) != factories.rng.end(); default: LOG(ERROR) << "Invalid plugin kind specified: " << PluginKindString(plugin_kind); return false; } } bool PluginRegistry::HasFactory(Platform::Id platform_id, PluginKind plugin_kind, PluginId plugin_id) const { auto iter = factories_.find(platform_id); if (iter != factories_.end()) { if (HasFactory(iter->second, plugin_kind, plugin_id)) { return true; } } return HasFactory(generic_factories_, plugin_kind, plugin_id); } // Explicit instantiations to support types exposed in user/public API. #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \ template port::StatusOr \ PluginRegistry::GetFactoryInternal( \ PluginId plugin_id, \ const std::map& factories, \ const std::map& \ generic_factories) const; \ \ template port::Status \ PluginRegistry::RegisterFactoryInternal( \ PluginId plugin_id, const string& plugin_name, \ PluginRegistry::FACTORY_TYPE factory, \ std::map* factories); \ \ template <> \ port::Status PluginRegistry::RegisterFactory( \ Platform::Id platform_id, PluginId plugin_id, const string& name, \ PluginRegistry::FACTORY_TYPE factory) { \ return RegisterFactoryInternal(plugin_id, name, factory, \ &factories_[platform_id].FACTORY_VAR); \ } \ \ template <> \ port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \ PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \ PluginRegistry::FACTORY_TYPE factory) { \ return RegisterFactoryInternal(plugin_id, name, factory, \ &generic_factories_.FACTORY_VAR); \ } \ \ template <> \ port::StatusOr PluginRegistry::GetFactory( \ Platform::Id platform_id, PluginId plugin_id) { \ if (plugin_id == PluginConfig::kDefault) { \ plugin_id = default_factories_[platform_id].FACTORY_VAR; \ \ if (plugin_id == kNullPlugin) { \ return port::Status{port::error::FAILED_PRECONDITION, \ "No suitable " PLUGIN_STRING \ " plugin registered, default or otherwise."}; \ } else { \ VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \ << plugin_names_[plugin_id]; \ } \ } \ return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \ generic_factories_.FACTORY_VAR); \ } \ \ /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \ template <> \ port::StatusOr PluginRegistry::GetFactory( \ PlatformKind platform_kind, PluginId plugin_id) { \ auto iter = platform_id_by_kind_.find(platform_kind); \ if (iter == platform_id_by_kind_.end()) { \ return port::Status{port::error::FAILED_PRECONDITION, \ port::Printf("Platform kind %d not registered.", \ static_cast(platform_kind))}; \ } \ return GetFactory(iter->second, plugin_id); \ } EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS"); EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN"); EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT"); EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG"); } // namespace gputools } // namespace perftools