about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc')
-rw-r--r-- tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 114
1 files changed, 11 insertions, 103 deletions
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 74f6681af3..ad3dce1784 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,10 +21,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
-#endif // GOOGLE_CUDA
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -36,11 +32,6 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr)
: BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {}
- void RecvPostCopyOps(const string& key, const string& key_with_step_id,
- const Rendezvous::Args& recv_args,
- const DoneCallback& done, const RdmaMessage& rm,
- RdmaChannel* rc, Tensor& val, const Status& s);
-
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
@@ -74,101 +65,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
string key(std::move(parsed.FullKey().ToString()));
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
- // insert callback
- rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
- recv_args, parsed, done]() {
- Status src_s, dst_s, s;
- Device* src_dev, *dst_dev;
- src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
- dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
- if (!src_s.ok() || !dst_s.ok()) {
- s = src_s.ok() ? dst_s : src_s;
- LOG(ERROR) << "s is not ok, error code " << s.error_message();
- done(s, Args(), recv_args, Tensor(), true);
- return;
- }
- RdmaBuffer* rb = rc->FindBuffer(key);
- RdmaMessage rm;
- CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
- RdmaMessage::ParseMessage(rm, rb->buffer_);
- CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
- Tensor val;
- if (!rm.is_dead_) {
- void* input = static_cast<char*>(rb->buffer_) +
- RdmaMessage::kTensorBufferStartIndex;
- bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_);
- if (can_memcpy) {
- if (dst_dev->tensorflow_gpu_device_info() &&
- (!recv_args.alloc_attrs.on_host())) {
-#if GOOGLE_CUDA
- CHECK(recv_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
- Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
- memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
-
- Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
- Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
-
- GPUUtil::CopyCPUTensorToGPU(
- &copy, recv_args.device_context, dst_dev, &gpu_copy,
- [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc](
- const Status& s) {
- CHECK(s.ok()) << "copy tensor to gpu sync";
- Tensor val;
- val = std::move(gpu_copy);
- RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc,
- val, s);
- });
-#endif // GOOGLE_CUDA
- return;
- } else {
- AllocatorAttributes host_alloc_attrs;
- host_alloc_attrs.set_gpu_compatible(true);
- host_alloc_attrs.set_on_host(true);
- Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs);
- Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
- memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
- val = std::move(copy);
- }
- } else {
- TensorProto proto;
- CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
- rb->size_);
- CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
- << "fail to parse proto from array";
- s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
- }
- }
- RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s);
- });
- // append key to message queue
- RdmaBuffer* rb = rc->tx_message_buffer_;
- RdmaMessage rm;
- rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
- rm.name_size_ = key.size();
- rm.name_ = key;
- rm.step_id_ = step_id_;
- string message = RdmaMessage::CreateMessage(rm);
- rb->EnqueueItem(message);
- rb->SendNextItem();
-}
-void RdmaRemoteRendezvous::RecvPostCopyOps(
- const string& key, const string& key_with_step_id,
- const Rendezvous::Args& recv_args, const DoneCallback& done,
- const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) {
- rc->RemoveRecvCallback(key_with_step_id);
- RdmaMessage br;
- br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
- br.name_size_ = key.size();
- br.name_ = key;
- string message = RdmaMessage::CreateMessage(br);
- RdmaBuffer* tb = rc->tx_message_buffer_;
- tb->EnqueueItem(message);
- tb->SendNextItem();
- done(s, Args(), recv_args, val, rm.is_dead_);
+ Device* dst_dev;
+ s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), true);
+ return;
+ }
+
+ RdmaTensorRequest* request =
+ rc->InsertTensorRequest(key, step_id_, dst_dev, recv_args, done);
+ request->Start();
}
RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env)