diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 114 |
1 files changed, 11 insertions, 103 deletions
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index 74f6681af3..ad3dce1784 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -21,10 +21,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -36,11 +32,6 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous { RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr) : BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {} - void RecvPostCopyOps(const string& key, const string& key_with_step_id, - const Rendezvous::Args& recv_args, - const DoneCallback& done, const RdmaMessage& rm, - RdmaChannel* rc, Tensor& val, const Status& s); - protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, @@ -74,101 +65,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( RdmaChannel* rc = rdma_mgr_->FindChannel(src_name); string key(std::move(parsed.FullKey().ToString())); string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); - // insert callback - rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc, - recv_args, parsed, done]() { - Status src_s, dst_s, s; - Device* src_dev, *dst_dev; - src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); - dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); - if (!src_s.ok() || !dst_s.ok()) { - s = src_s.ok() ? dst_s : src_s; - LOG(ERROR) << "s is not ok, error code " << s.error_message(); - done(s, Args(), recv_args, Tensor(), true); - return; - } - RdmaBuffer* rb = rc->FindBuffer(key); - RdmaMessage rm; - CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes); - RdmaMessage::ParseMessage(rm, rb->buffer_); - CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE); - Tensor val; - if (!rm.is_dead_) { - void* input = static_cast<char*>(rb->buffer_) + - RdmaMessage::kTensorBufferStartIndex; - bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_); - if (can_memcpy) { - if (dst_dev->tensorflow_gpu_device_info() && - (!recv_args.alloc_attrs.on_host())) { -#if GOOGLE_CUDA - CHECK(recv_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor copy(alloc, rm.data_type_, rm.tensor_shape_); - memcpy(DMAHelper::base(©), input, rm.tensor_bytes_); - - Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs); - Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_); - - GPUUtil::CopyCPUTensorToGPU( - ©, recv_args.device_context, dst_dev, &gpu_copy, - [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc]( - const Status& s) { - CHECK(s.ok()) << "copy tensor to gpu sync"; - Tensor val; - val = std::move(gpu_copy); - RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, - val, s); - }); -#endif // GOOGLE_CUDA - return; - } else { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs); - Tensor copy(alloc, rm.data_type_, rm.tensor_shape_); - memcpy(DMAHelper::base(©), input, rm.tensor_bytes_); - val = std::move(copy); - } - } else { - TensorProto proto; - CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <= - rb->size_); - CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_)) - << "fail to parse proto from array"; - s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val); - } - } - RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s); - }); - // append key to message queue - RdmaBuffer* rb = rc->tx_message_buffer_; - RdmaMessage rm; - rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST; - rm.name_size_ = key.size(); - rm.name_ = key; - rm.step_id_ = step_id_; - string message = RdmaMessage::CreateMessage(rm); - rb->EnqueueItem(message); - rb->SendNextItem(); -} -void RdmaRemoteRendezvous::RecvPostCopyOps( - const string& key, const string& key_with_step_id, - const Rendezvous::Args& recv_args, const DoneCallback& done, - const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) { - rc->RemoveRecvCallback(key_with_step_id); - RdmaMessage br; - br.type_ = RDMA_MESSAGE_BUFFER_IDLE; - br.name_size_ = key.size(); - br.name_ = key; - string message = RdmaMessage::CreateMessage(br); - RdmaBuffer* tb = rc->tx_message_buffer_; - tb->EnqueueItem(message); - tb->SendNextItem(); - done(s, Args(), recv_args, val, rm.is_dead_); + Device* dst_dev; + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); + CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), true); + return; + } + + RdmaTensorRequest* request = + rc->InsertTensorRequest(key, step_id_, dst_dev, recv_args, done); + request->Start(); } RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env) |