aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/machine_manager.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/machine_manager.cc')
-rw-r--r--tensorflow/stream_executor/machine_manager.cc276
1 files changed, 276 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/machine_manager.cc b/tensorflow/stream_executor/machine_manager.cc
new file mode 100644
index 0000000000..6d7bc50379
--- /dev/null
+++ b/tensorflow/stream_executor/machine_manager.cc
@@ -0,0 +1,276 @@
+#include "tensorflow/stream_executor/machine_manager.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+// Lock taken by CreateSingleton() and singleton() before they inspect or
+// assign singleton_.
+mutex MachineManager::mu_{LINKER_INITIALIZED};
+
+// The singleton instance; stays nullptr until CreateSingletonInternal() stores
+// a successfully created manager in it.
+MachineManager *MachineManager::singleton_ = nullptr;
+
+// Reports which platform the lazily-constructed singleton should target. The
+// answer is fixed at compile time by the PREFER_OPENCL / PREFER_HOST macros and
+// falls back to CUDA when neither is defined.
+PlatformKind MachineManager::DetectPreferredPlatform() {
+  // TODO(leary) for KNC card experiments, figure out a legitimate way to
+  // determine this. For now, we use a compile-time hint so we can compile
+  // tests for both.
+#if defined(TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_PREFER_OPENCL)
+  const PlatformKind preferred = PlatformKind::kOpenCL;
+#elif defined(TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_PREFER_HOST)
+  const PlatformKind preferred = PlatformKind::kHost;
+#else
+  const PlatformKind preferred = PlatformKind::kCuda;
+#endif
+  return preferred;
+}
+
+// Constructs a MachineManager for the given platform kind, device options and
+// plugin configuration, then runs Init() on it. Returns the initialized
+// manager on success, or the Init() failure status otherwise (in which case
+// the partially-built manager is destroyed).
+/* static */ port::StatusOr<std::unique_ptr<MachineManager>>
+MachineManager::Create(PlatformKind kind, DeviceOptions options,
+                       const PluginConfig &config) {
+  std::unique_ptr<MachineManager> manager{
+      new MachineManager{kind, options, config}};
+
+  auto init_result = manager->Init();
+  if (init_result.ok()) {
+    // The explicit move is needed to convert the unique_ptr into a StatusOr.
+    return std::move(manager);
+  }
+
+  return init_result;
+}
+
+// Stores the platform kind, device options and plugin configuration for later
+// use by Init(). Both NUMA bounds start at zero; Init() assigns their real
+// values once devices have been enumerated.
+MachineManager::MachineManager(PlatformKind platform,
+ DeviceOptions device_options,
+ const PluginConfig &config)
+ : platform_(platform),
+ device_options_(device_options),
+ plugin_config_(config),
+ min_numa_node_(0),
+ limit_numa_node_(0) {}
+
+// Creates and initializes one StreamExecutor and one default Stream for every
+// device the platform reports, and records the range of NUMA nodes those
+// devices occupy in [min_numa_node_, limit_numa_node_).
+//
+// The executor for ordinal 0 is brought up first because the platform's device
+// count is queried through it. Every device, including ordinal 0, then goes
+// through the same per-device setup (NUMA bookkeeping, ECC warning, default
+// stream creation).
+//
+// Returns FAILED_PRECONDITION as soon as any executor or default stream fails
+// to initialize. Returns OK, after logging a warning, when the platform reports
+// no devices.
+port::Status MachineManager::Init() {
+  // Initialize the first StreamExecutor, then use that platform interface to
+  // grab the device count.
+  executors_.resize(1);
+  executors_[0].reset(new StreamExecutor{platform_, plugin_config_});
+  auto status = executors_[0]->Init(0 /* = device_ordinal */, device_options_);
+  if (!status.ok()) {
+    return port::Status{
+        port::error::FAILED_PRECONDITION,
+        port::StrCat(
+            "failed to initialize StreamExecutor for device ordinal 0: ",
+            status.ToString())};
+  }
+
+  const int device_count = executors_[0]->PlatformDeviceCount();
+  // A non-positive count must never reach resize(): a negative int would be
+  // converted to an enormous unsigned size.
+  if (device_count <= 0) {
+    LOG(WARNING) << "no devices found for platform "
+                 << PlatformKindString(platform_);
+    min_numa_node_ = limit_numa_node_ = 0;
+    return port::Status::OK();
+  }
+
+  executors_.resize(device_count);
+  streams_.resize(device_count);
+
+  for (int device_ordinal = 0; device_ordinal < device_count;
+       ++device_ordinal) {
+    if (device_ordinal > 0) {
+      // Ordinal 0 already has an initialized executor from the probe above.
+      executors_[device_ordinal].reset(
+          new StreamExecutor{platform_, plugin_config_});
+      auto init_status =
+          executors_[device_ordinal]->Init(device_ordinal, device_options_);
+      if (!init_status.ok()) {
+        return port::Status(
+            port::error::FAILED_PRECONDITION,
+            port::StrCat(
+                "failed to initialize StreamExecutor for device ordinal ",
+                device_ordinal, ": ", init_status.ToString()));
+      }
+    }
+    StreamExecutor *stream_exec = executors_[device_ordinal].get();
+
+    const int numa_node = stream_exec->GetDeviceDescription().numa_node();
+    if (device_ordinal == 0) {
+      // The first device seeds the half-open NUMA range.
+      min_numa_node_ = numa_node;
+      limit_numa_node_ = numa_node + 1;
+    } else {
+      min_numa_node_ = std::min(min_numa_node_, numa_node);
+      limit_numa_node_ = std::max(limit_numa_node_, numa_node + 1);
+    }
+
+    // Checked for every ordinal; device 0 used to be skipped silently.
+    if (!stream_exec->GetDeviceDescription().ecc_enabled()) {
+      LOG(WARNING) << "ECC not enabled for device ordinal: " << device_ordinal;
+    }
+
+    streams_[device_ordinal].reset(new Stream(stream_exec));
+    if (!streams_[device_ordinal]->Init().ok()) {
+      return port::Status(
+          port::error::FAILED_PRECONDITION,
+          port::StrCat(
+              "failed to initialize default stream for device ordinal ",
+              device_ordinal));
+    }
+  }
+
+  return port::Status::OK();
+}
+
+// Number of StreamExecutor slots currently held by this manager.
+int MachineManager::device_count() const {
+  return static_cast<int>(executors_.size());
+}
+
+// Walks every (from, to) device pair produced by GetPeerAccessMap() and
+// enables peer access for each pair reported as capable. Pairs that are not
+// capable are only logged. Stops at, and returns, the first failure from
+// EnablePeerAccessTo(); returns OK otherwise.
+port::Status MachineManager::EnablePeerAccess() {
+  auto access_map = GetPeerAccessMap();
+  for (const auto &entry : *access_map) {
+    const auto ordinals = entry.first;
+    const bool can_enable = entry.second;
+
+    if (!can_enable) {
+      LOG(INFO) << "cannot enable peer access from device ordinal "
+                << ordinals.first << " to device ordinal " << ordinals.second;
+      continue;
+    }
+
+    StreamExecutor *source = executors_[ordinals.first].get();
+    StreamExecutor *target = executors_[ordinals.second].get();
+    auto enable_status = source->EnablePeerAccessTo(target);
+    if (!enable_status.ok()) {
+      return enable_status;
+    }
+  }
+  return port::Status::OK();
+}
+
+// Builds a map keyed by (from_ordinal, to_ordinal) for every ordered pair of
+// managed devices, including a device paired with itself. Each value is the
+// answer of CanEnablePeerAccessTo() from the first executor to the second.
+std::unique_ptr<std::map<std::pair<int, int>, bool>>
+MachineManager::GetPeerAccessMap() {
+  using PeerAccessMap = std::map<std::pair<int, int>, bool>;
+
+  std::unique_ptr<PeerAccessMap> access_map{new PeerAccessMap};
+  const int num_devices = device_count();
+  for (int from_ordinal = 0; from_ordinal < num_devices; ++from_ordinal) {
+    for (int to_ordinal = 0; to_ordinal < num_devices; ++to_ordinal) {
+      StreamExecutor *source = executors_[from_ordinal].get();
+      StreamExecutor *target = executors_[to_ordinal].get();
+      (*access_map)[{from_ordinal, to_ordinal}] =
+          source->CanEnablePeerAccessTo(target);
+    }
+  }
+
+  return access_map;
+}
+
+// Returns the StreamExecutor stored for `device_ordinal`. CHECK-fails when the
+// ordinal is negative, is not below device_count(), or when the slot holds a
+// null executor.
+StreamExecutor *MachineManager::executor_for_device(int device_ordinal) const {
+ CHECK_GE(device_ordinal, 0) << "device ordinal must be non-negative";
+ // The lower bound is re-tested here; only the upper bound can still fail.
+ CHECK(0 <= device_ordinal && device_ordinal < device_count())
+ << "device " << device_ordinal << " out of range with device count "
+ << device_count();
+ StreamExecutor *executor = executors_[device_ordinal].get();
+ CHECK(executor != nullptr);
+ return executor;
+}
+
+// Bus ordinal of `stream_exec`: its device's NUMA node expressed as an offset
+// from min_numa_node_.
+int MachineManager::ExecutorToBus(const StreamExecutor *stream_exec) const {
+  const auto numa_node = stream_exec->GetDeviceDescription().numa_node();
+  return numa_node - min_numa_node_;
+}
+
+// Bus ordinal for the device at `device_ordinal`; the ordinal is validated by
+// executor_for_device().
+int MachineManager::DeviceToBus(int device_ordinal) const {
+  const StreamExecutor *executor = executor_for_device(device_ordinal);
+  return ExecutorToBus(executor);
+}
+
+// NUMA node reported by the device description of `stream_exec`.
+int MachineManager::ExecutorToNumaNode(
+    const StreamExecutor *stream_exec) const {
+  const auto &description = stream_exec->GetDeviceDescription();
+  return description.numa_node();
+}
+
+// NUMA node for the device at `device_ordinal`; the ordinal is validated by
+// executor_for_device().
+int MachineManager::DeviceToNumaNode(int device_ordinal) const {
+  const StreamExecutor *executor = executor_for_device(device_ordinal);
+  return ExecutorToNumaNode(executor);
+}
+
+// Returns the lowest-ordinal executor whose bus ordinal (per ExecutorToBus)
+// equals `bus_ordinal`. CHECK-fails when `bus_ordinal` is not below
+// bus_count(). Logs a warning and returns nullptr when no executor matches.
+StreamExecutor *MachineManager::first_executor_for_bus(int bus_ordinal) {
+  CHECK_LT(bus_ordinal, bus_count()) << "bus ordinal out of available range";
+
+  StreamExecutor *match = nullptr;
+  for (auto &candidate : executors_) {
+    if (ExecutorToBus(candidate.get()) == bus_ordinal) {
+      match = candidate.get();
+      break;
+    }
+  }
+
+  if (match == nullptr) {
+    LOG(WARNING) << "could not find executor requested for bus ordinal: "
+                 << bus_ordinal;
+  }
+  return match;
+}
+
+// Returns the lowest-ordinal executor whose device reports `numa_node`. Logs a
+// warning and returns nullptr when no executor matches.
+StreamExecutor *MachineManager::first_executor_for_numa_node(int numa_node) {
+  for (auto it = executors_.begin(); it != executors_.end(); ++it) {
+    StreamExecutor *candidate = it->get();
+    if (ExecutorToNumaNode(candidate) == numa_node) {
+      return candidate;
+    }
+  }
+
+  LOG(WARNING) << "could not find executor requested for numa_node: "
+               << numa_node;
+  return nullptr;
+}
+
+// Returns the default Stream created for `device_ordinal` during Init().
+// CHECK-fails when the ordinal has no stream slot or the slot is null.
+Stream *MachineManager::stream_for_device(int device_ordinal) {
+  // Bound against streams_ itself rather than device_count(): the latter
+  // reflects executors_, which keeps its probe executor (size 1) when the
+  // platform reports zero devices while streams_ stays empty. Indexing
+  // streams_ with an ordinal validated against executors_ would then read out
+  // of bounds.
+  CHECK(0 <= device_ordinal &&
+        device_ordinal < static_cast<int>(streams_.size()))
+      << "device " << device_ordinal << " out of range with stream count "
+      << streams_.size();
+  Stream *stream = streams_[device_ordinal].get();
+  CHECK(stream != nullptr);
+  return stream;
+}
+
+// Creates the singleton MachineManager and stores it in singleton_. Does not
+// take mu_ itself; both callers in this file (CreateSingleton and singleton)
+// hold it around the call.
+//
+// Returns ALREADY_EXISTS when a singleton is already present, the Create()
+// failure status when construction fails, and the new singleton otherwise.
+/* static */ port::StatusOr<MachineManager *>
+MachineManager::CreateSingletonInternal(PlatformKind platform,
+                                        DeviceOptions options,
+                                        const PluginConfig &config) {
+  const bool already_created = (singleton_ != nullptr);
+  if (already_created) {
+    return port::Status{
+        port::error::ALREADY_EXISTS,
+        "cannot create machine manager singleton; one already exists"};
+  }
+
+  auto created = Create(platform, options, config);
+  if (!created.ok()) {
+    return created.status();
+  }
+
+  // Ownership moves out of the unique_ptr into the raw static pointer.
+  std::unique_ptr<MachineManager> manager = created.ConsumeValueOrDie();
+  singleton_ = manager.release();
+
+  VLOG(1) << "machine manager singleton is " << singleton_ << " with platform "
+          << PlatformKindString(platform) << " and device options "
+          << options.ToString();
+
+  return singleton_;
+}
+
+// Same as CreateSingleton(), except that any failure is reported through
+// LOG(FATAL) instead of being returned to the caller.
+/* static */ MachineManager *MachineManager::CreateSingletonOrDie(
+    PlatformKind platform, DeviceOptions options, const PluginConfig &config) {
+  auto result = CreateSingleton(platform, options, config);
+  const bool created = result.ok();
+  if (!created) {
+    LOG(FATAL) << "failed to create MachineManager singleton: "
+               << result.status();
+  }
+  return result.ValueOrDie();
+}
+
+// Locks mu_ and delegates to CreateSingletonInternal(); see that function for
+// the possible results.
+/* static */ port::StatusOr<MachineManager *> MachineManager::CreateSingleton(
+    PlatformKind platform, DeviceOptions device_options,
+    const PluginConfig &config) {
+  mutex_lock guard{mu_};
+  auto result = CreateSingletonInternal(platform, device_options, config);
+  return result;
+}
+
+// Returns the singleton MachineManager, building it on first use with the
+// platform from DetectPreferredPlatform(), default device options and a
+// default PluginConfig. The whole call runs under mu_. A failed lazy
+// construction is reported through LOG(FATAL).
+/* static */ MachineManager *MachineManager::singleton() {
+  mutex_lock guard{mu_};
+  if (singleton_ != nullptr) {
+    return singleton_;
+  }
+
+  PlatformKind preferred_platform = DetectPreferredPlatform();
+  DeviceOptions default_options = DeviceOptions::Default();
+  auto result = CreateSingletonInternal(preferred_platform, default_options,
+                                        PluginConfig());
+  if (!result.ok()) {
+    LOG(FATAL)
+        << "failed to create MachineManager singleton: "
+           "singleton accessor attempted lazy construction but failed: "
+        << result.status();
+  }
+  return result.ValueOrDie();
+}
+
+} // namespace gputools
+} // namespace perftools