diff options
author | Shanqing Cai <cais@google.com> | 2017-12-06 18:43:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-06 18:47:41 -0800 |
commit | fe8406149feec453250905965a14285465cd2063 (patch) | |
tree | be3cd75d543f3c0f29f368da61d915abbae7fcbf /tensorflow/contrib/verbs | |
parent | 8ad62af489df718992561710123bc8c037e7d17b (diff) |
Merge changes from github.
PiperOrigin-RevId: 178185697
Diffstat (limited to 'tensorflow/contrib/verbs')
-rw-r--r-- | tensorflow/contrib/verbs/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/contrib/verbs/rdma.cc | 98 | ||||
-rw-r--r-- | tensorflow/contrib/verbs/rdma.h | 29 | ||||
-rw-r--r-- | tensorflow/contrib/verbs/rdma_mgr.cc | 51 | ||||
-rw-r--r-- | tensorflow/contrib/verbs/rdma_mgr.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 46 | ||||
-rw-r--r-- | tensorflow/contrib/verbs/verbs_server_lib.cc | 5 |
7 files changed, 176 insertions, 64 deletions
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 746ff38b37..38a84ffb10 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -7,6 +7,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + exports_files(["LICENSE"]) filegroup( @@ -97,7 +99,7 @@ cc_library( alwayslink = 1, ) -cc_library( +tf_cuda_library( name = "rdma_rendezvous_mgr", srcs = ["rdma_rendezvous_mgr.cc"], hdrs = ["rdma_rendezvous_mgr.h"], @@ -130,7 +132,7 @@ cc_library( ], ) -cc_library( +tf_cuda_library( name = "rdma", srcs = ["rdma.cc"], hdrs = ["rdma.h"], diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index ac8d994502..ae9a384565 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -18,11 +18,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/rdma.h" #include <fcntl.h> #include <cstdlib> +#include <fcntl.h> #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/framework/rendezvous.h" @@ -31,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { @@ -418,9 +422,6 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env) 0); CHECK(cq_) << "Failed to create completion queue"; CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification"; - polling_thread_.reset(Env::Default()->StartThread( - ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); })); - VLOG(2) << "Start RdmaAdapter: " << name(); } RdmaAdapter::~RdmaAdapter() { @@ -432,6 +433,12 @@ RdmaAdapter::~RdmaAdapter() { CHECK(!ibv_close_device(context_)) << "Failed to release context"; } +void RdmaAdapter::StartPolling() { + polling_thread_.reset(Env::Default()->StartThread( + ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); })); + VLOG(2) << "Start RdmaAdapter: " << name(); +} + string RdmaAdapter::name() const { return string(context_->device->name); } // Function to process incoming messages @@ -452,9 +459,9 @@ void RdmaAdapter::Process_CQ() { CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { CHECK(wc_[i].status == IBV_WC_SUCCESS) - << "Failed status \n" - << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " - << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err; + << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " " + << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " " + << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id); // put back a recv wr. 
@@ -557,9 +564,44 @@ void RdmaAdapter::Process_CQ() { } } +int RdmaChannel::PingPostRecv() { + struct ibv_recv_wr wr, *bad_wr; + memset(&wr, 0, sizeof(wr)); + wr.sg_list = &ping_sge_list_; + wr.num_sge = 1; + wr.wr_id = kPingRecvWrid; + + return ibv_post_recv(qp_, &wr, &bad_wr); +} + +int RdmaChannel::PingPostSend() { + struct ibv_send_wr wr, *bad_wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = (uint64_t) this; + wr.sg_list = &ping_sge_list_; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND; + wr.send_flags = IBV_SEND_SIGNALED; + + return ibv_post_send(qp_, &wr, &bad_wr); +} + RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, const string remote_name) : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) { + + struct ibv_sge list; + + mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize, + IBV_ACCESS_LOCAL_WRITE); + CHECK(mr_) << "Failed to register memory region"; + + memset(&list, 0, sizeof(list)); + list.addr = (uintptr_t)ping_buff_; + list.length = kPingBuffSize; + list.lkey = mr_->lkey; + + ping_sge_list_ = list; // Create queue pair { struct ibv_qp_init_attr attr; @@ -610,7 +652,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. 
{ const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + "tx_ack_buffer", "rx_ack_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); @@ -632,15 +674,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, buffer_index_name_table_.insert({index, buffer_names[i]}); buffer_name_index_table_.insert({buffer_names[i], index}); } - - // Initiate recv - for (int i = 0; i < 100; i++) { - Recv(); - } } + CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_ + << " with error " << std::strerror(errno); } RdmaChannel::~RdmaChannel() { + ibv_dereg_mr(mr_); CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP"; delete tx_message_buffer_; delete rx_message_buffer_; @@ -671,7 +711,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } @@ -825,11 +865,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class; int r; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV | + IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; memset(&attr, 0, sizeof(ibv_qp_attr)); @@ -840,10 +880,10 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - CHECK(!(r = 
ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; connected_ = true; @@ -930,7 +970,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -1025,9 +1065,10 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { - CHECK(send_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); +#if GOOGLE_CUDA + CHECK(send_args.device_context) << "send dev name: " << src_dev->name() + << " gpu_info: " + << src_dev->tensorflow_gpu_device_info(); if (can_memcpy) { AllocatorAttributes host_alloc_attrs; @@ -1053,8 +1094,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( // aync instead GPUUtil::SetProtoFromGPU( in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { + [this, proto, buffer_size, key, in, step_id, key_with_step_id, + is_dead, send_args, recv_args](const Status& s) mutable { CHECK(s.ok()) << "copy proto from gpu sync"; auto tensor_bytes = proto.ByteSize(); buffer_size += tensor_bytes; @@ -1063,6 +1104,7 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( &proto, NULL, send_args, recv_args); }); } +#endif // GOOGLE_CUDA } else { // tensor is in CPU memory. 
StringPiece copy_buf; diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h index 00217c81d4..fea2327d77 100644 --- a/tensorflow/contrib/verbs/rdma.h +++ b/tensorflow/contrib/verbs/rdma.h @@ -67,9 +67,20 @@ struct RemoteMR { uint64_t remote_addr; uint32_t rkey; }; -enum BufferStatus { none, idle, busy }; -enum Location { local, remote }; -enum BufferType { ACK, MESSAGE, TENSOR }; +enum BufferStatus { + none, + idle, + busy +}; +enum Location { + local, + remote +}; +enum BufferType { + ACK, + MESSAGE, + TENSOR +}; enum RdmaMessageType { RDMA_MESSAGE_ACK, RDMA_MESSAGE_BUFFER_IDLE, @@ -96,6 +107,7 @@ class RdmaAdapter { ~RdmaAdapter(); // Adapter name, e.g. mlx5_0. string name() const; + void StartPolling(); void Process_CQ(); protected: @@ -150,6 +162,15 @@ class RdmaChannel { void RemoveRecvCallback(const string& key); void RunRecvCallback(const string& key); static const int kNumMessageBuffers = 4; + static const int kPingRecvWrid = 0; + + private: + static const int kPingBuffSize = 1024; + char ping_buff_[kPingBuffSize]; + struct ibv_mr* mr_; + struct ibv_sge ping_sge_list_; + int PingPostRecv(); + int PingPostSend(); protected: const RdmaAdapter* adapter_; @@ -202,7 +223,7 @@ class RdmaBuffer { } void FreeBuffer(); void EnqueueItem(string Item); - virtual void SendNextItem(){}; + virtual void SendNextItem() {}; void CreateCPUBuffer(size_t size, bool lock = true); void SetRemoteMR(RemoteMR rmi, bool override); uint32_t LookupBufferIndex(const string& buffer_name) { diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 09b878843f..9cb307bcfa 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -115,6 +115,57 @@ void RdmaMgr::SetupChannels() { } } +// Check connectivity by pinging every channel +bool RdmaMgr::ConnectivityCheck() { + int i, rcnt = 0, scnt = 0; + + for (const auto& p : channel_table_) { + string worker_name = p.first; + RdmaChannel* rc = 
p.second; + + VLOG(2) << "Ping to " << worker_name; + CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name + << " with error: " << std::strerror(errno); + for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) { + rc->Recv(); + } + } + + while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) { + int ne; + do { + ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_, + rdma_adapter_->wc_); + CHECK(ne >= 0) << "poll CQ failed " << ne << "with error" + << std::strerror(errno); + } while (ne < 1); + + for (i = 0; i < ne; ++i) { + ibv_wc_status s = rdma_adapter_->wc_[i].status; + // recv complete + if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) { + CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str( + rdma_adapter_->wc_[i].status) + << "(" << rdma_adapter_->wc_[i].status + << ") for PING_RECV_WRID"; + ++rcnt; + // send complete + } else { + RdmaChannel* rc = + reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id); + CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str( + rdma_adapter_->wc_[i].status) + << "(" << rdma_adapter_->wc_[i].status + << ") to " << rc->remote_name_; + ++scnt; + } + } // for + } // while + CHECK(rcnt == scnt) << "Connectivity check failed!"; + rdma_adapter_->StartPolling(); + return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt); +} + RdmaMgr::~RdmaMgr() { for (const auto& p : channel_table_) delete p.second; channel_table_.clear(); diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h index b156f64096..e711e60478 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.h +++ b/tensorflow/contrib/verbs/rdma_mgr.h @@ -28,12 +28,16 @@ limitations under the License. 
namespace tensorflow { class RdmaMgr { + friend class RdmaChannel; + friend class RdmaAdapter; + public: explicit RdmaMgr(const WorkerEnv* const worker_env, GrpcChannelCache* const channel_cache); ~RdmaMgr(); RdmaChannel* FindChannel(const string& key); void SetupChannels(); + bool ConnectivityCheck(); const string& local_worker() { return local_worker_; } private: @@ -44,7 +48,6 @@ class RdmaMgr { RdmaAdapter* rdma_adapter_; typedef std::unordered_map<string, RdmaChannel*> ChannelTable; ChannelTable channel_table_; - TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr); }; diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index ce82ca2883..74f6681af3 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -21,8 +21,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#endif // GOOGLE_CUDA #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -58,20 +60,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( // parse src_name and dst_name string src_name, dst_name, unused; if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name, + &unused) || + !DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name, &unused)) { - s = errors::Internal("Could not parse src name."); + s = errors::Internal("Could not parse src or dst name."); } - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor{}, false); - return; - } - if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name, - &unused)) { - s = 
errors::Internal("Could not parse dst name."); - } - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); if (!s.ok()) { + LOG(ERROR) << "s is not ok, error code " << s.error_message(); done(s, Args(), recv_args, Tensor{}, false); return; } @@ -82,18 +77,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( // insert callback rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc, recv_args, parsed, done]() { - Status s; - Device* src_dev; - s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor(), true); - return; - } - Device* dst_dev; - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { + Status src_s, dst_s, s; + Device* src_dev, *dst_dev; + src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); + dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); + if (!src_s.ok() || !dst_s.ok()) { + s = src_s.ok() ? 
dst_s : src_s; + LOG(ERROR) << "s is not ok, error code " << s.error_message(); done(s, Args(), recv_args, Tensor(), true); return; } @@ -110,9 +100,10 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( if (can_memcpy) { if (dst_dev->tensorflow_gpu_device_info() && (!recv_args.alloc_attrs.on_host())) { +#if GOOGLE_CUDA CHECK(recv_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); Tensor copy(alloc, rm.data_type_, rm.tensor_shape_); memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_); @@ -122,14 +113,15 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( GPUUtil::CopyCPUTensorToGPU( &copy, recv_args.device_context, dst_dev, &gpu_copy, - [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, - rc](const Status& s) { + [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc]( + const Status& s) { CHECK(s.ok()) << "copy tensor to gpu sync"; Tensor val; val = std::move(gpu_copy); RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s); }); +#endif // GOOGLE_CUDA return; } else { AllocatorAttributes host_alloc_attrs; diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 6d1c79c0fb..a606ef75a4 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -49,8 +49,8 @@ VerbsServer::~VerbsServer() { Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, GrpcChannelCache** channel_cache) { string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", - "/task:", server_def.task_index()); + strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:", + server_def.task_index()); GrpcChannelSpec channel_spec; TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, 
&channel_spec)); @@ -103,6 +103,7 @@ Status VerbsServer::Start() { ThreadOptions(), "TF_verbs_service", [this] { verbs_service_->HandleRPCsLoop(); })); rdma_mgr_->SetupChannels(); + CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!"; verbs_state_ = CONNECTED; } } |