diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 51 |
1 files changed, 35 insertions, 16 deletions
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index 3ba6510711..ce82ca2883 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -33,6 +34,11 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous { RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr) : BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {} + void RecvPostCopyOps(const string& key, const string& key_with_step_id, + const Rendezvous::Args& recv_args, + const DoneCallback& done, const RdmaMessage& rm, + RdmaChannel* rc, Tensor& val, const Status& s); + protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, @@ -113,10 +119,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs); Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_); - s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context, - dst_dev, &gpu_copy); - CHECK(s.ok()) << "copy tensor to gpu sync"; - val = std::move(gpu_copy); + + GPUUtil::CopyCPUTensorToGPU( + &copy, recv_args.device_context, dst_dev, &gpu_copy, + [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, + rc](const Status& s) { + CHECK(s.ok()) << "copy tensor to gpu sync"; + Tensor val; + val = std::move(gpu_copy); + RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, + val, s); + }); + return; } else { AllocatorAttributes host_alloc_attrs; 
host_alloc_attrs.set_gpu_compatible(true); @@ -135,18 +149,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val); } } - - rc->RemoveRecvCallback(key_with_step_id); - // create message - RdmaMessage br; - br.type_ = RDMA_MESSAGE_BUFFER_IDLE; - br.name_size_ = key.size(); - br.name_ = key; - string message = RdmaMessage::CreateMessage(br); - RdmaBuffer* tb = rc->tx_message_buffer_; - tb->EnqueueItem(message); - tb->SendNextItem(); - done(s, Args(), recv_args, val, rm.is_dead_); + RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s); }); // append key to message queue RdmaBuffer* rb = rc->tx_message_buffer_; @@ -160,6 +163,22 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( rb->SendNextItem(); } +void RdmaRemoteRendezvous::RecvPostCopyOps( + const string& key, const string& key_with_step_id, + const Rendezvous::Args& recv_args, const DoneCallback& done, + const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) { + rc->RemoveRecvCallback(key_with_step_id); + RdmaMessage br; + br.type_ = RDMA_MESSAGE_BUFFER_IDLE; + br.name_size_ = key.size(); + br.name_ = key; + string message = RdmaMessage::CreateMessage(br); + RdmaBuffer* tb = rc->tx_message_buffer_; + tb->EnqueueItem(message); + tb->SendNextItem(); + done(s, Args(), recv_args, val, rm.is_dead_); +} + RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {} |