diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma.cc | 55 |
1 files changed, 10 insertions, 45 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 6f3a616fe8..bc687be0ab 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/framework/rendezvous.h" @@ -684,6 +683,7 @@ void RdmaTensorBuffer::SendNextItem() { << " error message: " << status.error_message(); size_t buffer_size = RdmaMessage::kMessageTotalBytes; size_t tensor_bytes = 0; + TensorProto proto; // Figures out which device the tensor is hosted on. Device* src_dev = nullptr; Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice( @@ -703,47 +703,21 @@ void RdmaTensorBuffer::SendNextItem() { CHECK(s.ok()) << "dst device not found"; AllocatorAttributes dst_alloc_attr; dst_alloc_attr.set_on_host(true); - - bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); // string tensor needs to be serialized - Tensor copy; - StringPiece copy_buf; - TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { CHECK(send_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - - if (can_memcpy) { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); - copy = Tensor(alloc, in.dtype(), in.shape()); - s = VerbsUtil::CopyGPUTensorToCPUSync( - src_dev, send_args.device_context, &in, ©); - CHECK(s.ok()) << "copy tensor from gpu sync"; - copy_buf = copy.tensor_data(); - } else { - // "val" is on a GPU. Uses GPUUtil to fill the proto. - s = VerbsUtil::SetProtoFromGPUSync( - in, src_dev, send_args.device_context, &proto, is_dead); - CHECK(s.ok()) << "set proto from gpu sync"; - } + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + // "val" is on a GPU. Uses GPUUtil to fill the proto. + s = VerbsUtil::SetProtoFromGPUSync( + in, src_dev, send_args.device_context, &proto, is_dead); + CHECK(s.ok()) << "set proto from gpu sync"; } else { // tensor is in CPU memory. - if (can_memcpy) { - copy_buf = in.tensor_data(); - } else { - in.AsProtoTensorContent(&proto); - } - } - if (can_memcpy) { - tensor_bytes = in.TotalBytes(); - } else { - tensor_bytes = proto.ByteSize(); + in.AsProtoTensorContent(&proto); } + tensor_bytes = proto.ByteSize(); // maybe some margin for string tensor? buffer_size += tensor_bytes; // prepare message @@ -797,16 +771,7 @@ void RdmaTensorBuffer::SendNextItem() { static_cast<void*>(static_cast<char*>(buffer_) + RdmaMessage::kTensorBufferStartIndex); CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); - if (can_memcpy) { - CHECK(copy_buf.size() == tensor_bytes) - << "unexpected tensor size: " - << copy_buf.size() - << " != " - << tensor_bytes; - memcpy(output, copy_buf.data(), tensor_bytes); - } else { - proto.SerializeToArray(output, tensor_bytes); - } + proto.SerializeToArray(output, tensor_bytes); } else { buffer_size = RdmaMessage::kMessageTotalBytes; } |