aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/multi_platform_manager.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-08 11:38:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 11:43:00 -0800
commit05c31035abedb2983899c49d172ac0382b6eceb7 (patch)
treefbf038709425824522bf8599634b2f29a241c842 /tensorflow/stream_executor/multi_platform_manager.cc
parenta6a0c0bf9486c11793b7dd0b4883a75ff3dcf3f3 (diff)
[SE] Initial perftools::gputools::Platform initialization support
Adds initialization methods to Platform. Some platforms require initialization. Those that do not have trivial implementations of these methods. PiperOrigin-RevId: 188363315
Diffstat (limited to 'tensorflow/stream_executor/multi_platform_manager.cc')
-rw-r--r--tensorflow/stream_executor/multi_platform_manager.cc86
1 files changed, 71 insertions, 15 deletions
diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc
index f23224ae77..f9f3737a06 100644
--- a/tensorflow/stream_executor/multi_platform_manager.cc
+++ b/tensorflow/stream_executor/multi_platform_manager.cc
@@ -23,11 +23,37 @@ limitations under the License.
namespace perftools {
namespace gputools {
+/* static */ mutex MultiPlatformManager::platforms_mutex_{LINKER_INITIALIZED};
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::LookupByNameLocked(
+ const string& target) {
+ PlatformMap* platform_map = GetPlatformMap();
+ auto it = platform_map->find(port::Lowercase(target));
+ if (it == platform_map->end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ "could not find registered platform with name: \"" + target + "\"");
+ }
+ return it->second;
+}
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::LookupByIdLocked(
+ const Platform::Id& id) {
+ PlatformIdMap* platform_map = GetPlatformByIdMap();
+ auto it = platform_map->find(id);
+ if (it == platform_map->end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::Printf("could not find registered platform with id: 0x%p", id));
+ }
+ return it->second;
+}
+
/* static */ port::Status MultiPlatformManager::RegisterPlatform(
std::unique_ptr<Platform> platform) {
CHECK(platform != nullptr);
string key = port::Lowercase(platform->Name());
- mutex_lock lock(GetPlatformsMutex());
+ mutex_lock lock(platforms_mutex_);
if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) {
return port::Status(port::error::INTERNAL,
"platform is already registered with name: \"" +
@@ -45,33 +71,63 @@ namespace gputools {
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
const string& target) {
- tf_shared_lock lock(GetPlatformsMutex());
- auto it = GetPlatformMap()->find(port::Lowercase(target));
+ mutex_lock lock(platforms_mutex_);
- if (it == GetPlatformMap()->end()) {
- return port::Status(
- port::error::NOT_FOUND,
- "could not find registered platform with name: \"" + target + "\"");
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
+ if (!platform->Initialized()) {
+ SE_RETURN_IF_ERROR(platform->Initialize({}));
}
- return it->second;
+ return platform;
}
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
const Platform::Id& id) {
- tf_shared_lock lock(GetPlatformsMutex());
- auto it = GetPlatformByIdMap()->find(id);
- if (it == GetPlatformByIdMap()->end()) {
+ mutex_lock lock(platforms_mutex_);
+
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
+ if (!platform->Initialized()) {
+ SE_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+
+ return platform;
+}
+
+/* static */ port::StatusOr<Platform*>
+MultiPlatformManager::InitializePlatformWithName(
+ const string& target, const std::map<string, string>& options) {
+ mutex_lock lock(platforms_mutex_);
+
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
+ if (platform->Initialized()) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "platform \"" + target + "\" is already initialized");
+ }
+
+ SE_RETURN_IF_ERROR(platform->Initialize(options));
+
+ return platform;
+}
+
+/* static */ port::StatusOr<Platform*>
+MultiPlatformManager::InitializePlatformWithId(
+ const Platform::Id& id, const std::map<string, string>& options) {
+ mutex_lock lock(platforms_mutex_);
+
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
+ if (platform->Initialized()) {
return port::Status(
- port::error::NOT_FOUND,
- port::Printf("could not find registered platform with id: 0x%p", id));
+ port::error::FAILED_PRECONDITION,
+ port::Printf("platform with id 0x%p is already initialized", id));
}
- return it->second;
+ SE_RETURN_IF_ERROR(platform->Initialize(options));
+
+ return platform;
}
/* static */ void MultiPlatformManager::ClearPlatformRegistry() {
- mutex_lock lock(GetPlatformsMutex());
+ mutex_lock lock(platforms_mutex_);
GetPlatformMap()->clear();
GetPlatformByIdMap()->clear();
}