Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r--  tensorflow/contrib/verbs/rdma.cc | 1348
1 file changed, 829 insertions(+), 519 deletions(-)
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index ae9a384565..ec5271abe0 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -16,18 +16,19 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma.h"
-#include <fcntl.h>
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include <cstdlib>
#include <fcntl.h>
-#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -41,32 +42,19 @@ namespace tensorflow {
#define RoCE_V2 "RoCE v2"
namespace {
-// hash name to 32-bit integer
-uint32_t NameHash(const string& name) {
- return Hash32(name.data(), name.size(), 0x1234ABCD);
-}
// convenience function for printing message
string MessageTypeToString(RdmaMessageType rmt) {
switch (rmt) {
- case RDMA_MESSAGE_ACK:
- return "RDMA_MESSAGE_ACK";
- break;
- case RDMA_MESSAGE_BUFFER_IDLE:
- return "RDMA_MESSAGE_BUFFER_IDLE";
- break;
- case RDMA_MESSAGE_BUFFER_REQUEST:
- return "RDMA_MESSAGE_BUFFER_REQUEST";
+ case RDMA_MESSAGE_META_DATA_UPDATE:
+ return "RDMA_MESSAGE_META_DATA_UPDATE";
break;
- case RDMA_MESSAGE_BUFFER_RESPONSE:
- return "RDMA_MESSAGE_BUFFER_RESPONSE";
+ case RDMA_MESSAGE_TENSOR_RE_REQUEST:
+ return "RDMA_MESSAGE_TENSOR_RE_REQUEST";
break;
case RDMA_MESSAGE_TENSOR_REQUEST:
return "RDMA_MESSAGE_TENSOR_REQUEST";
break;
- case RDMA_MESSAGE_TENSOR_WRITE:
- return "RDMA_MESSAGE_TENSOR_WRITE";
- break;
default:
return "UNKNOWN MESSAGE";
}
@@ -347,7 +335,7 @@ uint32_t set_param(uint32_t default_val, const char* env_param) {
enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
ibv_port_attr port_attr;
- enum ibv_mtu mtu;
+ enum ibv_mtu mtu = IBV_MTU_512;
string mtu_s;
int rc, mtu_i;
@@ -468,97 +456,70 @@ void RdmaAdapter::Process_CQ() {
rc->Recv();
// imm_data is the index of RX buffer in the buffer table.
uint32_t imm_data = wc_[i].imm_data;
- RdmaBuffer* rb = rc->FindBuffer(imm_data);
+ RdmaMessageBuffer* rb;
RdmaMessage rm;
- RdmaMessage::ParseMessage(rm, rb->buffer_);
- VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
- if (rm.type_ == RDMA_MESSAGE_ACK) {
+ if (imm_data == RDMA_IMM_DATA_ACK) {
// receive an ack to a message
rb = rc->tx_message_buffer_;
rb->SetBufferStatus(remote, idle);
rb->SendNextItem();
- } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
- // received a request-for-tensor message
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find or create buffer
- RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
- string key_with_step_id =
- VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
- tb->EnqueueItem(key_with_step_id);
- // send the next tensor
- worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
- } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
- // receive tensor-buffer-ready message
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find buffer
- RdmaTensorBuffer* tb =
- reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_));
- tb->SetBufferStatus(remote, idle);
- worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); });
- } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
- // remote host requests to create a tensor buffer;
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find or create the buffer
- RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR);
- RemoteMR rmr;
- rmr.remote_addr = rm.remote_addr_;
- rmr.rkey = rm.rkey_;
- tb->SetRemoteMR(rmr, true);
- tb->CreateCPUBuffer(rm.buffer_size_);
- // create RDMA_MESSAGE_BUFFER_RESPONSE message
- RdmaMessage br;
- br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE;
- br.name_size_ = rm.name_.size();
- br.name_ = rm.name_;
- br.buffer_size_ = rm.buffer_size_;
- br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_);
- br.rkey_ = tb->self_->rkey;
- string message = RdmaMessage::CreateMessage(br);
- RdmaBuffer* mb = rc->tx_message_buffer_;
- mb->EnqueueItem(message);
- mb->SendNextItem();
- } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
- // remote creates a buffer and responds
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find buffer
- RdmaTensorBuffer* tb =
- reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_));
- CHECK(rm.buffer_size_ == tb->size_)
- << "rm.buffer_size = " << rm.buffer_size_
- << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
- RemoteMR rmr;
- rmr.remote_addr = rm.remote_addr_;
- rmr.rkey = rm.rkey_;
- tb->SetRemoteMR(rmr, true);
- tb->SetBufferStatus(local, idle);
- tb->SetBufferStatus(remote, idle);
- worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); });
- } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
- // tensor RDMA write completed
- worker_env_->compute_pool->Schedule([rm, rc]() {
- string key_with_step_id =
- VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
- rc->RunRecvCallback(key_with_step_id);
- });
+ continue;
}
- } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
- RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
- rb->SetBufferStatus(local, idle);
- RdmaMessage rm;
+
+ if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
+ // receive a tensor RDMA write
+ uint32_t request_index = imm_data;
+ RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
+ request->RecvTensorContent();
+ continue;
+ }
+
+ // receive a control message
+ rb = rc->rx_message_buffer_;
RdmaMessage::ParseMessage(rm, rb->buffer_);
- VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
- if (rm.type_ != RDMA_MESSAGE_ACK) {
- worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
+ RdmaMessageBuffer::SendAck(rc);
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Received " << MessageTypeToString(rm.type_) << " "
+ << "#" << rm.request_index_ << ": " << rm.name_;
+
+ if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
+ RdmaTensorResponse* response = rc->AddTensorResponse(rm);
+ response->Start();
+ } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
+ RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
+ request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
+ rm.is_dead_, rm.tensor_bytes_);
+#ifdef RDMA_DATA_VALIDATION
+ request->RecvTensorChecksum(rm.checksum_);
+#endif
+ } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
+ RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
+ response->Resume();
+ } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
+ RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
+ request->RecvErrorStatus(rm.status_);
+ }
+ } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
+ RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
+ RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
+ switch (wr_id->write_type) {
+ case RDMA_WRITE_ID_ACK:
+ break;
+ case RDMA_WRITE_ID_MESSAGE: {
+ RdmaMessageBuffer* rb =
+ reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
+ rb->SetBufferStatus(local, idle);
+ rb->SendNextItem();
+ break;
+ }
+ case RDMA_WRITE_ID_TENSOR_WRITE: {
+ RdmaTensorResponse* response =
+ reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
+ response->Destroy();
+ }
}
+ delete wr_id;
}
}
}
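
For orientation: the rewritten completion handler above no longer hashes buffer names into a lookup table; every incoming RDMA write is classified solely by its 32-bit immediate value (ACK, tensor content keyed by request index, or control message). A minimal, self-contained sketch of that classification; the two thresholds are parameters here and correspond to the RDMA_IMM_DATA_ACK and RDMA_IMM_MAX_REQUEST_ID constants defined in rdma.h:

// Illustrative sketch only (not part of the patch).
#include <cstdint>

enum class ImmKind { kAck, kTensorContent, kControlMessage };

ImmKind ClassifyImmData(uint32_t imm_data, uint32_t imm_ack,
                        uint32_t max_request_id) {
  if (imm_data == imm_ack) {
    return ImmKind::kAck;            // release the sender's TX message buffer
  }
  if (imm_data <= max_request_id) {
    return ImmKind::kTensorContent;  // imm_data is the pending request index
  }
  return ImmKind::kControlMessage;   // parse the RX message buffer, send ACK
}
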
@@ -588,8 +549,10 @@ int RdmaChannel::PingPostSend() {
RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name)
- : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
-
+ : adapter_(adapter),
+ local_name_(local_name),
+ remote_name_(remote_name),
+ request_serial_(0) {
struct ibv_sge list;
mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
@@ -651,29 +614,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// create message and ack buffers, then initialize the tables.
{
- const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
- "tx_ack_buffer", "rx_ack_buffer"};
+ const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"};
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
- tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
- rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]);
message_buffers_.reserve(kNumMessageBuffers);
message_buffers_.push_back(tx_message_buffer_);
message_buffers_.push_back(rx_message_buffer_);
- message_buffers_.push_back(tx_ack_buffer_);
- message_buffers_.push_back(rx_ack_buffer_);
// create buffer on host
tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
- tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
- rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
- // bt_mu_.lock() is not used in constructor.
- for (int i = 0; i < kNumMessageBuffers; i++) {
- uint32_t index = NameHash(buffer_names[i]);
- buffer_table_.insert({index, message_buffers_[i]});
- buffer_index_name_table_.insert({index, buffer_names[i]});
- buffer_name_index_table_.insert({buffer_names[i], index});
- }
}
CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
<< " with error " << std::strerror(errno);
@@ -684,8 +633,6 @@ RdmaChannel::~RdmaChannel() {
CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
delete tx_message_buffer_;
delete rx_message_buffer_;
- delete tx_ack_buffer_;
- delete rx_ack_buffer_;
}
void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
@@ -716,114 +663,31 @@ void RdmaChannel::Recv() {
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
-// Lookup 32-bit buffer index from buffer name
-// Args:
-// buffer_name: name of the buffer
-// Returns:
-// 32-bit index
-uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
- mutex_lock lock{bt_mu_};
- BufferNameIndexTable::iterator iter =
- buffer_name_index_table_.find(buffer_name);
- CHECK(iter != buffer_name_index_table_.end());
- return iter->second;
-}
-
-// Find a buffer by its 32-bit index
-// Args:
-// index: 32-bit hash code of the tensor buffer name
-// Returns:
-// name of the tensor buffer
-RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) {
- mutex_lock lock{bt_mu_};
- BufferTable::iterator iter = buffer_table_.find(index);
- CHECK(iter != buffer_table_.end());
- return iter->second;
-}
-
-// Find a buffer by its name
-// Args:
-// name: name of the buffer
-// Returns:
-// the named rdma buffer
-RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
- uint32_t index = LookupBufferIndex(name);
- return FindBuffer(index);
-}
-
-// Find a buffer if it exists, otherwise create one.
-// The memory inside the created buffer is not allocated.
-// Args:
-// name: the name of the buffer
-// buffer_type: TENSOR, MESSAGE or ACK.
-// Returns:
-// the named buffer
-RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
- BufferType buffer_type) {
- mutex_lock lock{bt_mu_};
- RdmaBuffer* rb;
- // find index
- BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name);
- if (iter != buffer_name_index_table_.end()) {
- uint32_t index = iter->second;
- // find buffer
- BufferTable::iterator iter = buffer_table_.find(index);
- CHECK(iter != buffer_table_.end());
- rb = iter->second;
- } else {
- uint32_t index = NameHash(name);
- if (buffer_type == TENSOR) {
- rb = new RdmaTensorBuffer(this, name);
- } else if (buffer_type == MESSAGE) {
- rb = new RdmaMessageBuffer(this, name);
- } else if (buffer_type == ACK) {
- rb = new RdmaAckBuffer(this, name);
- }
- buffer_name_index_table_.insert({name, index});
- buffer_index_name_table_.insert({index, name});
- buffer_table_.insert({index, rb});
+RdmaTensorRequest* RdmaChannel::InsertTensorRequest(
+ const string& key, int64 step_id, Device* dst_dev,
+ const Rendezvous::Args recv_args,
+ const RdmaTensorRequest::RecvDoneCallback& done) {
+ mutex_lock lock{ct_mu_};
+ uint32_t request_index = request_serial_++;
+ if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) {
+ request_serial_ = 0;
}
- CHECK(rb);
- return rb;
+ RdmaTensorRequest request(request_index, key, step_id, this, dst_dev,
+ recv_args, done);
+ auto it = request_table_.emplace(request_index, request);
+ return &it.first->second;
}
-// Insert callback to the callback_table.
-// The callback is activated when the corresponding tensor is received.
-// Arg:
-// key: the name of the tensor
-// recv_done: the callback associated with the tensor.
-// Returns:
-// None
-void RdmaChannel::InsertRecvCallback(const string& key,
- std::function<void()> recv_done) {
+void RdmaChannel::RemoveTensorRequest(uint32_t request_index) {
mutex_lock lock{ct_mu_};
- callback_table_.insert({key, recv_done});
+ request_table_.erase(request_index);
}
-// Remove callback from the callback_table.
-// Arg:
-// key: the name of the tensor
-// Returns:
-// None
-void RdmaChannel::RemoveRecvCallback(const string& key) {
+RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) {
mutex_lock lock{ct_mu_};
- callback_table_.erase(key);
-}
-
-// Run named callback in the callback_table.
-// Arg:
-// key: the name of the tensor
-// Returns:
-// None
-void RdmaChannel::RunRecvCallback(const string& key) {
- std::function<void()> recv_done;
- {
- mutex_lock lock{ct_mu_};
- CallbackTable::iterator iter = callback_table_.find(key);
- CHECK(iter != callback_table_.end());
- recv_done = iter->second;
- }
- recv_done();
+ RequestTable::iterator iter = request_table_.find(request_index);
+ CHECK(iter != request_table_.end());
+ return &iter->second;
}
void RdmaChannel::Connect() {
@@ -888,25 +752,22 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
connected_ = true;
} else {
- LOG(INFO) << "channel already connected";
+ RDMA_LOG(2) << "channel already connected";
}
}
-RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
+RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
: channel_(channel), name_(name) {}
-RdmaBuffer::~RdmaBuffer() {
+RdmaMessageBuffer::~RdmaMessageBuffer() {
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
FreeBuffer();
}
-void RdmaBuffer::FreeBuffer() {
+void RdmaMessageBuffer::FreeBuffer() {
if ((buffer_ != nullptr) && buffer_on_host_) {
free(buffer_);
}
- // TODO
- // release buffer if it is on device.
- // We don't support RDMABuffer on device at this moment.
}
// Allocate CPU memory for the Rdma buffer
@@ -915,7 +776,7 @@ void RdmaBuffer::FreeBuffer() {
// lock: whether or not mutex_lock the process to protect concurrency.
// Returns:
// None
-void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
+void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) {
CHECK(size > 0);
if (lock) {
mu_.lock();
@@ -943,7 +804,7 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
// override: whether override existing information
// Returns:
// None
-void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
+void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
mutex_lock lock{mu_};
if ((override) || (remote_status_ == none)) {
remote_.remote_addr = rmr.remote_addr;
@@ -956,63 +817,51 @@ void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
}
// Put a task in the buffer's job queue
-void RdmaBuffer::EnqueueItem(string item) {
+void RdmaMessageBuffer::EnqueueItem(string item) {
mutex_lock lock{mu_};
queue_.push(item);
}
// Rdma-Write the content of the buffer
-void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+ Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey,
+ remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this);
+}
+
+// Generalized Write method
+void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data,
+ size_t buffer_size, uint64_t src_addr,
+ uint32_t lkey, uint64_t remote_addr,
+ uint32_t rkey, RdmaWriteIDType write_type,
+ void* write_context) {
struct ibv_sge list;
- list.addr = (uint64_t)buffer_;
+ list.addr = src_addr;
list.length = buffer_size;
- list.lkey = self_->lkey;
+ list.lkey = lkey;
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context);
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.send_flags = IBV_SEND_SIGNALED;
wr.imm_data = imm_data;
- wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
- wr.wr.rdma.rkey = remote_.rkey;
+ wr.wr.rdma.remote_addr = remote_addr;
+ wr.wr.rdma.rkey = rkey;
struct ibv_send_wr* bad_wr;
- CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
-}
-
-RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
- : RdmaBuffer(channel, name) {}
-
-RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
- : RdmaBuffer(channel, name) {}
-
-RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
- : RdmaBuffer(channel, name) {}
-
-RdmaTensorBuffer::~RdmaTensorBuffer() {
- for (Itable it = retable.begin(); it != retable.end(); ++it) {
- delete (it->second);
- }
+ CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send";
}
// Send the next ack from the buffer's job queue.
-void RdmaAckBuffer::SendNextItem() {
- uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
- RdmaMessage rm;
- rm.name_ = "rx_ack_buffer";
- rm.type_ = RDMA_MESSAGE_ACK;
- rm.name_size_ = rm.name_.size();
- string message = RdmaMessage::CreateMessage(rm);
- memcpy(buffer_, message.data(), message.size());
- Write(imm_data, message.size());
+void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) {
+ Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr);
}
// Send the next message from the buffer's job queue.
void RdmaMessageBuffer::SendNextItem() {
- uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
+ uint32_t imm_data = RDMA_IMM_DATA_MESSAGE;
mu_.lock();
if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
local_status_ = busy;
@@ -1029,244 +878,392 @@ void RdmaMessageBuffer::SendNextItem() {
}
}
-Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
- const string& key_with_step_id, const string& key, int64 step_id,
- const Rendezvous::ParsedKey& parsed) {
- Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id, parsed](
- const Status& status, const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
- CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
- << " error message: " << status.error_message();
- size_t buffer_size = RdmaMessage::kMessageTotalBytes;
- size_t tensor_bytes = 0;
- // Figures out which device the tensor is hosted on.
- Device* src_dev = nullptr;
- Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
- parsed.src_device, &src_dev);
- CHECK(s.ok()) << "src device not found";
- // Does the device have the right incarnation number we expect?
- CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
- << "RecvTensor expects a different device incarnation: "
- << parsed.src_incarnation << " vs. "
- << src_dev->attributes().incarnation()
- << ". Your worker job was probably restarted. Check your "
- << "worker job for the reason why it was restarted.";
- Device* dst_dev = nullptr;
- // destination is on CPU.
- s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
- &dst_dev);
- CHECK(s.ok()) << "dst device not found";
- AllocatorAttributes dst_alloc_attr;
- dst_alloc_attr.set_on_host(true);
-
- bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
- // string tensor needs to be serialized
- Tensor copy;
- TensorProto proto;
- if (src_dev->tensorflow_gpu_device_info() &&
- (!send_args.alloc_attrs.on_host())) {
#if GOOGLE_CUDA
- CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
- << " gpu_info: "
- << src_dev->tensorflow_gpu_device_info();
-
- if (can_memcpy) {
- AllocatorAttributes host_alloc_attrs;
- host_alloc_attrs.set_gpu_compatible(true);
- host_alloc_attrs.set_on_host(true);
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
- copy = Tensor(alloc, in.dtype(), in.shape());
- tensor_bytes = in.TotalBytes();
- buffer_size += tensor_bytes;
- GPUUtil::CopyGPUTensorToCPU(
- src_dev, send_args.device_context, &in, &copy,
- [this, copy, tensor_bytes, buffer_size, key, in, step_id,
- key_with_step_id, is_dead, send_args, recv_args](const Status& s) {
- CHECK(s.ok()) << "copy tensor from gpu sync";
- StringPiece copy_buf;
- copy_buf = copy.tensor_data();
- PostCopyOperations(true, buffer_size, tensor_bytes, key, in,
- step_id, is_dead, key_with_step_id, &copy,
- NULL, &copy_buf, send_args, recv_args);
- });
- } else {
- // "val" is on a GPU. No longer uses GPUUtil to fill the proto, use
- // aync instead
- GPUUtil::SetProtoFromGPU(
- in, src_dev, send_args.device_context, &proto, is_dead,
- [this, proto, buffer_size, key, in, step_id, key_with_step_id,
- is_dead, send_args, recv_args](const Status& s) mutable {
- CHECK(s.ok()) << "copy proto from gpu sync";
- auto tensor_bytes = proto.ByteSize();
- buffer_size += tensor_bytes;
- PostCopyOperations(false, buffer_size, tensor_bytes, key, in,
- step_id, is_dead, key_with_step_id, NULL,
- &proto, NULL, send_args, recv_args);
- });
- }
-#endif // GOOGLE_CUDA
- } else {
- // tensor is in CPU memory.
- StringPiece copy_buf;
- if (can_memcpy) {
- copy_buf = in.tensor_data();
- tensor_bytes = in.TotalBytes();
- } else {
- in.AsProtoTensorContent(&proto);
- tensor_bytes = proto.ByteSize();
- }
- buffer_size += tensor_bytes;
- PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in,
- step_id, is_dead, key_with_step_id, &copy, &proto,
- &copy_buf, send_args, recv_args);
+static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
+ size_t tensor_bytes, bool is_gpu_to_cpu) {
+#ifdef RDMA_COUNT_COPIES
+ static uint64_t numGPUToCPUCopies = 0;
+ static uint64_t numGPUToCPUCopiedBytes = 0;
+ static uint64_t numCPUToGPUCopies = 0;
+ static uint64_t numCPUToGPUCopiedBytes = 0;
+ static uint64_t numTotalCopies = 0;
+
+ if (is_gpu_to_cpu) {
+ ++numGPUToCPUCopies;
+ numGPUToCPUCopiedBytes += tensor_bytes;
+ } else {
+ ++numCPUToGPUCopies;
+ numCPUToGPUCopiedBytes += tensor_bytes;
+ }
+ if ((++numTotalCopies % 0x400) == 0) {
+ RDMA_LOG(0) << "Tensor copies:"
+ << " GPU to CPU: " << numGPUToCPUCopies
+ << " (" << numGPUToCPUCopiedBytes << " Bytes)"
+ << " CPU to GPU: " << numCPUToGPUCopies
+ << " (" << numCPUToGPUCopiedBytes << " Bytes)";
+ }
+ RDMA_LOG(2) << "Copying tensor " << key
+ << " From: " << src_addr << " To: " << dst_addr;
+#endif // RDMA_COUNT_COPIES
+}
+#endif // GOOGLE_CUDA
+
+#ifdef RDMA_DATA_VALIDATION
+static uint64_t Checksum(Device* device, const DeviceContext* device_context,
+ const Tensor& in) {
+ uint64 checksum = 0;
+ if (DataTypeCanUseMemcpy(in.dtype())) {
+#if GOOGLE_CUDA
+ if (in.TotalBytes() == 0) {
+ return 0;
}
- };
- return cb;
+ checksum = (device_context != nullptr)
+ ? GPUUtil::Checksum(device, device_context, in)
+ : GPUUtil::Checksum(in);
+#endif // GOOGLE_CUDA
+ } else {
+ string s = in.SummarizeValue(999999);
+ checksum = Hash64(s.c_str(), s.size(), 0);
+ }
+ return checksum;
}
-// Send the next tensor from the buffer's job queue.
-void RdmaTensorBuffer::SendNextItem() {
- // get the key
- string key_with_step_id = "";
- {
- mutex_lock lock{mu_};
- if (!queue_.empty()) {
- key_with_step_id = queue_.front();
- queue_.pop();
+static void ValidateChecksum(uint64_t expected, uint64_t actual,
+ const Tensor& in, uint32_t request_index,
+ const std::string& key, const std::string& msg) {
+ RDMA_LOG(2) << "Request #" << request_index << ": " << key
+ << ": Checksum: " << std::hex << " Expected = 0x" << expected
+ << ". Actual = 0x" << actual << ".";
+
+ if (expected != actual) {
+ // Checksum failed. There is one case where this is allowed - if the
+ // tensor is an AssignAdd of the global step. Since the data-validation
+ // always postpones the Tensor response in order to send a checksum message,
+ // it is possible that the global-step was updated while the response was
+ // still in queue.
+ if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) {
+ int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1;
+ actual = Hash64((const char*)&prev_val, 8, 0);
+ }
+ if (expected != actual) {
+ LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #"
+ << request_index << ": " << key << std::hex << " "
+ << DataTypeString(in.dtype()) << " "
+ << in.shape().DebugString() << " (0x" << in.TotalBytes()
+ << " bytes): "
+ << " Expected 0x" << expected << ". Got 0x" << actual << ".";
}
}
+}
+#endif // RDMA_DATA_VALIDATION
+
+#if GOOGLE_CUDA
+// Sync the 'done' operation on the GPU stream, but without all the data
+// copying.
+static void StreamGPUOp(Device* gpu_device,
+ const DeviceContext* device_context,
+ StatusCallback done) {
+ Tensor dummy1, dummy2;
+ GPUUtil::CopyGPUTensorToCPU(
+ gpu_device, device_context, &dummy1, &dummy2, done);
+}
+#endif // GOOGLE_CUDA
+
+RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) {
+ mutex_lock lock{mu_};
+ auto it =
+ responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm));
+ CHECK(it.second) << "Response with the ID " << rm.request_index_
+ << " already exists.";
+ return &it.first->second;
+}
+
+RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) {
+ mutex_lock lock{mu_};
+ auto it = responses_table_.find(rm.request_index_);
+ CHECK(it != responses_table_.end()) << "No response found.";
+ RdmaTensorResponse* response = &it->second;
+ response->Update(rm);
+ return response;
+}
- // send the tensor if a key is acquired.
- if (key_with_step_id != "") {
- VLOG(2) << "try to send tensor: " << key_with_step_id;
- string key;
- int64 step_id;
- VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
- CHECK(key.compare(name_) == 0);
- Rendezvous::ParsedKey parsed;
- Rendezvous::ParseKey(key, &parsed);
- Rendezvous::DoneCallback cb =
- getRecvTensorCallback(key_with_step_id, key, step_id, parsed);
- channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id,
- parsed, cb);
+void RdmaChannel::RemoveTensorResponse(uint32_t request_index) {
+ mutex_lock lock{mu_};
+ responses_table_.erase(request_index);
+}
+
+void RdmaTensorResponse::Start() {
+ Rendezvous::ParsedKey parsed;
+ Status s = Rendezvous::ParseKey(rm_.name_, &parsed);
+ if (!s.ok()) {
+ SendErrorStatus(s);
+ return;
}
+
+ channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(
+ rm_.step_id_, parsed,
+ [this, parsed](const Status& status, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in,
+ bool is_dead) {
+ CHECK(status.ok()) << "RecvLocalAsync was not ok."
+ << " error message: " << status.error_message();
+ RecvHandler(parsed, send_args, recv_args, in, is_dead);
+ });
}
-void RdmaTensorBuffer::ReSendNextItem() {
- // get the key
- string key_with_step_id = "";
- {
- mutex_lock lock{mu_};
- if (!requeue.empty()) {
- key_with_step_id = requeue.front();
- requeue.pop();
- }
+void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); }
+
+// Helper for RecvTensor. Validates "key" and returns the source
+// device in "*src_dev".
+Status RdmaTensorResponse::PrepareRecvTensor(
+ const Rendezvous::ParsedKey& parsed, Device** src_dev) {
+ // Figures out which device the tensor is hosted on.
+ string local_name = DeviceNameUtils::LocalName(parsed.src_device);
+ TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice(
+ local_name, src_dev));
+
+ // Does the device have the right incarnation number we expect?
+ if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
+ return errors::Aborted(
+ "RecvTensor expects a different device incarnation: ",
+ parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
+ ". Your worker job was probably restarted. Check your "
+ "worker job for the reason why it was restarted.");
}
- // send the tensor if a key is acquired.
- if (key_with_step_id != "") {
- VLOG(2) << "try to send tensor: " << key_with_step_id;
- string key;
- int64 step_id;
- VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
- CHECK(key.compare(name_) == 0);
- Rendezvous::ParsedKey parsed;
- Rendezvous::ParseKey(key, &parsed);
- Rendezvous::DoneCallback cb =
- getRecvTensorCallback(key_with_step_id, key, step_id, parsed);
- ReItem* item;
- {
- mutex_lock lock{mu_};
- Itable it = retable.find(key_with_step_id);
- CHECK(it != retable.end()) << "Could not find dup-recv context";
- item = it->second;
- retable.erase(it);
+ return Status::OK();
+}
+
+void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& in, bool is_dead) {
+ Status s = PrepareRecvTensor(parsed, &src_dev_);
+ if (!s.ok()) {
+ SendErrorStatus(s);
+ return;
+ }
+
+ meta_data_changed_ = TensorMetaDataChanged(in, is_dead);
+#ifdef RDMA_DATA_VALIDATION
+ // Always send a meta data message with the source checksum
+ meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST;
+ checksum_ = Checksum(src_dev_, send_args.device_context, in);
+#endif
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ // string tensor needs to be serialized
+ Tensor copy;
+ TensorProto proto;
+ const bool on_host = send_args.alloc_attrs.on_host();
+ if (src_dev_->tensorflow_gpu_device_info() && !on_host) {
+#if GOOGLE_CUDA
+ DeviceContext* send_dev_context = send_args.device_context;
+ CHECK(send_dev_context)
+ << "send dev name: " << src_dev_->name()
+ << " gpu_info: " << src_dev_->tensorflow_gpu_device_info();
+
+ if (can_memcpy) {
+ // If the tensor is located on a GDR compatible GPU, there is no need to
+ // copy it. We can send directly from the source, just need to make sure
+ // we are in sync with the GPU stream.
+ // If the tensor's meta-data has changed, however, we will need to clone
+ // it, so we will have to copy it from GPU to CPU first anyway. If at some
+ // point in time Clone() is changed to only save a shallow copy, we can
+ // skip the copy here as well.
+ if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
+ (RdmaMemoryMgr::Singleton().FindMemoryRegion(
+ (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
+ StreamGPUOp(src_dev_, send_dev_context,
+ [this, in, proto, is_dead](const Status& s) {
+ Send(in, proto, is_dead, s);
+ });
+ return;
+ }
+
+ // The tensor must be copied from GPU to CPU, because either:
+ // 1. The tensor is located on a non GDR compatible GPU.
+ // 2. The tensor's meta-data has changed.
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ copy = Tensor(alloc, in.dtype(), in.shape());
+ CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
+ (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
+ GPUUtil::CopyGPUTensorToCPU(
+ src_dev_, send_dev_context, &in, &copy,
+ [this, copy, proto, is_dead](const Status& s) {
+ Send(copy, proto, is_dead, s);
+ });
+ } else {
+ GPUUtil::SetProtoFromGPU(
+ in, src_dev_, send_args.device_context, &proto, is_dead,
+ [this, in, proto, is_dead](const Status& s) mutable {
+ Send(in, proto, is_dead, s);
+ });
+ }
+#else
+ SendErrorStatus(errors::Internal("No GPU device in process"));
+#endif // GOOGLE_CUDA
+ } else {
+ // tensor is in CPU memory.
+ if (!can_memcpy) {
+ in.AsProtoTensorContent(&proto);
}
- cb(Status::OK(), item->send_args, item->recv_args, item->in, item->is_dead);
- delete (item);
+ Send(in, proto, is_dead, Status::OK());
+ }
+}
+
+void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
+ bool is_dead, const Status& status) {
+ if (!status.ok()) {
+ SendErrorStatus(status);
+ return;
+ }
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ bool proto_size_changed = (!can_memcpy) &&
+ (proto.ByteSize() != rm_.tensor_bytes_);
+ if (meta_data_changed_ || proto_size_changed) {
+ Clone(in, proto, is_dead);
+ SendMetaData(in, proto, is_dead);
+ } else {
+ SendContent(in, proto, is_dead);
+ }
+}
+
+bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) {
+ return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) ||
+ (rm_.is_dead_ != is_dead);
+}
+
+void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto,
+ bool is_dead) {
+ // Clone the data to be sent later. For simplicity, we clone the tensor's
+ // data even if it is already a copy. Performance is less of a concern here
+ // since the meta-data hardly ever changes. The reason we create a copy is
+ // that some tensors share their buffer between different step-ids, so the
+ // tensor content may change before the re-request completes.
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ if (can_memcpy && (in.TotalBytes() > 0)) {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_nic_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs);
+ tensor_ = new Tensor(allocator, in.dtype(), in.shape());
+ memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes());
+ } else {
+ tensor_ = new Tensor(in.dtype(), in.shape());
}
+ if (!can_memcpy) {
+ proto_ = new TensorProto(proto);
+ }
+ is_dead_ = is_dead;
}
-void RdmaTensorBuffer::PostCopyOperations(
- bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key,
- const Tensor& in, int64 step_id, bool is_dead,
- const string& key_with_step_id, const Tensor* copy,
- const TensorProto* proto, const StringPiece* copy_buf,
- const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args) {
- // prepare message
+void RdmaTensorResponse::SendMetaData(const Tensor& in,
+ const TensorProto& proto, bool is_dead) {
+ RDMA_LOG(2) << "Request #" << rm_.request_index_
+ << ": Meta data changed: " << rm_.name_;
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
+
+ // Send meta-data update:
RdmaMessage rm;
- rm.name_size_ = key.size();
- rm.name_ = key;
+ rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE;
+ rm.name_size_ = rm_.name_.size();
+ rm.name_ = rm_.name_;
rm.tensor_shape_ = in.shape();
rm.data_type_ = in.dtype();
- rm.step_id_ = step_id;
+ rm.step_id_ = rm_.step_id_;
rm.is_dead_ = is_dead;
rm.tensor_bytes_ = tensor_bytes;
- rm.buffer_size_ = buffer_size;
- mu_.lock();
- if (local_status_ == none || (buffer_size > size_ && local_status_ == idle &&
- remote_status_ == idle)) {
- if ((local_status_ != none) && (buffer_size > size_)) {
- VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size;
- }
- CreateCPUBuffer(buffer_size, false);
- // Need to be received again, put into the re-recv queue and the table
- requeue.push(key_with_step_id);
- ReItem* item = new ReItem(send_args, recv_args, in, is_dead);
- retable.insert(std::pair<string, ReItem*>(key_with_step_id, item));
- mu_.unlock();
- // no longer used: put back the key since it is not sent;
- // ask the remote to create the same buffer
- rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
- rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
- rm.rkey_ = self_->rkey;
- string message = RdmaMessage::CreateMessage(rm);
- channel_->tx_message_buffer_->EnqueueItem(message);
- channel_->tx_message_buffer_->SendNextItem();
- } else if ((local_status_ == idle) && (remote_status_ == idle)) {
- // both buffers are ready, send the tensor
- local_status_ = busy;
- remote_status_ = busy;
- // local/remote_status_ won't be set back to idle
- // unitl Write() is successful
- mu_.unlock();
- if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
- (buffer_size <= size_ && rm.data_type_ == DT_STRING))) {
- VLOG(2) << "Tensor and buffer size do not agree,"
- << " buffer_size = " << size_
- << " requested tensor size = " << buffer_size << in.DebugString();
- }
- uint32_t imm_data = LookupBufferIndex(key);
- rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
- string message = RdmaMessage::CreateMessage(rm);
- memcpy(buffer_, message.data(), message.size());
- if (!is_dead) {
- // copy the tensor buffer content
- void* output = static_cast<void*>(static_cast<char*>(buffer_) +
- RdmaMessage::kTensorBufferStartIndex);
- CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
- if (can_memcpy) {
- CHECK(copy != NULL) << "callback missing pointer to copy tensor";
- CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer";
- CHECK(copy_buf->size() == tensor_bytes)
- << "unexpected tensor size: " << copy_buf->size()
- << " != " << tensor_bytes;
- memcpy(output, copy_buf->data(), tensor_bytes);
- } else {
- CHECK(proto != NULL) << "callback missing pointer to proto tensor";
- proto->SerializeToArray(output, tensor_bytes);
+ rm.request_index_ = rm_.request_index_;
+#ifdef RDMA_DATA_VALIDATION
+ rm.checksum_ = checksum_;
+#endif
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #"
+ << rm.request_index_ << ": " << rm.name_
+ << " (shape = " << rm.tensor_shape_.DebugString() << "."
+ << " data-type = " << DataTypeString(rm.data_type_) << "."
+ << " is-dead = " << rm.is_dead_ << ")";
+
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+}
+
+void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto,
+ bool is_dead) {
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
+ uint32_t imm_data = rm_.request_index_;
+ if (!is_dead) {
+ if (can_memcpy) {
+ src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in));
+ if (src_buffer_ != nullptr) {
+ src_buffer_->Ref(); // Keep buffer alive until write is complete
+ src_addr_ = src_buffer_->data();
+ mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_,
+ tensor_bytes);
}
} else {
- buffer_size = RdmaMessage::kMessageTotalBytes;
+ RDMA_LOG(2) << "Encoding proto: " << rm_.name_
+ << " (Size: " << tensor_bytes << ") " << in.DebugString();
+ src_addr_ = malloc(tensor_bytes);
+ mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ proto.SerializeToArray(src_addr_, tensor_bytes);
}
- Write(imm_data, buffer_size);
} else {
- // Need to be received again, put into the re-recv queue and the table
- requeue.push(key_with_step_id);
- ReItem* item = new ReItem(send_args, recv_args, in, is_dead);
- retable.insert(std::pair<string, ReItem*>(key_with_step_id, item));
- mu_.unlock();
+ tensor_bytes = 0;
}
+
+ uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey;
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec
+ << ": Sending tensor content #" << rm_.request_index_ << " from "
+ << std::hex << src_addr_ << " (0x" << lkey << ")"
+ << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_
+ << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes
+ << ")";
+
+ RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes,
+ (uint64_t)src_addr_, lkey, rm_.remote_addr_,
+ rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this);
+}
+
+void RdmaTensorResponse::SendErrorStatus(const Status& status) {
+ RdmaMessage rm;
+ rm.type_ = RDMA_MESSAGE_ERROR_STATUS;
+ rm.name_size_ = rm_.name_.size();
+ rm.name_ = rm_.name_;
+ rm.step_id_ = rm_.step_id_;
+ rm.request_index_ = rm_.request_index_;
+ rm.status_ = status;
+ LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Sending RDMA_MESSAGE_ERROR_STATUS #"
+ << rm.request_index_ << ": " << rm.name_
+ << ". Status: " << status.ToString();
+
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+
+ // Destroy the response.
+ Destroy();
+}
+
+void RdmaTensorResponse::Destroy() {
+ if (src_buffer_ != nullptr) {
+ src_buffer_->Unref();
+ }
+ if (tensor_ != nullptr) {
+ delete tensor_;
+ }
+ if (proto_ != nullptr) {
+ ibv_dereg_mr(mr_);
+ free(src_addr_);
+ delete proto_;
+ }
+ // Remove response from the pending list:
+ channel_->RemoveTensorResponse(rm_.request_index_);
}
// Create a RdmaMessage according to the pre-defined format
@@ -1276,43 +1273,46 @@ void RdmaTensorBuffer::PostCopyOperations(
// message in string format
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
// Rdma Message format
- // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
- // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
- // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
- // ...| XB | XB | 8B |...
+ // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
+ // ...|data_type|tensor_shape|tensor_bytes|error_status |
+ // ...| XB | XB | 8B |size - 4B, proto - XB |
//
- // ACK: type|13|"rx_ack_buffer"
- // TENSOR_REQUEST: type|name_size|tensor_name|step_id
- // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
- // |data_type|tensor_shape|tensor_bytes
- // BUFFER_IDLE: type|name_size|buffer_name
- // BUFFER_REQUEST:
- // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
- // BUFFER_RESPONSE:
- // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
- char message[kMessageTotalBytes];
+ // ACK: Imm-type: ACK
+ // TENSOR_REQUEST: Imm-type: MESSAGE
+ // Fields: type, request_index, name, step_id, remote_addr,
+ // rkey, is_dead, data_type, tensor_shape, tensor_bytes
+ // META_DATA_UPDATE: Imm-type: MESSAGE
+ // Fields: type, request_index, is_dead, data_type,
+ // tensor_shape, tensor_bytes
+ // TENSOR_RE_REQUEST: Imm-type: MESSAGE
+ // Fields: type, request_index, name, step_id, remote_addr,
+ // rkey, is_dead, data_type, tensor_shape, tensor_bytes
+ // ERROR_STATUS: Imm-type: MESSAGE
+ // Fields: type, request_index, name, step_id, error_status
+ // Tensor content: Imm-type: request_index
+ size_t message_size = kMessageTotalBytes;
+ char message[kMessageTotalBytes + kErrorStatusMaxSize];
// type
message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
- // size of name
- memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_));
- // name
- memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
- // buffer_size, remote_addr, rkey
- if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
- (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
- memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
- sizeof(rm.buffer_size_));
+ // request index
+ memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
+ sizeof(rm.request_index_));
+ // name, step_id, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
+ memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
+ sizeof(rm.name_size_));
+ memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
sizeof(rm.remote_addr_));
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
- }
- // step_id
- if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
- (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
}
// is_dead, data_type, tensor_shape, tensor_bytes
- if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
@@ -1322,7 +1322,31 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
sizeof(rm.tensor_bytes_));
}
- return string(message, kMessageTotalBytes);
+ // checksum
+#ifdef RDMA_DATA_VALIDATION
+ memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
+#endif
+ // error status
+ if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
+ ::grpc::Status gs = ToGrpcStatus(rm.status_);
+ ErrorStatusProto gsProto;
+ gsProto.set_error_code(gs.error_code());
+ gsProto.set_error_message(gs.error_message());
+ gsProto.set_error_details(gs.error_details());
+ uint32_t gsProtoSize = gsProto.ByteSize();
+ if (gsProtoSize + 4 > kErrorStatusMaxSize) {
+ LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
+ << "is too big to fit in RDMA message ("
+ << kErrorStatusMaxSize << " bytes). Truncated.";
+ gsProtoSize = kErrorStatusMaxSize - 4;
+ }
+ uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
+ *proto_size = gsProtoSize;
+ gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4],
+ gsProtoSize);
+ message_size += gsProtoSize + 4;
+ }
+ return string(message, message_size);
}
// Parse a RdmaMessage according to the pre-defined format
@@ -1335,26 +1359,24 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
char* message = static_cast<char*>(buffer);
// type
rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
- // name_size_
- memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_));
- // name
- rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
- // buffer_size, remote_addr, rkey
- if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
- (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
- memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
- sizeof(rm.buffer_size_));
+ // request index
+ memcpy(&rm.request_index_, &message[kRequestIndexStartIndex],
+ sizeof(rm.request_index_));
+ // name, step_id, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
+ memcpy(&rm.name_size_, &message[kNameSizeStartIndex],
+ sizeof(rm.name_size_));
+ rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
sizeof(rm.remote_addr_));
memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
- }
- // step_id
- if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
- (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
}
// data_type, tensor_bytes, tensor_shape, is_dead
- if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
sizeof(rm.data_type_));
@@ -1363,6 +1385,294 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
sizeof(rm.tensor_bytes_));
}
+ // checksum
+#ifdef RDMA_DATA_VALIDATION
+ memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_));
+#endif
+ // error status
+ if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
+ ErrorStatusProto gsProto;
+ uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
+ CHECK(ParseProtoUnlimited(
+ &gsProto, &message[kErrorStatusStartIndex + 4], gsProtoSize))
+ << "Failed to parse error status proto from message. Aborting.";
+ ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
+ gsProto.error_message(), gsProto.error_details());
+ rm.status_ = FromGrpcStatus(gs);
+ }
+}
+
+//*****************************************************************************
+// RdmaMemoryMgr
+//*****************************************************************************
+
+ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) {
+ mutex_lock l(mrs_mu_);
+ auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
+ if (iter == std::end(mrs_) || iter->get()->addr > addr) {
+ return nullptr;
+ } else {
+ return iter->get();
+ }
+}
+
+void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name) {
+ if (length == 0) return;
+ ibv_mr* mr = ibv_reg_mr(pd_, addr, length,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". ["
+ << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]"
+ << " SIZE: 0x" << length << " (" << allocator_name << ").";
+ if (mr != nullptr) {
+ mutex_lock l(mrs_mu_);
+ auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
+ mrs_.insert(iter, {mr, &MRDeleter});
+ } else {
+ LOG(WARNING) << "Cannot register memory region";
+ }
+}
+
+void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) {
+ if (length == 0) return;
+ mutex_lock l(mrs_mu_);
+ auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
+ if (iter != std::end(mrs_) && iter->get()->addr == addr) {
+ // Read the rkey for logging before erase() invalidates the iterator.
+ RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey;
+ mrs_.erase(iter);
+ } else {
+ LOG(WARNING) << "Failed to de-register memory region";
+ }
+}
+
+const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData(
+ const std::string& tensor_name) {
+ mutex_lock l(tensor_meta_data_mu_);
+ auto it = tensors_meta_data_.find(tensor_name);
+ if (it == tensors_meta_data_.end()) {
+ return nullptr;
+ }
+ return &it->second;
+}
+
+const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData(
+ const std::string& tensor_name, DataType dtype, const TensorShape& shape,
+ bool is_dead, size_t proto_size) {
+ mutex_lock l(tensor_meta_data_mu_);
+ TensorMetaData& meta_data = tensors_meta_data_[tensor_name];
+ meta_data.data_type_ = dtype;
+ meta_data.tensor_shape_ = shape;
+ meta_data.proto_size_ = proto_size;
+ meta_data.is_dead_ = is_dead;
+ return &meta_data;
+}
+
+//*****************************************************************************
+// RdmaTensorRequest
+//*****************************************************************************
+
+RdmaTensorRequest::RdmaTensorRequest(
+ uint32_t index, const string& key, int64 step_id, RdmaChannel* channel,
+ Device* dst_dev, const Rendezvous::Args recv_args,
+ const RdmaTensorRequest::RecvDoneCallback& done)
+ : index_(index),
+ key_(key),
+ step_id_(step_id),
+ channel_(channel),
+ dst_dev_(dst_dev),
+ recv_args_(recv_args),
+ meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)),
+ result_tensor_(nullptr),
+ proxy_tensor_(nullptr),
+ rdma_addr_(nullptr),
+ mr_(nullptr),
+ done_(done) {}
+
+RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); }
+
+void RdmaTensorRequest::Done(const Status& s) {
+ Tensor val = std::move(*result_tensor_);
+
+#ifdef RDMA_DATA_VALIDATION
+ // Validate checksum
+ // Unfortunately we can't always do a Checksum directly on the result tensor.
+ // If the result tensor is on GPU, then we need to copy it back to CPU. If
+ // we happen to be in the midst of a proxy callback, then the copying will
+ // get stuck.
+ uint64_t checksum = (proxy_tensor_ != nullptr)
+ ? Checksum(nullptr, nullptr, *proxy_tensor_)
+ : Checksum(dst_dev_, recv_args_.device_context, val);
+ ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA");
+#endif
+
+ Rendezvous::Args recv_args = std::move(recv_args_);
+ bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_;
+ RecvDoneCallback done = done_;
+ DeallocateTensors();
+ channel_->RemoveTensorRequest(index_);
+ done(s, Rendezvous::Args(), recv_args, val, is_dead);
+}
+
+void RdmaTensorRequest::DeallocateTensors() {
+ if (result_tensor_ != nullptr) {
+ delete result_tensor_;
+ result_tensor_ = nullptr;
+ }
+ if (proxy_tensor_ != nullptr) {
+ delete proxy_tensor_;
+ proxy_tensor_ = nullptr;
+ }
+}
+
+bool RdmaTensorRequest::AllocateTensors() {
+ result_tensor_ =
+ new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs),
+ meta_data_->data_type_, meta_data_->tensor_shape_);
+
+ size_t tensor_size = result_tensor_->TotalBytes();
+ bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype());
+ if (can_memcpy) {
+ if (tensor_size == 0) {
+ return true;
+ }
+ rdma_addr_ = DMAHelper::base(result_tensor_);
+ mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
+#if GOOGLE_CUDA
+ if (mr_ == nullptr) {
+ // Can't RDMA directly to result. Use a proxy.
+ proxy_tensor_ =
+ new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0),
+ result_tensor_->dtype(), result_tensor_->shape());
+ rdma_addr_ = DMAHelper::base(proxy_tensor_);
+ mr_ =
+ RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
+ }
+#endif
+ } else {
+ uint32_t proto_size = meta_data_->proto_size_;
+ rdma_addr_ = malloc(proto_size);
+ mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ }
+ CHECK(mr_ != nullptr) << " No memory region found for address " << rdma_addr_
+ << ": " << key_;
+ return true;
+}
+
+void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
+ AllocateTensors();
+ bool on_host = recv_args_.alloc_attrs.on_host();
+ if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
+ (proxy_tensor_ == nullptr)) {
+#if GOOGLE_CUDA
+ // We need to sync the memory allocation on the GPU:
+ StreamGPUOp(dst_dev_, recv_args_.device_context, done);
+#endif
+ } else {
+ done(Status::OK());
+ }
+}
+
+void RdmaTensorRequest::Send(RdmaMessageType message_type) {
+ RdmaMessageBuffer* rb = channel_->tx_message_buffer_;
+ RdmaMessage rm;
+ rm.type_ = message_type;
+ rm.request_index_ = index_;
+ rm.name_size_ = key_.size();
+ rm.name_ = key_;
+ rm.step_id_ = step_id_;
+ rm.remote_addr_ = (uint64_t)rdma_addr_;
+ if (meta_data_ != nullptr) {
+ rm.data_type_ = meta_data_->data_type_;
+ rm.tensor_shape_ = meta_data_->tensor_shape_;
+ rm.is_dead_ = meta_data_->is_dead_;
+ rm.tensor_bytes_ = meta_data_->proto_size_;
+ } else {
+ rm.data_type_ = DT_INVALID;
+ }
+ rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;
+
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Sending " << MessageTypeToString(message_type)
+ << " #" << index_ << ": "
+ << rm.name_ << " on " << rdma_addr_
+ << " (rkey: 0x" << std::hex << rm.rkey_ << ")";
+
+ string message = RdmaMessage::CreateMessage(rm);
+ rb->EnqueueItem(message);
+ rb->SendNextItem();
+}
+
+void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
+ bool is_dead, size_t proto_size) {
+ meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData(
+ key_, dtype, shape, is_dead, proto_size);
+
+ DeallocateTensors();
+ AllocateTensorsAsync([this](const Status& s) {
+ Send(RDMA_MESSAGE_TENSOR_RE_REQUEST);
+ });
+}
+
+void RdmaTensorRequest::RecvTensorContent() {
+ bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_);
+ size_t message_size =
+ can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
+ RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
+ << ": Received tensor content #" << index_ << ": "
+ << key_ << " (Size: 0x" << std::hex << message_size << ")";
+
+ Tensor val;
+
+#if GOOGLE_CUDA
+ if (proxy_tensor_ != nullptr) {
+ CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_),
+ (void*)DMAHelper::base(result_tensor_),
+ result_tensor_->TotalBytes(), false);
+ GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context,
+ dst_dev_, result_tensor_,
+ [this](const Status& s) {
+ CHECK(s.ok()) << "copy tensor to gpu sync";
+ Done(s);
+ });
+ return;
+ }
+#endif
+
+ if (can_memcpy) {
+ Done(Status::OK());
+ } else {
+ RDMA_LOG(2) << "Decoding proto: " << key_
+ << " (Size: " << meta_data_->proto_size_ << ")";
+ TensorProto proto;
+ CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_))
+ << "fail to parse proto from array";
+ ibv_dereg_mr(mr_);
+ free(rdma_addr_);
+ Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs,
+ result_tensor_);
+ Done(s);
+ }
+}
+
+void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
+ if (result_tensor_ == nullptr) {
+ result_tensor_ = new Tensor();
+ }
+ LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString();
+ Done(status);
+}
+
+void RdmaTensorRequest::Start() {
+ meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
+ if (meta_data_ != nullptr) {
+ AllocateTensorsAsync([this](const Status& s) {
+ Send(RDMA_MESSAGE_TENSOR_REQUEST);
+ });
+ } else {
+ Send(RDMA_MESSAGE_TENSOR_REQUEST);
+ }
}
} // end namespace tensorflow
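
Taken together, the receive path introduced by this patch is driven entirely through RdmaChannel and RdmaTensorRequest: the receiver registers a pending request, sends RDMA_MESSAGE_TENSOR_REQUEST, may get RDMA_MESSAGE_META_DATA_UPDATE back (after which it re-allocates and answers with RDMA_MESSAGE_TENSOR_RE_REQUEST), and is finally completed when the sender's RDMA write arrives carrying the request index as immediate data. A hypothetical receiver-side driver showing that public API; the real caller is the verbs rendezvous manager, which is outside this file, and the channel lookup is assumed to happen elsewhere:

// Hypothetical usage sketch (not part of the patch).
void RecvTensorOverRdma(RdmaChannel* rc, const string& key, int64 step_id,
                        Device* dst_dev, const Rendezvous::Args& recv_args,
                        const RdmaTensorRequest::RecvDoneCallback& done) {
  // Register the request so the incoming immediate value can be routed back
  // to it, then start the protocol described above.
  RdmaTensorRequest* request =
      rc->InsertTensorRequest(key, step_id, dst_dev, recv_args, done);
  request->Start();
}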