aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_platform.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/stream_executor/cuda/cuda_platform.cc
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_platform.cc')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform.cc172
1 files changed, 172 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc
new file mode 100644
index 0000000000..ef88b89eda
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_platform.cc
@@ -0,0 +1,172 @@
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
// Defines the unique Platform::Id instance used to identify the CUDA
// platform in MultiPlatformManager lookups.
PLATFORM_DEFINE_ID(kCudaPlatformId);
+
// Constructs the platform with an empty NUMA range [0, 0); the real node
// range is discovered lazily by InspectNumaNodes() on first reference.
CudaPlatform::CudaPlatform()
    : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}
+
// No explicit teardown: cached executors are owned by executor_cache_.
CudaPlatform::~CudaPlatform() {}
+
// Due to legacy issues in user code, we can't currently call InspectNumaNodes
// at module initialization time, because non-GPU programs still include this
// plugin via various methods, so instead, it has to be init-on-reference.
void CudaPlatform::InspectNumaNodes() {
  // To get NUMA node information, we need to create all executors, so we can
  // examine their device descriptions to see their bus assignments.
  //
  // Runs at most once: the LINKER_INITIALIZED mutex serializes concurrent
  // callers and the `initialized` flag turns subsequent calls into no-ops.
  static bool initialized = false;
  static mutex numa_mutex(LINKER_INITIALIZED);
  mutex_lock lock(numa_mutex);
  if (initialized) {
    return;
  }

  StreamExecutorConfig config;
  for (int i = 0; i < VisibleDeviceCount(); i++) {
    config.ordinal = i;
    // NOTE(review): ValueOrDie() aborts the process if executor creation
    // fails for any visible device.
    StreamExecutor* exec = GetExecutor(config).ValueOrDie();
    if (i == 0) {
      // NUMA nodes may not start at 0, so set the minimum node based on the
      // first executor we see.
      min_numa_node_ = exec->GetDeviceDescription().numa_node();
      limit_numa_node_ = min_numa_node_ + 1;
    } else {
      // Widen [min_numa_node_, limit_numa_node_) to cover this device's node.
      min_numa_node_ =
          std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
      limit_numa_node_ = std::max(limit_numa_node_,
                                  exec->GetDeviceDescription().numa_node() + 1);
    }
  }
  initialized = true;
}
+
+int CudaPlatform::BusCount() {
+ InspectNumaNodes();
+ return limit_numa_node_ - min_numa_node_;
+}
+
+int CudaPlatform::DeviceToBus(int device_ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = device_ordinal;
+ StreamExecutor* exec = GetExecutor(config).ValueOrDie();
+ return exec->GetDeviceDescription().numa_node() - min_numa_node_;
+}
+
+port::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
+ int bus_ordinal) {
+ InspectNumaNodes();
+ CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
+ for (int i = 0; i < VisibleDeviceCount(); i++) {
+ if (DeviceToBus(i) == bus_ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = i;
+ return GetExecutor(config).ValueOrDie();
+ }
+ }
+
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::Printf("Executor for bus %d not found.", bus_ordinal)};
+}
+
// Identifies this platform instance as the CUDA platform.
Platform::Id CudaPlatform::id() const { return kCudaPlatformId; }
+
// Returns the number of CUDA devices visible to the driver, or -1 if the
// driver cannot be initialized.
int CudaPlatform::VisibleDeviceCount() const {
  // Initialize the driver if necessary. Init() logs failures internally and
  // is safe to call more than once; since this [containing] function isn't in
  // the path of user control, we report -1 on failure rather than crashing.
  if (!cuda::CUDADriver::Init().ok()) {
    return -1;
  }

  return CUDADriver::GetDeviceCount();
}
+
// Human-readable platform name ("CUDA"), fixed at construction.
const string& CudaPlatform::Name() const { return name_; }
+
+port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ config.plugin_config = PluginConfig();
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
+ int device_ordinal, const PluginConfig& plugin_config) {
+ StreamExecutorConfig config;
+ config.ordinal = device_ordinal;
+ config.plugin_config = plugin_config;
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
+ const StreamExecutorConfig& config) {
+ mutex_lock lock(mu_);
+
+ port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
+ if (status.ok()) {
+ return status.ValueOrDie();
+ }
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
+ GetUncachedExecutor(config);
+ if (!executor.ok()) {
+ return executor.status();
+ }
+
+ StreamExecutor* naked_executor = executor.ValueOrDie().get();
+ executor_cache_.Insert(config, executor.ConsumeValueOrDie());
+ return naked_executor;
+}
+
+port::StatusOr<std::unique_ptr<StreamExecutor>>
+CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+ auto executor = port::MakeUnique<StreamExecutor>(PlatformKind::kCuda,
+ config.plugin_config);
+ auto init_status = executor->Init(config.ordinal, config.device_options);
+ if (!init_status.ok()) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf(
+ "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
+ config.ordinal, init_status.ToString().c_str())};
+ }
+
+ return std::move(executor);
+}
+
// Trace-listener registration is not implemented for CUDA; calling this
// intentionally aborts the process.
void CudaPlatform::RegisterTraceListener(
    std::unique_ptr<TraceListener> listener) {
  LOG(FATAL) << "not yet implemented: register CUDA trace listener";
}
+
// Trace-listener unregistration is not implemented for CUDA; calling this
// intentionally aborts the process.
void CudaPlatform::UnregisterTraceListener(TraceListener* listener) {
  LOG(FATAL) << "not yet implemented: unregister CUDA trace listener";
}
+
+} // namespace cuda
+
+static void InitializeCudaPlatform() {
+ // Disabling leak checking, MultiPlatformManager does not destroy its
+ // registered platforms.
+
+ std::unique_ptr<cuda::CudaPlatform> platform(new cuda::CudaPlatform);
+ SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
+}
+
+} // namespace gputools
+} // namespace perftools
+
// Hooks platform registration into module initialization so the CUDA
// platform is available to MultiPlatformManager before main() runs.
REGISTER_MODULE_INITIALIZER(cuda_platform,
                            perftools::gputools::InitializeCudaPlatform());