Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r--  tensorflow/contrib/verbs/rdma.cc | 1348
1 file changed, 829 insertions, 519 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index ae9a384565..ec5271abe0 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -16,18 +16,19 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS #include "tensorflow/contrib/verbs/rdma.h" -#include <fcntl.h> +#include "tensorflow/contrib/verbs/verbs_service.pb.h" #include <cstdlib> #include <fcntl.h> -#include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/process_util.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" #endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -41,32 +42,19 @@ namespace tensorflow { #define RoCE_V2 "RoCE v2" namespace { -// hash name to 32-bit integer -uint32_t NameHash(const string& name) { - return Hash32(name.data(), name.size(), 0x1234ABCD); -} // convenience function for printing message string MessageTypeToString(RdmaMessageType rmt) { switch (rmt) { - case RDMA_MESSAGE_ACK: - return "RDMA_MESSAGE_ACK"; - break; - case RDMA_MESSAGE_BUFFER_IDLE: - return "RDMA_MESSAGE_BUFFER_IDLE"; - break; - case RDMA_MESSAGE_BUFFER_REQUEST: - return "RDMA_MESSAGE_BUFFER_REQUEST"; + case RDMA_MESSAGE_META_DATA_UPDATE: + return "RDMA_MESSAGE_META_DATA_UPDATE"; break; - case RDMA_MESSAGE_BUFFER_RESPONSE: - return "RDMA_MESSAGE_BUFFER_RESPONSE"; + case RDMA_MESSAGE_TENSOR_RE_REQUEST: + return "RDMA_MESSAGE_TENSOR_RE_REQUEST"; break; case RDMA_MESSAGE_TENSOR_REQUEST: return "RDMA_MESSAGE_TENSOR_REQUEST"; break; - case RDMA_MESSAGE_TENSOR_WRITE: - return "RDMA_MESSAGE_TENSOR_WRITE"; - break; default: return "UNKNOWN MESSAGE"; } @@ -347,7 +335,7 @@ uint32_t set_param(uint32_t default_val, const char* env_param) { enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) { ibv_port_attr port_attr; - enum ibv_mtu mtu; + enum ibv_mtu mtu = IBV_MTU_512; string mtu_s; int rc, mtu_i; @@ -468,97 +456,70 @@ void RdmaAdapter::Process_CQ() { rc->Recv(); // imm_data is the index of RX buffer in the buffer table. 
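With the per-tensor buffer table and its 32-bit name hashing removed, the rewritten Process_CQ() below dispatches on the immediate data alone: one reserved value marks an ACK, one marks a control message, and anything up to RDMA_IMM_MAX_REQUEST_ID is a tensor write carrying its request index. A minimal sketch of that dispatch; the constant values here are assumptions for illustration (the real ones are defined in rdma.h):

    #include <cassert>
    #include <cstdint>

    // Assumed values for illustration only; see rdma.h for the real constants.
    constexpr uint32_t RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD;
    constexpr uint32_t RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFE;
    constexpr uint32_t RDMA_IMM_DATA_ACK = 0xFFFFFFFF;

    enum class ImmKind { kAck, kTensorContent, kControlMessage };

    // Mirrors the order of checks in Process_CQ(): ACK first, then the
    // request-index range; everything else is a control message.
    ImmKind ClassifyImmData(uint32_t imm_data) {
      if (imm_data == RDMA_IMM_DATA_ACK) return ImmKind::kAck;
      if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) return ImmKind::kTensorContent;
      assert(imm_data == RDMA_IMM_DATA_MESSAGE);
      return ImmKind::kControlMessage;
    }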
uint32_t imm_data = wc_[i].imm_data; - RdmaBuffer* rb = rc->FindBuffer(imm_data); + RdmaMessageBuffer* rb; RdmaMessage rm; - RdmaMessage::ParseMessage(rm, rb->buffer_); - VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_); - if (rm.type_ == RDMA_MESSAGE_ACK) { + if (imm_data == RDMA_IMM_DATA_ACK) { // receive an ack to a message rb = rc->tx_message_buffer_; rb->SetBufferStatus(remote, idle); rb->SendNextItem(); - } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) { - // received a request-for-tensor message - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find or create buffer - RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_); - string key_with_step_id = - VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); - tb->EnqueueItem(key_with_step_id); - // send the next tensor - worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); }); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) { - // receive tensor-buffer-ready message - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find buffer - RdmaTensorBuffer* tb = - reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_)); - tb->SetBufferStatus(remote, idle); - worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); }); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) { - // remote host requests to create a tensor buffer; - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find or create the buffer - RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR); - RemoteMR rmr; - rmr.remote_addr = rm.remote_addr_; - rmr.rkey = rm.rkey_; - tb->SetRemoteMR(rmr, true); - tb->CreateCPUBuffer(rm.buffer_size_); - // create RDMA_MESSAGE_BUFFER_RESPONSE message - RdmaMessage br; - br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE; - br.name_size_ = rm.name_.size(); - br.name_ = rm.name_; - br.buffer_size_ = rm.buffer_size_; - br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_); - br.rkey_ = tb->self_->rkey; - string message = RdmaMessage::CreateMessage(br); - RdmaBuffer* mb = rc->tx_message_buffer_; - mb->EnqueueItem(message); - mb->SendNextItem(); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) { - // remote creates a buffer and responds - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find buffer - RdmaTensorBuffer* tb = - reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_)); - CHECK(rm.buffer_size_ == tb->size_) - << "rm.buffer_size = " << rm.buffer_size_ - << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_; - RemoteMR rmr; - rmr.remote_addr = rm.remote_addr_; - rmr.rkey = rm.rkey_; - tb->SetRemoteMR(rmr, true); - tb->SetBufferStatus(local, idle); - tb->SetBufferStatus(remote, idle); - worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); }); - } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { - // tensor RDMA write completed - worker_env_->compute_pool->Schedule([rm, rc]() { - string key_with_step_id = - VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); - rc->RunRecvCallback(key_with_step_id); - }); + continue; } - } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) { - RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id); - rb->SetBufferStatus(local, idle); - RdmaMessage rm; + + if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) { + // receive a tensor RDMA write + uint32_t request_index = imm_data; + 
RdmaTensorRequest* request = rc->GetTensorRequest(request_index); + request->RecvTensorContent(); + continue; + } + + // receive a control message + rb = rc->rx_message_buffer_; RdmaMessage::ParseMessage(rm, rb->buffer_); - VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_); - if (rm.type_ != RDMA_MESSAGE_ACK) { - worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); }); + RdmaMessageBuffer::SendAck(rc); + RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Received " << MessageTypeToString(rm.type_) << " " + << "#" << rm.request_index_ << ": " << rm.name_; + + if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) { + RdmaTensorResponse* response = rc->AddTensorResponse(rm); + response->Start(); + } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) { + RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_); + request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_, + rm.is_dead_, rm.tensor_bytes_); +#ifdef RDMA_DATA_VALIDATION + request->RecvTensorChecksum(rm.checksum_); +#endif + } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) { + RdmaTensorResponse* response = rc->UpdateTensorResponse(rm); + response->Resume(); + } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) { + RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_); + request->RecvErrorStatus(rm.status_); + } + } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) { + RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id); + RDMA_LOG(2) << "Write complete of type " << wr_id->write_type; + switch (wr_id->write_type) { + case RDMA_WRITE_ID_ACK: + break; + case RDMA_WRITE_ID_MESSAGE: { + RdmaMessageBuffer* rb = + reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context); + rb->SetBufferStatus(local, idle); + rb->SendNextItem(); + break; + } + case RDMA_WRITE_ID_TENSOR_WRITE: { + RdmaTensorResponse* response = + reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context); + response->Destroy(); + } } + delete wr_id; } } } @@ -588,8 +549,10 @@ int RdmaChannel::PingPostSend() { RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, const string remote_name) - : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) { - + : adapter_(adapter), + local_name_(local_name), + remote_name_(remote_name), + request_serial_(0) { struct ibv_sge list; mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize, @@ -651,29 +614,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. 
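Note the ownership pattern in the IBV_WC_RDMA_WRITE arm above: every posted write carries a heap-allocated RdmaWriteID in wr_id so the completion handler can tell ACK, message and tensor writes apart and free the tag exactly once. A sketch of the pattern, with RdmaWriteID assumed to be declared in rdma.h as used here:

    #include <cstdint>

    enum RdmaWriteIDType {  // values as used in the completion switch above
      RDMA_WRITE_ID_ACK,
      RDMA_WRITE_ID_MESSAGE,
      RDMA_WRITE_ID_TENSOR_WRITE
    };

    struct RdmaWriteID {  // assumed shape of the struct from rdma.h
      RdmaWriteID(RdmaWriteIDType t, void* ctx)
          : write_type(t), write_context(ctx) {}
      RdmaWriteIDType write_type;
      void* write_context;  // RdmaMessageBuffer* or RdmaTensorResponse*
    };

    // Posting side tags the work request:
    //   wr.wr_id = (uint64_t) new RdmaWriteID(type, context);
    // The completion handler is the single owner and deletes it once:
    void OnWriteComplete(uint64_t wr_id_raw) {
      RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wr_id_raw);
      // ... dispatch on wr_id->write_type as in Process_CQ() ...
      delete wr_id;
    }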
{ - const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); - tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); - rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]); message_buffers_.reserve(kNumMessageBuffers); message_buffers_.push_back(tx_message_buffer_); message_buffers_.push_back(rx_message_buffer_); - message_buffers_.push_back(tx_ack_buffer_); - message_buffers_.push_back(rx_ack_buffer_); // create buffer on host tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize); rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize); - tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize); - rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize); - // bt_mu_.lock() is not used in constructor. - for (int i = 0; i < kNumMessageBuffers; i++) { - uint32_t index = NameHash(buffer_names[i]); - buffer_table_.insert({index, message_buffers_[i]}); - buffer_index_name_table_.insert({index, buffer_names[i]}); - buffer_name_index_table_.insert({buffer_names[i], index}); - } } CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_ << " with error " << std::strerror(errno); @@ -684,8 +633,6 @@ RdmaChannel::~RdmaChannel() { CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP"; delete tx_message_buffer_; delete rx_message_buffer_; - delete tx_ack_buffer_; - delete rx_ack_buffer_; } void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { @@ -716,114 +663,31 @@ void RdmaChannel::Recv() { CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } -// Lookup 32-bit buffer index from buffer name -// Args: -// buffer_name: name of the buffer -// Returns: -// 32-bit index -uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) { - mutex_lock lock{bt_mu_}; - BufferNameIndexTable::iterator iter = - buffer_name_index_table_.find(buffer_name); - CHECK(iter != buffer_name_index_table_.end()); - return iter->second; -} - -// Find a buffer by its 32-bit index -// Args: -// index: 32-bit hash code of the tensor buffer name -// Returns: -// name of the tensor buffer -RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) { - mutex_lock lock{bt_mu_}; - BufferTable::iterator iter = buffer_table_.find(index); - CHECK(iter != buffer_table_.end()); - return iter->second; -} - -// Find a buffer by its name -// Args: -// name: name of the buffer -// Returns: -// the named rdma buffer -RdmaBuffer* RdmaChannel::FindBuffer(const string& name) { - uint32_t index = LookupBufferIndex(name); - return FindBuffer(index); -} - -// Find a buffer if it exists, otherwise create one. -// The memory inside the created buffer is not allocated. -// Args: -// name: the name of the buffer -// buffer_type: TENSOR, MESSAGE or ACK. 
-// Returns: -// the named buffer -RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name, - BufferType buffer_type) { - mutex_lock lock{bt_mu_}; - RdmaBuffer* rb; - // find index - BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name); - if (iter != buffer_name_index_table_.end()) { - uint32_t index = iter->second; - // find buffer - BufferTable::iterator iter = buffer_table_.find(index); - CHECK(iter != buffer_table_.end()); - rb = iter->second; - } else { - uint32_t index = NameHash(name); - if (buffer_type == TENSOR) { - rb = new RdmaTensorBuffer(this, name); - } else if (buffer_type == MESSAGE) { - rb = new RdmaMessageBuffer(this, name); - } else if (buffer_type == ACK) { - rb = new RdmaAckBuffer(this, name); - } - buffer_name_index_table_.insert({name, index}); - buffer_index_name_table_.insert({index, name}); - buffer_table_.insert({index, rb}); +RdmaTensorRequest* RdmaChannel::InsertTensorRequest( + const string& key, int64 step_id, Device* dst_dev, + const Rendezvous::Args recv_args, + const RdmaTensorRequest::RecvDoneCallback& done) { + mutex_lock lock{ct_mu_}; + uint32_t request_index = request_serial_++; + if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) { + request_serial_ = 0; } - CHECK(rb); - return rb; + RdmaTensorRequest request(request_index, key, step_id, this, dst_dev, + recv_args, done); + auto it = request_table_.emplace(request_index, request); + return &it.first->second; } -// Insert callback to the callback_table. -// The callback is activated when the corresponding tensor is received. -// Arg: -// key: the name of the tensor -// recv_done: the callback associated with the tensor. -// Returns: -// None -void RdmaChannel::InsertRecvCallback(const string& key, - std::function<void()> recv_done) { +void RdmaChannel::RemoveTensorRequest(uint32_t request_index) { mutex_lock lock{ct_mu_}; - callback_table_.insert({key, recv_done}); + request_table_.erase(request_index); } -// Remove callback from the callback_table. -// Arg: -// key: the name of the tensor -// Returns: -// None -void RdmaChannel::RemoveRecvCallback(const string& key) { +RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) { mutex_lock lock{ct_mu_}; - callback_table_.erase(key); -} - -// Run named callback in the callback_table. -// Arg: -// key: the name of the tensor -// Returns: -// None -void RdmaChannel::RunRecvCallback(const string& key) { - std::function<void()> recv_done; - { - mutex_lock lock{ct_mu_}; - CallbackTable::iterator iter = callback_table_.find(key); - CHECK(iter != callback_table_.end()); - recv_done = iter->second; - } - recv_done(); + RequestTable::iterator iter = request_table_.find(request_index); + CHECK(iter != request_table_.end()); + return &iter->second; } void RdmaChannel::Connect() { @@ -888,25 +752,22 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { connected_ = true; } else { - LOG(INFO) << "channel already connected"; + RDMA_LOG(2) << "channel already connected"; } } -RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name) +RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name) : channel_(channel), name_(name) {} -RdmaBuffer::~RdmaBuffer() { +RdmaMessageBuffer::~RdmaMessageBuffer() { CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed"; FreeBuffer(); } -void RdmaBuffer::FreeBuffer() { +void RdmaMessageBuffer::FreeBuffer() { if ((buffer_ != nullptr) && buffer_on_host_) { free(buffer_); } - // TODO - // release buffer if it is on device. - // We don't support RDMABuffer on device at this moment. 
} // Allocate CPU memory for the Rdma buffer @@ -915,7 +776,7 @@ void RdmaBuffer::FreeBuffer() { // lock: whether or not mutex_lock the process to protect concurrency. // Returns: // None -void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) { +void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) { CHECK(size > 0); if (lock) { mu_.lock(); @@ -943,7 +804,7 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) { // override: whether override existing information // Returns: // None -void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) { +void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) { mutex_lock lock{mu_}; if ((override) || (remote_status_ == none)) { remote_.remote_addr = rmr.remote_addr; @@ -956,63 +817,51 @@ void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) { } // Put a task in the buffer's job queue -void RdmaBuffer::EnqueueItem(string item) { +void RdmaMessageBuffer::EnqueueItem(string item) { mutex_lock lock{mu_}; queue_.push(item); } // Rdma-Write the content of the buffer -void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { +void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) { + Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey, + remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this); +} + +// Generalized Write method +void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data, + size_t buffer_size, uint64_t src_addr, + uint32_t lkey, uint64_t remote_addr, + uint32_t rkey, RdmaWriteIDType write_type, + void* write_context) { struct ibv_sge list; - list.addr = (uint64_t)buffer_; + list.addr = src_addr; list.length = buffer_size; - list.lkey = self_->lkey; + list.lkey = lkey; struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context); wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wr.send_flags = IBV_SEND_SIGNALED; wr.imm_data = imm_data; - wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr; - wr.wr.rdma.rkey = remote_.rkey; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = rkey; struct ibv_send_wr* bad_wr; - CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send"; -} - -RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - -RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - -RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - -RdmaTensorBuffer::~RdmaTensorBuffer() { - for (Itable it = retable.begin(); it != retable.end(); ++it) { - delete (it->second); - } + CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send"; } // Send the next ack from the buffer's job queue. -void RdmaAckBuffer::SendNextItem() { - uint32_t imm_data = LookupBufferIndex("rx_ack_buffer"); - RdmaMessage rm; - rm.name_ = "rx_ack_buffer"; - rm.type_ = RDMA_MESSAGE_ACK; - rm.name_size_ = rm.name_.size(); - string message = RdmaMessage::CreateMessage(rm); - memcpy(buffer_, message.data(), message.size()); - Write(imm_data, message.size()); +void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) { + Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr); } // Send the next message from the buffer's job queue. 
void RdmaMessageBuffer::SendNextItem() { - uint32_t imm_data = LookupBufferIndex("rx_message_buffer"); + uint32_t imm_data = RDMA_IMM_DATA_MESSAGE; mu_.lock(); if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) { local_status_ = busy; @@ -1029,244 +878,392 @@ void RdmaMessageBuffer::SendNextItem() { } } -Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( - const string& key_with_step_id, const string& key, int64 step_id, - const Rendezvous::ParsedKey& parsed) { - Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id, parsed]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { - CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id - << " error message: " << status.error_message(); - size_t buffer_size = RdmaMessage::kMessageTotalBytes; - size_t tensor_bytes = 0; - // Figures out which device the tensor is hosted on. - Device* src_dev = nullptr; - Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice( - parsed.src_device, &src_dev); - CHECK(s.ok()) << "src device not found"; - // Does the device have the right incarnation number we expect? - CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation) - << "RecvTensor expects a different device incarnation: " - << parsed.src_incarnation << " vs. " - << src_dev->attributes().incarnation() - << ". Your worker job was probably restarted. Check your " - << "worker job for the reason why it was restarted."; - Device* dst_dev = nullptr; - // destination is on CPU. - s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0", - &dst_dev); - CHECK(s.ok()) << "dst device not found"; - AllocatorAttributes dst_alloc_attr; - dst_alloc_attr.set_on_host(true); - - bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); - // string tensor needs to be serialized - Tensor copy; - TensorProto proto; - if (src_dev->tensorflow_gpu_device_info() && - (!send_args.alloc_attrs.on_host())) { #if GOOGLE_CUDA - CHECK(send_args.device_context) << "send dev name: " << src_dev->name() - << " gpu_info: " - << src_dev->tensorflow_gpu_device_info(); - - if (can_memcpy) { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); - copy = Tensor(alloc, in.dtype(), in.shape()); - tensor_bytes = in.TotalBytes(); - buffer_size += tensor_bytes; - GPUUtil::CopyGPUTensorToCPU( - src_dev, send_args.device_context, &in, ©, - [this, copy, tensor_bytes, buffer_size, key, in, step_id, - key_with_step_id, is_dead, send_args, recv_args](const Status& s) { - CHECK(s.ok()) << "copy tensor from gpu sync"; - StringPiece copy_buf; - copy_buf = copy.tensor_data(); - PostCopyOperations(true, buffer_size, tensor_bytes, key, in, - step_id, is_dead, key_with_step_id, ©, - NULL, ©_buf, send_args, recv_args); - }); - } else { - // "val" is on a GPU. 
No longer uses GPUUtil to fill the proto, use - // aync instead - GPUUtil::SetProtoFromGPU( - in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { - CHECK(s.ok()) << "copy proto from gpu sync"; - auto tensor_bytes = proto.ByteSize(); - buffer_size += tensor_bytes; - PostCopyOperations(false, buffer_size, tensor_bytes, key, in, - step_id, is_dead, key_with_step_id, NULL, - &proto, NULL, send_args, recv_args); - }); - } -#endif // GOOGLE_CUDA - } else { - // tensor is in CPU memory. - StringPiece copy_buf; - if (can_memcpy) { - copy_buf = in.tensor_data(); - tensor_bytes = in.TotalBytes(); - } else { - in.AsProtoTensorContent(&proto); - tensor_bytes = proto.ByteSize(); - } - buffer_size += tensor_bytes; - PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in, - step_id, is_dead, key_with_step_id, ©, &proto, - ©_buf, send_args, recv_args); +static void CountCopies(const std::string& key, void* src_addr, void* dst_addr, + size_t tensor_bytes, bool is_gpu_to_cpu) { +#ifdef RDMA_COUNT_COPIES + static uint64_t numGPUToCPUCopies = 0; + static uint64_t numGPUToCPUCopiedBytes = 0; + static uint64_t numCPUToGPUCopies = 0; + static uint64_t numCPUToGPUCopiedBytes = 0; + static uint64_t numTotalCopies = 0; + + if (is_gpu_to_cpu) { + ++numGPUToCPUCopies; + numGPUToCPUCopiedBytes += tensor_bytes; + } else { + ++numCPUToGPUCopies; + numCPUToGPUCopiedBytes += tensor_bytes; + } + if ((++numTotalCopies % 0x400) == 0) { + RDMA_LOG(0) << "Tensor copies:" + << " GPU to CPU: " << numGPUToCPUCopies + << " (" << numGPUToCPUCopiedBytes << " Bytes)" + << " CPU to GPU: " << numCPUToGPUCopies + << " (" << numCPUToGPUCopiedBytes << " Bytes)"; + } + RDMA_LOG(2) << "Copying tensor " << key + << " From: " << src_addr << " To: " << dst_addr; +#endif // RDMA_COUNT_COPIES +} +#endif // GOOGLE_CUDA + +#ifdef RDMA_DATA_VALIDATION +static uint64_t Checksum(Device* device, const DeviceContext* device_context, + const Tensor& in) { + uint64 checksum = 0; + if (DataTypeCanUseMemcpy(in.dtype())) { +#if GOOGLE_CUDA + if (in.TotalBytes() == 0) { + return 0; } - }; - return cb; + checksum = (device_context != nullptr) + ? GPUUtil::Checksum(device, device_context, in) + : GPUUtil::Checksum(in); +#endif // GOOGLE_CUDA + } else { + string s = in.SummarizeValue(999999); + checksum = Hash64(s.c_str(), s.size(), 0); + } + return checksum; } -// Send the next tensor from the buffer's job queue. -void RdmaTensorBuffer::SendNextItem() { - // get the key - string key_with_step_id = ""; - { - mutex_lock lock{mu_}; - if (!queue_.empty()) { - key_with_step_id = queue_.front(); - queue_.pop(); +static void ValidateChecksum(uint64_t expected, uint64_t actual, + const Tensor& in, uint32_t request_index, + const std::string& key, const std::string& msg) { + RDMA_LOG(2) << "Request #" << request_index << ": " << key + << ": Checksum: " << std::hex << " Expected = 0x" << expected + << ". Actual = 0x" << actual << "."; + + if (expected != actual) { + // Checksum failed. There is one case where this is allowed - if the + // tensor is an AssignAdd of the global step. Since the data-validation + // always postpones the Tensor response in order to send a checksum message, + // it is possible that the global-step was updated while the response was + // still in queue. 
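Checksum() above picks its strategy by dtype: raw bytes for memcpy-able tensors (delegated to GPUUtil::Checksum so GPU-resident data works too), and a textual value summary for everything else, e.g. string tensors. A CPU-only sketch of the same idea, hashing with tensorflow::Hash64 in both branches for illustration:

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/types.h"
    #include "tensorflow/core/lib/hash/hash.h"

    namespace tensorflow {

    // Illustrative only: hash the tensor's own bytes when the dtype is safe
    // to memcpy, otherwise hash a summary of its values.
    uint64 CpuTensorChecksum(const Tensor& in) {
      if (DataTypeCanUseMemcpy(in.dtype())) {
        StringPiece data = in.tensor_data();
        return Hash64(data.data(), data.size(), 0);
      }
      string s = in.SummarizeValue(999999);  // same bound as the code above
      return Hash64(s.c_str(), s.size(), 0);
    }

    }  // end namespace tensorflow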
+ if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) { + int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1; + actual = Hash64((const char*)&prev_val, 8, 0); + } + if (expected != actual) { + LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #" + << request_index << ": " << key << std::hex << " " + << DataTypeString(in.dtype()) << " " + << in.shape().DebugString() << " (0x" << in.TotalBytes() + << " bytes): " + << " Expected 0x" << expected << ". Got 0x" << actual << "."; } } +} +#endif // RDMA_DATA_VALIDATION + +#if GOOGLE_CUDA +// Sync the 'done' operation on the GPU stream, but without all the data +// copying. +static void StreamGPUOp(Device* gpu_device, + const DeviceContext* device_context, + StatusCallback done) { + Tensor dummy1, dummy2; + GPUUtil::CopyGPUTensorToCPU( + gpu_device, device_context, &dummy1, &dummy2, done); +} +#endif // GOOGLE_CUDA + +RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) { + mutex_lock lock{mu_}; + auto it = + responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm)); + CHECK(it.second) << "Response with the ID " << rm.request_index_ + << " already exists."; + return &it.first->second; +} + +RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) { + mutex_lock lock{mu_}; + auto it = responses_table_.find(rm.request_index_); + CHECK(it != responses_table_.end()) << "No response found."; + RdmaTensorResponse* response = &it->second; + response->Update(rm); + return response; +} - // send the tensor if a key is acquired. - if (key_with_step_id != "") { - VLOG(2) << "try to send tensor: " << key_with_step_id; - string key; - int64 step_id; - VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id); - CHECK(key.compare(name_) == 0); - Rendezvous::ParsedKey parsed; - Rendezvous::ParseKey(key, &parsed); - Rendezvous::DoneCallback cb = - getRecvTensorCallback(key_with_step_id, key, step_id, parsed); - channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id, - parsed, cb); +void RdmaChannel::RemoveTensorResponse(uint32_t request_index) { + mutex_lock lock{mu_}; + responses_table_.erase(request_index); +} + +void RdmaTensorResponse::Start() { + Rendezvous::ParsedKey parsed; + Status s = Rendezvous::ParseKey(rm_.name_, &parsed); + if (!s.ok()) { + SendErrorStatus(s); + return; } + + channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync( + rm_.step_id_, parsed, + [this, parsed](const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, + bool is_dead) { + CHECK(status.ok()) << "RecvLocalAsync was not ok." + << " error message: " << status.error_message(); + RecvHandler(parsed, send_args, recv_args, in, is_dead); + }); } -void RdmaTensorBuffer::ReSendNextItem() { - // get the key - string key_with_step_id = ""; - { - mutex_lock lock{mu_}; - if (!requeue.empty()) { - key_with_step_id = requeue.front(); - requeue.pop(); - } +void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); } + +// Helper for RecvTensor. Validates "key" and returns the source +// device in "*src_dev". +Status RdmaTensorResponse::PrepareRecvTensor( + const Rendezvous::ParsedKey& parsed, Device** src_dev) { + // Figures out which device the tensor is hosted on. + string local_name = DeviceNameUtils::LocalName(parsed.src_device); + TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice( + local_name, src_dev)); + + // Does the device have the right incarnation number we expect? 
+ if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { + return errors::Aborted( + "RecvTensor expects a different device incarnation: ", + parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), + ". Your worker job was probably restarted. Check your " + "worker job for the reason why it was restarted."); } - // send the tensor if a key is acquired. - if (key_with_step_id != "") { - VLOG(2) << "try to send tensor: " << key_with_step_id; - string key; - int64 step_id; - VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id); - CHECK(key.compare(name_) == 0); - Rendezvous::ParsedKey parsed; - Rendezvous::ParseKey(key, &parsed); - Rendezvous::DoneCallback cb = - getRecvTensorCallback(key_with_step_id, key, step_id, parsed); - ReItem* item; - { - mutex_lock lock{mu_}; - Itable it = retable.find(key_with_step_id); - CHECK(it != retable.end()) << "Could not find dup-recv context"; - item = it->second; - retable.erase(it); + return Status::OK(); +} + +void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& in, bool is_dead) { + Status s = PrepareRecvTensor(parsed, &src_dev_); + if (!s.ok()) { + SendErrorStatus(s); + return; + } + + meta_data_changed_ = TensorMetaDataChanged(in, is_dead); +#ifdef RDMA_DATA_VALIDATION + // Always send a meta data message with the source checksum + meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST; + checksum_ = Checksum(src_dev_, send_args.device_context, in); +#endif + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + // string tensor needs to be serialized + Tensor copy; + TensorProto proto; + const bool on_host = send_args.alloc_attrs.on_host(); + if (src_dev_->tensorflow_gpu_device_info() && !on_host) { +#if GOOGLE_CUDA + DeviceContext* send_dev_context = send_args.device_context; + CHECK(send_dev_context) + << "send dev name: " << src_dev_->name() + << " gpu_info: " << src_dev_->tensorflow_gpu_device_info(); + + if (can_memcpy) { + // If the tensor is located on a GDR compatible GPU, there is no need to + // copy it. We can send directly from the source, just need to make sure + // we are in sync with the GPU stream. + // If the tensor's meta-data changed however, we will need to clone it, + // so anyway we'll have to copy it from GPU to CPU first. If at some + // point in time Clone() is changed to only save a shallow copy, we can + // skip the copy here as well. + if ((in.TotalBytes() > 0) && !meta_data_changed_ && + (RdmaMemoryMgr::Singleton().FindMemoryRegion( + (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) { + StreamGPUOp(src_dev_, send_dev_context, + [this, in, proto, is_dead](const Status& s) { + Send(in, proto, is_dead, s); + }); + return; + } + + // The tensor must be copied from GPU to CPU, because either: + // 1. The tensor is located on a non GDR compatible GPU. + // 2. The tensor's meta-data has changed. 
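The in-place branch of RecvHandler() above is the GPUDirect fast path: it is taken only when the payload is memcpy-able, non-empty, its meta-data is unchanged, and the source buffer is already registered with the RDMA device. A hypothetical helper condensing that predicate (names are illustrative, not from the source):

    #include <cstddef>

    struct ibv_mr;  // from <infiniband/verbs.h>

    bool CanSendInPlace(bool can_memcpy, bool meta_data_changed,
                        size_t total_bytes, ibv_mr* mr) {
      return can_memcpy &&          // raw bytes can go on the wire as-is
             !meta_data_changed &&  // no clone needed for a meta-data update
             total_bytes > 0 &&     // empty tensors take the regular path
             mr != nullptr;         // buffer registered for RDMA (GDR-capable)
    }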
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + copy = Tensor(alloc, in.dtype(), in.shape()); + CountCopies(rm_.name_, (void*)DMAHelper::base(&in), + (void*)DMAHelper::base(©), in.TotalBytes(), true); + GPUUtil::CopyGPUTensorToCPU( + src_dev_, send_dev_context, &in, ©, + [this, copy, proto, is_dead](const Status& s) { + Send(copy, proto, is_dead, s); + }); + } else { + GPUUtil::SetProtoFromGPU( + in, src_dev_, send_args.device_context, &proto, is_dead, + [this, in, proto, is_dead](const Status& s) mutable { + Send(in, proto, is_dead, s); + }); + } +#else + SendErrorStatus(errors::Internal("No GPU device in process")); +#endif // GOOGLE_CUDA + } else { + // tensor is in CPU memory. + if (!can_memcpy) { + in.AsProtoTensorContent(&proto); } - cb(Status::OK(), item->send_args, item->recv_args, item->in, item->is_dead); - delete (item); + Send(in, proto, is_dead, Status::OK()); + } +} + +void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto, + bool is_dead, const Status& status) { + if (!status.ok()) { + SendErrorStatus(status); + return; + } + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + bool proto_size_changed = (!can_memcpy) && + (proto.ByteSize() != rm_.tensor_bytes_); + if (meta_data_changed_ || proto_size_changed) { + Clone(in, proto, is_dead); + SendMetaData(in, proto, is_dead); + } else { + SendContent(in, proto, is_dead); + } +} + +bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) { + return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) || + (rm_.is_dead_ != is_dead); +} + +void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto, + bool is_dead) { + // Clone the data to be sent later. For simplicity, we clone the tensor's + // data even if it is already a copy. Performance is less of a concern here + // since the meta-data hardly ever changes. The reason we create a copy, is + // that some tensors share their buffer between different step-ids, so the + // tensor content may change before re-request was completed. + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + if (can_memcpy && (in.TotalBytes() > 0)) { + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_nic_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs); + tensor_ = new Tensor(allocator, in.dtype(), in.shape()); + memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes()); + } else { + tensor_ = new Tensor(in.dtype(), in.shape()); } + if (!can_memcpy) { + proto_ = new TensorProto(proto); + } + is_dead_ = is_dead; } -void RdmaTensorBuffer::PostCopyOperations( - bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key, - const Tensor& in, int64 step_id, bool is_dead, - const string& key_with_step_id, const Tensor* copy, - const TensorProto* proto, const StringPiece* copy_buf, - const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args) { - // prepare message +void RdmaTensorResponse::SendMetaData(const Tensor& in, + const TensorProto& proto, bool is_dead) { + RDMA_LOG(2) << "Request #" << rm_.request_index_ + << ": Meta data changed: " << rm_.name_; + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + size_t tensor_bytes = (can_memcpy) ? 
in.TotalBytes() : proto.ByteSize(); + + // Send meta-data update: RdmaMessage rm; - rm.name_size_ = key.size(); - rm.name_ = key; + rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE; + rm.name_size_ = rm_.name_.size(); + rm.name_ = rm_.name_; rm.tensor_shape_ = in.shape(); rm.data_type_ = in.dtype(); - rm.step_id_ = step_id; + rm.step_id_ = rm_.step_id_; rm.is_dead_ = is_dead; rm.tensor_bytes_ = tensor_bytes; - rm.buffer_size_ = buffer_size; - mu_.lock(); - if (local_status_ == none || (buffer_size > size_ && local_status_ == idle && - remote_status_ == idle)) { - if ((local_status_ != none) && (buffer_size > size_)) { - VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size; - } - CreateCPUBuffer(buffer_size, false); - // Need to be received again, put into the re-recv queue and the table - requeue.push(key_with_step_id); - ReItem* item = new ReItem(send_args, recv_args, in, is_dead); - retable.insert(std::pair<string, ReItem*>(key_with_step_id, item)); - mu_.unlock(); - // no longer used: put back the key since it is not sent; - // ask the remote to create the same buffer - rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST; - rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_); - rm.rkey_ = self_->rkey; - string message = RdmaMessage::CreateMessage(rm); - channel_->tx_message_buffer_->EnqueueItem(message); - channel_->tx_message_buffer_->SendNextItem(); - } else if ((local_status_ == idle) && (remote_status_ == idle)) { - // both buffers are ready, send the tensor - local_status_ = busy; - remote_status_ = busy; - // local/remote_status_ won't be set back to idle - // unitl Write() is successful - mu_.unlock(); - if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) || - (buffer_size <= size_ && rm.data_type_ == DT_STRING))) { - VLOG(2) << "Tensor and buffer size do not agree," - << " buffer_size = " << size_ - << " requested tensor size = " << buffer_size << in.DebugString(); - } - uint32_t imm_data = LookupBufferIndex(key); - rm.type_ = RDMA_MESSAGE_TENSOR_WRITE; - string message = RdmaMessage::CreateMessage(rm); - memcpy(buffer_, message.data(), message.size()); - if (!is_dead) { - // copy the tensor buffer content - void* output = static_cast<void*>(static_cast<char*>(buffer_) + - RdmaMessage::kTensorBufferStartIndex); - CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); - if (can_memcpy) { - CHECK(copy != NULL) << "callback missing pointer to copy tensor"; - CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer"; - CHECK(copy_buf->size() == tensor_bytes) - << "unexpected tensor size: " << copy_buf->size() - << " != " << tensor_bytes; - memcpy(output, copy_buf->data(), tensor_bytes); - } else { - CHECK(proto != NULL) << "callback missing pointer to proto tensor"; - proto->SerializeToArray(output, tensor_bytes); + rm.request_index_ = rm_.request_index_; +#ifdef RDMA_DATA_VALIDATION + rm.checksum_ = checksum_; +#endif + RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #" + << rm.request_index_ << ": " << rm.name_ + << " (shape = " << rm.tensor_shape_.DebugString() << "." + << " data-type = " << DataTypeString(rm.data_type_) << "." 
+ << " is-dead = " << rm.is_dead_ << ")"; + + string message = RdmaMessage::CreateMessage(rm); + channel_->tx_message_buffer_->EnqueueItem(message); + channel_->tx_message_buffer_->SendNextItem(); +} + +void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto, + bool is_dead) { + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize(); + uint32_t imm_data = rm_.request_index_; + if (!is_dead) { + if (can_memcpy) { + src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in)); + if (src_buffer_ != nullptr) { + src_buffer_->Ref(); // Keep buffer alive until write is complete + src_addr_ = src_buffer_->data(); + mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_, + tensor_bytes); } } else { - buffer_size = RdmaMessage::kMessageTotalBytes; + RDMA_LOG(2) << "Encoding proto: " << rm_.name_ + << " (Size: " << tensor_bytes << ") " << in.DebugString(); + src_addr_ = malloc(tensor_bytes); + mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + proto.SerializeToArray(src_addr_, tensor_bytes); } - Write(imm_data, buffer_size); } else { - // Need to be received again, put into the re-recv queue and the table - requeue.push(key_with_step_id); - ReItem* item = new ReItem(send_args, recv_args, in, is_dead); - retable.insert(std::pair<string, ReItem*>(key_with_step_id, item)); - mu_.unlock(); + tensor_bytes = 0; } + + uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey; + RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec + << ": Sending tensor content #" << rm_.request_index_ << " from " + << std::hex << src_addr_ << " (0x" << lkey << ")" + << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_ + << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes + << ")"; + + RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes, + (uint64_t)src_addr_, lkey, rm_.remote_addr_, + rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this); +} + +void RdmaTensorResponse::SendErrorStatus(const Status& status) { + RdmaMessage rm; + rm.type_ = RDMA_MESSAGE_ERROR_STATUS; + rm.name_size_ = rm_.name_.size(); + rm.name_ = rm_.name_; + rm.step_id_ = rm_.step_id_; + rm.request_index_ = rm_.request_index_; + rm.status_ = status; + LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Sending RDMA_MESSAGE_ERROR_STATUS #" + << rm.request_index_ << ": " << rm.name_ + << ". Status: " << status.ToString(); + + string message = RdmaMessage::CreateMessage(rm); + channel_->tx_message_buffer_->EnqueueItem(message); + channel_->tx_message_buffer_->SendNextItem(); + + // Destroy the response. + Destroy(); +} + +void RdmaTensorResponse::Destroy() { + if (src_buffer_ != nullptr) { + src_buffer_->Unref(); + } + if (tensor_ != nullptr) { + delete tensor_; + } + if (proto_ != nullptr) { + ibv_dereg_mr(mr_); + free(src_addr_); + delete proto_; + } + // Remove response from the pending list: + channel_->RemoveTensorResponse(rm_.request_index_); } // Create a RdmaMessage according to the pre-defined format @@ -1276,43 +1273,46 @@ void RdmaTensorBuffer::PostCopyOperations( // message in string format string RdmaMessage::CreateMessage(const RdmaMessage& rm) { // Rdma Message format - // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|... - // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... - // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer - // ...| XB | XB | 8B |... 
+ // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|... + // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... + // ...|data_type|tensor_shape|tensor_bytes|error_status | + // ...| XB | XB | 8B |size - 4B, proto - XB | // - // ACK: type|13|"rx_ack_buffer" - // TENSOR_REQUEST: type|name_size|tensor_name|step_id - // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead - // |data_type|tensor_shape|tensor_bytes - // BUFFER_IDLE: type|name_size|buffer_name - // BUFFER_REQUEST: - // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| - // BUFFER_RESPONSE: - // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| - char message[kMessageTotalBytes]; + // ACK: Imm-type: ACK + // TENSOR_REQUEST: Imm-type: MESSAGE + // Fields: type, request_index, name, step_id, remote_addr, + // rkey, is_dead, data_type, tensor_shape, tensor_bytes + // META_DATA_UPDATE: Imm-type: MESSAGE + // Fields: type, request_index, is_dead, data_type, + // tensor_shape, tensor_bytes + // TENSOR_RE_REQUST: Imm-type: MESSAGE + // Fields: type, request_index, name, step_id, remote_addr, + // rkey, is_dead, data_type, tensor_shape, tensor_bytes + // ERROR_STATUS: Imm-type: MESSAGE + // Fields: type, request_index, name, step_id, error_status + // Tensor content: Imm-type: request_index + size_t message_size = kMessageTotalBytes; + char message[kMessageTotalBytes + kErrorStatusMaxSize]; // type message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff; - // size of name - memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_)); - // name - memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size()); - // buffer_size, remote_addr, rkey - if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || - (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) { - memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_, - sizeof(rm.buffer_size_)); + // request index + memcpy(&message[kRequestIndexStartIndex], &rm.request_index_, + sizeof(rm.request_index_)); + // name, step_id, remote_addr, rkey + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { + memcpy(&message[kNameSizeStartIndex], &rm.name_size_, + sizeof(rm.name_size_)); + memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size()); memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_, sizeof(rm.remote_addr_)); memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_)); - } - // step_id - if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || - (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) { memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_)); } // is_dead, data_type, tensor_shape, tensor_bytes - if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_)); memcpy(&message[kDataTypeStartIndex], &rm.data_type_, @@ -1322,7 +1322,31 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) { memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_, sizeof(rm.tensor_bytes_)); } - return string(message, kMessageTotalBytes); + // checksum +#ifdef RDMA_DATA_VALIDATION + memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_)); +#endif + // error status + if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) { + ::grpc::Status gs = ToGrpcStatus(rm.status_); + ErrorStatusProto gsProto; + gsProto.set_error_code(gs.error_code()); + 
gsProto.set_error_message(gs.error_message()); + gsProto.set_error_details(gs.error_details()); + uint32_t gsProtoSize = gsProto.ByteSize(); + if (gsProtoSize + 4 > kErrorStatusMaxSize) { + LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) " + << "is too big to fit in RDMA message (" + << kErrorStatusMaxSize << " bytes). Truncated."; + gsProtoSize = kErrorStatusMaxSize - 4; + } + uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex]; + *proto_size = gsProtoSize; + gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], + gsProtoSize); + message_size += gsProtoSize + 4; + } + return string(message, message_size); } // Parse a RdmaMessage according to the pre-defined format @@ -1335,26 +1359,24 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) { char* message = static_cast<char*>(buffer); // type rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]); - // name_size_ - memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_)); - // name - rm.name_ = string(&message[kNameStartIndex], rm.name_size_); - // buffer_size, remote_addr, rkey - if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || - (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) { - memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex], - sizeof(rm.buffer_size_)); + // request index + memcpy(&rm.request_index_, &message[kRequestIndexStartIndex], + sizeof(rm.request_index_)); + // name, step_id, remote_addr, rkey + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { + memcpy(&rm.name_size_, &message[kNameSizeStartIndex], + sizeof(rm.name_size_)); + rm.name_ = string(&message[kNameStartIndex], rm.name_size_); memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex], sizeof(rm.remote_addr_)); memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_)); - } - // step_id - if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || - (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) { memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_)); } // data_type, tensor_bytes, tensor_shape, is_dead - if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_)); memcpy(&rm.data_type_, &message[kDataTypeStartIndex], sizeof(rm.data_type_)); @@ -1363,6 +1385,294 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) { memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex], sizeof(rm.tensor_bytes_)); } + // checksum +#ifdef RDMA_DATA_VALIDATION + memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_)); +#endif + // error status + if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) { + ErrorStatusProto gsProto; + uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex]; + CHECK(ParseProtoUnlimited( + &gsProto, &message[kErrorStatusStartIndex + 4], gsProtoSize)) + << "Failed to parse error status proto from message. 
Aborting."; + ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(), + gsProto.error_message(), gsProto.error_details()); + rm.status_ = FromGrpcStatus(gs); + } +} + +//***************************************************************************** +// RdmaMemoryMgr +//***************************************************************************** + +ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) { + mutex_lock l(mrs_mu_); + auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); + if (iter == std::end(mrs_) || iter->get()->addr > addr) { + return nullptr; + } else { + return iter->get(); + } +} + +void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name) { + if (length == 0) return; + ibv_mr* mr = ibv_reg_mr(pd_, addr, length, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". [" + << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]" + << " SIZE: 0x" << length << " (" << allocator_name << ")."; + if (mr != nullptr) { + mutex_lock l(mrs_mu_); + auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); + mrs_.insert(iter, {mr, &MRDeleter}); + } else { + LOG(WARNING) << "Cannot register memory region"; + } +} + +void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) { + if (length == 0) return; + mutex_lock l(mrs_mu_); + auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); + if (iter != std::end(mrs_) && iter->get()->addr == addr) { + mrs_.erase(iter); + RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey; + + } else { + LOG(WARNING) << "Failed to de-register memory region"; + } +} + +const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData( + const std::string& tensor_name) { + mutex_lock l(tensor_meta_data_mu_); + auto it = tensors_meta_data_.find(tensor_name); + if (it == tensors_meta_data_.end()) { + return nullptr; + } + return &it->second; +} + +const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData( + const std::string& tensor_name, DataType dtype, const TensorShape& shape, + bool is_dead, size_t proto_size) { + mutex_lock l(tensor_meta_data_mu_); + TensorMetaData& meta_data = tensors_meta_data_[tensor_name]; + meta_data.data_type_ = dtype; + meta_data.tensor_shape_ = shape; + meta_data.proto_size_ = proto_size; + meta_data.is_dead_ = is_dead; + return &meta_data; +} + +//***************************************************************************** +// RdmaTensorRequest +//***************************************************************************** + +RdmaTensorRequest::RdmaTensorRequest( + uint32_t index, const string& key, int64 step_id, RdmaChannel* channel, + Device* dst_dev, const Rendezvous::Args recv_args, + const RdmaTensorRequest::RecvDoneCallback& done) + : index_(index), + key_(key), + step_id_(step_id), + channel_(channel), + dst_dev_(dst_dev), + recv_args_(recv_args), + meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)), + result_tensor_(nullptr), + proxy_tensor_(nullptr), + rdma_addr_(nullptr), + mr_(nullptr), + done_(done) {} + +RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); } + +void RdmaTensorRequest::Done(const Status& s) { + Tensor val = std::move(*result_tensor_); + +#ifdef RDMA_DATA_VALIDATION + // Validate checksum + // Unfortunately we can't always do a Checksum directly on the result tensor. + // If the result tensor is on GPU, then we need to copy it back to CPU. 
If + // we happen to be in the midst of a proxy callback, then the copying will + // get stuck. + uint64_t checksum = (proxy_tensor_ != nullptr) + ? Checksum(nullptr, nullptr, *proxy_tensor_) + : Checksum(dst_dev_, recv_args_.device_context, val); + ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA"); +#endif + + Rendezvous::Args recv_args = std::move(recv_args_); + bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_; + RecvDoneCallback done = done_; + DeallocateTensors(); + channel_->RemoveTensorRequest(index_); + done(s, Rendezvous::Args(), recv_args, val, is_dead); +} + +void RdmaTensorRequest::DeallocateTensors() { + if (result_tensor_ != nullptr) { + delete result_tensor_; + result_tensor_ = nullptr; + } + if (proxy_tensor_ != nullptr) { + delete proxy_tensor_; + proxy_tensor_ = nullptr; + } +} + +bool RdmaTensorRequest::AllocateTensors() { + result_tensor_ = + new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs), + meta_data_->data_type_, meta_data_->tensor_shape_); + + size_t tensor_size = result_tensor_->TotalBytes(); + bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype()); + if (can_memcpy) { + if (tensor_size == 0) { + return true; + } + rdma_addr_ = DMAHelper::base(result_tensor_); + mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size); +#if GOOGLE_CUDA + if (mr_ == nullptr) { + // Can't RDMA directly to result. Use a proxy. + proxy_tensor_ = + new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0), + result_tensor_->dtype(), result_tensor_->shape()); + rdma_addr_ = DMAHelper::base(proxy_tensor_); + mr_ = + RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size); + } +#endif + } else { + uint32_t proto_size = meta_data_->proto_size_; + rdma_addr_ = malloc(proto_size); + mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + } + CHECK(mr_ != nullptr) << " No memory region found for address " << rdma_addr_ + << ": " << key_; + return true; +} + +void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) { + AllocateTensors(); + bool on_host = recv_args_.alloc_attrs.on_host(); + if (dst_dev_->tensorflow_gpu_device_info() && !on_host && + (proxy_tensor_ == nullptr)) { +#if GOOGLE_CUDA + // We need to sync the memory allocation on the GPU: + StreamGPUOp(dst_dev_, recv_args_.device_context, done); +#endif + } else { + done(Status::OK()); + } +} + +void RdmaTensorRequest::Send(RdmaMessageType message_type) { + RdmaMessageBuffer* rb = channel_->tx_message_buffer_; + RdmaMessage rm; + rm.type_ = message_type; + rm.request_index_ = index_; + rm.name_size_ = key_.size(); + rm.name_ = key_; + rm.step_id_ = step_id_; + rm.remote_addr_ = (uint64_t)rdma_addr_; + if (meta_data_ != nullptr) { + rm.data_type_ = meta_data_->data_type_; + rm.tensor_shape_ = meta_data_->tensor_shape_; + rm.is_dead_ = meta_data_->is_dead_; + rm.tensor_bytes_ = meta_data_->proto_size_; + } else { + rm.data_type_ = DT_INVALID; + } + rm.rkey_ = (mr_ == nullptr) ? 
0 : mr_->rkey; + + RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Sending " << MessageTypeToString(message_type) + << " #" << index_ << ": " + << rm.name_ << " on " << rdma_addr_ + << " (rkey: 0x" << std::hex << rm.rkey_ << ")"; + + string message = RdmaMessage::CreateMessage(rm); + rb->EnqueueItem(message); + rb->SendNextItem(); +} + +void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape, + bool is_dead, size_t proto_size) { + meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData( + key_, dtype, shape, is_dead, proto_size); + + DeallocateTensors(); + AllocateTensorsAsync([this](const Status& s) { + Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); + }); +} + +void RdmaTensorRequest::RecvTensorContent() { + bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_); + size_t message_size = + can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_; + RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec + << ": Received tensor content #" << index_ << ": " + << key_ << " (Size: 0x" << std::hex << message_size << ")"; + + Tensor val; + +#if GOOGLE_CUDA + if (proxy_tensor_ != nullptr) { + CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_), + (void*)DMAHelper::base(result_tensor_), + result_tensor_->TotalBytes(), false); + GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context, + dst_dev_, result_tensor_, + [this](const Status& s) { + CHECK(s.ok()) << "copy tensor to gpu sync"; + Done(s); + }); + return; + } +#endif + + if (can_memcpy) { + Done(Status::OK()); + } else { + RDMA_LOG(2) << "Decoding proto: " << key_ + << " (Size: " << meta_data_->proto_size_ << ")"; + TensorProto proto; + CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_)) + << "fail to parse proto from array"; + ibv_dereg_mr(mr_); + free(rdma_addr_); + Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs, + result_tensor_); + Done(s); + } +} + +void RdmaTensorRequest::RecvErrorStatus(const Status& status) { + if (result_tensor_ == nullptr) { + result_tensor_ = new Tensor(); + } + LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString(); + Done(status); +} + +void RdmaTensorRequest::Start() { + meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_); + if (meta_data_ != nullptr) { + AllocateTensorsAsync([this](const Status& s) { + Send(RDMA_MESSAGE_TENSOR_REQUEST); + }); + } else { + Send(RDMA_MESSAGE_TENSOR_REQUEST); + } } } // end namespace tensorflow
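Taken together, the RdmaTensorRequest methods above form a small receiver-side state machine: Start() sends the initial request (pre-allocating buffers when cached meta-data exists), a meta-data update triggers reallocation and a re-request, and the tensor write (or an error status) completes the rendezvous. A compact sketch of that flow; the state names are illustrative, not from the source:

    // Receiver-side lifecycle implied by the code above.
    enum class RequestState {
      kRequestSent,   // Start(): RDMA_MESSAGE_TENSOR_REQUEST posted
      kReRequested,   // RecvTensorMetaData(): buffers reallocated,
                      // RDMA_MESSAGE_TENSOR_RE_REQUEST posted
      kDone           // RecvTensorContent() / RecvErrorStatus(): Done(status)
    };

    RequestState OnMetaDataUpdate() {
      // Meta-data changed: drop the old tensors, allocate for the new
      // dtype/shape, then ask the sender to resume with the content.
      return RequestState::kReRequested;
    }

    RequestState OnTensorContent() {
      // Content landed at the pre-registered address; decode the proto if
      // the dtype is not memcpy-able, or copy through the proxy tensor when
      // the GPU destination was not GDR-capable.
      return RequestState::kDone;
    }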