aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-04-22 06:08:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-22 07:28:38 -0700
commit326942394e69074d50d5889218a24c9371eff259 (patch)
tree50c78852c36b828440761a16650718f224560f7b /tensorflow/contrib/verbs
parent3c0900a49c11b7975c7accc026153bbc2001c018 (diff)
Merge changes from github.
Change: 153925676
Diffstat (limited to 'tensorflow/contrib/verbs')
-rw-r--r--tensorflow/contrib/verbs/BUILD168
-rw-r--r--tensorflow/contrib/verbs/README.md77
-rw-r--r--tensorflow/contrib/verbs/design_diagram.pngbin0 -> 13625 bytes
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_client.cc47
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_client.h50
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.cc165
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.h72
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.cc68
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.h89
-rw-r--r--tensorflow/contrib/verbs/rdma.cc874
-rw-r--r--tensorflow/contrib/verbs/rdma.h277
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc133
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.h54
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc149
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.h64
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc172
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.h66
-rw-r--r--tensorflow/contrib/verbs/verbs_service.proto60
-rw-r--r--tensorflow/contrib/verbs/verbs_util.cc61
-rw-r--r--tensorflow/contrib/verbs/verbs_util.h41
20 files changed, 2687 insertions, 0 deletions
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
new file mode 100644
index 0000000000..e747fa4c9e
--- /dev/null
+++ b/tensorflow/contrib/verbs/BUILD
@@ -0,0 +1,168 @@
+# Description:
+# Verbs RDMA communication interfaces and implementations for TensorFlow.
+
+package(default_visibility = [
+ "//tensorflow:__subpackages__",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+# For platform specific build config
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_cc",
+)
+
+tf_proto_library_cc(
+ name = "verbs_service_proto",
+ srcs = ["verbs_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+cc_library(
+ name = "verbs_util",
+ srcs = ["verbs_util.cc"],
+ hdrs = ["verbs_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "grpc_verbs_service",
+ srcs = ["grpc_verbs_service.cc"],
+ hdrs = ["grpc_verbs_service.h"],
+ deps = [
+ ":grpc_verbs_service_impl",
+ ":rdma_mgr",
+ ":verbs_service_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_call",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+ "@grpc//:grpc++_unsecure",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "grpc_verbs_service_impl",
+ srcs = ["grpc_verbs_service_impl.cc"],
+ hdrs = ["grpc_verbs_service_impl.h"],
+ deps = [
+ ":verbs_service_proto_cc",
+ "@grpc//:grpc++_unsecure",
+ ],
+)
+
+cc_library(
+ name = "grpc_verbs_client",
+ srcs = ["grpc_verbs_client.cc"],
+ hdrs = ["grpc_verbs_client.h"],
+ deps = [
+ ":grpc_verbs_service_impl",
+ ":verbs_service_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:call_options",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "rdma_rendezvous_mgr",
+ srcs = ["rdma_rendezvous_mgr.cc"],
+ hdrs = ["rdma_rendezvous_mgr.h"],
+ deps = [
+ ":rdma_mgr",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ ],
+)
+
+cc_library(
+ name = "rdma_mgr",
+ srcs = ["rdma_mgr.cc"],
+ hdrs = ["rdma_mgr.h"],
+ deps = [
+ ":grpc_verbs_client",
+ ":rdma",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ ],
+)
+
+cc_library(
+ name = "rdma",
+ srcs = ["rdma.cc"],
+ hdrs = ["rdma.h"],
+ linkopts = select({
+ "//tensorflow:with_verbs_support": ["-libverbs"],
+ "//conditions:default": [],
+ }),
+ deps = [
+ ":verbs_util",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
+ "//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ ],
+)
+
+cc_library(
+ name = "verbs_server_lib",
+ srcs = ["verbs_server_lib.cc"],
+ hdrs = ["verbs_server_lib.h"],
+ linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
+ deps = [
+ ":grpc_verbs_service",
+ ":rdma_mgr",
+ ":rdma_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md
new file mode 100644
index 0000000000..37a543dda8
--- /dev/null
+++ b/tensorflow/contrib/verbs/README.md
@@ -0,0 +1,77 @@
+## How to compile and use Rdma-enabled tensorflow
+1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based Rdma support, answer yes to this question:
+
+ ```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```
+
+2. To turn on Rdma connection, add the protocol "grpc+verbs" in server definition:
+
+ ```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
+
+## Overview
+The design is based on Tensorflow r1.0. An Rdma path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the Rdma path, exchanging computation graphs, etc.
+
+During the server setup, an Rdma manager is created to manage low-level Rdma components such as Rdma channel and Rdma adapter, an Rdma rendezvous manager is created to oversee send/recv operations between servers. Following the distributed Tensorflow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer.
+
+Tensorflow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for Rdma operations where pinned memory is required. Two remedies are possible, either the memory is pinned, transfer, then unpinned for each and every tensor to be transferred, or a buffer is pre-allocated and pinned for each tensor. The former incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. The latter incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. The second approach is adopted in this design. Each Rdma channel, representing a Rdma connection to a peer, contains a table of pinned buffers for all the seen tensors that requires transfer. It is assumed that the tensor size rarely changes across different steps. So only one buffer is created for the same tensor across all the steps. In the rare case when the tensor size does increases, the old buffer is discarded and new buffer of larger size is created and pinned.
+
+When a tensor is prepared fro transfer, it is first converted to TensorProto, then the proto is serialized to byte array and copied to the pinned buffer. The content of the buffer is transferred to the remote node via Rdma write. On the remote side, the process is reversed. This is illustrated in the diagram below. The conversion of TensorProto is introduced to simplify transfer of string-tensors. Also since the TensorProto lives in host memory, even if the origin tensor lives in the device, the pinned buffers are all allocated in the host memory.
+![Tensorflow Rdma path](./design_diagram.png)
+
+The following improvements can be made in the future. First, conversion to TensorProto and serialization can be avoided for numeric (float/int) tensors since their internal buffer can be access directly as byte array. Second, the pinned buffer may be allocated on device if the tensor is located in the device. This avoids extra device-to-host copy at the expense of extra device memory consumption.
+## Design details
+
+### Rdma components
+
+* **Rdma adapter:** The base for Rdma communications. It may contain multiple channels and buffers. It is responsible for handling various incoming Rdma messages.
+* **Rdma channel:** Responsible for Rdma connection to a particular node. It manages multiple buffers. A channel has a callback table which stores all the callbacks for the requested tensors.
+* **Rdma buffer:** Responsible for sending or receiving data. It has a fixed size memory to store the data. It has a queue to store the pending jobs. There are three types of buffers, message buffer, ACK buffer and tensor buffer. A channel has two message buffers, two ack buffers and many tensor buffers.
+* **Rdma manager:** Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
+* **Rdma rendezvous manager:** manages multiple rdma rendezvous.
+* **Rdma rendezvous:** a derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
+
+### The SEND operation
+
+In tensorflow, when rendezvous sends a tensor, it merely puts a tensor in a local table in the corresponding rendezvous. If the tensor has been requested, a callback exists in the table. "send" will activate the callback, which tries to send the tensor across the node.
+
+
+### The RECV operation
+
+When a tensor is requested, rendezvous' recv function is called. The function first places a callback in the channel's callback table, which will be activated once the tensor is sent from the source. In the next step, a message is sent to notify the source of the requested tensor. Once the source receives the message, it will check locally for the tensor, if not found, a callback is placed in the table, otherwise, the tensor id will be placed at corresponding Rdma buffer's job queue for future transmission. When a tensor is scheduled to be transmitted, the Rdma buffer needs to have the memory allocated and initialized (registered with the remote buffer info). If the memory is not ready, the transmission is deferred, a message is sent to the destination to establish the memory first. The other case a transimssion can be deferred is when the buffer is still being used by an on-going transmission.
+
+### Three types of Rdma buffers
+
+* **Message buffer:** responsible for sending message only.
+* **Ack buffer:** once a message is sent, the recipient needs to send an ack via the ack buffer to free up the message buffer. An ack buffer is exclusively for its coupled message buffer.
+* **Tensor buffer:** responsible for sending tensors. The recipient needs to send back a message to free up the sending buffer.
+
+### Rdma packet format
+
+|type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|data_type|tensor_shape|tensor_bytes|tensor_buffer|
+
+### Six types of Rdma messages
+* RDMA_MESSAGE_ACK
+* RDMA_MESSAGE_BUFFER_IDLE
+* RDMA_MESSAGE_BUFFER_REQUEST
+* RDMA_MESSAGE_BUFFER_RESPONSE
+* RDMA_MESSAGE_TENSOR_REQUEST
+* RDMA_MESSAGE_TENSOR_WRITE
+
+### Actions upon receiving Rdma messages
+* RDMA_MESSAGE_ACK
+ * sender: mark local ack buffer idle.
+ * receiver: mark remote message buffer idle, send next item.
+* RDMA_MESSAGE_BUFFER_IDLE
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, set remote tensor buffer idle, send next item.
+* RDMA_MESSAGE_BUFFER_REQUEST
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, find or create tensor buffer, send BUFFER_RESPONSE.
+* RDMA_MESSAGE_BUFFER_RESPONSE
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, set remote buffer info, set local and remote buffer idle, send next item.
+* RDMA_MESSAGE_TENSOR_REQUEST
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, find or create tensor buffer, enqueue tensor id, send next item.
+* RDMA_MESSAGE_TENSOR_WRITE
+ * sender: mark local message buffer idle, send next item.
+ * receiver: run callback.
diff --git a/tensorflow/contrib/verbs/design_diagram.png b/tensorflow/contrib/verbs/design_diagram.png
new file mode 100644
index 0000000000..f0ad27455f
--- /dev/null
+++ b/tensorflow/contrib/verbs/design_diagram.png
Binary files differ
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.cc b/tensorflow/contrib/verbs/grpc_verbs_client.cc
new file mode 100644
index 0000000000..608a9140d3
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.cc
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
+ const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response) {
+ ::grpc::ClientContext ctx;
+ ctx.set_fail_fast(false);
+ SetDeadline(&ctx, call_options->GetTimeout());
+ return FromGrpcStatus(stub_->GetRemoteAddress(&ctx, *request, response));
+}
+
+Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response) {
+ CallOptions call_options;
+ call_options.SetTimeout(-1); // no time out
+ return GetRemoteAddress(&call_options, request, response);
+}
+
+void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
+ int64 time_in_ms) {
+ if (time_in_ms > 0) {
+ ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
+ }
+}
+
+} // namespace tensorflow \ No newline at end of file
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h
new file mode 100644
index 0000000000..358977f925
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// GrpcVerbsClient is a client that uses gRPC to talk to the Verbs service.
+class GrpcVerbsClient {
+ public:
+ explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
+ : stub_(grpc::VerbsService::NewStub(client_channel)) {}
+ ~GrpcVerbsClient() {}
+
+ Status GetRemoteAddress(CallOptions* call_options,
+ const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response);
+ Status GetRemoteAddress(const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response);
+
+ private:
+ std::unique_ptr<grpc::VerbsService::Stub> stub_;
+
+ void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc
new file mode 100644
index 0000000000..e73b2700bd
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "grpc++/alarm.h"
+#include "grpc++/grpc++.h"
+#include "grpc++/server_builder.h"
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+
+namespace tensorflow {
+
+GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder)
+ : is_shutdown_(false), worker_env_(worker_env) {
+ builder->RegisterService(&verbs_service_);
+ cq_ = builder->AddCompletionQueue().release();
+}
+
+GrpcVerbsService::~GrpcVerbsService() {
+ delete shutdown_alarm_;
+ delete cq_;
+}
+
+void GrpcVerbsService::Shutdown() {
+ bool did_shutdown = false;
+ {
+ mutex_lock l(shutdown_mu_);
+ if (!is_shutdown_) {
+ LOG(INFO) << "Shutting down GrpcWorkerService.";
+ is_shutdown_ = true;
+ did_shutdown = true;
+ }
+ }
+ if (did_shutdown) {
+ shutdown_alarm_ =
+ new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
+ }
+}
+
+// This macro creates a new request for the given RPC method name
+// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
+// `this->cq_`.
+//
+// This macro is invoked one or more times for each RPC method to
+// ensure that there are sufficient completion queue entries to
+// handle incoming requests without blocking.
+//
+// The implementation of the request handler for each RPC method
+// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
+// to keep accepting new requests.
+#define ENQUEUE_REQUEST(method, supports_cancel) \
+ do { \
+ mutex_lock l(shutdown_mu_); \
+ if (!is_shutdown_) { \
+ Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
+ method##Request, method##Response>:: \
+ EnqueueRequest(&verbs_service_, cq_, \
+ &grpc::VerbsService::AsyncService::Request##method, \
+ &GrpcVerbsService::method##Handler, \
+ (supports_cancel)); \
+ } \
+ } while (0)
+
+// This method blocks forever handling requests from the completion queue.
+void GrpcVerbsService::HandleRPCsLoop() {
+ for (int i = 0; i < 10; ++i) {
+ ENQUEUE_REQUEST(GetRemoteAddress, false);
+ }
+
+ void* tag;
+ bool ok;
+
+ while (cq_->Next(&tag, &ok)) {
+ UntypedCall<GrpcVerbsService>::Tag* callback_tag =
+ static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
+ if (callback_tag) {
+ callback_tag->OnCompleted(this, ok);
+ } else {
+ cq_->Shutdown();
+ }
+ }
+}
+
+void GrpcVerbsService::GetRemoteAddressHandler(
+ WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
+ Status s = GetRemoteAddressSync(&call->request, &call->response);
+ call->SendResponse(ToGrpcStatus(s));
+ ENQUEUE_REQUEST(GetRemoteAddress, false);
+}
+
+// synchronous method
+Status GrpcVerbsService::GetRemoteAddressSync(
+ const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response) {
+ // analyzing request
+ // the channel setting part is redundant.
+ const string remote_host_name = request->host_name();
+ RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
+ CHECK(rc);
+ RdmaAddress ra;
+ ra.lid = request->channel().lid();
+ ra.qpn = request->channel().qpn();
+ ra.psn = request->channel().psn();
+ rc->SetRemoteAddress(ra, false);
+ rc->Connect();
+ int i = 0;
+ int idx[] = {1, 0, 3, 2};
+ std::vector<RdmaBuffer*> mb(rc->message_buffers());
+ CHECK_EQ(request->mr_size(), 4);
+ for (const auto& mr : request->mr()) {
+ // the connections are crossed, i.e.
+ // local tx_message_buffer <---> remote rx_message_buffer_
+ // local rx_message_buffer <---> remote tx_message_buffer_
+ // local tx_ack_buffer <---> remote rx_ack_buffer_
+ // local rx_ack_buffer <---> remote tx_ack_buffer_
+ // hence idx[] = {1, 0, 3, 2}.
+ RdmaBuffer* rb = mb[idx[i]];
+ RemoteMR rmr;
+ rmr.remote_addr = mr.remote_addr();
+ rmr.rkey = mr.rkey();
+ rb->SetRemoteMR(rmr, false);
+ i++;
+ }
+ CHECK(i == RdmaChannel::kNumMessageBuffers);
+
+ // setting up response
+ response->set_host_name(
+ worker_env_->session_mgr->LegacySession()->worker_name);
+ Channel* channel_info = response->mutable_channel();
+ channel_info->set_lid(rc->self().lid);
+ channel_info->set_qpn(rc->self().qpn);
+ channel_info->set_psn(rc->self().psn);
+ for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
+ MemoryRegion* mr = response->add_mr();
+ mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
+ mr->set_rkey(mb[i]->self()->rkey);
+ }
+ return Status::OK();
+}
+
+// Create a GrpcVerbsService, then assign it to a given handle.
+void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder) {
+ *handle = new GrpcVerbsService(worker_env, builder);
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h
new file mode 100644
index 0000000000..aa509602b5
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+namespace grpc {
+class ServerBuilder;
+class ServerCompletionQueue;
+class Alarm;
+} // namespace grpc
+
+namespace tensorflow {
+
+class GrpcVerbsService : public AsyncServiceInterface {
+ public:
+ GrpcVerbsService(const WorkerEnv* worker_env, ::grpc::ServerBuilder* builder);
+ ~GrpcVerbsService();
+ void HandleRPCsLoop() override;
+ void Shutdown() override;
+ void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
+
+ private:
+ template <class RequestMessage, class ResponseMessage>
+ using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
+ RequestMessage, ResponseMessage>;
+ void GetRemoteAddressHandler(
+ WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
+ Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response);
+
+ ::grpc::ServerCompletionQueue* cq_;
+ grpc::VerbsService::AsyncService verbs_service_;
+ mutex shutdown_mu_;
+ bool is_shutdown_ GUARDED_BY(shutdown_mu_);
+ ::grpc::Alarm* shutdown_alarm_;
+ // not owned
+ RdmaMgr* rdma_mgr_;
+ const WorkerEnv* const worker_env_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
+};
+
+// Create a GrpcVerbsService, then assign it to a given handle.
+void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
new file mode 100644
index 0000000000..e0ba78dbfd
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
@@ -0,0 +1,68 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+
+#include "grpc++/impl/codegen/async_stream.h"
+#include "grpc++/impl/codegen/async_unary_call.h"
+#include "grpc++/impl/codegen/channel_interface.h"
+#include "grpc++/impl/codegen/client_unary_call.h"
+#include "grpc++/impl/codegen/method_handler_impl.h"
+#include "grpc++/impl/codegen/rpc_service_method.h"
+#include "grpc++/impl/codegen/service_type.h"
+#include "grpc++/impl/codegen/sync_stream.h"
+
+namespace tensorflow {
+
+namespace grpc {
+
+static const char* grpcVerbsService_method_names[] = {
+ "/tensorflow.VerbsService/GetRemoteAddress",
+};
+
+std::unique_ptr<VerbsService::Stub> VerbsService::NewStub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+ const ::grpc::StubOptions& options) {
+ std::unique_ptr<VerbsService::Stub> stub(new VerbsService::Stub(channel));
+ return stub;
+}
+
+VerbsService::Stub::Stub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel)
+ : channel_(channel),
+ rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
+ ::grpc::RpcMethod::NORMAL_RPC, channel) {}
+
+::grpc::Status VerbsService::Stub::GetRemoteAddress(
+ ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+ GetRemoteAddressResponse* response) {
+ return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_GetRemoteAddress_,
+ context, request, response);
+}
+
+VerbsService::AsyncService::AsyncService() {
+ for (int i = 0; i < 1; ++i) {
+ AddMethod(new ::grpc::RpcServiceMethod(grpcVerbsService_method_names[i],
+ ::grpc::RpcMethod::NORMAL_RPC,
+ nullptr));
+ ::grpc::Service::MarkMethodAsync(i);
+ }
+}
+
+VerbsService::AsyncService::~AsyncService() {}
+
+} // namespace grpc
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
new file mode 100644
index 0000000000..f7ea774b66
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -0,0 +1,89 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+
+#include "grpc++/impl/codegen/async_stream.h"
+#include "grpc++/impl/codegen/async_unary_call.h"
+#include "grpc++/impl/codegen/proto_utils.h"
+#include "grpc++/impl/codegen/rpc_method.h"
+#include "grpc++/impl/codegen/service_type.h"
+#include "grpc++/impl/codegen/status.h"
+#include "grpc++/impl/codegen/stub_options.h"
+#include "grpc++/impl/codegen/sync_stream.h"
+
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+
+namespace grpc {
+class CompletionQueue;
+class Channel;
+class RpcService;
+class ServerCompletionQueue;
+class ServerContext;
+} // namespace grpc
+
+namespace tensorflow {
+
+namespace grpc {
+
+// Implementation of `tensorflow.VerbsService`, based on the
+// definition in "//tensorflow/contrib/verbs/verbs_service.proto",
+// and the gRPC generated stub and service classes.
+// See the proto file for the definition of methods and messages.
+class VerbsService GRPC_FINAL {
+ public:
+ class StubInterface {
+ public:
+ virtual ~StubInterface() {}
+ virtual ::grpc::Status GetRemoteAddress(
+ ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+ GetRemoteAddressResponse* response) = 0;
+ };
+ class Stub GRPC_FINAL : public StubInterface {
+ public:
+ Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
+ ::grpc::Status GetRemoteAddress(
+ ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+ GetRemoteAddressResponse* response) GRPC_OVERRIDE;
+
+ private:
+ std::shared_ptr< ::grpc::ChannelInterface> channel_;
+ const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_;
+ };
+ static std::unique_ptr<Stub> NewStub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+ const ::grpc::StubOptions& options = ::grpc::StubOptions());
+
+ class AsyncService : public ::grpc::Service {
+ public:
+ AsyncService();
+ virtual ~AsyncService();
+ void RequestGetRemoteAddress(
+ ::grpc::ServerContext* context, GetRemoteAddressRequest* request,
+ ::grpc::ServerAsyncResponseWriter<GetRemoteAddressResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(0, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ };
+};
+
+} // namespace grpc
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
new file mode 100644
index 0000000000..53d840f5d1
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -0,0 +1,874 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include <cstdlib>
+#include "tensorflow/contrib/verbs/verbs_util.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+namespace {
+// hash name to 32-bit integer
+uint32_t NameHash(const string& name) {
+ return Hash32(name.data(), name.size(), 0x1234ABCD);
+}
+
+// convenience function for printing message
+string MessageTypeToString(RdmaMessageType rmt) {
+ switch (rmt) {
+ case RDMA_MESSAGE_ACK:
+ return "RDMA_MESSAGE_ACK";
+ break;
+ case RDMA_MESSAGE_BUFFER_IDLE:
+ return "RDMA_MESSAGE_BUFFER_IDLE";
+ break;
+ case RDMA_MESSAGE_BUFFER_REQUEST:
+ return "RDMA_MESSAGE_BUFFER_REQUEST";
+ break;
+ case RDMA_MESSAGE_BUFFER_RESPONSE:
+ return "RDMA_MESSAGE_BUFFER_RESPONSE";
+ break;
+ case RDMA_MESSAGE_TENSOR_REQUEST:
+ return "RDMA_MESSAGE_TENSOR_REQUEST";
+ break;
+ case RDMA_MESSAGE_TENSOR_WRITE:
+ return "RDMA_MESSAGE_TENSOR_WRITE";
+ break;
+ default:
+ return "UNKNOWN MESSAGE";
+ }
+}
+} // namespace
+
+ibv_context* open_default_device() {
+ ibv_device** dev_list;
+ ibv_device* ib_dev;
+ dev_list = ibv_get_device_list(NULL);
+ CHECK(dev_list) << "No InfiniBand device found";
+ ib_dev = dev_list[0];
+ CHECK(ib_dev) << "No InfiniBand device found";
+ ibv_context* context = ibv_open_device(ib_dev);
+ CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
+ return context;
+}
+
+ibv_pd* alloc_protection_domain(ibv_context* context) {
+ ibv_pd* pd = ibv_alloc_pd(context);
+ CHECK(pd) << "Failed to allocate protection domain";
+ return pd;
+}
+
+RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
+ : context_(open_default_device()),
+ pd_(alloc_protection_domain(context_)),
+ worker_env_(worker_env) {
+ event_channel_ = ibv_create_comp_channel(context_);
+ CHECK(event_channel_) << "Failed to create completion channel";
+ cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
+ 0);
+ CHECK(cq_) << "Failed to create completion queue";
+ CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
+ polling_thread_.reset(Env::Default()->StartThread(
+ ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
+ VLOG(2) << "Start RdmaAdapter: " << name();
+}
+
+RdmaAdapter::~RdmaAdapter() {
+ polling_thread_.reset();
+ CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
+ CHECK(!ibv_destroy_comp_channel(event_channel_))
+ << "Failed to destroy channel";
+ CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
+ CHECK(!ibv_close_device(context_)) << "Failed to release context";
+}
+
+string RdmaAdapter::name() const { return string(context_->device->name); }
+
+// Function to process incoming messages
+// There are two types of messages:
+// 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
+// 2. IBV_WC_RDMA_WRITE (send))
+void RdmaAdapter::Process_CQ() {
+ while (true) {
+ ibv_cq* cq;
+ void* cq_context;
+ CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
+ CHECK(cq == cq_);
+ ibv_ack_cq_events(cq, 1);
+ CHECK(!ibv_req_notify_cq(cq_, 0));
+
+ int ne =
+ ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
+ CHECK_GE(ne, 0);
+ for (int i = 0; i < ne; ++i) {
+ CHECK(wc_[i].status == IBV_WC_SUCCESS)
+ << "Failed status \n"
+ << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+ << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
+ if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
+ RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
+ // put back a recv wr.
+ rc->Recv();
+ // imm_data is the index of RX buffer in the buffer table.
+ uint32_t imm_data = wc_[i].imm_data;
+ RdmaBuffer* rb = rc->FindBuffer(imm_data);
+ RdmaMessage rm;
+ RdmaMessage::ParseMessage(rm, rb->buffer_);
+ VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
+
+ if (rm.type_ == RDMA_MESSAGE_ACK) {
+ // receive an ack to a message
+ rb = rc->tx_message_buffer_;
+ rb->SetBufferStatus(remote, idle);
+ rb->SendNextItem();
+ } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
+ // received a request-for-tensor message
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find or create buffer
+ RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
+ string key_with_step_id =
+ VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
+ tb->EnqueueItem(key_with_step_id);
+ // send the next tensor
+ worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+ } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
+ // receive tensor-buffer-ready message
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find buffer
+ RdmaBuffer* tb = rc->FindBuffer(rm.name_);
+ tb->SetBufferStatus(remote, idle);
+ worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+ } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
+ // remote host requests to create a tensor buffer;
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find or create the buffer
+ RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR);
+ RemoteMR rmr;
+ rmr.remote_addr = rm.remote_addr_;
+ rmr.rkey = rm.rkey_;
+ tb->SetRemoteMR(rmr, true);
+ tb->CreateCPUBuffer(rm.buffer_size_);
+ // create RDMA_MESSAGE_BUFFER_RESPONSE message
+ RdmaMessage br;
+ br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE;
+ br.name_size_ = rm.name_.size();
+ br.name_ = rm.name_;
+ br.buffer_size_ = rm.buffer_size_;
+ br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_);
+ br.rkey_ = tb->self_->rkey;
+ string message = RdmaMessage::CreateMessage(br);
+ RdmaBuffer* mb = rc->tx_message_buffer_;
+ mb->EnqueueItem(message);
+ mb->SendNextItem();
+ } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
+ // remote creates a buffer and responds
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find buffer
+ RdmaBuffer* tb = rc->FindBuffer(rm.name_);
+ CHECK(rm.buffer_size_ == tb->size_)
+ << "rm.buffer_size = " << rm.buffer_size_
+ << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
+ RemoteMR rmr;
+ rmr.remote_addr = rm.remote_addr_;
+ rmr.rkey = rm.rkey_;
+ tb->SetRemoteMR(rmr, true);
+ tb->SetBufferStatus(local, idle);
+ tb->SetBufferStatus(remote, idle);
+ worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+ } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ // tensor RDMA write completed
+ worker_env_->compute_pool->Schedule([rm, rc]() {
+ string key_with_step_id =
+ VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
+ rc->RunRecvCallback(key_with_step_id);
+ });
+ }
+ } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
+ RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
+ rb->SetBufferStatus(local, idle);
+ RdmaMessage rm;
+ RdmaMessage::ParseMessage(rm, rb->buffer_);
+ VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
+ if (rm.type_ != RDMA_MESSAGE_ACK) {
+ worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
+ }
+ }
+ }
+ }
+}
+
+RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
+ const string remote_name)
+ : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
+ // Create queue pair
+ {
+ struct ibv_qp_init_attr attr;
+ memset(&attr, 0, sizeof(ibv_qp_init_attr));
+ attr.send_cq = adapter_->cq_;
+ attr.recv_cq = adapter_->cq_;
+ attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_send_sge = 1;
+ attr.cap.max_recv_sge = 1;
+ attr.qp_type = IBV_QPT_RC;
+
+ qp_ = ibv_create_qp(adapter_->pd_, &attr);
+ CHECK(qp_) << "Failed to create queue pair";
+ }
+
+ // Init queue pair
+ {
+ struct ibv_qp_attr attr;
+ memset(&attr, 0, sizeof(ibv_qp_attr));
+ attr.qp_state = IBV_QPS_INIT;
+ attr.pkey_index = 0;
+ attr.port_num = 1;
+ attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
+
+ int mask =
+ IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
+ CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
+ }
+
+ // Local address
+ {
+ struct ibv_port_attr attr;
+ CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
+ << "Query port";
+ self_.lid = attr.lid;
+ self_.qpn = qp_->qp_num;
+ self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
+ }
+
+ // create message and ack buffers, then initialize the tables.
+ {
+ const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
+ "tx_ack_buffer", "rx_ack_buffer"};
+ tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
+ rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
+ tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
+ rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]);
+ message_buffers_.reserve(kNumMessageBuffers);
+ message_buffers_.push_back(tx_message_buffer_);
+ message_buffers_.push_back(rx_message_buffer_);
+ message_buffers_.push_back(tx_ack_buffer_);
+ message_buffers_.push_back(rx_ack_buffer_);
+ // create buffer on host
+ tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
+ rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
+ tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
+ rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
+ // bt_mu_.lock() is not used in constructor.
+ for (int i = 0; i < kNumMessageBuffers; i++) {
+ uint32_t index = NameHash(buffer_names[i]);
+ buffer_table_.insert({index, message_buffers_[i]});
+ buffer_index_name_table_.insert({index, buffer_names[i]});
+ buffer_name_index_table_.insert({buffer_names[i], index});
+ }
+
+ // Initiate recv
+ for (int i = 0; i < 100; i++) {
+ Recv();
+ }
+ }
+}
+
+RdmaChannel::~RdmaChannel() {
+ CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
+ delete tx_message_buffer_;
+ delete rx_message_buffer_;
+ delete tx_ack_buffer_;
+ delete rx_ack_buffer_;
+}
+
+void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
+ mutex_lock lock{mu_};
+ if ((override) || (!remote_set_)) {
+ remote_.lid = ra.lid;
+ remote_.qpn = ra.qpn;
+ remote_.psn = ra.psn;
+ remote_set_ = true;
+ } else {
+ CHECK(remote_.lid == ra.lid);
+ CHECK(remote_.qpn == ra.qpn);
+ CHECK(remote_.psn == ra.psn);
+ }
+}
+
+// Adding tokens to the completion queue
+// Tokens are needed to process future messages.
+void RdmaChannel::Recv() {
+ struct ibv_recv_wr wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.wr_id = (uint64_t)this;
+ struct ibv_recv_wr* bad_wr;
+ CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
+}
+
+// Lookup 32-bit buffer index from buffer name
+// Args:
+// buffer_name: name of the buffer
+// Returns:
+// 32-bit index
+uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
+ mutex_lock lock{bt_mu_};
+ BufferNameIndexTable::iterator iter =
+ buffer_name_index_table_.find(buffer_name);
+ CHECK(iter != buffer_name_index_table_.end());
+ return iter->second;
+}
+
+// Find a buffer by its 32-bit index
+// Args:
+// index: 32-bit hash code of the tensor buffer name
+// Returns:
+// name of the tensor buffer
+RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) {
+ mutex_lock lock{bt_mu_};
+ BufferTable::iterator iter = buffer_table_.find(index);
+ CHECK(iter != buffer_table_.end());
+ return iter->second;
+}
+
+// Find a buffer by its name
+// Args:
+// name: name of the buffer
+// Returns:
+// the named rdma buffer
+RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
+ uint32_t index = LookupBufferIndex(name);
+ return FindBuffer(index);
+}
+
+// Find a buffer if it exists, otherwise create one.
+// The memory inside the created buffer is not allocated.
+// Args:
+// name: the name of the buffer
+// buffer_type: TENSOR, MESSAGE or ACK.
+// Returns:
+// the named buffer
+RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
+ BufferType buffer_type) {
+ mutex_lock lock{bt_mu_};
+ RdmaBuffer* rb;
+ // find index
+ BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name);
+ if (iter != buffer_name_index_table_.end()) {
+ uint32_t index = iter->second;
+ // find buffer
+ BufferTable::iterator iter = buffer_table_.find(index);
+ CHECK(iter != buffer_table_.end());
+ rb = iter->second;
+ } else {
+ uint32_t index = NameHash(name);
+ if (buffer_type == TENSOR) {
+ rb = new RdmaTensorBuffer(this, name);
+ } else if (buffer_type == MESSAGE) {
+ rb = new RdmaMessageBuffer(this, name);
+ } else if (buffer_type == ACK) {
+ rb = new RdmaAckBuffer(this, name);
+ }
+ buffer_name_index_table_.insert({name, index});
+ buffer_index_name_table_.insert({index, name});
+ buffer_table_.insert({index, rb});
+ }
+ CHECK(rb);
+ return rb;
+}
+
+// Insert callback to the callback_table.
+// The callback is activated when the corresponding tensor is received.
+// Arg:
+// key: the name of the tensor
+// recv_done: the callback associated with the tensor.
+// Returns:
+// None
+void RdmaChannel::InsertRecvCallback(const string& key,
+ std::function<void()> recv_done) {
+ mutex_lock lock{ct_mu_};
+ callback_table_.insert({key, recv_done});
+}
+
+// Remove callback from the callback_table.
+// Arg:
+// key: the name of the tensor
+// Returns:
+// None
+void RdmaChannel::RemoveRecvCallback(const string& key) {
+ mutex_lock lock{ct_mu_};
+ callback_table_.erase(key);
+}
+
+// Run named callback in the callback_table.
+// Arg:
+// key: the name of the tensor
+// Returns:
+// None
+void RdmaChannel::RunRecvCallback(const string& key) {
+ std::function<void()> recv_done;
+ {
+ mutex_lock lock{ct_mu_};
+ CallbackTable::iterator iter = callback_table_.find(key);
+ CHECK(iter != callback_table_.end());
+ recv_done = iter->second;
+ }
+ recv_done();
+}
+
+void RdmaChannel::Connect() {
+ {
+ mutex_lock lock{mu_};
+ CHECK(remote_set_) << "remote channel is not set";
+ }
+ Connect(remote_);
+}
+
+// Setup channel to a remote node
+// Args:
+// remoteAddr: the rdma address of a remote channel.
+// Returns:
+// None
+void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
+ mutex_lock lock{mu_};
+ if (!connected_) {
+ struct ibv_qp_attr attr;
+ memset(&attr, 0, sizeof(ibv_qp_attr));
+ attr.qp_state = IBV_QPS_RTR;
+ attr.path_mtu = IBV_MTU_4096;
+ attr.dest_qp_num = remoteAddr.qpn;
+ attr.rq_psn = remoteAddr.psn;
+ attr.max_dest_rd_atomic = 1;
+ attr.min_rnr_timer = 12;
+ attr.ah_attr.is_global = 0;
+ attr.ah_attr.dlid = remoteAddr.lid;
+ attr.ah_attr.sl = 0;
+ attr.ah_attr.src_path_bits = 0;
+ attr.ah_attr.port_num = 1;
+
+ int r;
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
+ << "QP to Ready to Receive " << r;
+
+ memset(&attr, 0, sizeof(ibv_qp_attr));
+ attr.qp_state = IBV_QPS_RTS;
+ attr.sq_psn = self_.psn;
+ attr.timeout = 14;
+ attr.retry_cnt = 7;
+ attr.rnr_retry = 7; /* infinite */
+ attr.max_rd_atomic = 1;
+
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
+ << "QP to Ready to Send " << r;
+
+ connected_ = true;
+ } else {
+ LOG(INFO) << "channel already connected";
+ }
+}
+
+RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
+ : channel_(channel), name_(name) {}
+
+RdmaBuffer::~RdmaBuffer() {
+ CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
+ FreeBuffer();
+}
+
+void RdmaBuffer::FreeBuffer() {
+ if ((buffer_ != nullptr) && buffer_on_host_) {
+ free(buffer_);
+ }
+ // TODO
+ // release buffer if it is on device.
+ // We don't support RDMABuffer on device at this moment.
+}
+
+// Allocate CPU memory for the Rdma buffer
+// Args:
+// size: to-be-allocated memory size
+// lock: whether or not mutex_lock the process to protect concurrency.
+// Returns:
+// None
+void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
+ CHECK(size > 0);
+ if (lock) {
+ mu_.lock();
+ }
+ if (local_status_ != none) {
+ // delete existing buffer
+ CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
+ FreeBuffer();
+ }
+ size_ = size;
+ buffer_ = malloc(size_);
+ self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ CHECK(self_) << "Failed to register memory region";
+ buffer_on_host_ = true;
+ local_status_ = idle;
+ if (lock) {
+ mu_.unlock();
+ }
+}
+
+// Set address of remote memory region
+// Args:
+// rmr: address of remote memory region
+// override: whether override existing information
+// Returns:
+// None
+void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
+ mutex_lock lock{mu_};
+ if ((override) || (remote_status_ == none)) {
+ remote_.remote_addr = rmr.remote_addr;
+ remote_.rkey = rmr.rkey;
+ remote_status_ = idle;
+ } else {
+ CHECK(remote_.remote_addr == rmr.remote_addr);
+ CHECK(remote_.rkey == rmr.rkey);
+ }
+}
+
+// Put a task in the buffer's job queue
+void RdmaBuffer::EnqueueItem(string item) {
+ mutex_lock lock{mu_};
+ queue_.push(item);
+}
+
+// Rdma-Write the content of the buffer
+void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+ struct ibv_sge list;
+ list.addr = (uint64_t)buffer_;
+ list.length = buffer_size;
+ list.lkey = self_->lkey;
+
+ struct ibv_send_wr wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.wr_id = (uint64_t)this;
+ wr.sg_list = &list;
+ wr.num_sge = 1;
+ wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+ wr.send_flags = IBV_SEND_SIGNALED;
+ wr.imm_data = imm_data;
+ wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
+ wr.wr.rdma.rkey = remote_.rkey;
+
+ struct ibv_send_wr* bad_wr;
+ CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
+}
+
+RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
+ : RdmaBuffer(channel, name) {}
+
+RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
+ : RdmaBuffer(channel, name) {}
+
+RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
+ : RdmaBuffer(channel, name) {}
+
+// Send the next ack from the buffer's job queue.
+void RdmaAckBuffer::SendNextItem() {
+ uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
+ RdmaMessage rm;
+ rm.name_ = "rx_ack_buffer";
+ rm.type_ = RDMA_MESSAGE_ACK;
+ rm.name_size_ = rm.name_.size();
+ string message = RdmaMessage::CreateMessage(rm);
+ memcpy(buffer_, message.data(), message.size());
+ Write(imm_data, message.size());
+}
+
+// Send the next message from the buffer's job queue.
+void RdmaMessageBuffer::SendNextItem() {
+ uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
+ mu_.lock();
+ if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
+ local_status_ = busy;
+ remote_status_ = busy;
+ string message = queue_.front();
+ queue_.pop();
+ // local/remote_status_ won't be set back to idle
+ // unitl Write() is successful
+ mu_.unlock();
+ memcpy(buffer_, message.data(), message.size());
+ Write(imm_data, message.size());
+ } else {
+ mu_.unlock();
+ }
+}
+
+// Send the next tensor from the buffer's job queue.
+void RdmaTensorBuffer::SendNextItem() {
+ // get the key
+ string key_with_step_id = "";
+ {
+ mutex_lock lock{mu_};
+ if (!queue_.empty()) {
+ key_with_step_id = queue_.front();
+ queue_.pop();
+ }
+ }
+ // send the tensor if a key is acquired.
+ if (key_with_step_id != "") {
+ VLOG(2) << "try to send tensor: " << key_with_step_id;
+ string key;
+ int64 step_id;
+ VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
+ CHECK(key.compare(name_) == 0);
+ Rendezvous::ParsedKey parsed;
+ Rendezvous::ParseKey(key, &parsed);
+ Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
+ parsed](const Status& status,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& in, bool is_dead) {
+ CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
+ << " error message: " << status.error_message();
+ size_t buffer_size = RdmaMessage::kMessageTotalBytes;
+ size_t tensor_bytes = 0;
+ TensorProto proto;
+ // Figures out which device the tensor is hosted on.
+ Device* src_dev = nullptr;
+ Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
+ parsed.src_device, &src_dev);
+ CHECK(s.ok()) << "src device not found";
+ // Does the device have the right incarnation number we expect?
+ CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
+ << "RecvTensor expects a different device incarnation: "
+ << parsed.src_incarnation << " vs. "
+ << src_dev->attributes().incarnation()
+ << ". Your worker job was probably restarted. Check your "
+ << "worker job for the reason why it was restarted.";
+ Device* dst_dev = nullptr;
+ // destination is on CPU.
+ s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
+ &dst_dev);
+ CHECK(s.ok()) << "dst device not found";
+ AllocatorAttributes dst_alloc_attr;
+ dst_alloc_attr.set_on_host(true);
+ // string tensor needs to be serialized
+ if (src_dev->tensorflow_gpu_device_info() &&
+ (!send_args.alloc_attrs.on_host())) {
+ CHECK(send_args.device_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ // "val" is on a GPU. Uses GPUUtil to fill the proto.
+ s = VerbsUtil::SetProtoFromGPUSync(
+ in, src_dev, send_args.device_context, &proto, is_dead);
+ CHECK(s.ok()) << "set proto from gpu sync";
+ } else {
+ // tensor is in CPU memory.
+ in.AsProtoTensorContent(&proto);
+ }
+ tensor_bytes = proto.ByteSize();
+ // maybe some margin for string tensor?
+ buffer_size += tensor_bytes;
+ // prepare message
+ RdmaMessage rm;
+ rm.name_size_ = key.size();
+ rm.name_ = key;
+ rm.tensor_shape_ = in.shape();
+ rm.data_type_ = in.dtype();
+ rm.step_id_ = step_id;
+ rm.is_dead_ = is_dead;
+ rm.tensor_bytes_ = tensor_bytes;
+ rm.buffer_size_ = buffer_size;
+ mu_.lock();
+ if (local_status_ == none ||
+ (buffer_size > size_ && local_status_ == idle &&
+ remote_status_ == idle)) {
+ if ((local_status_ != none) && (buffer_size > size_)) {
+ CHECK(rm.data_type_ == DT_STRING)
+ << "Only string tensor allows to change size";
+ }
+ CreateCPUBuffer(buffer_size, false);
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ // ask the remote to create the same buffer
+ rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
+ rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
+ rm.rkey_ = self_->rkey;
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+ } else if ((local_status_ == idle) && (remote_status_ == idle)) {
+ // both buffers are ready, send the tensor
+ local_status_ = busy;
+ remote_status_ = busy;
+ // local/remote_status_ won't be set back to idle
+ // unitl Write() is successful
+ mu_.unlock();
+ CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
+ (buffer_size <= size_ && rm.data_type_ == DT_STRING))
+ << "tensor and buffer size do not agree!"
+ << " buffer_size = " << size_
+ << " requested tensor size = " << buffer_size << in.DebugString();
+ uint32_t imm_data = LookupBufferIndex(key);
+ rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
+ string message = RdmaMessage::CreateMessage(rm);
+ memcpy(buffer_, message.data(), message.size());
+ if (!is_dead) {
+ // copy the tensor buffer content
+ void* output =
+ static_cast<void*>(static_cast<char*>(buffer_) +
+ RdmaMessage::kTensorBufferStartIndex);
+ CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
+ proto.SerializeToArray(output, tensor_bytes);
+ } else {
+ buffer_size = RdmaMessage::kMessageTotalBytes;
+ }
+ Write(imm_data, buffer_size);
+ } else {
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ }
+ };
+ // Use default session (legacy_session_)
+ // TODO use WorkerSessionForSession
+ // need to pass in session handle
+ channel_->adapter_->worker_env_->session_mgr->LegacySession()
+ ->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
+ }
+}
+
+// Create a RdmaMessage according to the pre-defined format
+// Args:
+// rm: the message structure
+// Returns:
+// message in string format
+string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
+ // Rdma Message format
+ // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
+ // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
+ // ...| XB | XB | 8B |...
+ //
+ // ACK: type|13|"rx_ack_buffer"
+ // TENSOR_REQUEST: type|name_size|tensor_name|step_id
+ // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
+ // |data_type|tensor_shape|tensor_bytes
+ // BUFFER_IDLE: type|name_size|buffer_name
+ // BUFFER_REQUEST:
+ // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
+ // BUFFER_RESPONSE:
+ // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
+ char message[kMessageTotalBytes];
+ // type
+ message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
+ // size of name
+ memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_));
+ // name
+ memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
+ // buffer_size, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
+ memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
+ sizeof(rm.buffer_size_));
+ memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
+ sizeof(rm.remote_addr_));
+ memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
+ }
+ // step_id
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
+ memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
+ }
+ // is_dead, data_type, tensor_shape, tensor_bytes
+ if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
+
+ memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
+ sizeof(rm.data_type_));
+ memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
+ sizeof(rm.tensor_shape_));
+ memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
+ sizeof(rm.tensor_bytes_));
+ }
+ return string(message, kMessageTotalBytes);
+}
+
+// Parse a RdmaMessage according to the pre-defined format
+// Args:
+// rm: the message structure where the parsed message will be saved
+// buffer: the place where the raw message is stored
+// Returns:
+// None
+void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
+ char* message = static_cast<char*>(buffer);
+ // type
+ rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
+ // name_size_
+ memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_));
+ // name
+ rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
+ // buffer_size, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
+ memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
+ sizeof(rm.buffer_size_));
+ memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
+ sizeof(rm.remote_addr_));
+ memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
+ }
+ // step_id
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
+ memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
+ }
+ // data_type, tensor_bytes, tensor_shape, is_dead
+ if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
+ memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
+ sizeof(rm.data_type_));
+ memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
+ sizeof(rm.tensor_shape_));
+ memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
+ sizeof(rm.tensor_bytes_));
+ }
+}
+
+} // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
new file mode 100644
index 0000000000..ae2aa63e3f
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include <infiniband/verbs.h>
+#include <cstring> // for memset
+#include <functional>
+#include <memory> // for shared_ptr
+#include <queue>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// structure to save the address of remote channels.
+struct RdmaAddress {
+ uint32_t lid;
+ uint32_t qpn;
+ uint32_t psn;
+};
+// structure to save information for remote memory regions.
+struct RemoteMR {
+ uint64_t remote_addr;
+ uint32_t rkey;
+};
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
+enum BufferType { ACK, MESSAGE, TENSOR };
+enum RdmaMessageType {
+ RDMA_MESSAGE_ACK,
+ RDMA_MESSAGE_BUFFER_IDLE,
+ RDMA_MESSAGE_BUFFER_REQUEST,
+ RDMA_MESSAGE_BUFFER_RESPONSE,
+ RDMA_MESSAGE_TENSOR_REQUEST,
+ RDMA_MESSAGE_TENSOR_WRITE
+};
+class RdmaBuffer;
+// Class that represents the Rdma Adapter.
+// Responsible for creation of the completion queue, and handling
+// of work completions.
+class RdmaAdapter {
+ friend class RdmaChannel;
+ friend class RdmaBuffer;
+ friend class RdmaAckBuffer;
+ friend class RdmaMessageBuffer;
+ friend class RdmaTensorBuffer;
+ friend class RdmaMgr;
+ friend class RdmaRemoteRendezvous;
+
+ public:
+ RdmaAdapter(const WorkerEnv* worker_env);
+ ~RdmaAdapter();
+ // Adapter name, e.g. mlx5_0.
+ string name() const;
+ void Process_CQ();
+
+ protected:
+ static const int MAX_CONCURRENT_WRITES = 1000;
+ ibv_context* context_;
+ // ibverbs protection domain
+ ibv_pd* pd_;
+ // Completion event channel, to wait for work completions
+ ibv_comp_channel* event_channel_;
+ // Completion queue, to poll on work completions
+ ibv_cq* cq_;
+ // Pre-allocated work completions array used for polling
+ ibv_wc wc_[MAX_CONCURRENT_WRITES * 2];
+ // worker env for thread
+ const WorkerEnv* worker_env_;
+ // thread for cq.
+ std::unique_ptr<Thread> polling_thread_;
+};
+
+// Class that represents a connection to a remote Rdma peer.
+// Responsible for connecting queue pairs.
+class RdmaChannel {
+ friend class RdmaAdapter;
+ friend class RdmaBuffer;
+ friend class RdmaAckBuffer;
+ friend class RdmaMessageBuffer;
+ friend class RdmaTensorBuffer;
+ friend class RdmaMgr;
+ friend class RdmaRemoteRendezvous;
+
+ public:
+ explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
+ const string remote_name_);
+ ~RdmaChannel();
+ inline const RdmaAddress& self() { return self_; }
+ RdmaAddress address() const;
+ inline const std::vector<RdmaBuffer*>& message_buffers() const {
+ return message_buffers_;
+ }
+ void Connect(const RdmaAddress& remoteAddr);
+ void Connect();
+ void Recv();
+ RdmaBuffer* FindBuffer(const uint32_t index);
+ RdmaBuffer* FindBuffer(const string& name);
+ RdmaBuffer* FindOrCreateBuffer(const string& name,
+ BufferType buffer_type = TENSOR);
+ uint32_t LookupBufferIndex(const string& buffer_name);
+ void SetRemoteAddress(const RdmaAddress& ra, bool override);
+ void InsertRecvCallback(const string& key, std::function<void()> recv_done);
+ void RemoveRecvCallback(const string& key);
+ void RunRecvCallback(const string& key);
+ static const int kNumMessageBuffers = 4;
+
+ protected:
+ const RdmaAdapter* adapter_;
+ RdmaAddress self_;
+ string local_name_;
+ string remote_name_;
+ ibv_qp* qp_;
+ mutex mu_;
+ bool connected_ GUARDED_BY(bt_mu_) = false;
+ RdmaAddress remote_ GUARDED_BY(bt_mu_);
+ bool remote_set_ GUARDED_BY(bt_mu_) = false;
+ mutex ct_mu_;
+ typedef std::unordered_map<string, std::function<void()> > CallbackTable;
+ CallbackTable callback_table_ GUARDED_BY(ct_mu_);
+ mutex bt_mu_;
+ typedef std::unordered_map<unsigned int, RdmaBuffer*> BufferTable;
+ BufferTable buffer_table_ GUARDED_BY(bt_mu_);
+ typedef std::unordered_map<uint32_t, string> BufferIndexNameTable;
+ BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_);
+ typedef std::unordered_map<string, uint32_t> BufferNameIndexTable;
+ BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_);
+ RdmaBuffer* tx_message_buffer_;
+ RdmaBuffer* rx_message_buffer_;
+ RdmaBuffer* tx_ack_buffer_;
+ RdmaBuffer* rx_ack_buffer_;
+ std::vector<RdmaBuffer*> message_buffers_;
+};
+
+// Class that represents a buffer for Rdma writes and reads.
+class RdmaBuffer {
+ friend class RdmaChannel;
+ friend class RdmaAdapter;
+ friend class RdmaMgr;
+ friend class RdmaRemoteRendezvous;
+
+ public:
+ explicit RdmaBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaBuffer();
+
+ inline void* buffer() const { return buffer_; }
+ inline ibv_mr* self() const { return self_; }
+ inline void SetBufferStatus(Location loc, BufferStatus status) {
+ mu_.lock();
+ if (loc == local) {
+ local_status_ = status;
+ } else {
+ remote_status_ = status;
+ }
+ mu_.unlock();
+ }
+ void FreeBuffer();
+ void EnqueueItem(string Item);
+ virtual void SendNextItem(){};
+ void CreateCPUBuffer(size_t size, bool lock = true);
+ void SetRemoteMR(RemoteMR rmi, bool override);
+ uint32_t LookupBufferIndex(const string& buffer_name) {
+ return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
+ }
+ void Write(uint32_t imm_data, size_t buffer_size);
+
+ protected:
+ const RdmaChannel* channel_;
+ void* buffer_ = nullptr;
+ bool buffer_on_host_ = true;
+ size_t size_ = 0;
+ const string name_;
+ ibv_mr* self_ = nullptr;
+ mutex mu_;
+ RemoteMR remote_;
+ std::queue<string> queue_ GUARDED_BY(mu_);
+ BufferStatus local_status_ GUARDED_BY(mu_) = none;
+ BufferStatus remote_status_ GUARDED_BY(mu_) = none;
+};
+
+class RdmaAckBuffer : public RdmaBuffer {
+ public:
+ explicit RdmaAckBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaAckBuffer() override {}
+ void SendNextItem() override;
+};
+
+class RdmaMessageBuffer : public RdmaBuffer {
+ friend class RdmaChannel;
+ friend class RdmaAapater;
+
+ public:
+ explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaMessageBuffer() override {}
+ void SendNextItem() override;
+};
+
+class RdmaTensorBuffer : public RdmaBuffer {
+ public:
+ explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaTensorBuffer() override {}
+ void SendNextItem() override;
+};
+
+struct RdmaMessage {
+ RdmaMessageType type_;
+ uint16_t name_size_;
+ string name_;
+ int64 step_id_;
+ uint64_t buffer_size_;
+ uint64_t remote_addr_;
+ uint32_t rkey_;
+ bool is_dead_;
+ DataType data_type_;
+ TensorShape tensor_shape_;
+ size_t tensor_bytes_;
+
+ // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
+ // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
+ // ...| XB | XB | 8B |...
+ //
+ static const size_t kNameCapacity = 512;
+ static const size_t kTypeStartIndex = 0;
+ static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
+ static const size_t kNameStartIndex =
+ kNameSizeStartIndex + sizeof(name_size_);
+ static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
+ static const size_t kBufferSizeStartIndex =
+ kStepIdStartIndex + sizeof(step_id_);
+ static const size_t kRemoteAddrStartIndex =
+ kBufferSizeStartIndex + sizeof(buffer_size_);
+ static const size_t kRkeyStartIndex =
+ kRemoteAddrStartIndex + sizeof(remote_addr_);
+ static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
+ static const size_t kDataTypeStartIndex =
+ kIsDeadStartIndex + sizeof(is_dead_);
+ static const size_t kTensorShapeStartIndex =
+ kDataTypeStartIndex + sizeof(data_type_);
+ static const size_t kTensorBytesStartIndex =
+ kTensorShapeStartIndex + sizeof(TensorShape);
+ static const size_t kTensorBufferStartIndex =
+ kTensorBytesStartIndex + sizeof(tensor_bytes_);
+ static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
+ static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
+ static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
+ static string CreateMessage(const RdmaMessage& rm);
+ static void ParseMessage(RdmaMessage& rm, void* buffer);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
new file mode 100644
index 0000000000..e28b80c6f6
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include <vector>
+#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
+ GrpcChannelCache* const channel_cache)
+ : worker_env_(worker_env), channel_cache_(channel_cache) {
+ rdma_adapter_ = new RdmaAdapter(worker_env_);
+ // hardcoded to default session (legacy_session_)
+ // TODO: use WorkerSessionForSession
+ // need to pass in session handle
+ local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
+ std::vector<string> workers;
+ worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
+ &workers);
+ num_remote_workers_ = workers.size() - 1;
+ VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
+ for (size_t i = 0; i < workers.size(); i++) {
+ if (local_worker_.compare(workers[i]) != 0) {
+ channel_table_.insert(
+ {workers[i],
+ new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
+ }
+ }
+}
+
+// Setup Rdma channels between peers.
+// This is done at the beginning of the server setup.
+
+void RdmaMgr::SetupChannels() {
+ for (const auto& p : channel_table_) {
+ string worker_name = p.first;
+ LOG(INFO) << "connecting to remote node " << worker_name;
+ RdmaChannel* rc = p.second;
+ GetRemoteAddressRequest req;
+ GetRemoteAddressResponse resp;
+ // get the channel cache
+ SharedGrpcChannelPtr client_channel =
+ channel_cache_->FindWorkerChannel(worker_name);
+ GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
+ CHECK(client != nullptr) << "No worker known as " << worker_name;
+
+ // setting up request
+ req.set_host_name(local_worker_);
+ Channel* channel_info = req.mutable_channel();
+ channel_info->set_lid(rc->self_.lid);
+ channel_info->set_qpn(rc->self_.qpn);
+ channel_info->set_psn(rc->self_.psn);
+ for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
+ MemoryRegion* mr = req.add_mr();
+ mr->set_remote_addr(
+ reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
+ mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
+ }
+ // synchronous call
+ Status s = client->GetRemoteAddress(&req, &resp);
+ // save obtained remote addresses
+ // connect to the remote channel
+ if (s.ok()) {
+ CHECK(worker_name.compare(resp.host_name()) == 0);
+ RdmaAddress ra;
+ ra.lid = resp.channel().lid();
+ ra.qpn = resp.channel().qpn();
+ ra.psn = resp.channel().psn();
+ rc->SetRemoteAddress(ra, false);
+ rc->Connect();
+ int i = 0;
+ int idx[] = {1, 0, 3, 2};
+ for (const auto& mr : resp.mr()) {
+ // the connections are crossed, i.e.
+ // local tx_message_buffer <---> remote rx_message_buffer_
+ // local rx_message_buffer <---> remote tx_message_buffer_
+ // local tx_ack_buffer <---> remote rx_ack_buffer_
+ // local rx_ack_buffer <---> remote tx_ack_buffer_
+ // hence idx[] = {1, 0, 3, 2}.
+ RdmaBuffer* rb = rc->message_buffers_[idx[i]];
+ RemoteMR rmr;
+ rmr.remote_addr = mr.remote_addr();
+ rmr.rkey = mr.rkey();
+ rb->SetRemoteMR(rmr, false);
+ i++;
+ }
+ CHECK(i == RdmaChannel::kNumMessageBuffers);
+ } else {
+ LOG(ERROR) << s.error_message();
+ }
+ delete client;
+ }
+}
+
+RdmaMgr::~RdmaMgr() {
+ for (const auto& p : channel_table_) delete p.second;
+ channel_table_.clear();
+ delete rdma_adapter_;
+}
+
+// Find a channel via the given name.
+// Args:
+// name: peer name, e.g. worker1
+// Returns
+// channel object that is connected to the named peer.
+RdmaChannel* RdmaMgr::FindChannel(const string& name) {
+ ChannelTable::iterator iter = channel_table_.find(name);
+ CHECK(iter != channel_table_.end());
+ return iter->second;
+}
+
+} // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
new file mode 100644
index 0000000000..b156f64096
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+
+namespace tensorflow {
+
+class RdmaMgr {
+ public:
+ explicit RdmaMgr(const WorkerEnv* const worker_env,
+ GrpcChannelCache* const channel_cache);
+ ~RdmaMgr();
+ RdmaChannel* FindChannel(const string& key);
+ void SetupChannels();
+ const string& local_worker() { return local_worker_; }
+
+ private:
+ string local_worker_;
+ size_t num_remote_workers_;
+ const WorkerEnv* const worker_env_;
+ GrpcChannelCache* const channel_cache_;
+ RdmaAdapter* rdma_adapter_;
+ typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
+ ChannelTable channel_table_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
new file mode 100644
index 0000000000..8cbdfaa943
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -0,0 +1,149 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
+#include <unordered_set>
+#include "tensorflow/contrib/verbs/verbs_util.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
+ public:
+ RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
+ int64 step_id, RdmaMgr* rdma_mgr)
+ : BaseRemoteRendezvous(env, worker_name, step_id, true),
+ rdma_mgr_(rdma_mgr) {}
+
+ protected:
+ void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ private:
+ ~RdmaRemoteRendezvous() override {}
+ RdmaMgr* rdma_mgr_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RdmaRemoteRendezvous);
+};
+
+void RdmaRemoteRendezvous::RecvFromRemoteAsync(
+ const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
+ DoneCallback done) {
+ Status s;
+ // parse src_name and dst_name
+ string src_name, dst_name, unused;
+ if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
+ &unused)) {
+ s = errors::Internal("Could not parse src name.");
+ }
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor{}, false);
+ return;
+ }
+ if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
+ &unused)) {
+ s = errors::Internal("Could not parse dst name.");
+ }
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor{}, false);
+ return;
+ }
+ CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
+ RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
+ string key(std::move(parsed.FullKey().ToString()));
+ string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
+ // insert callback
+ rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
+ recv_args, parsed, done]() {
+ Status s;
+ Device* src_dev;
+ s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), true);
+ return;
+ }
+ Device* dst_dev;
+ s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), true);
+ return;
+ }
+ RdmaBuffer* rb = rc->FindBuffer(key);
+ RdmaMessage rm;
+ CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
+ RdmaMessage::ParseMessage(rm, rb->buffer_);
+ CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
+ Tensor val;
+ if (!rm.is_dead_) {
+ void* input = static_cast<char*>(rb->buffer_) +
+ RdmaMessage::kTensorBufferStartIndex;
+ TensorProto proto;
+ CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+ rb->size_);
+ CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+ << "fail to parse proto from array";
+ s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+ }
+
+ rc->RemoveRecvCallback(key_with_step_id);
+ // create message
+ RdmaMessage br;
+ br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
+ br.name_size_ = key.size();
+ br.name_ = key;
+ string message = RdmaMessage::CreateMessage(br);
+ RdmaBuffer* tb = rc->tx_message_buffer_;
+ tb->EnqueueItem(message);
+ tb->SendNextItem();
+ done(s, Args(), recv_args, val, rm.is_dead_);
+ });
+ // append key to message queue
+ RdmaBuffer* rb = rc->tx_message_buffer_;
+ RdmaMessage rm;
+ rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
+ rm.name_size_ = key.size();
+ rm.name_ = key;
+ rm.step_id_ = step_id_;
+ string message = RdmaMessage::CreateMessage(rm);
+ rb->EnqueueItem(message);
+ rb->SendNextItem();
+}
+
+RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env,
+ const string& worker_name,
+ WorkerCacheInterface* worker_cache)
+ : BaseRendezvousMgr(env, worker_name) {}
+
+BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
+ const WorkerEnv* worker_env,
+ const string& worker_name) {
+ return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_);
+}
+
+} // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
new file mode 100644
index 0000000000..57cd4bf5e4
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received. Each global unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+// fork execution of an graph executor using "rendez" on thread 1;
+// fork execution of another graph executor using "rendez" on thread 2;
+// ...
+// join threads 1 and 2;
+//
+// In the example above, execution in thread 1 and 2 communicates with
+// each other by send/recv operations through the "rend".
+//
+// Tensors sent and recved through rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
+class RdmaRendezvousMgr : public BaseRendezvousMgr {
+ public:
+ explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name,
+ WorkerCacheInterface* worker_cache);
+ void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
+
+ protected:
+ BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
+ const string& worker_name) override;
+
+ private:
+ RdmaMgr* rdma_mgr_;
+ TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
new file mode 100644
index 0000000000..b061c81d2d
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/verbs_server_lib.h"
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+namespace {
+// static utility function
+RendezvousMgrInterface* NewRdmaRendezvousMgr(
+ const WorkerEnv* env, const string& worker_name,
+ WorkerCacheInterface* worker_cache) {
+ return new RdmaRendezvousMgr(env, worker_name, worker_cache);
+}
+
+} // namespace
+
+VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
+ : GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
+
+VerbsServer::~VerbsServer() {
+ TF_CHECK_OK(Stop());
+ TF_CHECK_OK(Join());
+ delete rdma_mgr_;
+ delete verbs_service_;
+ delete channel_cache_;
+}
+
+Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
+ GrpcChannelCache** channel_cache) {
+ string name_prefix =
+ strings::StrCat("/job:", server_def.job_name(), "/replica:0",
+ "/task:", server_def.task_index());
+
+ GrpcChannelSpec channel_spec;
+ TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
+
+ *channel_cache =
+ NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def));
+
+ const string host_port = (*channel_cache)->TranslateTask(name_prefix);
+ int requested_port;
+
+ if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
+ &requested_port)) {
+ return errors::Internal("Could not parse port for local server from \"",
+ (*channel_cache)->TranslateTask(name_prefix),
+ "\".");
+ }
+ if (requested_port != bound_port()) {
+ return errors::InvalidArgument("Requested port ", requested_port,
+ " differs from expected port ",
+ bound_port());
+ }
+
+ return Status::OK();
+}
+
+Status VerbsServer::Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendezvous_mgr_func) {
+ Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
+ {
+ mutex_lock l(mu_);
+ CHECK_EQ(verbs_state_, DISCONNECTED);
+ CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
+ rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
+ // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
+ verbs_service_->SetRdmaMgr(rdma_mgr_);
+ // hardcoded to default session (legacy_session_)
+ // TODO: use WorkerSessionForSession
+ // need to pass in session handle
+ dynamic_cast<RdmaRendezvousMgr*>(
+ worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get())
+ ->SetRdmaMgr(rdma_mgr_);
+ }
+ return s;
+}
+
+Status VerbsServer::Start() {
+ Status s = GrpcServer::Start();
+ {
+ mutex_lock l(mu_);
+ if (verbs_state_ == DISCONNECTED) {
+ // verbs_thread needs to be initiated
+ // before rdma_mgr sets up the rdma channels.
+ verbs_thread_.reset(worker_env()->env->StartThread(
+ ThreadOptions(), "TF_verbs_service",
+ [this] { verbs_service_->HandleRPCsLoop(); }));
+ rdma_mgr_->SetupChannels();
+ verbs_state_ = CONNECTED;
+ }
+ }
+ return s;
+}
+
+Status VerbsServer::Join() {
+ Status s = GrpcServer::Join();
+ {
+ mutex_lock l(mu_);
+ if (verbs_state_ == CONNECTED) {
+ verbs_state_ = DISCONNECTED;
+ verbs_thread_.reset();
+ }
+ }
+ return s;
+}
+
+/* static */
+Status VerbsServer::Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<ServerInterface>* out_server) {
+ std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
+ ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder) {
+ return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
+ };
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
+ *out_server = std::move(ret);
+ return Status::OK();
+}
+
+namespace {
+
+class VerbsServerFactory : public ServerFactory {
+ public:
+ bool AcceptsOptions(const ServerDef& server_def) override {
+ return server_def.protocol() == "grpc+verbs";
+ }
+
+ Status NewServer(const ServerDef& server_def,
+ std::unique_ptr<ServerInterface>* out_server) override {
+ return VerbsServer::Create(server_def, Env::Default(), out_server);
+ }
+};
+
+// Registers a `ServerFactory` for `VerbsServer` instances.
+class VerbsServerRegistrar {
+ public:
+ VerbsServerRegistrar() {
+ gpr_allocation_functions alloc_fns;
+ alloc_fns.malloc_fn = port::Malloc;
+ alloc_fns.realloc_fn = port::Realloc;
+ alloc_fns.free_fn = port::Free;
+ gpr_set_allocation_functions(alloc_fns);
+ ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
+ }
+};
+static VerbsServerRegistrar registrar;
+
+} // namespace
+} // namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.h b/tensorflow/contrib/verbs/verbs_server_lib.h
new file mode 100644
index 0000000000..855380129f
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_server_lib.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+namespace tensorflow {
+
+class VerbsServer : public GrpcServer {
+ protected:
+ VerbsServer(const ServerDef& server_def, Env* env);
+
+ public:
+ static Status Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<ServerInterface>* out_server);
+
+ // Destruction is only supported in the factory method. Clean
+ // shutdown is not currently implemented for this server type.
+ virtual ~VerbsServer() override;
+
+ // Implementations of ServerInterface methods.
+ Status Start() override;
+ Status Join() override;
+
+ protected:
+ Status Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendezvous_mgr_func);
+ Status ChannelCacheFactory(const ServerDef& server_def,
+ GrpcChannelCache** channel_cache);
+
+ private:
+ RdmaMgr* rdma_mgr_;
+
+ // Guards state transitions.
+ mutex mu_;
+
+ enum State { DISCONNECTED, CONNECTED };
+ State verbs_state_ GUARDED_BY(mu_);
+
+ GrpcVerbsService* verbs_service_ = nullptr;
+ std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
+ GrpcChannelCache* channel_cache_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
diff --git a/tensorflow/contrib/verbs/verbs_service.proto b/tensorflow/contrib/verbs/verbs_service.proto
new file mode 100644
index 0000000000..b985febfb8
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_service.proto
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option java_outer_classname = "VerbsServiceProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.contrib.verbs";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// GRPC Helper messages used to exchange RDMA information.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message Channel {
+ int32 lid = 1;
+ int32 qpn = 2;
+ int32 psn = 3;
+}
+
+message MemoryRegion {
+ uint64 remote_addr = 1;
+ uint32 rkey = 2;
+}
+message GetRemoteAddressRequest {
+ string host_name = 1;
+ Channel channel = 2;
+ repeated MemoryRegion mr = 3;
+}
+
+message GetRemoteAddressResponse {
+ string host_name = 1;
+ Channel channel = 2;
+ repeated MemoryRegion mr = 3;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// VerbsService
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service VerbsService {
+ rpc GetRemoteAddress(GetRemoteAddressRequest)
+ returns (GetRemoteAddressResponse);
+}
diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc
new file mode 100644
index 0000000000..c3350f7958
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_util.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/verbs_util.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+namespace tensorflow {
+
+// static sync wrapper:
+Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead) {
+ Notification n;
+ Status status;
+ GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return status;
+}
+
+// static
+string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
+ return strings::StrCat(key, ";", step_id);
+}
+
+// static
+void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
+ int64& step_id) {
+ StringPiece s(key_with_step_id);
+ // a key (with step_id) has exact 6 parts if split by ";"
+ // part 1: src_device;
+ // part 2: src_incarnation;
+ // part 3: dst_device;
+ // part 4: name;
+ // part 5: frame_iter.frame_id:frame_iter.iter_id
+ // part 6: step_id
+ std::vector<string> parts = str_util::Split(s, ';');
+ CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
+ strings::safe_strto64(parts[5], &step_id);
+ parts.pop_back(); // remove step_id
+ key.assign(str_util::Join(parts, ";")); // stitch them together
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h
new file mode 100644
index 0000000000..cbc01adae4
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_util.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+#define TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class TensorProto;
+
+class VerbsUtil {
+ public:
+ // synchronous wrapper of SetProtoFromGPU
+ static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead);
+ static string AppendStepidToKey(const string& key, int64 step_id);
+ static void GetKeyAndStepId(const string& key_with_step_id, string& key,
+ int64& step_id);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_RDMA_UTIL_H_