diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma.cc | 413 |
1 files changed, 370 insertions, 43 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 26e18b28aa..331943a3ef 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/contrib/verbs/rdma.h" #include <cstdlib> +#include <fcntl.h> #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -33,6 +34,8 @@ limitations under the License. namespace tensorflow { +#define RoCE_V2 "RoCE v2" + namespace { // hash name to 32-bit integer uint32_t NameHash(const string& name) { @@ -66,16 +69,337 @@ string MessageTypeToString(RdmaMessageType rmt) { } } // namespace -ibv_context* open_default_device() { +// Function to get environment variable +// Args: +// var_name - the name of the environmental variable +// Returns: +// string with it's value or empty string if not set +string get_env_var(char const* var_name) { + char const* var_temp = getenv(var_name); + + return (var_temp == NULL) ? string() : string(var_temp); +} + +// Function to open device +// Args: +// ibv_dev device to open +// Returns: +// context of the opened device +ibv_context* open_device(ibv_device* ibv_dev) { + ibv_context* context = ibv_open_device(ibv_dev); + + CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev); + return context; +} + +// Function to count the number of active ports for device +// Args: +// device - to check active ports +// Returns: +// number of active ports of the given device +int get_dev_active_port_count(ibv_device* device) { + ibv_device_attr device_att; + ibv_port_attr port_attr; + ibv_context* context = NULL; + int rc, port_index, active_ports = 0; + + context = ibv_open_device(device); + CHECK(context) << "Open context failed for " << ibv_get_device_name(device); + rc = ibv_query_device(context, &device_att); + CHECK(!rc) << "Failed to query the device"; + + for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { + rc = ibv_query_port(context, port_index, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_index; + if (port_attr.state == IBV_PORT_ACTIVE) { + active_ports++; + } + } + ibv_close_device(context); + return active_ports; +} + +// Function to set device. If RDMA_DEVICE not set, search for device with active +// port. +// Fails if more than one device with active port was found. +// Returns: +// device to use +ibv_device* set_device() { ibv_device** dev_list; - ibv_device* ib_dev; - dev_list = ibv_get_device_list(NULL); + int dev_num, device_index, device_to_open = 0; + int num_devs_with_active_port = 0; + string env_p_rdma_device, str_port_num; + + dev_list = ibv_get_device_list(&dev_num); CHECK(dev_list) << "No InfiniBand device found"; - ib_dev = dev_list[0]; - CHECK(ib_dev) << "No InfiniBand device found"; - ibv_context* context = ibv_open_device(ib_dev); - CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev); - return context; + + env_p_rdma_device = get_env_var("RDMA_DEVICE"); + if (!env_p_rdma_device.empty()) { + for (device_index = 0; device_index < dev_num; device_index++) { + if (!env_p_rdma_device.compare( + ibv_get_device_name(dev_list[device_index]))) { + CHECK(get_dev_active_port_count(dev_list[device_index]) != 0) + << "Device " << ibv_get_device_name(dev_list[device_index]) + << " has no active ports"; + return dev_list[device_index]; + } + } + // check validity of input device + CHECK(false) << "The device " << env_p_rdma_device << " wasn't found"; + } else { + // set default device + str_port_num = get_env_var("RDMA_DEVICE_PORT"); + CHECK(str_port_num.empty()) + << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user"; + for (device_index = 0; device_index < dev_num; device_index++) { + // get port_num + if (get_dev_active_port_count(dev_list[device_index]) > 0) { + num_devs_with_active_port++; + CHECK(num_devs_with_active_port <= 1) << ". More than one device with " + "active port in the system. " + "Please enter RDMA_DEVICE"; + // found device with at least 1 active port + device_to_open = device_index; + } + } + CHECK(num_devs_with_active_port > 0) + << "There is no active port in the system"; + return dev_list[device_to_open]; + } + CHECK(false) << "No device was set!"; + return NULL; // never happens +} + +// Function to set port for device. +// If RDMA_DEVICE_PORT not set, first active port of the device will be set. +// Args: +// context of the device +// Returns: +// port to use +uint8_t set_port(ibv_context* context) { + uint8_t port_num = 0; //0 is illegal port number + string str_port_num; + ibv_device_attr device_att; + ibv_port_attr port_attr; + int rc, port_index; + + rc = ibv_query_device(context, &device_att); + CHECK(!rc) << "Failed to query the device\n"; + + str_port_num = get_env_var("RDMA_DEVICE_PORT"); + // user defined port + if (!str_port_num.empty()) { + port_num = stoi(str_port_num); + CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive"; + CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be " + "less or equal to amount of " + "available ports"; + rc = ibv_query_port(context, port_num, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_num; + // check if port id active + CHECK(port_attr.state == IBV_PORT_ACTIVE) + << "Selected RDMA_DEVICE_PORT is not active"; + } + // set default port + else { + for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { + rc = ibv_query_port(context, port_index, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_index; + if (port_attr.state == IBV_PORT_ACTIVE) { + port_num = port_index; + break; + } + } + CHECK_GT(port_num, 0) << "No active ports"; + } + return port_num; +} + +// Function read from sysfs file +// Args: +// dir - directory +// file - file +// buff - buffer for the result +// size - buffer size +// Returns: +// number of bytes were read or -1 if failed +int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) { + char* path; + int fd; + int len; + + if (asprintf(&path, "%s/%s", dir, file) < 0) return -1; + + fd = open(path, O_RDONLY); + if (fd < 0) { + free(path); + return -1; + } + + len = read(fd, buf, size); + + close(fd); + free(path); + + if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0'; + + return len; +} + +// Function to check if GID index support RoCE V2 +// Args: +// context - device context +// port_num - port number +// index - GID index +// Returns: +// if GID supports RoCE V2 - true, otherwise - false. +bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num, + uint8_t index) { + char name[32]; + char buff[41]; + + snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index); + if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <= + 0) { + return false; + } + return !strcmp(buff, RoCE_V2); +} + +// Function to set GID index. +// If the port link is IB, no GID index should be selected. +// If Ethernet but RDMA_GID_INDEX not set gid index that supports +// RoCE V2 will be chosen(fails if more then one IP is configured) +// Args: +// context - device context +// port_num - port number +// Returns: +// GID index to use +uint8_t set_gid(uint8_t port_num, ibv_context* context) { + ibv_port_attr port_attr; + string gid_str; + int rc, i, gids_num = 0, v2_ip_num = 0; + union ibv_gid gid; + uint8_t gid_index = 0; + + rc = ibv_query_port(context, port_num, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_num; + + for (i = 0; i < port_attr.gid_tbl_len; i++) { + rc = ibv_query_gid(context, port_num, i, &gid); + CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index " + << i; + if (gid.global.interface_id) { + gids_num++; + if (gid.global.subnet_prefix == 0 && + is_gid_type_roce_v2(context, port_num, i)) { + if (v2_ip_num == 0) { + // can be overwritten by RDMA_GID_INDEX later + gid_index = i; + } + v2_ip_num++; + } + } + } + switch (port_attr.link_layer) { + case(IBV_LINK_LAYER_ETHERNET) : + gid_str = get_env_var("RDMA_GID_INDEX"); + if (!gid_str.empty()) { + gid_index = stoi(gid_str); + CHECK(gid_index < gids_num) + << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num; + } else { + CHECK(v2_ip_num <= 1) + << "More than one IP is available, please specify GID_INDEX"; + } + break; + case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index + break; + default: + LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and " + "InfiniBand only. "; + } + if (!is_gid_type_roce_v2(context, port_num, gid_index)) { + LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index; + } + return gid_index; +} + +// set the default or environment value to the configuration parameter. +// Args: +// default_val- the default value for this parameter +// env_param- the environment parameter's name +// Returns: +// 32-bit value +uint32_t set_param(uint32_t default_val, const char* env_param) { + uint32_t val = default_val; + string val_s; + + val_s = get_env_var(env_param); + + if (!val_s.empty()) { + val = stoi(val_s); + } + return val; +} + +enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) { + ibv_port_attr port_attr; + enum ibv_mtu mtu; + string mtu_s; + int rc, mtu_i; + + rc = ibv_query_port(context, port_num, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_num; + + mtu_s = get_env_var("RDMA_MTU"); + + if (!mtu_s.empty()) { + mtu_i = stoi(mtu_s); + switch (mtu_i) { + case 256: + mtu = IBV_MTU_256; + break; + case 512: + mtu = IBV_MTU_512; + break; + case 1024: + mtu = IBV_MTU_1024; + break; + case 2048: + mtu = IBV_MTU_2048; + break; + case 4096: + mtu = IBV_MTU_4096; + break; + default: + CHECK(0) << "Error: MTU input value must be one of the following: 256, " + "512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n"; + break; + } + CHECK(mtu < port_attr.active_mtu) + << "MTU configuration for the QPs is larger than active MTU"; + } else { + mtu = port_attr.active_mtu; + } + return mtu; +} + +RdmaParams params_init(ibv_context* context) { + RdmaParams params; + + params.port_num = set_port(context); + params.sgid_index = set_gid(params.port_num, context); + params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY"); + params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH"); + params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT"); + params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT"); + params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL"); + CHECK(params.sl <= 7) << "SL value is " << (int)params.sl + << ". Valid values are 0-7."; + params.mtu = set_mtu(params.port_num, context); + params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS"); + return params; } ibv_pd* alloc_protection_domain(ibv_context* context) { @@ -85,7 +409,8 @@ ibv_pd* alloc_protection_domain(ibv_context* context) { } RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env) - : context_(open_default_device()), + : context_(open_device(set_device())), + params_(params_init(context_)), pd_(alloc_protection_domain(context_)), worker_env_(worker_env) { event_channel_ = ibv_create_comp_channel(context_); @@ -128,9 +453,9 @@ void RdmaAdapter::Process_CQ() { CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { CHECK(wc_[i].status == IBV_WC_SUCCESS) - << "Failed status \n" - << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " - << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err; + << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " " + << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " " + << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id); // put back a recv wr. @@ -242,8 +567,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, memset(&attr, 0, sizeof(ibv_qp_init_attr)); attr.send_cq = adapter_->cq_; attr.recv_cq = adapter_->cq_; - attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES; - attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES; + attr.cap.max_send_wr = adapter_->params_.queue_depth; + attr.cap.max_recv_wr = adapter_->params_.queue_depth; attr.cap.max_send_sge = 1; attr.cap.max_recv_sge = 1; attr.qp_type = IBV_QPT_RC; @@ -257,8 +582,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_INIT; - attr.pkey_index = 0; - attr.port_num = 1; + attr.pkey_index = adapter_->params_.pkey_index; + attr.port_num = adapter_->params_.port_num; attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; int mask = @@ -269,13 +594,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // Local address { struct ibv_port_attr attr; - CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr)) + CHECK( + !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr)) << "Query port"; self_.lid = attr.lid; self_.qpn = qp_->qp_num; self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff; union ibv_gid gid; - CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid)) + CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num, + adapter_->params_.sgid_index, &gid)) << "Query gid"; self_.snp = gid.global.subnet_prefix; self_.iid = gid.global.interface_id; @@ -284,7 +611,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. { const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + "tx_ack_buffer", "rx_ack_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); @@ -345,7 +672,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } @@ -479,11 +806,9 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTR; - struct ibv_port_attr port_attr; - CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr)) - << "Query port failed"; + // This assumes both QP's ports are configured with the same MTU - attr.path_mtu = port_attr.active_mtu; + attr.path_mtu = adapter_->params_.mtu; attr.dest_qp_num = remoteAddr.qpn; attr.rq_psn = remoteAddr.psn; attr.max_dest_rd_atomic = 1; @@ -494,30 +819,32 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.grh.flow_label = 0; attr.ah_attr.grh.hop_limit = 255; attr.ah_attr.dlid = remoteAddr.lid; - attr.ah_attr.sl = 0; + attr.ah_attr.sl = adapter_->params_.sl; attr.ah_attr.src_path_bits = 0; - attr.ah_attr.port_num = 1; + attr.ah_attr.port_num = adapter_->params_.port_num; + attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index; + attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class; int r; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV | + IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTS; attr.sq_psn = self_.psn; - attr.timeout = 14; - attr.retry_cnt = 7; + attr.timeout = adapter_->params_.timeout; + attr.retry_cnt = adapter_->params_.retry_cnt; attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; connected_ = true; @@ -604,7 +931,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -699,9 +1026,9 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { - CHECK(send_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + CHECK(send_args.device_context) << "send dev name: " << src_dev->name() + << " gpu_info: " + << src_dev->tensorflow_gpu_device_info(); if (can_memcpy) { AllocatorAttributes host_alloc_attrs; @@ -727,8 +1054,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( // aync instead GPUUtil::SetProtoFromGPU( in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { + [this, proto, buffer_size, key, in, step_id, key_with_step_id, + is_dead, send_args, recv_args](const Status& s) mutable { CHECK(s.ok()) << "copy proto from gpu sync"; auto tensor_bytes = proto.ByteSize(); buffer_size += tensor_bytes; |