blob: a65add05c5c51189eee2814ebce6fd58c6c9c3fb (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
|
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
namespace perftools {
namespace gputools {

// Guards both registry maps returned by GetPlatformMap() and
// GetPlatformByIdMap(). NOTE(review): LINKER_INITIALIZED presumably selects a
// constructor safe for a static-storage object (no dependence on dynamic
// initialization order) — confirm against the port mutex header.
/* static */ mutex MultiPlatformManager::platforms_mutex_(LINKER_INITIALIZED);

// Registers `platform` under its lowercased Name() (so name lookups are
// case-insensitive) and under its id().
//
// Returns an INTERNAL error, and leaves the registry unchanged, if a platform
// with the same (case-insensitive) name or the same id is already registered.
// CHECK-fails if `platform` is null. On success the registry keeps a raw
// pointer to the platform and ownership is deliberately released, so the
// platform is never destroyed.
/* static */ port::Status MultiPlatformManager::RegisterPlatform(
    std::unique_ptr<Platform> platform) {
  CHECK(platform != nullptr);
  string key = port::Lowercase(platform->Name());
  mutex_lock lock(platforms_mutex_);
  if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) {
    return port::Status(port::error::INTERNAL,
                        "platform is already registered with name: \"" +
                            platform->Name() + "\"");
  }
  // The id must be checked too: map::insert keeps an existing entry and
  // silently ignores the new one, which would leave the name map and the id
  // map pointing at different platforms.
  if (GetPlatformByIdMap()->find(platform->id()) !=
      GetPlatformByIdMap()->end()) {
    return port::Status(
        port::error::INTERNAL,
        port::Printf("platform is already registered with id: %p",
                     platform->id()));
  }
  GetPlatformByIdMap()->insert(std::make_pair(platform->id(), platform.get()));
  // Release ownership/uniqueness to prevent destruction on program exit.
  // This avoids Platforms "cleaning up" on program exit, because otherwise,
  // there are _very_ tricky races between StreamExecutor and underlying
  // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
  // program, these are deemed acceptable.
  (*GetPlatformMap())[key] = platform.release();
  return port::Status::OK();
}

// Looks up a registered platform by name, ignoring case.
// Returns NOT_FOUND if no platform was registered under that name.
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
    const string& target) {
  // Normalize outside the critical section; it touches no shared state.
  const string key = port::Lowercase(target);
  mutex_lock lock(platforms_mutex_);
  auto it = GetPlatformMap()->find(key);
  if (it == GetPlatformMap()->end()) {
    return port::Status(
        port::error::NOT_FOUND,
        "could not find registered platform with name: \"" + target + "\"");
  }
  return it->second;
}

// Looks up a registered platform by its id.
// Returns NOT_FOUND if no platform was registered with that id.
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
    const Platform::Id& id) {
  mutex_lock lock(platforms_mutex_);
  auto it = GetPlatformByIdMap()->find(id);
  if (it == GetPlatformByIdMap()->end()) {
    // No literal "0x" here: the output of %p is implementation-defined and
    // common C libraries (e.g. glibc) already prefix it, which used to
    // produce "0x0x...".
    return port::Status(
        port::error::NOT_FOUND,
        port::Printf("could not find registered platform with id: %p", id));
  }
  return it->second;
}

// Removes every entry from both registry maps.
// The maps hold raw pointers whose ownership RegisterPlatform released, so
// clearing them does not delete the Platform objects; they stay alive and any
// Platform* handed out earlier remains valid. NOTE(review): presumably meant
// for tests that need to re-register platforms — confirm with callers.
/* static */ void MultiPlatformManager::ClearPlatformRegistry() {
  mutex_lock lock(platforms_mutex_);
  GetPlatformMap()->clear();
  GetPlatformByIdMap()->clear();
}

}  // namespace gputools
}  // namespace perftools
|