diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 41 |
1 file changed, 35 insertions, 6 deletions
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 5871400f26..9ea696589a 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -99,12 +100,40 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
   if (!rm.is_dead_) {
     void* input = static_cast<char*>(rb->buffer_) +
                   RdmaMessage::kTensorBufferStartIndex;
-    TensorProto proto;
-    CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
-          rb->size_);
-    CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
-        << "fail to parse proto from array";
-    s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+    bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_);
+    if (can_memcpy) {
+      if (dst_dev->tensorflow_gpu_device_info() &&
+          (!recv_args.alloc_attrs.on_host())) {
+        CHECK(recv_args.device_context)
+            << "send dev name: " << src_dev->name()
+            << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+        Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+        Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+        memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+
+        Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
+        Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
+        s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context,
+                                              dst_dev, &gpu_copy);
+        CHECK(s.ok()) << "copy tensor to gpu sync";
+        val = std::move(gpu_copy);
+      } else {
+        AllocatorAttributes host_alloc_attrs;
+        host_alloc_attrs.set_gpu_compatible(true);
+        host_alloc_attrs.set_on_host(true);
+        Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs);
+        Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+        memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+        val = std::move(copy);
+      }
+    } else {
+      TensorProto proto;
+      CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+            rb->size_);
+      CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+          << "fail to parse proto from array";
+      s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+    }
   }
   rc->RemoveRecvCallback(key_with_step_id);