aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc')
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc41
1 files changed, 35 insertions, 6 deletions
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 5871400f26..9ea696589a 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -99,12 +100,40 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
if (!rm.is_dead_) {
void* input = static_cast<char*>(rb->buffer_) +
RdmaMessage::kTensorBufferStartIndex;
- TensorProto proto;
- CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
- rb->size_);
- CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
- << "fail to parse proto from array";
- s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+ bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_);
+ if (can_memcpy) {
+ if (dst_dev->tensorflow_gpu_device_info() &&
+ (!recv_args.alloc_attrs.on_host())) {
+ CHECK(recv_args.device_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+ memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+
+ Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
+ Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
+ s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context,
+ dst_dev, &gpu_copy);
+ CHECK(s.ok()) << "copy tensor to gpu sync";
+ val = std::move(gpu_copy);
+ } else {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs);
+ Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+ memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+ val = std::move(copy);
+ }
+ } else {
+ TensorProto proto;
+ CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+ rb->size_);
+ CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+ << "fail to parse proto from array";
+ s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+ }
}
rc->RemoveRecvCallback(key_with_step_id);