Diffstat (limited to 'tensorflow/stream_executor/platform.cc')
-rw-r--r--  tensorflow/stream_executor/platform.cc  115
1 file changed, 115 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc
new file mode 100644
index 0000000000..8be9353bbe
--- /dev/null
+++ b/tensorflow/stream_executor/platform.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/stream_executor/platform.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace perftools {
+namespace gputools {
+
+string PlatformKindString(PlatformKind kind) {
+ switch (kind) {
+ case PlatformKind::kCuda:
+ return "CUDA";
+ case PlatformKind::kOpenCL:
+ return "OpenCL";
+ case PlatformKind::kOpenCLAltera:
+ return "OpenCL+Altera";
+ case PlatformKind::kHost:
+ return "Host";
+ case PlatformKind::kMock:
+ return "Mock";
+ default:
+ return port::StrCat("InvalidPlatformKind(", static_cast<int>(kind), ")");
+ }
+}
+
+PlatformKind PlatformKindFromString(string kind) {
+ for (int i = 0; i < static_cast<int>(PlatformKind::kSize); ++i) {
+ if (kind == PlatformKindString(static_cast<PlatformKind>(i))) {
+ return static_cast<PlatformKind>(i);
+ }
+ }
+
+ return PlatformKind::kInvalid;
+}
+
+bool PlatformIsRunnable(PlatformKind kind) {
+ switch (kind) {
+ case PlatformKind::kCuda:
+ case PlatformKind::kOpenCL:
+ case PlatformKind::kHost:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool PlatformIsRunnableOnDevice(PlatformKind kind) {
+ switch (kind) {
+ case PlatformKind::kCuda:
+ case PlatformKind::kOpenCL:
+ return true;
+ default:
+ return false;
+ }
+}
+
+void CheckPlatformKindIsValid(PlatformKind kind) {
+ CHECK(static_cast<int>(PlatformKind::kCuda) <= static_cast<int>(kind) &&
+ static_cast<int>(kind) <= static_cast<int>(PlatformKind::kMock))
+ << "invalid GPU executor kind: " << PlatformKindString(kind);
+}
+
+StreamExecutorConfig::StreamExecutorConfig()
+ : ordinal(-1), device_options(DeviceOptions::Default()) {}
+
+StreamExecutorConfig::StreamExecutorConfig(int ordinal_in)
+ : ordinal(ordinal_in), device_options(DeviceOptions::Default()) {}
+
+Platform::~Platform() {}
+
+port::Status Platform::ForceExecutorShutdown() {
+ return port::Status(port::error::UNIMPLEMENTED,
+ "executor shutdown is not supported on this platform");
+}
+
+std::unique_ptr<Platform::PeerAccessMap> Platform::GetPeerAccessMap() {
+ auto *map = new PeerAccessMap;
+
+ int device_count = VisibleDeviceCount();
+ for (int i = 0; i < device_count; ++i) {
+ for (int j = 0; j < device_count; ++j) {
+ StreamExecutor *from = ExecutorForDevice(i).ValueOrDie();
+ StreamExecutor *to = ExecutorForDevice(j).ValueOrDie();
+ (*map)[{i, j}] = from->CanEnablePeerAccessTo(to);
+ }
+ }
+
+ return std::unique_ptr<Platform::PeerAccessMap>{map};
+}
+
+port::Status Platform::EnablePeerAccess() {
+ auto peer_access_map = GetPeerAccessMap();
+ for (const auto &access : *peer_access_map) {
+ auto devices = access.first;
+ if (access.second) {
+ StreamExecutor *from = ExecutorForDevice(devices.first).ValueOrDie();
+ StreamExecutor *to = ExecutorForDevice(devices.second).ValueOrDie();
+ auto status = from->EnablePeerAccessTo(to);
+ if (!status.ok()) {
+ return status;
+ }
+ } else {
+ LOG(INFO) << "cannot enable peer access from device ordinal "
+ << devices.first << " to device ordinal " << devices.second;
+ }
+ }
+ return port::Status::OK();
+}
+
+} // namespace gputools
+} // namespace perftools
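
Not part of the patch above, but for context, a minimal caller-side sketch of the PlatformKind helpers this file introduces (PlatformKindFromString, CheckPlatformKindIsValid, PlatformKindString, PlatformIsRunnable). The main() harness and the printf call are illustrative assumptions, not anything this change defines.

#include <cstdio>

#include "tensorflow/stream_executor/platform.h"

namespace gpu = perftools::gputools;

int main() {
  // Round-trip a platform kind through its string form.
  gpu::PlatformKind kind = gpu::PlatformKindFromString("CUDA");

  // CHECK-fails only for kinds outside the valid range (e.g. kInvalid).
  gpu::CheckPlatformKindIsValid(kind);

  // Prints "CUDA is runnable: 1" -- kCuda is one of the runnable kinds.
  std::printf("%s is runnable: %d\n",
              gpu::PlatformKindString(kind).c_str(),
              gpu::PlatformIsRunnable(kind));
  return 0;
}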