aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs/rdma.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r--tensorflow/contrib/verbs/rdma.cc55
1 files changed, 45 insertions, 10 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index bc687be0ab..6f3a616fe8 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -683,7 +684,6 @@ void RdmaTensorBuffer::SendNextItem() {
<< " error message: " << status.error_message();
size_t buffer_size = RdmaMessage::kMessageTotalBytes;
size_t tensor_bytes = 0;
- TensorProto proto;
// Figures out which device the tensor is hosted on.
Device* src_dev = nullptr;
Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
@@ -703,21 +703,47 @@ void RdmaTensorBuffer::SendNextItem() {
CHECK(s.ok()) << "dst device not found";
AllocatorAttributes dst_alloc_attr;
dst_alloc_attr.set_on_host(true);
+
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
// string tensor needs to be serialized
+ Tensor copy;
+ StringPiece copy_buf;
+ TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
CHECK(send_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
- // "val" is on a GPU. Uses GPUUtil to fill the proto.
- s = VerbsUtil::SetProtoFromGPUSync(
- in, src_dev, send_args.device_context, &proto, is_dead);
- CHECK(s.ok()) << "set proto from gpu sync";
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+
+ if (can_memcpy) {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ copy = Tensor(alloc, in.dtype(), in.shape());
+ s = VerbsUtil::CopyGPUTensorToCPUSync(
+ src_dev, send_args.device_context, &in, &copy);
+ CHECK(s.ok()) << "copy tensor from gpu sync";
+ copy_buf = copy.tensor_data();
+ } else {
+ // "val" is on a GPU. Uses GPUUtil to fill the proto.
+ s = VerbsUtil::SetProtoFromGPUSync(
+ in, src_dev, send_args.device_context, &proto, is_dead);
+ CHECK(s.ok()) << "set proto from gpu sync";
+ }
} else {
// tensor is in CPU memory.
- in.AsProtoTensorContent(&proto);
+ if (can_memcpy) {
+ copy_buf = in.tensor_data();
+ } else {
+ in.AsProtoTensorContent(&proto);
+ }
+ }
+ if (can_memcpy) {
+ tensor_bytes = in.TotalBytes();
+ } else {
+ tensor_bytes = proto.ByteSize();
}
- tensor_bytes = proto.ByteSize();
// maybe some margin for string tensor?
buffer_size += tensor_bytes;
// prepare message
@@ -771,7 +797,16 @@ void RdmaTensorBuffer::SendNextItem() {
static_cast<void*>(static_cast<char*>(buffer_) +
RdmaMessage::kTensorBufferStartIndex);
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
- proto.SerializeToArray(output, tensor_bytes);
+ if (can_memcpy) {
+ CHECK(copy_buf.size() == tensor_bytes)
+ << "unexpected tensor size: "
+ << copy_buf.size()
+ << " != "
+ << tensor_bytes;
+ memcpy(output, copy_buf.data(), tensor_bytes);
+ } else {
+ proto.SerializeToArray(output, tensor_bytes);
+ }
} else {
buffer_size = RdmaMessage::kMessageTotalBytes;
}