diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-08 11:38:46 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-08 11:43:00 -0800 |
commit | 05c31035abedb2983899c49d172ac0382b6eceb7 (patch) | |
tree | fbf038709425824522bf8599634b2f29a241c842 /tensorflow/stream_executor/multi_platform_manager.h | |
parent | a6a0c0bf9486c11793b7dd0b4883a75ff3dcf3f3 (diff) |
[SE] Initial perftools::gputools::Platform initialization support
Adds initialization methods to Platform. Some platforms require initialization.
Those that do not have trivial implementations of these methods.
PiperOrigin-RevId: 188363315
Diffstat (limited to 'tensorflow/stream_executor/multi_platform_manager.h')
-rw-r--r-- | tensorflow/stream_executor/multi_platform_manager.h | 63 |
1 files changed, 44 insertions, 19 deletions
diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h index ea6155b482..438653ee20 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -67,13 +67,13 @@ limitations under the License. #include <functional> #include <map> #include <memory> -#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/platform/thread_annotations.h" namespace perftools { namespace gputools { @@ -85,26 +85,43 @@ class MultiPlatformManager { // already registered. The associated listener, if not null, will be used to // trace events for ALL executors for that platform. // Takes ownership of listener. - static port::Status RegisterPlatform(std::unique_ptr<Platform> platform); + static port::Status RegisterPlatform(std::unique_ptr<Platform> platform) + LOCKS_EXCLUDED(platforms_mutex_); - // Retrieves the platform registered with the given platform name; e.g. - // "CUDA", "OpenCL", ... + // Retrieves the platform registered with the given platform name (e.g. + // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the + // Platform's Id() method). + // + // If the platform has not already been initialized, it will be initialized + // with a default set of parameters. // // If the requested platform is not registered, an error status is returned. // Ownership of the platform is NOT transferred to the caller -- // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static port::StatusOr<Platform*> PlatformWithName(const string& target); - - // Retrieves the platform registered with the given platform ID, which - // is an opaque (but comparable) value. + static port::StatusOr<Platform*> PlatformWithName(const string& target) + LOCKS_EXCLUDED(platforms_mutex_); + static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id) + LOCKS_EXCLUDED(platforms_mutex_); + + // Retrieves the platform registered with the given platform name (e.g. + // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the + // Platform's Id() method). + // + // The platform will be initialized with the given options. If the platform + // was already initialized, an error will be returned. // // If the requested platform is not registered, an error status is returned. // Ownership of the platform is NOT transferred to the caller -- // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id); + static port::StatusOr<Platform*> InitializePlatformWithName( + const string& target, const std::map<string, string>& options) + LOCKS_EXCLUDED(platforms_mutex_); + static port::StatusOr<Platform*> InitializePlatformWithId( + const Platform::Id& id, const std::map<string, string>& options) + LOCKS_EXCLUDED(platforms_mutex_); // Clears the set of registered platforms, primarily used for testing. - static void ClearPlatformRegistry(); + static void ClearPlatformRegistry() LOCKS_EXCLUDED(platforms_mutex_); // Although the MultiPlatformManager "owns" its platforms, it holds them as // undecorated pointers to prevent races during program exit (between this @@ -122,17 +139,16 @@ class MultiPlatformManager { // Provides access to the available set of platforms under a lock. static port::Status WithPlatforms( - std::function<port::Status(PlatformMap*)> callback) { - mutex_lock lock(GetPlatformsMutex()); + std::function<port::Status(PlatformMap*)> callback) + LOCKS_EXCLUDED(platforms_mutex_) { + mutex_lock lock(platforms_mutex_); return callback(GetPlatformMap()); } private: - // mutex that guards the platform map. - static mutex& GetPlatformsMutex() { - static mutex* platforms_mutex = new mutex; - return *platforms_mutex; - } + using PlatformIdMap = std::map<Platform::Id, Platform*>; + + static mutex platforms_mutex_; // TODO(b/22689637): Clean up these two maps; make sure they coexist nicely. // TODO(b/22689637): Move this (whatever the final/"official" map is) to @@ -147,12 +163,21 @@ class MultiPlatformManager { // Holds a Platform::Id-to-object mapping. // Unlike platforms_ above, this map does not own its contents. - static std::map<Platform::Id, Platform*>* GetPlatformByIdMap() { - using PlatformIdMap = std::map<Platform::Id, Platform*>; + static PlatformIdMap* GetPlatformByIdMap() { static PlatformIdMap* instance = new PlatformIdMap; return instance; } + // Looks up the platform object with the given name. Assumes the Platforms + // mutex is held. + static port::StatusOr<Platform*> LookupByNameLocked(const string& target) + EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_); + + // Looks up the platform object with the given id. Assumes the Platforms + // mutex is held. + static port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id) + EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_); + SE_DISALLOW_COPY_AND_ASSIGN(MultiPlatformManager); }; |