aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-12-06 18:43:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 18:47:41 -0800
commitfe8406149feec453250905965a14285465cd2063 (patch)
treebe3cd75d543f3c0f29f368da61d915abbae7fcbf /tensorflow/contrib/verbs
parent8ad62af489df718992561710123bc8c037e7d17b (diff)
Merge changes from github.
PiperOrigin-RevId: 178185697
Diffstat (limited to 'tensorflow/contrib/verbs')
-rw-r--r--tensorflow/contrib/verbs/BUILD6
-rw-r--r--tensorflow/contrib/verbs/rdma.cc98
-rw-r--r--tensorflow/contrib/verbs/rdma.h29
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc51
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.h5
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc46
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc5
7 files changed, 176 insertions, 64 deletions
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
index 746ff38b37..38a84ffb10 100644
--- a/tensorflow/contrib/verbs/BUILD
+++ b/tensorflow/contrib/verbs/BUILD
@@ -7,6 +7,8 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
+load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
+
exports_files(["LICENSE"])
filegroup(
@@ -97,7 +99,7 @@ cc_library(
alwayslink = 1,
)
-cc_library(
+tf_cuda_library(
name = "rdma_rendezvous_mgr",
srcs = ["rdma_rendezvous_mgr.cc"],
hdrs = ["rdma_rendezvous_mgr.h"],
@@ -130,7 +132,7 @@ cc_library(
],
)
-cc_library(
+tf_cuda_library(
name = "rdma",
srcs = ["rdma.cc"],
hdrs = ["rdma.h"],
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index ac8d994502..ae9a384565 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -18,11 +18,14 @@ limitations under the License.
#include "tensorflow/contrib/verbs/rdma.h"
#include <fcntl.h>
#include <cstdlib>
+#include <fcntl.h>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -31,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
@@ -418,9 +422,6 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
0);
CHECK(cq_) << "Failed to create completion queue";
CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
- polling_thread_.reset(Env::Default()->StartThread(
- ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
- VLOG(2) << "Start RdmaAdapter: " << name();
}
RdmaAdapter::~RdmaAdapter() {
@@ -432,6 +433,12 @@ RdmaAdapter::~RdmaAdapter() {
CHECK(!ibv_close_device(context_)) << "Failed to release context";
}
+void RdmaAdapter::StartPolling() {
+ polling_thread_.reset(Env::Default()->StartThread(
+ ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
+ VLOG(2) << "Start RdmaAdapter: " << name();
+}
+
string RdmaAdapter::name() const { return string(context_->device->name); }
// Function to process incoming messages
@@ -452,9 +459,9 @@ void RdmaAdapter::Process_CQ() {
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS)
- << "Failed status \n"
- << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
- << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
+ << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
+ << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
+ << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@@ -557,9 +564,44 @@ void RdmaAdapter::Process_CQ() {
}
}
+int RdmaChannel::PingPostRecv() {
+ struct ibv_recv_wr wr, *bad_wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.sg_list = &ping_sge_list_;
+ wr.num_sge = 1;
+ wr.wr_id = kPingRecvWrid;
+
+ return ibv_post_recv(qp_, &wr, &bad_wr);
+}
+
+int RdmaChannel::PingPostSend() {
+ struct ibv_send_wr wr, *bad_wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.wr_id = (uint64_t) this;
+ wr.sg_list = &ping_sge_list_;
+ wr.num_sge = 1;
+ wr.opcode = IBV_WR_SEND;
+ wr.send_flags = IBV_SEND_SIGNALED;
+
+ return ibv_post_send(qp_, &wr, &bad_wr);
+}
+
RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name)
: adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
+
+ struct ibv_sge list;
+
+ mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
+ IBV_ACCESS_LOCAL_WRITE);
+ CHECK(mr_) << "Failed to register memory region";
+
+ memset(&list, 0, sizeof(list));
+ list.addr = (uintptr_t)ping_buff_;
+ list.length = kPingBuffSize;
+ list.lkey = mr_->lkey;
+
+ ping_sge_list_ = list;
// Create queue pair
{
struct ibv_qp_init_attr attr;
@@ -610,7 +652,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// create message and ack buffers, then initialize the tables.
{
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
- "tx_ack_buffer", "rx_ack_buffer"};
+ "tx_ack_buffer", "rx_ack_buffer"};
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
@@ -632,15 +674,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
buffer_index_name_table_.insert({index, buffer_names[i]});
buffer_name_index_table_.insert({buffer_names[i], index});
}
-
- // Initiate recv
- for (int i = 0; i < 100; i++) {
- Recv();
- }
}
+ CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
+ << " with error " << std::strerror(errno);
}
RdmaChannel::~RdmaChannel() {
+ ibv_dereg_mr(mr_);
CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
delete tx_message_buffer_;
delete rx_message_buffer_;
@@ -671,7 +711,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t)this;
+ wr.wr_id = (uint64_t) this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@@ -825,11 +865,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
int r;
- CHECK(!(r = ibv_modify_qp(qp_, &attr,
- IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
- IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
- IBV_QP_MAX_DEST_RD_ATOMIC |
- IBV_QP_MIN_RNR_TIMER)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
+ IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
@@ -840,10 +880,10 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
- CHECK(!(r = ibv_modify_qp(qp_, &attr,
- IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
- IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
- IBV_QP_MAX_QP_RD_ATOMIC)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
+ IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
@@ -930,7 +970,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t)this;
+ wr.wr_id = (uint64_t) this;
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@@ -1025,9 +1065,10 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
- CHECK(send_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+#if GOOGLE_CUDA
+ CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
+ << " gpu_info: "
+ << src_dev->tensorflow_gpu_device_info();
if (can_memcpy) {
AllocatorAttributes host_alloc_attrs;
@@ -1053,8 +1094,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
// aync instead
GPUUtil::SetProtoFromGPU(
in, src_dev, send_args.device_context, &proto, is_dead,
- [this, proto, buffer_size, key, in, step_id, key_with_step_id,
- is_dead, send_args, recv_args](const Status& s) mutable {
+ [this, proto, buffer_size, key, in, step_id, key_with_step_id,
+ is_dead, send_args, recv_args](const Status& s) mutable {
CHECK(s.ok()) << "copy proto from gpu sync";
auto tensor_bytes = proto.ByteSize();
buffer_size += tensor_bytes;
@@ -1063,6 +1104,7 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
&proto, NULL, send_args, recv_args);
});
}
+#endif // GOOGLE_CUDA
} else {
// tensor is in CPU memory.
StringPiece copy_buf;
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index 00217c81d4..fea2327d77 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -67,9 +67,20 @@ struct RemoteMR {
uint64_t remote_addr;
uint32_t rkey;
};
-enum BufferStatus { none, idle, busy };
-enum Location { local, remote };
-enum BufferType { ACK, MESSAGE, TENSOR };
+enum BufferStatus {
+ none,
+ idle,
+ busy
+};
+enum Location {
+ local,
+ remote
+};
+enum BufferType {
+ ACK,
+ MESSAGE,
+ TENSOR
+};
enum RdmaMessageType {
RDMA_MESSAGE_ACK,
RDMA_MESSAGE_BUFFER_IDLE,
@@ -96,6 +107,7 @@ class RdmaAdapter {
~RdmaAdapter();
// Adapter name, e.g. mlx5_0.
string name() const;
+ void StartPolling();
void Process_CQ();
protected:
@@ -150,6 +162,15 @@ class RdmaChannel {
void RemoveRecvCallback(const string& key);
void RunRecvCallback(const string& key);
static const int kNumMessageBuffers = 4;
+ static const int kPingRecvWrid = 0;
+
+ private:
+ static const int kPingBuffSize = 1024;
+ char ping_buff_[kPingBuffSize];
+ struct ibv_mr* mr_;
+ struct ibv_sge ping_sge_list_;
+ int PingPostRecv();
+ int PingPostSend();
protected:
const RdmaAdapter* adapter_;
@@ -202,7 +223,7 @@ class RdmaBuffer {
}
void FreeBuffer();
void EnqueueItem(string Item);
- virtual void SendNextItem(){};
+ virtual void SendNextItem() {};
void CreateCPUBuffer(size_t size, bool lock = true);
void SetRemoteMR(RemoteMR rmi, bool override);
uint32_t LookupBufferIndex(const string& buffer_name) {
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 09b878843f..9cb307bcfa 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -115,6 +115,57 @@ void RdmaMgr::SetupChannels() {
}
}
+// Check connectivity by pinging every channel
+bool RdmaMgr::ConnectivityCheck() {
+ int i, rcnt = 0, scnt = 0;
+
+ for (const auto& p : channel_table_) {
+ string worker_name = p.first;
+ RdmaChannel* rc = p.second;
+
+ VLOG(2) << "Ping to " << worker_name;
+ CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name
+ << " with error: " << std::strerror(errno);
+ for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) {
+ rc->Recv();
+ }
+ }
+
+ while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) {
+ int ne;
+ do {
+ ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_,
+ rdma_adapter_->wc_);
+ CHECK(ne >= 0) << "poll CQ failed " << ne << "with error"
+ << std::strerror(errno);
+ } while (ne < 1);
+
+ for (i = 0; i < ne; ++i) {
+ ibv_wc_status s = rdma_adapter_->wc_[i].status;
+ // recv complete
+ if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
+ CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
+ rdma_adapter_->wc_[i].status)
+ << "(" << rdma_adapter_->wc_[i].status
+ << ") for PING_RECV_WRID";
+ ++rcnt;
+ // send complete
+ } else {
+ RdmaChannel* rc =
+ reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
+ CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
+ rdma_adapter_->wc_[i].status)
+ << "(" << rdma_adapter_->wc_[i].status
+ << ") to " << rc->remote_name_;
+ ++scnt;
+ }
+ } // for
+ } // while
+ CHECK(rcnt == scnt) << "Connectivity check failed!";
+ rdma_adapter_->StartPolling();
+ return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt);
+}
+
RdmaMgr::~RdmaMgr() {
for (const auto& p : channel_table_) delete p.second;
channel_table_.clear();
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
index b156f64096..e711e60478 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.h
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -28,12 +28,16 @@ limitations under the License.
namespace tensorflow {
class RdmaMgr {
+ friend class RdmaChannel;
+ friend class RdmaAdapter;
+
public:
explicit RdmaMgr(const WorkerEnv* const worker_env,
GrpcChannelCache* const channel_cache);
~RdmaMgr();
RdmaChannel* FindChannel(const string& key);
void SetupChannels();
+ bool ConnectivityCheck();
const string& local_worker() { return local_worker_; }
private:
@@ -44,7 +48,6 @@ class RdmaMgr {
RdmaAdapter* rdma_adapter_;
typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
ChannelTable channel_table_;
-
TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
};
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index ce82ca2883..74f6681af3 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#endif // GOOGLE_CUDA
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -58,20 +60,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
// parse src_name and dst_name
string src_name, dst_name, unused;
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
+ &unused) ||
+ !DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
&unused)) {
- s = errors::Internal("Could not parse src name.");
+ s = errors::Internal("Could not parse src or dst name.");
}
- CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
- if (!s.ok()) {
- done(s, Args(), recv_args, Tensor{}, false);
- return;
- }
- if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
- &unused)) {
- s = errors::Internal("Could not parse dst name.");
- }
- CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
if (!s.ok()) {
+ LOG(ERROR) << "s is not ok, error code " << s.error_message();
done(s, Args(), recv_args, Tensor{}, false);
return;
}
@@ -82,18 +77,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
// insert callback
rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
recv_args, parsed, done]() {
- Status s;
- Device* src_dev;
- s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
- CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
- if (!s.ok()) {
- done(s, Args(), recv_args, Tensor(), true);
- return;
- }
- Device* dst_dev;
- s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
- CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
- if (!s.ok()) {
+ Status src_s, dst_s, s;
+ Device* src_dev, *dst_dev;
+ src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
+ dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+ if (!src_s.ok() || !dst_s.ok()) {
+ s = src_s.ok() ? dst_s : src_s;
+ LOG(ERROR) << "s is not ok, error code " << s.error_message();
done(s, Args(), recv_args, Tensor(), true);
return;
}
@@ -110,9 +100,10 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
if (can_memcpy) {
if (dst_dev->tensorflow_gpu_device_info() &&
(!recv_args.alloc_attrs.on_host())) {
+#if GOOGLE_CUDA
CHECK(recv_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
@@ -122,14 +113,15 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
GPUUtil::CopyCPUTensorToGPU(
&copy, recv_args.device_context, dst_dev, &gpu_copy,
- [this, gpu_copy, key, key_with_step_id, recv_args, done, rm,
- rc](const Status& s) {
+ [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc](
+ const Status& s) {
CHECK(s.ok()) << "copy tensor to gpu sync";
Tensor val;
val = std::move(gpu_copy);
RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc,
val, s);
});
+#endif // GOOGLE_CUDA
return;
} else {
AllocatorAttributes host_alloc_attrs;
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index 6d1c79c0fb..a606ef75a4 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -49,8 +49,8 @@ VerbsServer::~VerbsServer() {
Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
GrpcChannelCache** channel_cache) {
string name_prefix =
- strings::StrCat("/job:", server_def.job_name(), "/replica:0",
- "/task:", server_def.task_index());
+ strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:",
+ server_def.task_index());
GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
@@ -103,6 +103,7 @@ Status VerbsServer::Start() {
ThreadOptions(), "TF_verbs_service",
[this] { verbs_service_->HandleRPCsLoop(); }));
rdma_mgr_->SetupChannels();
+ CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
verbs_state_ = CONNECTED;
}
}