aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc')
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc51
1 files changed, 35 insertions, 16 deletions
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 3ba6510711..ce82ca2883 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -33,6 +34,11 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr)
: BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {}
+ void RecvPostCopyOps(const string& key, const string& key_with_step_id,
+ const Rendezvous::Args& recv_args,
+ const DoneCallback& done, const RdmaMessage& rm,
+ RdmaChannel* rc, Tensor& val, const Status& s);
+
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
@@ -113,10 +119,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
- s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context,
- dst_dev, &gpu_copy);
- CHECK(s.ok()) << "copy tensor to gpu sync";
- val = std::move(gpu_copy);
+
+ GPUUtil::CopyCPUTensorToGPU(
+ &copy, recv_args.device_context, dst_dev, &gpu_copy,
+ [this, gpu_copy, key, key_with_step_id, recv_args, done, rm,
+ rc](const Status& s) {
+ CHECK(s.ok()) << "copy tensor to gpu sync";
+ Tensor val;
+ val = std::move(gpu_copy);
+ RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc,
+ val, s);
+ });
+ return;
} else {
AllocatorAttributes host_alloc_attrs;
host_alloc_attrs.set_gpu_compatible(true);
@@ -135,18 +149,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
}
}
-
- rc->RemoveRecvCallback(key_with_step_id);
- // create message
- RdmaMessage br;
- br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
- br.name_size_ = key.size();
- br.name_ = key;
- string message = RdmaMessage::CreateMessage(br);
- RdmaBuffer* tb = rc->tx_message_buffer_;
- tb->EnqueueItem(message);
- tb->SendNextItem();
- done(s, Args(), recv_args, val, rm.is_dead_);
+ RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s);
});
// append key to message queue
RdmaBuffer* rb = rc->tx_message_buffer_;
@@ -160,6 +163,22 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
rb->SendNextItem();
}
+void RdmaRemoteRendezvous::RecvPostCopyOps(
+ const string& key, const string& key_with_step_id,
+ const Rendezvous::Args& recv_args, const DoneCallback& done,
+ const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) {
+ rc->RemoveRecvCallback(key_with_step_id);
+ RdmaMessage br;
+ br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
+ br.name_size_ = key.size();
+ br.name_ = key;
+ string message = RdmaMessage::CreateMessage(br);
+ RdmaBuffer* tb = rc->tx_message_buffer_;
+ tb->EnqueueItem(message);
+ tb->SendNextItem();
+ done(s, Args(), recv_args, val, rm.is_dead_);
+}
+
RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env)
: BaseRendezvousMgr(env) {}