aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs/rdma.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r--tensorflow/contrib/verbs/rdma.cc197
1 files changed, 110 insertions, 87 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 445cbe290a..ec5adfdaa0 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -707,7 +707,6 @@ void RdmaTensorBuffer::SendNextItem() {
bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
// string tensor needs to be serialized
Tensor copy;
- StringPiece copy_buf;
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
@@ -721,109 +720,133 @@ void RdmaTensorBuffer::SendNextItem() {
host_alloc_attrs.set_on_host(true);
Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
copy = Tensor(alloc, in.dtype(), in.shape());
- s = VerbsUtil::CopyGPUTensorToCPUSync(
- src_dev, send_args.device_context, &in, &copy);
- CHECK(s.ok()) << "copy tensor from gpu sync";
- copy_buf = copy.tensor_data();
+ tensor_bytes = in.TotalBytes();
+ buffer_size += tensor_bytes;
+ GPUUtil::CopyGPUTensorToCPU(
+ src_dev, send_args.device_context, &in, &copy,
+ [this, copy, tensor_bytes, buffer_size, key, in, step_id,
+ key_with_step_id, is_dead](const Status& s) {
+ CHECK(s.ok()) << "copy tensor from gpu sync";
+ StringPiece copy_buf;
+ copy_buf = copy.tensor_data();
+ PostCopyOperations(true, buffer_size, tensor_bytes, key, in,
+ step_id, is_dead, key_with_step_id, &copy,
+ NULL, &copy_buf);
+ });
} else {
- // "val" is on a GPU. Uses GPUUtil to fill the proto.
- s = VerbsUtil::SetProtoFromGPUSync(
- in, src_dev, send_args.device_context, &proto, is_dead);
- CHECK(s.ok()) << "set proto from gpu sync";
+ // "val" is on a GPU. No longer uses GPUUtil to fill the proto, use
+ // aync instead
+ GPUUtil::SetProtoFromGPU(
+ in, src_dev, send_args.device_context, &proto, is_dead,
+ [this, proto, buffer_size, key, in, step_id, key_with_step_id,
+ is_dead](const Status& s) mutable {
+ CHECK(s.ok()) << "copy proto from gpu sync";
+ auto tensor_bytes = proto.ByteSize();
+ buffer_size += tensor_bytes;
+ PostCopyOperations(false, buffer_size, tensor_bytes, key, in,
+ step_id, is_dead, key_with_step_id, NULL,
+ &proto, NULL);
+ });
}
} else {
// tensor is in CPU memory.
+ StringPiece copy_buf;
if (can_memcpy) {
copy_buf = in.tensor_data();
+ tensor_bytes = in.TotalBytes();
} else {
in.AsProtoTensorContent(&proto);
+ tensor_bytes = proto.ByteSize();
}
- }
- if (can_memcpy) {
- tensor_bytes = in.TotalBytes();
- } else {
- tensor_bytes = proto.ByteSize();
+ buffer_size += tensor_bytes;
+ PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in,
+ step_id, is_dead, key_with_step_id, &copy, &proto,
+ &copy_buf);
}
// maybe some margin for string tensor?
- buffer_size += tensor_bytes;
- // prepare message
- RdmaMessage rm;
- rm.name_size_ = key.size();
- rm.name_ = key;
- rm.tensor_shape_ = in.shape();
- rm.data_type_ = in.dtype();
- rm.step_id_ = step_id;
- rm.is_dead_ = is_dead;
- rm.tensor_bytes_ = tensor_bytes;
- rm.buffer_size_ = buffer_size;
- mu_.lock();
- if (local_status_ == none ||
- (buffer_size > size_ && local_status_ == idle &&
- remote_status_ == idle)) {
- if ((local_status_ != none) && (buffer_size > size_)) {
- VLOG(2) << "Extend RDMA buffer from " << size_ << " to "
- << buffer_size;
- }
- CreateCPUBuffer(buffer_size, false);
- mu_.unlock();
- // put back the key since it is not sent;
- EnqueueItem(key_with_step_id);
- // ask the remote to create the same buffer
- rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
- rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
- rm.rkey_ = self_->rkey;
- string message = RdmaMessage::CreateMessage(rm);
- channel_->tx_message_buffer_->EnqueueItem(message);
- channel_->tx_message_buffer_->SendNextItem();
- } else if ((local_status_ == idle) && (remote_status_ == idle)) {
- // both buffers are ready, send the tensor
- local_status_ = busy;
- remote_status_ = busy;
- // local/remote_status_ won't be set back to idle
- // unitl Write() is successful
- mu_.unlock();
- if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
- (buffer_size <= size_ && rm.data_type_ == DT_STRING))) {
- VLOG(2) << "Tensor and buffer size do not agree,"
- << " buffer_size = " << size_
- << " requested tensor size = "
- << buffer_size << in.DebugString();
- }
- uint32_t imm_data = LookupBufferIndex(key);
- rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
- string message = RdmaMessage::CreateMessage(rm);
- memcpy(buffer_, message.data(), message.size());
- if (!is_dead) {
- // copy the tensor buffer content
- void* output =
- static_cast<void*>(static_cast<char*>(buffer_) +
- RdmaMessage::kTensorBufferStartIndex);
- CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
- if (can_memcpy) {
- CHECK(copy_buf.size() == tensor_bytes)
- << "unexpected tensor size: "
- << copy_buf.size()
- << " != "
- << tensor_bytes;
- memcpy(output, copy_buf.data(), tensor_bytes);
- } else {
- proto.SerializeToArray(output, tensor_bytes);
- }
- } else {
- buffer_size = RdmaMessage::kMessageTotalBytes;
- }
- Write(imm_data, buffer_size);
- } else {
- mu_.unlock();
- // put back the key since it is not sent;
- EnqueueItem(key_with_step_id);
- }
};
+
channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id,
parsed, cb);
}
}
+void RdmaTensorBuffer::PostCopyOperations(
+ bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key,
+ const Tensor& in, int64 step_id, bool is_dead,
+ const string& key_with_step_id, const Tensor* copy,
+ const TensorProto* proto, const StringPiece* copy_buf) {
+ // prepare message
+ RdmaMessage rm;
+ rm.name_size_ = key.size();
+ rm.name_ = key;
+ rm.tensor_shape_ = in.shape();
+ rm.data_type_ = in.dtype();
+ rm.step_id_ = step_id;
+ rm.is_dead_ = is_dead;
+ rm.tensor_bytes_ = tensor_bytes;
+ rm.buffer_size_ = buffer_size;
+ mu_.lock();
+ if (local_status_ == none || (buffer_size > size_ && local_status_ == idle &&
+ remote_status_ == idle)) {
+ if ((local_status_ != none) && (buffer_size > size_)) {
+ VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size;
+ }
+ CreateCPUBuffer(buffer_size, false);
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ // ask the remote to create the same buffer
+ rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
+ rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
+ rm.rkey_ = self_->rkey;
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+ } else if ((local_status_ == idle) && (remote_status_ == idle)) {
+ // both buffers are ready, send the tensor
+ local_status_ = busy;
+ remote_status_ = busy;
+ // local/remote_status_ won't be set back to idle
+ // unitl Write() is successful
+ mu_.unlock();
+ if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
+ (buffer_size <= size_ && rm.data_type_ == DT_STRING))) {
+ VLOG(2) << "Tensor and buffer size do not agree,"
+ << " buffer_size = " << size_
+ << " requested tensor size = " << buffer_size << in.DebugString();
+ }
+ uint32_t imm_data = LookupBufferIndex(key);
+ rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
+ string message = RdmaMessage::CreateMessage(rm);
+ memcpy(buffer_, message.data(), message.size());
+ if (!is_dead) {
+ // copy the tensor buffer content
+ void* output = static_cast<void*>(static_cast<char*>(buffer_) +
+ RdmaMessage::kTensorBufferStartIndex);
+ CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
+ if (can_memcpy) {
+ CHECK(copy != NULL) << "callback missing pointer to copy tensor";
+ CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer";
+ CHECK(copy_buf->size() == tensor_bytes)
+ << "unexpected tensor size: " << copy_buf->size()
+ << " != " << tensor_bytes;
+ memcpy(output, copy_buf->data(), tensor_bytes);
+ } else {
+ CHECK(proto != NULL) << "callback missing pointer to proto tensor";
+ proto->SerializeToArray(output, tensor_bytes);
+ }
+ } else {
+ buffer_size = RdmaMessage::kMessageTotalBytes;
+ }
+ Write(imm_data, buffer_size);
+ } else {
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ }
+}
+
// Create a RdmaMessage according to the pre-defined format
// Args:
// rm: the message structure