aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/multi_platform_manager.cc
blob: a65add05c5c51189eee2814ebce6fd58c6c9c3fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "tensorflow/stream_executor/multi_platform_manager.h"

#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"

namespace perftools {
namespace gputools {

// Guards the platform registries: each registry function defined below takes
// this lock before reading or modifying the name-keyed and id-keyed maps.
// NOTE(review): LINKER_INITIALIZED presumably selects a mutex constructor that
// is usable before dynamic initializers have run -- confirm against the mutex
// header.
/* static */ mutex MultiPlatformManager::platforms_mutex_(LINKER_INITIALIZED);

// Registers `platform` under its lowercased name and under its id.
//
// `platform` must be non-null (enforced with CHECK). On success the registry
// takes over the object and never destroys it. On failure the registry is left
// untouched and `platform` is destroyed when this function returns.
//
// Returns:
//   OK        if the platform was added to both lookup maps.
//   INTERNAL  if a platform with the same (case-insensitive) name or the same
//             id is already registered.
/* static */ port::Status MultiPlatformManager::RegisterPlatform(
    std::unique_ptr<Platform> platform) {
  CHECK(platform != nullptr);
  string key = port::Lowercase(platform->Name());
  mutex_lock lock(platforms_mutex_);
  if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) {
    return port::Status(port::error::INTERNAL,
                        "platform is already registered with name: \"" +
                            platform->Name() + "\"");
  }
  // The id map must be checked as well: insert() silently keeps the existing
  // entry on a key collision, which would leave the two maps disagreeing
  // (lookup by name yields the new platform, lookup by id the old one) while
  // still reporting success to the caller.
  if (GetPlatformByIdMap()->find(platform->id()) !=
      GetPlatformByIdMap()->end()) {
    return port::Status(port::error::INTERNAL,
                        "platform id is already registered; cannot register "
                        "platform with name: \"" +
                            platform->Name() + "\"");
  }
  GetPlatformByIdMap()->insert(std::make_pair(platform->id(), platform.get()));
  // Release ownership so the platform is never destroyed, not even at program
  // exit. Letting Platforms "clean up" on exit exposes _very_ tricky races
  // between StreamExecutor and the underlying platforms (CUDA, OpenCL). The
  // resulting leak is fixed-size and happens once per program, so it is deemed
  // acceptable.
  (*GetPlatformMap())[key] = platform.release();
  return port::Status::OK();
}

// Looks up a registered platform by name. The name is lowercased before the
// lookup, matching the lowercased keys used at registration time.
//
// Returns the registered Platform pointer, or a NOT_FOUND status that quotes
// `target` exactly as the caller supplied it.
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
    const string& target) {
  mutex_lock registry_guard(platforms_mutex_);

  // Fetch the name-keyed map once and reuse it for both find() and end().
  const auto& platforms_by_name = *GetPlatformMap();
  const auto found = platforms_by_name.find(port::Lowercase(target));
  if (found != platforms_by_name.end()) {
    return found->second;
  }

  return port::Status(
      port::error::NOT_FOUND,
      "could not find registered platform with name: \"" + target + "\"");
}

// Looks up a registered platform by its id.
//
// Returns the registered Platform pointer, or a NOT_FOUND status whose message
// contains the id that was requested.
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
    const Platform::Id& id) {
  mutex_lock registry_guard(platforms_mutex_);

  // Fetch the id-keyed map once and reuse it for both find() and end().
  const auto& platforms_by_id = *GetPlatformByIdMap();
  const auto found = platforms_by_id.find(id);
  if (found != platforms_by_id.end()) {
    return found->second;
  }

  // NOTE(review): "%p" output is implementation-defined and may already carry
  // a "0x" prefix on some C libraries, which would double the prefix here --
  // confirm on the supported platforms before changing the message.
  return port::Status(
      port::error::NOT_FOUND,
      port::Printf("could not find registered platform with id: 0x%p", id));
}

// Empties both lookup maps (by name and by id) while holding the registry
// lock. Only the map entries are removed; this function performs no delete on
// the Platform objects the entries referred to.
/* static */ void MultiPlatformManager::ClearPlatformRegistry() {
  mutex_lock registry_guard(platforms_mutex_);

  auto& platforms_by_name = *GetPlatformMap();
  platforms_by_name.clear();

  auto& platforms_by_id = *GetPlatformByIdMap();
  platforms_by_id.clear();
}

}  // namespace gputools
}  // namespace perftools