aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_platform.cc
blob: ef88b89edaaba046f83d4c60c29516d5a879855b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include "tensorflow/stream_executor/cuda/cuda_platform.h"

#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"

namespace perftools {
namespace gputools {
namespace cuda {

// Instantiates the platform identifier kCudaPlatformId via the
// PLATFORM_DEFINE_ID macro; CudaPlatform::id() returns this value.
PLATFORM_DEFINE_ID(kCudaPlatformId);

// The NUMA range starts out empty ([0, 0)); InspectNumaNodes() replaces it with
// the range actually spanned by the visible devices.
CudaPlatform::CudaPlatform()
    : name_("CUDA"),
      min_numa_node_(0),
      limit_numa_node_(0) {}

CudaPlatform::~CudaPlatform() = default;

// Computes [min_numa_node_, limit_numa_node_), the half-open range of NUMA
// nodes spanned by the visible CUDA devices. Bus ordinals on this platform are
// those NUMA node ids rebased so the lowest node seen becomes bus 0.
//
// Due to legacy issues in user code, we can't currently call InspectNumaNodes
// at module initialization time, because non-GPU programs still include this
// plugin via various methods, so instead, it has to be init-on-reference.
//
// NOTE(review): `initialized` and `numa_mutex` are function-local statics, so
// they are shared by every CudaPlatform instance, not per-object state.
void CudaPlatform::InspectNumaNodes() {
  // To get NUMA node information, we need to create all executors, so we can
  // examine their device descriptions to see their bus assignments.
  static bool initialized = false;
  static mutex numa_mutex(LINKER_INITIALIZED);
  mutex_lock lock(numa_mutex);
  if (initialized) {
    return;
  }

  // VisibleDeviceCount() re-runs driver initialization on every call, so query
  // it once. On driver failure it returns -1, the loop is skipped, and the
  // range keeps its current (constructor) values.
  const int device_count = VisibleDeviceCount();
  StreamExecutorConfig config;
  for (int i = 0; i < device_count; i++) {
    config.ordinal = i;
    StreamExecutor* exec = GetExecutor(config).ValueOrDie();
    const auto numa_node = exec->GetDeviceDescription().numa_node();
    if (i == 0) {
      // NUMA nodes may not start at 0, so seed the range from the first
      // executor we see.
      min_numa_node_ = numa_node;
      limit_numa_node_ = numa_node + 1;
    } else {
      min_numa_node_ = std::min(min_numa_node_, numa_node);
      limit_numa_node_ = std::max(limit_numa_node_, numa_node + 1);
    }
  }
  initialized = true;
}

// Number of distinct bus ordinals, i.e. the width of the NUMA node range
// computed by InspectNumaNodes().
int CudaPlatform::BusCount() {
  InspectNumaNodes();
  const int span = limit_numa_node_ - min_numa_node_;
  return span;
}

int CudaPlatform::DeviceToBus(int device_ordinal) {
  StreamExecutorConfig config;
  config.ordinal = device_ordinal;
  StreamExecutor* exec = GetExecutor(config).ValueOrDie();
  return exec->GetDeviceDescription().numa_node() - min_numa_node_;
}

// Returns the executor of the lowest-ordinal visible device whose bus ordinal
// equals `bus_ordinal`, or NOT_FOUND if none matches. CHECK-fails when
// `bus_ordinal` is not below BusCount().
port::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
    int bus_ordinal) {
  // BusCount() runs InspectNumaNodes(), so the NUMA range is populated from
  // here on.
  CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";

  // Hoisted: VisibleDeviceCount() re-initializes the driver on every call.
  const int device_count = VisibleDeviceCount();
  for (int i = 0; i < device_count; i++) {
    if (DeviceToBus(i) == bus_ordinal) {
      StreamExecutorConfig config;
      config.ordinal = i;
      // Propagate executor-creation errors through the StatusOr return value
      // instead of aborting the process.
      return GetExecutor(config);
    }
  }

  return port::Status{
      port::error::NOT_FOUND,
      port::Printf("Executor for bus %d not found.", bus_ordinal)};
}

// Every CudaPlatform reports the same identifier, kCudaPlatformId.
Platform::Id CudaPlatform::id() const {
  return kCudaPlatformId;
}

// Returns the device count reported by CUDADriver::GetDeviceCount(), or -1 if
// CUDADriver::Init() fails.
int CudaPlatform::VisibleDeviceCount() const {
  // CUDADriver::Init() logs failures internally and is safe to call more than
  // once, so it is invoked on every query. Its status is collapsed into the -1
  // sentinel because this function is not in the path of user control.
  if (!CUDADriver::Init().ok()) {
    return -1;
  }

  return CUDADriver::GetDeviceCount();
}

// Display name of the platform; the constructor fixes it to "CUDA".
const string& CudaPlatform::Name() const {
  return name_;
}

// Returns the (possibly cached) executor for device `ordinal`, built with a
// default-constructed PluginConfig and DeviceOptions::Default().
port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal) {
  StreamExecutorConfig executor_config;
  executor_config.device_options = DeviceOptions::Default();
  executor_config.plugin_config = PluginConfig();
  executor_config.ordinal = ordinal;
  return GetExecutor(executor_config);
}

// Like ExecutorForDevice(), but uses the caller-supplied `plugin_config`;
// device options are still DeviceOptions::Default().
port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
    int device_ordinal, const PluginConfig& plugin_config) {
  StreamExecutorConfig executor_config;
  executor_config.device_options = DeviceOptions::Default();
  executor_config.plugin_config = plugin_config;
  executor_config.ordinal = device_ordinal;
  return GetExecutor(executor_config);
}

// Returns the executor cached for `config`, creating and caching one on first
// use. The cache owns the executor; callers receive a non-owning pointer.
port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
    const StreamExecutorConfig& config) {
  // mu_ is held across both lookup and creation, so concurrent callers cannot
  // both build an executor for the same config.
  mutex_lock lock(mu_);

  // Fast path: an executor for this config already exists.
  port::StatusOr<StreamExecutor*> cached = executor_cache_.Get(config);
  if (cached.ok()) {
    return cached.ValueOrDie();
  }

  // Slow path: build a new executor and transfer ownership to the cache.
  port::StatusOr<std::unique_ptr<StreamExecutor>> created =
      GetUncachedExecutor(config);
  if (!created.ok()) {
    return created.status();
  }
  std::unique_ptr<StreamExecutor> owned = created.ConsumeValueOrDie();
  StreamExecutor* borrowed = owned.get();
  executor_cache_.Insert(config, std::move(owned));
  return borrowed;
}

// Creates and initializes a brand-new StreamExecutor for `config` without
// touching executor_cache_. Any Init() failure is reported as INTERNAL, with
// the underlying status text embedded in the message.
port::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
  std::unique_ptr<StreamExecutor> fresh = port::MakeUnique<StreamExecutor>(
      PlatformKind::kCuda, config.plugin_config);

  const auto init_result = fresh->Init(config.ordinal, config.device_options);
  if (init_result.ok()) {
    return std::move(fresh);
  }

  return port::Status{
      port::error::INTERNAL,
      port::Printf(
          "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
          config.ordinal, init_result.ToString().c_str())};
}

// Trace listeners are not supported on this platform yet: any attempt to
// register one terminates via LOG(FATAL), and `listener` is never stored.
void CudaPlatform::RegisterTraceListener(
    std::unique_ptr<TraceListener> listener) {
  LOG(FATAL) << "not yet implemented: register CUDA trace listener";
}

// Counterpart of RegisterTraceListener(); equally unimplemented, so calling it
// terminates via LOG(FATAL).
void CudaPlatform::UnregisterTraceListener(
    TraceListener* listener) {
  LOG(FATAL) << "not yet implemented: unregister CUDA trace listener";
}

}  // namespace cuda

static void InitializeCudaPlatform() {
  // Disabling leak checking, MultiPlatformManager does not destroy its
  // registered platforms.
  
  std::unique_ptr<cuda::CudaPlatform> platform(new cuda::CudaPlatform);
  SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
}

}  // namespace gputools
}  // namespace perftools

// Registers InitializeCudaPlatform() as the `cuda_platform` module initializer
// so the CUDA platform gets added to MultiPlatformManager when module
// initializers run. NOTE(review): presumably at static-initialization time —
// confirm against lib/initialize.h.
REGISTER_MODULE_INITIALIZER(cuda_platform,
                            perftools::gputools::InitializeCudaPlatform());