aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs/rdma.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r--tensorflow/contrib/verbs/rdma.cc413
1 files changed, 370 insertions, 43 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 26e18b28aa..331943a3ef 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/contrib/verbs/rdma.h"
#include <cstdlib>
+#include <fcntl.h>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -33,6 +34,8 @@ limitations under the License.
namespace tensorflow {
+#define RoCE_V2 "RoCE v2"
+
namespace {
// hash name to 32-bit integer
uint32_t NameHash(const string& name) {
@@ -66,16 +69,337 @@ string MessageTypeToString(RdmaMessageType rmt) {
}
} // namespace
-ibv_context* open_default_device() {
+// Looks up an environment variable.
+// Args:
+//   var_name - name of the environment variable to read
+// Returns:
+//   the variable's value, or an empty string if it is not set
+string get_env_var(char const* var_name) {
+  // getenv() yields NULL for an unset variable; map that to an empty string.
+  if (char const* raw_value = getenv(var_name)) {
+    return string(raw_value);
+  }
+  return string();
+}
+
+// Opens the given verbs device and aborts (CHECK) if the open fails.
+// Args:
+//   ibv_dev - device to open
+// Returns:
+//   context of the opened device (never NULL)
+ibv_context* open_device(ibv_device* ibv_dev) {
+  ibv_context* const context = ibv_open_device(ibv_dev);
+  CHECK(context) << "Open context failed for "
+                 << ibv_get_device_name(ibv_dev);
+  return context;
+}
+
+// Counts the ports of a device that are in the IBV_PORT_ACTIVE state.
+// The device is opened only for the duration of the query and closed again
+// before returning. Any failing verbs call aborts the process (CHECK).
+// Args:
+//   device - device whose ports are inspected
+// Returns:
+//   number of active ports of the given device
+int get_dev_active_port_count(ibv_device* device) {
+  // open_device() performs the same open + CHECK(context) sequence.
+  ibv_context* context = open_device(device);
+
+  ibv_device_attr device_att;
+  int rc = ibv_query_device(context, &device_att);
+  CHECK(!rc) << "Failed to query the device";
+
+  int active_ports = 0;
+  // Ports are numbered starting from 1.
+  for (int port_index = 1; port_index <= device_att.phys_port_cnt;
+       ++port_index) {
+    ibv_port_attr port_attr;
+    rc = ibv_query_port(context, port_index, &port_attr);
+    CHECK(!rc) << "Failed to query the port" << port_index;
+    active_ports += (port_attr.state == IBV_PORT_ACTIVE) ? 1 : 0;
+  }
+
+  ibv_close_device(context);
+  return active_ports;
+}
+
+// Selects the RDMA device to use.
+// If the RDMA_DEVICE environment variable is set, the device with exactly that
+// name is returned; it must have at least one active port.
+// If RDMA_DEVICE is not set, the only device that has an active port is
+// returned, and RDMA_DEVICE_PORT must not be set either.
+// Aborts (CHECK) when: the device list cannot be obtained, the named device is
+// not found or has no active port, RDMA_DEVICE_PORT is set without
+// RDMA_DEVICE, no device has an active port, or more than one device has an
+// active port.
+// NOTE(review): dev_list is never passed to ibv_free_device_list() here.
+// The returned ibv_device* is an entry of that list and has not been opened
+// yet, so this looks intentional - confirm.
+// Returns:
+// device to use
+ibv_device* set_device() {
ibv_device** dev_list;
- ibv_device* ib_dev;
- dev_list = ibv_get_device_list(NULL);
+ int dev_num, device_index, device_to_open = 0;
+ int num_devs_with_active_port = 0;
+ string env_p_rdma_device, str_port_num;
+
+ dev_list = ibv_get_device_list(&dev_num);
CHECK(dev_list) << "No InfiniBand device found";
- ib_dev = dev_list[0];
- CHECK(ib_dev) << "No InfiniBand device found";
- ibv_context* context = ibv_open_device(ib_dev);
- CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
- return context;
+
+ env_p_rdma_device = get_env_var("RDMA_DEVICE");
+ if (!env_p_rdma_device.empty()) {
+ // user named a device: look it up by exact name
+ for (device_index = 0; device_index < dev_num; device_index++) {
+ if (!env_p_rdma_device.compare(
+ ibv_get_device_name(dev_list[device_index]))) {
+ CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
+ << "Device " << ibv_get_device_name(dev_list[device_index])
+ << " has no active ports";
+ return dev_list[device_index];
+ }
+ }
+ // check validity of input device
+ CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
+ } else {
+ // set default device
+ str_port_num = get_env_var("RDMA_DEVICE_PORT");
+ CHECK(str_port_num.empty())
+ << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
+ for (device_index = 0; device_index < dev_num; device_index++) {
+ // get port_num
+ if (get_dev_active_port_count(dev_list[device_index]) > 0) {
+ num_devs_with_active_port++;
+ CHECK(num_devs_with_active_port <= 1) << ". More than one device with "
+ "active port in the system. "
+ "Please enter RDMA_DEVICE";
+ // found device with at least 1 active port
+ device_to_open = device_index;
+ }
+ }
+ CHECK(num_devs_with_active_port > 0)
+ << "There is no active port in the system";
+ return dev_list[device_to_open];
+ }
+ // Both branches above either return or abort, so this is a safety net only.
+ CHECK(false) << "No device was set!";
+ return NULL; // never happens
+}
+
+// Selects the port of the device to use.
+// If RDMA_DEVICE_PORT is set, that port is used; it must lie in
+// [1, phys_port_cnt] and be active. Otherwise the first active port of the
+// device is used. Any violation or failing verbs call aborts (CHECK).
+// Args:
+//   context - context of the opened device
+// Returns:
+//   port to use
+uint8_t set_port(ibv_context* context) {
+  ibv_device_attr device_att;
+  int rc = ibv_query_device(context, &device_att);
+  CHECK(!rc) << "Failed to query the device\n";
+
+  // Kept as int until it has been range-checked: narrowing the parsed value
+  // to uint8_t first would let e.g. "257" silently become port 1, and a
+  // uint8_t streamed into a log message is printed as a raw character.
+  int port_num = 0;  // 0 is illegal port number
+
+  const string str_port_num = get_env_var("RDMA_DEVICE_PORT");
+  if (!str_port_num.empty()) {
+    // user defined port
+    port_num = stoi(str_port_num);
+    CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
+    CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
+                                                   "less or equal to amount of "
+                                                   "available ports";
+    ibv_port_attr port_attr;
+    rc = ibv_query_port(context, static_cast<uint8_t>(port_num), &port_attr);
+    CHECK(!rc) << "Failed to query the port" << port_num;
+    // check if port is active
+    CHECK(port_attr.state == IBV_PORT_ACTIVE)
+        << "Selected RDMA_DEVICE_PORT is not active";
+  } else {
+    // set default port: first active one
+    for (int port_index = 1; port_index <= device_att.phys_port_cnt;
+         port_index++) {
+      ibv_port_attr port_attr;
+      rc = ibv_query_port(context, port_index, &port_attr);
+      CHECK(!rc) << "Failed to query the port" << port_index;
+      if (port_attr.state == IBV_PORT_ACTIVE) {
+        port_num = port_index;
+        break;
+      }
+    }
+    CHECK_GT(port_num, 0) << "No active ports";
+  }
+  // port_num <= phys_port_cnt (a uint8_t) here, so the narrowing is lossless.
+  return static_cast<uint8_t>(port_num);
+}
+
+// Reads the content of the file <dir>/<file> into a caller supplied buffer.
+// The result is always NUL-terminated; a single trailing newline is stripped.
+// At most size - 1 bytes are read so that the terminator always fits.
+// Args:
+//   dir - directory
+//   file - file name inside dir
+//   buf - buffer for the result
+//   size - buffer size in bytes (must be at least 1)
+// Returns:
+//   number of bytes stored in buf (excluding the terminator and a stripped
+//   trailing newline) or -1 if failed
+int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
+  // Without room for the terminator nothing useful can be returned.
+  if (buf == NULL || size == 0) return -1;
+
+  char* path = NULL;
+  if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;
+
+  const int fd = open(path, O_RDONLY);
+  free(path);  // the path is only needed for open()
+  if (fd < 0) return -1;
+
+  // Leave one byte spare: callers treat buf as a C string (e.g. strcmp), so
+  // it has to be terminated even when the file has no trailing newline or is
+  // longer than the buffer.
+  int len = static_cast<int>(read(fd, buf, size - 1));
+  close(fd);
+  if (len < 0) return -1;
+
+  buf[len] = '\0';
+  if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';
+
+  return len;
+}
+
+// Tells whether a GID table entry is of type RoCE v2.
+// The type is read from the sysfs attribute
+// ports/<port_num>/gid_attrs/types/<index> below the device's ibdev_path and
+// compared with the RoCE_V2 string.
+// Args:
+//   context - device context
+//   port_num - port number
+//   index - GID index
+// Returns:
+//   true if the attribute could be read and equals RoCE_V2, otherwise false.
+bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
+                         uint8_t index) {
+  char attr_path[32];
+  snprintf(attr_path, sizeof(attr_path), "ports/%d/gid_attrs/types/%d",
+           port_num, index);
+
+  char type_name[41];
+  const int bytes_read =
+      read_sysfs_file(context->device->ibdev_path, attr_path, type_name,
+                      sizeof(type_name));
+  // An unreadable or empty attribute is treated as "not RoCE v2".
+  if (bytes_read <= 0) {
+    return false;
+  }
+  return strcmp(type_name, RoCE_V2) == 0;
+}
+
+// Selects the GID index to use on the given port.
+// The GID table of the port is scanned; the first entry with a non-zero
+// interface id, a zero subnet prefix and type RoCE v2 becomes the default.
+// If the port link layer is Ethernet and RDMA_GID_INDEX is set, that value
+// is used instead; if it is not set, at most one RoCE v2 entry may exist.
+// If the port link layer is InfiniBand, no GID index has to be selected.
+// Args:
+//   port_num - port number
+//   context - device context
+// Returns:
+//   GID index to use
+uint8_t set_gid(uint8_t port_num, ibv_context* context) {
+  // The selected index is returned as uint8_t, so indices above 255 can
+  // neither be probed through is_gid_type_roce_v2() nor be returned.
+  const int kMaxGidIndex = 255;
+
+  ibv_port_attr port_attr;
+  int rc = ibv_query_port(context, port_num, &port_attr);
+  // Cast: a uint8_t would otherwise be logged as a raw character.
+  CHECK(!rc) << "Failed to query the port" << static_cast<int>(port_num);
+
+  int gids_num = 0, v2_ip_num = 0;
+  uint8_t gid_index = 0;
+  for (int i = 0; i < port_attr.gid_tbl_len; i++) {
+    union ibv_gid gid;
+    rc = ibv_query_gid(context, port_num, i, &gid);
+    CHECK(!rc) << "Failed to query gid to port " << static_cast<int>(port_num)
+               << " index " << i;
+    if (!gid.global.interface_id) continue;  // unused table entry
+    gids_num++;
+    if (i <= kMaxGidIndex && gid.global.subnet_prefix == 0 &&
+        is_gid_type_roce_v2(context, port_num, static_cast<uint8_t>(i))) {
+      if (v2_ip_num == 0) {
+        // can be overwritten by RDMA_GID_INDEX later
+        gid_index = static_cast<uint8_t>(i);
+      }
+      v2_ip_num++;
+    }
+  }
+
+  switch (port_attr.link_layer) {
+    case IBV_LINK_LAYER_ETHERNET: {
+      const string gid_str = get_env_var("RDMA_GID_INDEX");
+      if (!gid_str.empty()) {
+        // Validate as int before narrowing; otherwise e.g. "256" would
+        // silently wrap around to index 0.
+        const int requested_index = stoi(gid_str);
+        CHECK(requested_index >= 0 && requested_index < gids_num &&
+              requested_index <= kMaxGidIndex)
+            << "RDMA_GID_INDEX should be non-negative and less than GIDs "
+               "amount "
+            << gids_num;
+        gid_index = static_cast<uint8_t>(requested_index);
+      } else {
+        CHECK(v2_ip_num <= 1)
+            << "More than one IP is available, please specify RDMA_GID_INDEX";
+      }
+      break;
+    }
+    case IBV_LINK_LAYER_INFINIBAND:  // no need in GID index
+      break;
+    default:
+      LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
+                   "InfiniBand only. ";
+  }
+  if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
+    LOG(INFO) << "RoCE v2 is not configured for GID_INDEX "
+              << static_cast<int>(gid_index);
+  }
+  return gid_index;
+}
+
+// Resolves a numeric configuration parameter: the value of the environment
+// variable if it is set (parsed with stoi), the default otherwise.
+// Args:
+//   default_val - the default value for this parameter
+//   env_param - the environment parameter's name
+// Returns:
+//   32-bit value
+uint32_t set_param(uint32_t default_val, const char* env_param) {
+  const string env_value = get_env_var(env_param);
+  if (env_value.empty()) {
+    return default_val;
+  }
+  return static_cast<uint32_t>(stoi(env_value));
+}
+
+// Selects the path MTU for the QPs on the given port.
+// If RDMA_MTU is set it must be one of 256, 512, 1024, 2048 or 4096 and must
+// not be larger than the active MTU of the port; otherwise the active MTU of
+// the port is used. Any violation or failing verbs call aborts (CHECK).
+// Args:
+//   port_num - port number
+//   context - device context
+// Returns:
+//   MTU to use
+enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
+  ibv_port_attr port_attr;
+  int rc = ibv_query_port(context, port_num, &port_attr);
+  // Cast: a uint8_t would otherwise be logged as a raw character.
+  CHECK(!rc) << "Failed to query the port" << static_cast<int>(port_num);
+
+  const string mtu_s = get_env_var("RDMA_MTU");
+  if (mtu_s.empty()) {
+    return port_attr.active_mtu;
+  }
+
+  const int mtu_i = stoi(mtu_s);
+  enum ibv_mtu mtu = port_attr.active_mtu;
+  switch (mtu_i) {
+    case 256:
+      mtu = IBV_MTU_256;
+      break;
+    case 512:
+      mtu = IBV_MTU_512;
+      break;
+    case 1024:
+      mtu = IBV_MTU_1024;
+      break;
+    case 2048:
+      mtu = IBV_MTU_2048;
+      break;
+    case 4096:
+      mtu = IBV_MTU_4096;
+      break;
+    default:
+      // Report the value the user supplied (mtu_i), not the enum variable.
+      CHECK(0) << "Error: MTU input value must be one of the following: 256, "
+                  "512, 1024, 2048, 4096. MTU " << mtu_i << " is invalid\n";
+      break;
+  }
+  // An MTU equal to the active MTU is valid; only a larger one is rejected.
+  CHECK(mtu <= port_attr.active_mtu)
+      << "MTU configuration for the QPs is larger than active MTU";
+  return mtu;
+}
+
+// Builds the RDMA parameters for an opened device.
+// Port, GID index and MTU are resolved through set_port(), set_gid() and
+// set_mtu(); the remaining fields come from set_param(), i.e. from their
+// environment variable when it is set and from the default otherwise.
+// Aborts (CHECK) if the resulting SL value is greater than 7.
+// Args:
+//   context - context of the opened device
+// Returns:
+//   the filled-in parameters
+RdmaParams params_init(ibv_context* context) {
+  RdmaParams params;
+
+  // The port is resolved first: the GID index and the MTU are per port.
+  params.port_num = set_port(context);
+  params.sgid_index = set_gid(params.port_num, context);
+
+  params.pkey_index =
+      static_cast<uint8_t>(set_param(PKEY_DEFAULT, "RDMA_PKEY"));
+  params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
+  params.timeout =
+      static_cast<uint8_t>(set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT"));
+  params.retry_cnt =
+      static_cast<uint8_t>(set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT"));
+
+  params.sl = static_cast<uint8_t>(set_param(SL_DEFAULT, "RDMA_SL"));
+  CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
+                        << ". Valid values are 0-7.";
+
+  params.mtu = set_mtu(params.port_num, context);
+  params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
+  return params;
}
ibv_pd* alloc_protection_domain(ibv_context* context) {
@@ -85,7 +409,8 @@ ibv_pd* alloc_protection_domain(ibv_context* context) {
}
RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
- : context_(open_default_device()),
+ : context_(open_device(set_device())),
+ params_(params_init(context_)),
pd_(alloc_protection_domain(context_)),
worker_env_(worker_env) {
event_channel_ = ibv_create_comp_channel(context_);
@@ -128,9 +453,9 @@ void RdmaAdapter::Process_CQ() {
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS)
- << "Failed status \n"
- << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
- << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
+ << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
+ << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
+ << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@@ -242,8 +567,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
memset(&attr, 0, sizeof(ibv_qp_init_attr));
attr.send_cq = adapter_->cq_;
attr.recv_cq = adapter_->cq_;
- attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
- attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_send_wr = adapter_->params_.queue_depth;
+ attr.cap.max_recv_wr = adapter_->params_.queue_depth;
attr.cap.max_send_sge = 1;
attr.cap.max_recv_sge = 1;
attr.qp_type = IBV_QPT_RC;
@@ -257,8 +582,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_INIT;
- attr.pkey_index = 0;
- attr.port_num = 1;
+ attr.pkey_index = adapter_->params_.pkey_index;
+ attr.port_num = adapter_->params_.port_num;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int mask =
@@ -269,13 +594,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// Local address
{
struct ibv_port_attr attr;
- CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
+ CHECK(
+ !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
<< "Query port";
self_.lid = attr.lid;
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
union ibv_gid gid;
- CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
+ CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
+ adapter_->params_.sgid_index, &gid))
<< "Query gid";
self_.snp = gid.global.subnet_prefix;
self_.iid = gid.global.interface_id;
@@ -284,7 +611,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// create message and ack buffers, then initialize the tables.
{
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
- "tx_ack_buffer", "rx_ack_buffer"};
+ "tx_ack_buffer", "rx_ack_buffer"};
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
@@ -345,7 +672,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t)this;
+ wr.wr_id = (uint64_t) this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@@ -479,11 +806,9 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTR;
- struct ibv_port_attr port_attr;
- CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr))
- << "Query port failed";
+
// This assumes both QP's ports are configured with the same MTU
- attr.path_mtu = port_attr.active_mtu;
+ attr.path_mtu = adapter_->params_.mtu;
attr.dest_qp_num = remoteAddr.qpn;
attr.rq_psn = remoteAddr.psn;
attr.max_dest_rd_atomic = 1;
@@ -494,30 +819,32 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.hop_limit = 255;
attr.ah_attr.dlid = remoteAddr.lid;
- attr.ah_attr.sl = 0;
+ attr.ah_attr.sl = adapter_->params_.sl;
attr.ah_attr.src_path_bits = 0;
- attr.ah_attr.port_num = 1;
+ attr.ah_attr.port_num = adapter_->params_.port_num;
+ attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
+ attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
int r;
- CHECK(!(r = ibv_modify_qp(qp_, &attr,
- IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
- IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
- IBV_QP_MAX_DEST_RD_ATOMIC |
- IBV_QP_MIN_RNR_TIMER)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
+ IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = self_.psn;
- attr.timeout = 14;
- attr.retry_cnt = 7;
+ attr.timeout = adapter_->params_.timeout;
+ attr.retry_cnt = adapter_->params_.retry_cnt;
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
- CHECK(!(r = ibv_modify_qp(qp_, &attr,
- IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
- IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
- IBV_QP_MAX_QP_RD_ATOMIC)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
+ IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
@@ -604,7 +931,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t)this;
+ wr.wr_id = (uint64_t) this;
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@@ -699,9 +1026,9 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
- CHECK(send_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
+ << " gpu_info: "
+ << src_dev->tensorflow_gpu_device_info();
if (can_memcpy) {
AllocatorAttributes host_alloc_attrs;
@@ -727,8 +1054,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
// aync instead
GPUUtil::SetProtoFromGPU(
in, src_dev, send_args.device_context, &proto, is_dead,
- [this, proto, buffer_size, key, in, step_id, key_with_step_id,
- is_dead, send_args, recv_args](const Status& s) mutable {
+ [this, proto, buffer_size, key, in, step_id, key_with_step_id,
+ is_dead, send_args, recv_args](const Status& s) mutable {
CHECK(s.ok()) << "copy proto from gpu sync";
auto tensor_bytes = proto.ByteSize();
buffer_size += tensor_bytes;