Diffstat (limited to 'tensorflow/contrib/verbs/rdma_mgr.cc')
-rw-r--r--  tensorflow/contrib/verbs/rdma_mgr.cc | 213
1 file changed, 180 insertions(+), 33 deletions(-)
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 9cb307bcfa..f3644af0b4 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -16,11 +16,16 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include <fstream>
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/common_runtime/bfc_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -53,7 +58,7 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
void RdmaMgr::SetupChannels() {
for (const auto& p : channel_table_) {
string worker_name = p.first;
- LOG(INFO) << "connecting to remote node " << worker_name;
+ RDMA_LOG(2) << "Connecting to remote node " << worker_name;
RdmaChannel* rc = p.second;
GetRemoteAddressRequest req;
GetRemoteAddressResponse resp;
@@ -78,39 +83,49 @@ void RdmaMgr::SetupChannels() {
mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
}
// synchronous call
- Status s = client->GetRemoteAddress(&req, &resp);
- // save obtained remote addresses
- // connect to the remote channel
- if (s.ok()) {
- CHECK(worker_name.compare(resp.host_name()) == 0);
- RdmaAddress ra;
- ra.lid = resp.channel().lid();
- ra.qpn = resp.channel().qpn();
- ra.psn = resp.channel().psn();
- ra.snp = resp.channel().snp();
- ra.iid = resp.channel().iid();
- rc->SetRemoteAddress(ra, false);
- rc->Connect();
- int i = 0;
- int idx[] = {1, 0, 3, 2};
- for (const auto& mr : resp.mr()) {
- // the connections are crossed, i.e.
- // local tx_message_buffer <---> remote rx_message_buffer_
- // local rx_message_buffer <---> remote tx_message_buffer_
- // local tx_ack_buffer <---> remote rx_ack_buffer_
- // local rx_ack_buffer <---> remote tx_ack_buffer_
- // hence idx[] = {1, 0, 3, 2}.
- RdmaBuffer* rb = rc->message_buffers_[idx[i]];
- RemoteMR rmr;
- rmr.remote_addr = mr.remote_addr();
- rmr.rkey = mr.rkey();
- rb->SetRemoteMR(rmr, false);
- i++;
+ Status s;
+ int attempts = 0;
+ static const int max_num_attempts = 5;
+ do {
+ s = client->GetRemoteAddress(&req, &resp);
+ // save obtained remote addresses
+ // connect to the remote channel
+ if (s.ok()) {
+ CHECK(worker_name.compare(resp.host_name()) == 0);
+ RdmaAddress ra;
+ ra.lid = resp.channel().lid();
+ ra.qpn = resp.channel().qpn();
+ ra.psn = resp.channel().psn();
+ ra.snp = resp.channel().snp();
+ ra.iid = resp.channel().iid();
+ rc->SetRemoteAddress(ra, false);
+ rc->Connect();
+ int i = 0;
+ int idx[] = {1, 0};
+ for (const auto& mr : resp.mr()) {
+ // the connections are crossed, i.e.
+ // local tx_message_buffer <---> remote rx_message_buffer_
+ // local rx_message_buffer <---> remote tx_message_buffer_
+ // hence idx[] = {1, 0}.
+ RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
+ RemoteMR rmr;
+ rmr.remote_addr = mr.remote_addr();
+ rmr.rkey = mr.rkey();
+ rb->SetRemoteMR(rmr, false);
+ i++;
+ }
+ CHECK(i == RdmaChannel::kNumMessageBuffers);
+ } else {
+        LOG(ERROR) << "Connecting to " << worker_name << ": Got "
+                   << s.error_message() << ". Retrying ("
+                   << (attempts + 1) << "/" << max_num_attempts << ")...";
+ if (++attempts == max_num_attempts) {
+ break;
+ }
+ worker_env_->env->SleepForMicroseconds(2000000);
}
- CHECK(i == RdmaChannel::kNumMessageBuffers);
- } else {
- LOG(ERROR) << s.error_message();
- }
+ } while (!s.ok());
+ RDMA_LOG(0) << "Connected to remote node " << worker_name;
delete client;
}
}
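
The new loop above retries the synchronous GetRemoteAddress call up to five
times, sleeping two seconds between failed attempts. Reduced to a standalone
sketch (the helper name RetryWithFixedDelay and the use of std::thread are
illustrative, not part of the patch):

    #include <chrono>
    #include <functional>
    #include <thread>

    // Fixed-delay retry, mirroring the loop in RdmaMgr::SetupChannels():
    // try at most kMaxAttempts times, pausing 2 s between failures.
    bool RetryWithFixedDelay(const std::function<bool()>& attempt) {
      static const int kMaxAttempts = 5;
      for (int i = 1; i <= kMaxAttempts; ++i) {
        if (attempt()) return true;    // success: stop retrying
        if (i == kMaxAttempts) break;  // last attempt failed: give up
        std::this_thread::sleep_for(std::chrono::seconds(2));
      }
      return false;
    }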
@@ -183,6 +198,138 @@ RdmaChannel* RdmaMgr::FindChannel(const string& name) {
return iter->second;
}
+bool IsGDRAvailable() {
+#if defined(__APPLE__)
+ return false;
+#elif defined(PLATFORM_WINDOWS)
+ return false;
+#else
+ std::ifstream ifs("/proc/modules");
+ string line;
+ while (std::getline(ifs, line)) {
+ auto sep = line.find(' ');
+ CHECK_NE(sep, std::string::npos);
+ if (line.substr(0, sep) == "nv_peer_mem") {
+ return true;
+ }
+ }
+ return false;
+#endif
+}
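
IsGDRAvailable() decides whether GPUDirect RDMA can be used by scanning
/proc/modules for the nv_peer_mem kernel module. A standalone sketch of the
same scan (the helper name KernelModuleLoaded is illustrative only):

    #include <fstream>
    #include <string>

    // Each line of /proc/modules begins with the module name followed by a
    // space and its size, so matching the first field is enough.
    bool KernelModuleLoaded(const std::string& name) {
      std::ifstream ifs("/proc/modules");
      std::string line;
      while (std::getline(ifs, line)) {
        if (line.substr(0, line.find(' ')) == name) return true;
      }
      return false;
    }
    // Usage: KernelModuleLoaded("nv_peer_mem") reports GPUDirect support.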
+
+int TryToReadNumaNode(ibv_device* device) {
+#if defined(__APPLE__)
+ LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
+ return 0;
+#elif defined(PLATFORM_WINDOWS)
+ // Windows support for NUMA is not currently implemented. Return node 0.
+ return 0;
+#else
+ VLOG(2) << "Trying to read NUMA node for device: " << device->name;
+ static const int kUnknownNumaNode = -1;
+
+ auto filename = string(device->ibdev_path) + "/device/numa_node";
+
+ std::ifstream ifs(filename.c_str());
+ string content;
+ CHECK(std::getline(ifs, content));
+
+ int32 value;
+ if (strings::safe_strto32(content, &value)) {
+ if (value < 0) {
+ LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
+ << value << "), but there must be at least one NUMA node"
+ ", so returning NUMA node zero";
+ return 0;
+ }
+ LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
+ return value;
+ }
+ return kUnknownNumaNode;
+#endif
+}
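
TryToReadNumaNode() resolves the device's NUMA node from sysfs; the kernel
reports -1 there when no NUMA information is available, which the patch clamps
to node 0. A minimal sketch under the same assumptions (ReadNumaNode is an
illustrative name, not part of the patch):

    #include <fstream>
    #include <string>

    // Reads <ibdev_path>/device/numa_node; returns -1 when the file is
    // missing or unparsable, and clamps the kernel's "-1 = no NUMA" to 0.
    int ReadNumaNode(const std::string& ibdev_path) {
      std::ifstream ifs(ibdev_path + "/device/numa_node");
      int value;
      if (!(ifs >> value)) return -1;  // unknown
      return value < 0 ? 0 : value;
    }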
+
+void MRDeleter(ibv_mr* mr) {
+ if (mr) {
+ ibv_dereg_mr(mr);
+ }
+}
+
+// TODO(byronyi): remove this class duplicated from the one in
+// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
+class BasicCPUAllocator : public SubAllocator {
+ public:
+ ~BasicCPUAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ return port::AlignedMalloc(num_bytes, alignment);
+ }
+ void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
+};
+
+// TODO(byronyi): remove this class and its registration when the default
+// cpu_allocator() returns visitable allocator
+class BFCRdmaAllocator : public BFCAllocator {
+ public:
+ BFCRdmaAllocator()
+ : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
+ }
+};
+
+REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);
+
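For reference, the arena handed to BFCAllocator above is 1LL << 36 bytes,
i.e. 64 GiB, and the priority of 101 is presumably chosen so this registration
outranks the default CPU allocator. The compile-time check below (not part of
the patch) spells out the size arithmetic:

    // 2^36 = 68,719,476,736 bytes = 64 * 2^30 bytes = 64 GiB.
    static_assert((1LL << 36) == 64LL * 1024 * 1024 * 1024,
                  "cpu_rdma_bfc arena is 64 GiB");
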
+void RdmaMgr::InitAllocators() {
+ RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;
+
+ Allocator* allocators[] = {
+#if GOOGLE_CUDA
+ ProcessState::singleton()->GetCUDAHostAllocator(0),
+ ProcessState::singleton()->GetCPUAllocator(0),
+#endif // GOOGLE_CUDA
+ cpu_allocator(),
+ };
+
+ using namespace std::placeholders;
+
+ std::set<Allocator*> instrumented_;
+
+ // Host memory allocators
+ for (Allocator* allocator : allocators) {
+ VisitableAllocator::Visitor alloc_visitor =
+ std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
+ &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
+ VisitableAllocator::Visitor free_visitor = std::bind(
+ &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
+
+ auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
+    CHECK(visitable_allocator) << "Allocator " << allocator->Name()
+                               << " is not visitable for instrumentation";
+ // Make sure we don't instrument the same allocator twice
+ if (instrumented_.find(allocator) == std::end(instrumented_)) {
+ visitable_allocator->AddAllocVisitor(alloc_visitor);
+ visitable_allocator->AddFreeVisitor(free_visitor);
+ instrumented_.insert(allocator);
+ LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
+ }
+ }
+
+#if GOOGLE_CUDA
+ if (IsGDRAvailable()) {
+ // Note we don't free allocated GPU memory so there is no free visitor
+ int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
+
+ char buf[8];
+ sprintf(buf, "gpu");
+ VisitableAllocator::Visitor cuda_alloc_visitor =
+ std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
+ &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
+
+ ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
+ LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
+ }
+#endif // GOOGLE_CUDA
+}
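
InitAllocators() wires RdmaMemoryMgr into each allocator through
VisitableAllocator visitors built with std::bind. The self-contained sketch
below shows the same binding against a stand-in registry type, together with
the equivalent lambda form (Registry and Insert are illustrative names, not
the real RdmaMemoryMgr interface):

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <string>

    // Stand-in for RdmaMemoryMgr: records a memory region under a name.
    struct Registry {
      void Insert(void* ptr, std::size_t len, const std::string& name) {
        std::cout << "register " << len << " bytes for " << name << "\n";
      }
    };

    int main() {
      Registry registry;
      using namespace std::placeholders;
      // std::bind form, mirroring the patch: the region name is bound,
      // the pointer and length are filled in by the caller.
      std::function<void(void*, std::size_t)> bind_visitor =
          std::bind(&Registry::Insert, &registry, _1, _2, "cpu");
      // Equivalent lambda form.
      std::function<void(void*, std::size_t)> lambda_visitor =
          [&registry](void* ptr, std::size_t len) {
            registry.Insert(ptr, len, "cpu");
          };
      char buf[16];
      bind_visitor(buf, sizeof(buf));
      lambda_visitor(buf, sizeof(buf));
    }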
+
} // end namespace tensorflow
#endif