diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.cc')
-rw-r--r-- | tensorflow/contrib/verbs/rdma.cc | 413 |
1 files changed, 43 insertions, 370 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 331943a3ef..26e18b28aa 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/contrib/verbs/rdma.h" #include <cstdlib> -#include <fcntl.h> #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -34,8 +33,6 @@ limitations under the License. namespace tensorflow { -#define RoCE_V2 "RoCE v2" - namespace { // hash name to 32-bit integer uint32_t NameHash(const string& name) { @@ -69,337 +66,16 @@ string MessageTypeToString(RdmaMessageType rmt) { } } // namespace -// Function to get environment variable -// Args: -// var_name - the name of the environmental variable -// Returns: -// string with it's value or empty string if not set -string get_env_var(char const* var_name) { - char const* var_temp = getenv(var_name); - - return (var_temp == NULL) ? string() : string(var_temp); -} - -// Function to open device -// Args: -// ibv_dev device to open -// Returns: -// context of the opened device -ibv_context* open_device(ibv_device* ibv_dev) { - ibv_context* context = ibv_open_device(ibv_dev); - - CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev); - return context; -} - -// Function to count the number of active ports for device -// Args: -// device - to check active ports -// Returns: -// number of active ports of the given device -int get_dev_active_port_count(ibv_device* device) { - ibv_device_attr device_att; - ibv_port_attr port_attr; - ibv_context* context = NULL; - int rc, port_index, active_ports = 0; - - context = ibv_open_device(device); - CHECK(context) << "Open context failed for " << ibv_get_device_name(device); - rc = ibv_query_device(context, &device_att); - CHECK(!rc) << "Failed to query the device"; - - for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { - rc = ibv_query_port(context, port_index, &port_attr); - CHECK(!rc) << "Failed to query the port" << port_index; - if (port_attr.state == IBV_PORT_ACTIVE) { - active_ports++; - } - } - ibv_close_device(context); - return active_ports; -} - -// Function to set device. If RDMA_DEVICE not set, search for device with active -// port. -// Fails if more than one device with active port was found. -// Returns: -// device to use -ibv_device* set_device() { +ibv_context* open_default_device() { ibv_device** dev_list; - int dev_num, device_index, device_to_open = 0; - int num_devs_with_active_port = 0; - string env_p_rdma_device, str_port_num; - - dev_list = ibv_get_device_list(&dev_num); + ibv_device* ib_dev; + dev_list = ibv_get_device_list(NULL); CHECK(dev_list) << "No InfiniBand device found"; - - env_p_rdma_device = get_env_var("RDMA_DEVICE"); - if (!env_p_rdma_device.empty()) { - for (device_index = 0; device_index < dev_num; device_index++) { - if (!env_p_rdma_device.compare( - ibv_get_device_name(dev_list[device_index]))) { - CHECK(get_dev_active_port_count(dev_list[device_index]) != 0) - << "Device " << ibv_get_device_name(dev_list[device_index]) - << " has no active ports"; - return dev_list[device_index]; - } - } - // check validity of input device - CHECK(false) << "The device " << env_p_rdma_device << " wasn't found"; - } else { - // set default device - str_port_num = get_env_var("RDMA_DEVICE_PORT"); - CHECK(str_port_num.empty()) - << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user"; - for (device_index = 0; device_index < dev_num; device_index++) { - // get port_num - if (get_dev_active_port_count(dev_list[device_index]) > 0) { - num_devs_with_active_port++; - CHECK(num_devs_with_active_port <= 1) << ". More than one device with " - "active port in the system. " - "Please enter RDMA_DEVICE"; - // found device with at least 1 active port - device_to_open = device_index; - } - } - CHECK(num_devs_with_active_port > 0) - << "There is no active port in the system"; - return dev_list[device_to_open]; - } - CHECK(false) << "No device was set!"; - return NULL; // never happens -} - -// Function to set port for device. -// If RDMA_DEVICE_PORT not set, first active port of the device will be set. -// Args: -// context of the device -// Returns: -// port to use -uint8_t set_port(ibv_context* context) { - uint8_t port_num = 0; //0 is illegal port number - string str_port_num; - ibv_device_attr device_att; - ibv_port_attr port_attr; - int rc, port_index; - - rc = ibv_query_device(context, &device_att); - CHECK(!rc) << "Failed to query the device\n"; - - str_port_num = get_env_var("RDMA_DEVICE_PORT"); - // user defined port - if (!str_port_num.empty()) { - port_num = stoi(str_port_num); - CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive"; - CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be " - "less or equal to amount of " - "available ports"; - rc = ibv_query_port(context, port_num, &port_attr); - CHECK(!rc) << "Failed to query the port" << port_num; - // check if port id active - CHECK(port_attr.state == IBV_PORT_ACTIVE) - << "Selected RDMA_DEVICE_PORT is not active"; - } - // set default port - else { - for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { - rc = ibv_query_port(context, port_index, &port_attr); - CHECK(!rc) << "Failed to query the port" << port_index; - if (port_attr.state == IBV_PORT_ACTIVE) { - port_num = port_index; - break; - } - } - CHECK_GT(port_num, 0) << "No active ports"; - } - return port_num; -} - -// Function read from sysfs file -// Args: -// dir - directory -// file - file -// buff - buffer for the result -// size - buffer size -// Returns: -// number of bytes were read or -1 if failed -int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) { - char* path; - int fd; - int len; - - if (asprintf(&path, "%s/%s", dir, file) < 0) return -1; - - fd = open(path, O_RDONLY); - if (fd < 0) { - free(path); - return -1; - } - - len = read(fd, buf, size); - - close(fd); - free(path); - - if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0'; - - return len; -} - -// Function to check if GID index support RoCE V2 -// Args: -// context - device context -// port_num - port number -// index - GID index -// Returns: -// if GID supports RoCE V2 - true, otherwise - false. -bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num, - uint8_t index) { - char name[32]; - char buff[41]; - - snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index); - if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <= - 0) { - return false; - } - return !strcmp(buff, RoCE_V2); -} - -// Function to set GID index. -// If the port link is IB, no GID index should be selected. -// If Ethernet but RDMA_GID_INDEX not set gid index that supports -// RoCE V2 will be chosen(fails if more then one IP is configured) -// Args: -// context - device context -// port_num - port number -// Returns: -// GID index to use -uint8_t set_gid(uint8_t port_num, ibv_context* context) { - ibv_port_attr port_attr; - string gid_str; - int rc, i, gids_num = 0, v2_ip_num = 0; - union ibv_gid gid; - uint8_t gid_index = 0; - - rc = ibv_query_port(context, port_num, &port_attr); - CHECK(!rc) << "Failed to query the port" << port_num; - - for (i = 0; i < port_attr.gid_tbl_len; i++) { - rc = ibv_query_gid(context, port_num, i, &gid); - CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index " - << i; - if (gid.global.interface_id) { - gids_num++; - if (gid.global.subnet_prefix == 0 && - is_gid_type_roce_v2(context, port_num, i)) { - if (v2_ip_num == 0) { - // can be overwritten by RDMA_GID_INDEX later - gid_index = i; - } - v2_ip_num++; - } - } - } - switch (port_attr.link_layer) { - case(IBV_LINK_LAYER_ETHERNET) : - gid_str = get_env_var("RDMA_GID_INDEX"); - if (!gid_str.empty()) { - gid_index = stoi(gid_str); - CHECK(gid_index < gids_num) - << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num; - } else { - CHECK(v2_ip_num <= 1) - << "More than one IP is available, please specify GID_INDEX"; - } - break; - case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index - break; - default: - LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and " - "InfiniBand only. "; - } - if (!is_gid_type_roce_v2(context, port_num, gid_index)) { - LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index; - } - return gid_index; -} - -// set the default or environment value to the configuration parameter. -// Args: -// default_val- the default value for this parameter -// env_param- the environment parameter's name -// Returns: -// 32-bit value -uint32_t set_param(uint32_t default_val, const char* env_param) { - uint32_t val = default_val; - string val_s; - - val_s = get_env_var(env_param); - - if (!val_s.empty()) { - val = stoi(val_s); - } - return val; -} - -enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) { - ibv_port_attr port_attr; - enum ibv_mtu mtu; - string mtu_s; - int rc, mtu_i; - - rc = ibv_query_port(context, port_num, &port_attr); - CHECK(!rc) << "Failed to query the port" << port_num; - - mtu_s = get_env_var("RDMA_MTU"); - - if (!mtu_s.empty()) { - mtu_i = stoi(mtu_s); - switch (mtu_i) { - case 256: - mtu = IBV_MTU_256; - break; - case 512: - mtu = IBV_MTU_512; - break; - case 1024: - mtu = IBV_MTU_1024; - break; - case 2048: - mtu = IBV_MTU_2048; - break; - case 4096: - mtu = IBV_MTU_4096; - break; - default: - CHECK(0) << "Error: MTU input value must be one of the following: 256, " - "512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n"; - break; - } - CHECK(mtu < port_attr.active_mtu) - << "MTU configuration for the QPs is larger than active MTU"; - } else { - mtu = port_attr.active_mtu; - } - return mtu; -} - -RdmaParams params_init(ibv_context* context) { - RdmaParams params; - - params.port_num = set_port(context); - params.sgid_index = set_gid(params.port_num, context); - params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY"); - params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH"); - params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT"); - params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT"); - params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL"); - CHECK(params.sl <= 7) << "SL value is " << (int)params.sl - << ". Valid values are 0-7."; - params.mtu = set_mtu(params.port_num, context); - params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS"); - return params; + ib_dev = dev_list[0]; + CHECK(ib_dev) << "No InfiniBand device found"; + ibv_context* context = ibv_open_device(ib_dev); + CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev); + return context; } ibv_pd* alloc_protection_domain(ibv_context* context) { @@ -409,8 +85,7 @@ ibv_pd* alloc_protection_domain(ibv_context* context) { } RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env) - : context_(open_device(set_device())), - params_(params_init(context_)), + : context_(open_default_device()), pd_(alloc_protection_domain(context_)), worker_env_(worker_env) { event_channel_ = ibv_create_comp_channel(context_); @@ -453,9 +128,9 @@ void RdmaAdapter::Process_CQ() { CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { CHECK(wc_[i].status == IBV_WC_SUCCESS) - << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " " - << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " " - << wc_[i].vendor_err; + << "Failed status \n" + << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " + << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id); // put back a recv wr. @@ -567,8 +242,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, memset(&attr, 0, sizeof(ibv_qp_init_attr)); attr.send_cq = adapter_->cq_; attr.recv_cq = adapter_->cq_; - attr.cap.max_send_wr = adapter_->params_.queue_depth; - attr.cap.max_recv_wr = adapter_->params_.queue_depth; + attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES; + attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES; attr.cap.max_send_sge = 1; attr.cap.max_recv_sge = 1; attr.qp_type = IBV_QPT_RC; @@ -582,8 +257,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_INIT; - attr.pkey_index = adapter_->params_.pkey_index; - attr.port_num = adapter_->params_.port_num; + attr.pkey_index = 0; + attr.port_num = 1; attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; int mask = @@ -594,15 +269,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // Local address { struct ibv_port_attr attr; - CHECK( - !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr)) + CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr)) << "Query port"; self_.lid = attr.lid; self_.qpn = qp_->qp_num; self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff; union ibv_gid gid; - CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num, - adapter_->params_.sgid_index, &gid)) + CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid)) << "Query gid"; self_.snp = gid.global.subnet_prefix; self_.iid = gid.global.interface_id; @@ -611,7 +284,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. { const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + "tx_ack_buffer", "rx_ack_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); @@ -672,7 +345,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t)this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } @@ -806,9 +479,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTR; - + struct ibv_port_attr port_attr; + CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr)) + << "Query port failed"; // This assumes both QP's ports are configured with the same MTU - attr.path_mtu = adapter_->params_.mtu; + attr.path_mtu = port_attr.active_mtu; attr.dest_qp_num = remoteAddr.qpn; attr.rq_psn = remoteAddr.psn; attr.max_dest_rd_atomic = 1; @@ -819,32 +494,30 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.grh.flow_label = 0; attr.ah_attr.grh.hop_limit = 255; attr.ah_attr.dlid = remoteAddr.lid; - attr.ah_attr.sl = adapter_->params_.sl; + attr.ah_attr.sl = 0; attr.ah_attr.src_path_bits = 0; - attr.ah_attr.port_num = adapter_->params_.port_num; - attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index; - attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class; + attr.ah_attr.port_num = 1; int r; - CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV | - IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTS; attr.sq_psn = self_.psn; - attr.timeout = adapter_->params_.timeout; - attr.retry_cnt = adapter_->params_.retry_cnt; + attr.timeout = 14; + attr.retry_cnt = 7; attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | - IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; connected_ = true; @@ -931,7 +604,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t)this; wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -1026,9 +699,9 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { - CHECK(send_args.device_context) << "send dev name: " << src_dev->name() - << " gpu_info: " - << src_dev->tensorflow_gpu_device_info(); + CHECK(send_args.device_context) + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); if (can_memcpy) { AllocatorAttributes host_alloc_attrs; @@ -1054,8 +727,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( // aync instead GPUUtil::SetProtoFromGPU( in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { + [this, proto, buffer_size, key, in, step_id, key_with_step_id, + is_dead, send_args, recv_args](const Status& s) mutable { CHECK(s.ok()) << "copy proto from gpu sync"; auto tensor_bytes = proto.ByteSize(); buffer_size += tensor_bytes; |