about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/verbs/rdma.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r-- tensorflow/contrib/verbs/rdma.cc 413
1 files changed, 43 insertions, 370 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 331943a3ef..26e18b28aa 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include "tensorflow/contrib/verbs/rdma.h"
#include <cstdlib>
-#include <fcntl.h>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -34,8 +33,6 @@ limitations under the License.
namespace tensorflow {
-#define RoCE_V2 "RoCE v2"
-
namespace {
// hash name to 32-bit integer
uint32_t NameHash(const string& name) {
@@ -69,337 +66,16 @@ string MessageTypeToString(RdmaMessageType rmt) {
}
} // namespace
-// Function to get environment variable
-// Args:
-// var_name - the name of the environmental variable
-// Returns:
-// string with it's value or empty string if not set
-string get_env_var(char const* var_name) {
- char const* var_temp = getenv(var_name);
-
- return (var_temp == NULL) ? string() : string(var_temp);
-}
-
-// Function to open device
-// Args:
-// ibv_dev device to open
-// Returns:
-// context of the opened device
-ibv_context* open_device(ibv_device* ibv_dev) {
- ibv_context* context = ibv_open_device(ibv_dev);
-
- CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
- return context;
-}
-
-// Function to count the number of active ports for device
-// Args:
-// device - to check active ports
-// Returns:
-// number of active ports of the given device
-int get_dev_active_port_count(ibv_device* device) {
- ibv_device_attr device_att;
- ibv_port_attr port_attr;
- ibv_context* context = NULL;
- int rc, port_index, active_ports = 0;
-
- context = ibv_open_device(device);
- CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
- rc = ibv_query_device(context, &device_att);
- CHECK(!rc) << "Failed to query the device";
-
- for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
- rc = ibv_query_port(context, port_index, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_index;
- if (port_attr.state == IBV_PORT_ACTIVE) {
- active_ports++;
- }
- }
- ibv_close_device(context);
- return active_ports;
-}
-
-// Function to set device. If RDMA_DEVICE not set, search for device with active
-// port.
-// Fails if more than one device with active port was found.
-// Returns:
-// device to use
-ibv_device* set_device() {
+ibv_context* open_default_device() {
ibv_device** dev_list;
- int dev_num, device_index, device_to_open = 0;
- int num_devs_with_active_port = 0;
- string env_p_rdma_device, str_port_num;
-
- dev_list = ibv_get_device_list(&dev_num);
+ ibv_device* ib_dev;
+ dev_list = ibv_get_device_list(NULL);
CHECK(dev_list) << "No InfiniBand device found";
-
- env_p_rdma_device = get_env_var("RDMA_DEVICE");
- if (!env_p_rdma_device.empty()) {
- for (device_index = 0; device_index < dev_num; device_index++) {
- if (!env_p_rdma_device.compare(
- ibv_get_device_name(dev_list[device_index]))) {
- CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
- << "Device " << ibv_get_device_name(dev_list[device_index])
- << " has no active ports";
- return dev_list[device_index];
- }
- }
- // check validity of input device
- CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
- } else {
- // set default device
- str_port_num = get_env_var("RDMA_DEVICE_PORT");
- CHECK(str_port_num.empty())
- << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
- for (device_index = 0; device_index < dev_num; device_index++) {
- // get port_num
- if (get_dev_active_port_count(dev_list[device_index]) > 0) {
- num_devs_with_active_port++;
- CHECK(num_devs_with_active_port <= 1) << ". More than one device with "
- "active port in the system. "
- "Please enter RDMA_DEVICE";
- // found device with at least 1 active port
- device_to_open = device_index;
- }
- }
- CHECK(num_devs_with_active_port > 0)
- << "There is no active port in the system";
- return dev_list[device_to_open];
- }
- CHECK(false) << "No device was set!";
- return NULL; // never happens
-}
-
-// Function to set port for device.
-// If RDMA_DEVICE_PORT not set, first active port of the device will be set.
-// Args:
-// context of the device
-// Returns:
-// port to use
-uint8_t set_port(ibv_context* context) {
- uint8_t port_num = 0; //0 is illegal port number
- string str_port_num;
- ibv_device_attr device_att;
- ibv_port_attr port_attr;
- int rc, port_index;
-
- rc = ibv_query_device(context, &device_att);
- CHECK(!rc) << "Failed to query the device\n";
-
- str_port_num = get_env_var("RDMA_DEVICE_PORT");
- // user defined port
- if (!str_port_num.empty()) {
- port_num = stoi(str_port_num);
- CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
- CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
- "less or equal to amount of "
- "available ports";
- rc = ibv_query_port(context, port_num, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_num;
- // check if port id active
- CHECK(port_attr.state == IBV_PORT_ACTIVE)
- << "Selected RDMA_DEVICE_PORT is not active";
- }
- // set default port
- else {
- for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
- rc = ibv_query_port(context, port_index, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_index;
- if (port_attr.state == IBV_PORT_ACTIVE) {
- port_num = port_index;
- break;
- }
- }
- CHECK_GT(port_num, 0) << "No active ports";
- }
- return port_num;
-}
-
-// Function read from sysfs file
-// Args:
-// dir - directory
-// file - file
-// buff - buffer for the result
-// size - buffer size
-// Returns:
-// number of bytes were read or -1 if failed
-int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
- char* path;
- int fd;
- int len;
-
- if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;
-
- fd = open(path, O_RDONLY);
- if (fd < 0) {
- free(path);
- return -1;
- }
-
- len = read(fd, buf, size);
-
- close(fd);
- free(path);
-
- if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';
-
- return len;
-}
-
-// Function to check if GID index support RoCE V2
-// Args:
-// context - device context
-// port_num - port number
-// index - GID index
-// Returns:
-// if GID supports RoCE V2 - true, otherwise - false.
-bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
- uint8_t index) {
- char name[32];
- char buff[41];
-
- snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
- if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
- 0) {
- return false;
- }
- return !strcmp(buff, RoCE_V2);
-}
-
-// Function to set GID index.
-// If the port link is IB, no GID index should be selected.
-// If Ethernet but RDMA_GID_INDEX not set gid index that supports
-// RoCE V2 will be chosen(fails if more then one IP is configured)
-// Args:
-// context - device context
-// port_num - port number
-// Returns:
-// GID index to use
-uint8_t set_gid(uint8_t port_num, ibv_context* context) {
- ibv_port_attr port_attr;
- string gid_str;
- int rc, i, gids_num = 0, v2_ip_num = 0;
- union ibv_gid gid;
- uint8_t gid_index = 0;
-
- rc = ibv_query_port(context, port_num, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_num;
-
- for (i = 0; i < port_attr.gid_tbl_len; i++) {
- rc = ibv_query_gid(context, port_num, i, &gid);
- CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index "
- << i;
- if (gid.global.interface_id) {
- gids_num++;
- if (gid.global.subnet_prefix == 0 &&
- is_gid_type_roce_v2(context, port_num, i)) {
- if (v2_ip_num == 0) {
- // can be overwritten by RDMA_GID_INDEX later
- gid_index = i;
- }
- v2_ip_num++;
- }
- }
- }
- switch (port_attr.link_layer) {
- case(IBV_LINK_LAYER_ETHERNET) :
- gid_str = get_env_var("RDMA_GID_INDEX");
- if (!gid_str.empty()) {
- gid_index = stoi(gid_str);
- CHECK(gid_index < gids_num)
- << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num;
- } else {
- CHECK(v2_ip_num <= 1)
- << "More than one IP is available, please specify GID_INDEX";
- }
- break;
- case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index
- break;
- default:
- LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
- "InfiniBand only. ";
- }
- if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
- LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
- }
- return gid_index;
-}
-
-// set the default or environment value to the configuration parameter.
-// Args:
-// default_val- the default value for this parameter
-// env_param- the environment parameter's name
-// Returns:
-// 32-bit value
-uint32_t set_param(uint32_t default_val, const char* env_param) {
- uint32_t val = default_val;
- string val_s;
-
- val_s = get_env_var(env_param);
-
- if (!val_s.empty()) {
- val = stoi(val_s);
- }
- return val;
-}
-
-enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
- ibv_port_attr port_attr;
- enum ibv_mtu mtu;
- string mtu_s;
- int rc, mtu_i;
-
- rc = ibv_query_port(context, port_num, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_num;
-
- mtu_s = get_env_var("RDMA_MTU");
-
- if (!mtu_s.empty()) {
- mtu_i = stoi(mtu_s);
- switch (mtu_i) {
- case 256:
- mtu = IBV_MTU_256;
- break;
- case 512:
- mtu = IBV_MTU_512;
- break;
- case 1024:
- mtu = IBV_MTU_1024;
- break;
- case 2048:
- mtu = IBV_MTU_2048;
- break;
- case 4096:
- mtu = IBV_MTU_4096;
- break;
- default:
- CHECK(0) << "Error: MTU input value must be one of the following: 256, "
- "512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n";
- break;
- }
- CHECK(mtu < port_attr.active_mtu)
- << "MTU configuration for the QPs is larger than active MTU";
- } else {
- mtu = port_attr.active_mtu;
- }
- return mtu;
-}
-
-RdmaParams params_init(ibv_context* context) {
- RdmaParams params;
-
- params.port_num = set_port(context);
- params.sgid_index = set_gid(params.port_num, context);
- params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
- params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
- params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
- params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
- params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
- CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
- << ". Valid values are 0-7.";
- params.mtu = set_mtu(params.port_num, context);
- params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
- return params;
+ ib_dev = dev_list[0];
+ CHECK(ib_dev) << "No InfiniBand device found";
+ ibv_context* context = ibv_open_device(ib_dev);
+ CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
+ return context;
}
ibv_pd* alloc_protection_domain(ibv_context* context) {
@@ -409,8 +85,7 @@ ibv_pd* alloc_protection_domain(ibv_context* context) {
}
RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
- : context_(open_device(set_device())),
- params_(params_init(context_)),
+ : context_(open_default_device()),
pd_(alloc_protection_domain(context_)),
worker_env_(worker_env) {
event_channel_ = ibv_create_comp_channel(context_);
@@ -453,9 +128,9 @@ void RdmaAdapter::Process_CQ() {
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS)
- << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
- << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
- << wc_[i].vendor_err;
+ << "Failed status \n"
+ << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+ << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@@ -567,8 +242,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
memset(&attr, 0, sizeof(ibv_qp_init_attr));
attr.send_cq = adapter_->cq_;
attr.recv_cq = adapter_->cq_;
- attr.cap.max_send_wr = adapter_->params_.queue_depth;
- attr.cap.max_recv_wr = adapter_->params_.queue_depth;
+ attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
attr.cap.max_send_sge = 1;
attr.cap.max_recv_sge = 1;
attr.qp_type = IBV_QPT_RC;
@@ -582,8 +257,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_INIT;
- attr.pkey_index = adapter_->params_.pkey_index;
- attr.port_num = adapter_->params_.port_num;
+ attr.pkey_index = 0;
+ attr.port_num = 1;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int mask =
@@ -594,15 +269,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// Local address
{
struct ibv_port_attr attr;
- CHECK(
- !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
+ CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
<< "Query port";
self_.lid = attr.lid;
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
union ibv_gid gid;
- CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
- adapter_->params_.sgid_index, &gid))
+ CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
<< "Query gid";
self_.snp = gid.global.subnet_prefix;
self_.iid = gid.global.interface_id;
@@ -611,7 +284,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// create message and ack buffers, then initialize the tables.
{
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
- "tx_ack_buffer", "rx_ack_buffer"};
+ "tx_ack_buffer", "rx_ack_buffer"};
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
@@ -672,7 +345,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@@ -806,9 +479,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTR;
-
+ struct ibv_port_attr port_attr;
+ CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr))
+ << "Query port failed";
// This assumes both QP's ports are configured with the same MTU
- attr.path_mtu = adapter_->params_.mtu;
+ attr.path_mtu = port_attr.active_mtu;
attr.dest_qp_num = remoteAddr.qpn;
attr.rq_psn = remoteAddr.psn;
attr.max_dest_rd_atomic = 1;
@@ -819,32 +494,30 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.hop_limit = 255;
attr.ah_attr.dlid = remoteAddr.lid;
- attr.ah_attr.sl = adapter_->params_.sl;
+ attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
- attr.ah_attr.port_num = adapter_->params_.port_num;
- attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
- attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
+ attr.ah_attr.port_num = 1;
int r;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
- IBV_QP_PATH_MTU |
- IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
- IBV_QP_MAX_DEST_RD_ATOMIC |
- IBV_QP_MIN_RNR_TIMER)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = self_.psn;
- attr.timeout = adapter_->params_.timeout;
- attr.retry_cnt = adapter_->params_.retry_cnt;
+ attr.timeout = 14;
+ attr.retry_cnt = 7;
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
- IBV_QP_RETRY_CNT |
- IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
- IBV_QP_MAX_QP_RD_ATOMIC)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
@@ -931,7 +604,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@@ -1026,9 +699,9 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
- CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
- << " gpu_info: "
- << src_dev->tensorflow_gpu_device_info();
+ CHECK(send_args.device_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
if (can_memcpy) {
AllocatorAttributes host_alloc_attrs;
@@ -1054,8 +727,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
// aync instead
GPUUtil::SetProtoFromGPU(
in, src_dev, send_args.device_context, &proto, is_dead,
- [this, proto, buffer_size, key, in, step_id, key_with_step_id,
- is_dead, send_args, recv_args](const Status& s) mutable {
+ [this, proto, buffer_size, key, in, step_id, key_with_step_id,
+ is_dead, send_args, recv_args](const Status& s) mutable {
CHECK(s.ok()) << "copy proto from gpu sync";
auto tensor_bytes = proto.ByteSize();
buffer_size += tensor_bytes;