diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma.cc | 197 |
1 files changed, 110 insertions, 87 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 445cbe290a..ec5adfdaa0 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -707,7 +707,6 @@ void RdmaTensorBuffer::SendNextItem() { bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); // string tensor needs to be serialized Tensor copy; - StringPiece copy_buf; TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { @@ -721,109 +720,133 @@ void RdmaTensorBuffer::SendNextItem() { host_alloc_attrs.set_on_host(true); Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); copy = Tensor(alloc, in.dtype(), in.shape()); - s = VerbsUtil::CopyGPUTensorToCPUSync( - src_dev, send_args.device_context, &in, &copy); - CHECK(s.ok()) << "copy tensor from gpu sync"; - copy_buf = copy.tensor_data(); + tensor_bytes = in.TotalBytes(); + buffer_size += tensor_bytes; + GPUUtil::CopyGPUTensorToCPU( + src_dev, send_args.device_context, &in, &copy, + [this, copy, tensor_bytes, buffer_size, key, in, step_id, + key_with_step_id, is_dead](const Status& s) { + CHECK(s.ok()) << "copy tensor from gpu sync"; + StringPiece copy_buf; + copy_buf = copy.tensor_data(); + PostCopyOperations(true, buffer_size, tensor_bytes, key, in, + step_id, is_dead, key_with_step_id, &copy, + NULL, &copy_buf); + }); } else { - // "val" is on a GPU. Uses GPUUtil to fill the proto. - s = VerbsUtil::SetProtoFromGPUSync( - in, src_dev, send_args.device_context, &proto, is_dead); - CHECK(s.ok()) << "set proto from gpu sync"; + // "val" is on a GPU.
No longer uses GPUUtil to fill the proto, use + // aync instead + GPUUtil::SetProtoFromGPU( + in, src_dev, send_args.device_context, &proto, is_dead, + [this, proto, buffer_size, key, in, step_id, key_with_step_id, + is_dead](const Status& s) mutable { + CHECK(s.ok()) << "copy proto from gpu sync"; + auto tensor_bytes = proto.ByteSize(); + buffer_size += tensor_bytes; + PostCopyOperations(false, buffer_size, tensor_bytes, key, in, + step_id, is_dead, key_with_step_id, NULL, + &proto, NULL); + }); } } else { // tensor is in CPU memory. + StringPiece copy_buf; if (can_memcpy) { copy_buf = in.tensor_data(); + tensor_bytes = in.TotalBytes(); } else { in.AsProtoTensorContent(&proto); + tensor_bytes = proto.ByteSize(); } - } - if (can_memcpy) { - tensor_bytes = in.TotalBytes(); - } else { - tensor_bytes = proto.ByteSize(); + buffer_size += tensor_bytes; + PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in, + step_id, is_dead, key_with_step_id, &copy, &proto, + &copy_buf); } // maybe some margin for string tensor?
- buffer_size += tensor_bytes; - // prepare message - RdmaMessage rm; - rm.name_size_ = key.size(); - rm.name_ = key; - rm.tensor_shape_ = in.shape(); - rm.data_type_ = in.dtype(); - rm.step_id_ = step_id; - rm.is_dead_ = is_dead; - rm.tensor_bytes_ = tensor_bytes; - rm.buffer_size_ = buffer_size; - mu_.lock(); - if (local_status_ == none || - (buffer_size > size_ && local_status_ == idle && - remote_status_ == idle)) { - if ((local_status_ != none) && (buffer_size > size_)) { - VLOG(2) << "Extend RDMA buffer from " << size_ << " to " - << buffer_size; - } - CreateCPUBuffer(buffer_size, false); - mu_.unlock(); - // put back the key since it is not sent; - EnqueueItem(key_with_step_id); - // ask the remote to create the same buffer - rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST; - rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_); - rm.rkey_ = self_->rkey; - string message = RdmaMessage::CreateMessage(rm); - channel_->tx_message_buffer_->EnqueueItem(message); - channel_->tx_message_buffer_->SendNextItem(); - } else if ((local_status_ == idle) && (remote_status_ == idle)) { - // both buffers are ready, send the tensor - local_status_ = busy; - remote_status_ = busy; - // local/remote_status_ won't be set back to idle - // unitl Write() is successful - mu_.unlock(); - if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) || - (buffer_size <= size_ && rm.data_type_ == DT_STRING))) { - VLOG(2) << "Tensor and buffer size do not agree," - << " buffer_size = " << size_ - << " requested tensor size = " - << buffer_size << in.DebugString(); - } - uint32_t imm_data = LookupBufferIndex(key); - rm.type_ = RDMA_MESSAGE_TENSOR_WRITE; - string message = RdmaMessage::CreateMessage(rm); - memcpy(buffer_, message.data(), message.size()); - if (!is_dead) { - // copy the tensor buffer content - void* output = - static_cast<void*>(static_cast<char*>(buffer_) + - RdmaMessage::kTensorBufferStartIndex); - CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); - if 
(can_memcpy) { - CHECK(copy_buf.size() == tensor_bytes) - << "unexpected tensor size: " - << copy_buf.size() - << " != " - << tensor_bytes; - memcpy(output, copy_buf.data(), tensor_bytes); - } else { - proto.SerializeToArray(output, tensor_bytes); - } - } else { - buffer_size = RdmaMessage::kMessageTotalBytes; - } - Write(imm_data, buffer_size); - } else { - mu_.unlock(); - // put back the key since it is not sent; - EnqueueItem(key_with_step_id); - } }; + channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb); } } +void RdmaTensorBuffer::PostCopyOperations( + bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key, + const Tensor& in, int64 step_id, bool is_dead, + const string& key_with_step_id, const Tensor* copy, + const TensorProto* proto, const StringPiece* copy_buf) { + // prepare message + RdmaMessage rm; + rm.name_size_ = key.size(); + rm.name_ = key; + rm.tensor_shape_ = in.shape(); + rm.data_type_ = in.dtype(); + rm.step_id_ = step_id; + rm.is_dead_ = is_dead; + rm.tensor_bytes_ = tensor_bytes; + rm.buffer_size_ = buffer_size; + mu_.lock(); + if (local_status_ == none || (buffer_size > size_ && local_status_ == idle && + remote_status_ == idle)) { + if ((local_status_ != none) && (buffer_size > size_)) { + VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size; + } + CreateCPUBuffer(buffer_size, false); + mu_.unlock(); + // put back the key since it is not sent; + EnqueueItem(key_with_step_id); + // ask the remote to create the same buffer + rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST; + rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_); + rm.rkey_ = self_->rkey; + string message = RdmaMessage::CreateMessage(rm); + channel_->tx_message_buffer_->EnqueueItem(message); + channel_->tx_message_buffer_->SendNextItem(); + } else if ((local_status_ == idle) && (remote_status_ == idle)) { + // both buffers are ready, send the tensor + local_status_ = busy; + remote_status_ = busy; + // 
local/remote_status_ won't be set back to idle + // unitl Write() is successful + mu_.unlock(); + if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) || + (buffer_size <= size_ && rm.data_type_ == DT_STRING))) { + VLOG(2) << "Tensor and buffer size do not agree," + << " buffer_size = " << size_ + << " requested tensor size = " << buffer_size << in.DebugString(); + } + uint32_t imm_data = LookupBufferIndex(key); + rm.type_ = RDMA_MESSAGE_TENSOR_WRITE; + string message = RdmaMessage::CreateMessage(rm); + memcpy(buffer_, message.data(), message.size()); + if (!is_dead) { + // copy the tensor buffer content + void* output = static_cast<void*>(static_cast<char*>(buffer_) + + RdmaMessage::kTensorBufferStartIndex); + CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); + if (can_memcpy) { + CHECK(copy != NULL) << "callback missing pointer to copy tensor"; + CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer"; + CHECK(copy_buf->size() == tensor_bytes) + << "unexpected tensor size: " << copy_buf->size() + << " != " << tensor_bytes; + memcpy(output, copy_buf->data(), tensor_bytes); + } else { + CHECK(proto != NULL) << "callback missing pointer to proto tensor"; + proto->SerializeToArray(output, tensor_bytes); + } + } else { + buffer_size = RdmaMessage::kMessageTotalBytes; + } + Write(imm_data, buffer_size); + } else { + mu_.unlock(); + // put back the key since it is not sent; + EnqueueItem(key_with_step_id); + } +} + // Create a RdmaMessage according to the pre-defined format // Args: // rm: the message structure |