author Yifei Feng <yifeif@google.com> 2017-11-22 13:42:21 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-11-22 13:50:02 -0800
commit b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938 (patch)
tree af207d5a90f3176bdd3fbffbe1e98258125bf389 /tensorflow/contrib/verbs
parent e219aeb542779d90a582ffe16f8602cd1b275b22 (diff)
Merge changes from github.
PiperOrigin-RevId: 176695926
Diffstat (limited to 'tensorflow/contrib/verbs')
-rw-r--r-- tensorflow/contrib/verbs/README.md | 14
-rw-r--r-- tensorflow/contrib/verbs/rdma.cc   | 372
-rw-r--r-- tensorflow/contrib/verbs/rdma.h    | 21
3 files changed, 382 insertions, 25 deletions
diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md
index da5f2b0223..dcb390b0a5 100644
--- a/tensorflow/contrib/verbs/README.md
+++ b/tensorflow/contrib/verbs/README.md
@@ -1,4 +1,4 @@
-## How to compile and use RDMA-enabled TensorFlow
+## How to compile, use and configure RDMA-enabled TensorFlow
1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based RDMA support, answer yes to this question:
```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```
@@ -7,6 +7,18 @@
```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
+3. RDMA configuration is done by setting the following environment variables (a minimal read-or-default sketch follows the list):
+ * **RDMA_DEVICE**: The RDMA device name to be used. If not defined by the user, a default device with an active port will be selected, if one exists.
+ * **RDMA_DEVICE_PORT**: The port within the selected device. Not relevant if RDMA_DEVICE is not defined. If not defined by the user, a default active port will be selected, if one exists.
+ * **RDMA_GID_INDEX**: The GID index of the port. If not defined by the user, a suitable default GID index will be selected (RoCE v2 is preferred as the default).
+ * **RDMA_QP_PKEY_INDEX**: The Pkey for the QP. If not defined by the user, the default value is 0.
+ * **RDMA_QP_QUEUE_DEPTH**: TX/RX queue size for the QP. If not defined by the user, the default value is 1024.
+ * **RDMA_QP_TIMEOUT**: The retransmission timeout for QPs. If not defined by the user, the default value is 14.
+ * **RDMA_QP_RETRY_COUNT**: Number of retransmissions for QPs. If not defined by the user, the default value is 7.
+ * **RDMA_QP_SL**: Service level configuration for QoS and ECN; valid values are 0-7. If not defined by the user, the default value is 0.
+ * **RDMA_QP_MTU**: MTU configuration for the QPs. If not defined by the user, the default value is the active MTU from query_port.
+ * **RDMA_TRAFFIC_CLASS**: Traffic class configuration for the QP, in case of DSCP trust level QoS configuration. If not defined by the user, the default value is 0. For more info see [HowTo Configure Trust state on Mellanox Adapters](https://community.mellanox.com/docs/DOC-2866).
+
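All of these variables are optional overrides: each one is read from the process environment and falls back to a built-in default when it is not set. The snippet below is a minimal sketch of that read-or-default pattern; the helper name is hypothetical and not part of the TensorFlow code base:

```c++
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical helper: read an RDMA_* override from the environment,
// falling back to a default when the variable is not set or empty.
uint32_t ReadRdmaParam(const char* env_name, uint32_t default_val) {
  const char* val = std::getenv(env_name);
  return (val == nullptr || *val == '\0')
             ? default_val
             : static_cast<uint32_t>(std::stoi(val));
}

// Example: the queue depth stays at 1024 unless RDMA_QP_QUEUE_DEPTH is exported.
// uint32_t queue_depth = ReadRdmaParam("RDMA_QP_QUEUE_DEPTH", 1024);
```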
## Overview
The design is based on TensorFlow r1.0. An RDMA path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the RDMA path, exchanging computation graphs, etc.
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 26e18b28aa..ac8d994502 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -16,6 +16,7 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma.h"
+#include <fcntl.h>
#include <cstdlib>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -33,6 +34,8 @@ limitations under the License.
namespace tensorflow {
+#define RoCE_V2 "RoCE v2"
+
namespace {
// hash name to 32-bit integer
uint32_t NameHash(const string& name) {
@@ -66,16 +69,336 @@ string MessageTypeToString(RdmaMessageType rmt) {
}
} // namespace
-ibv_context* open_default_device() {
+// Function to get environment variable
+// Args:
+// var_name - the name of the environment variable
+// Returns:
+// string with its value or empty string if not set
+string get_env_var(char const* var_name) {
+ char const* var_temp = getenv(var_name);
+
+ return (var_temp == NULL) ? string() : string(var_temp);
+}
+
+// Function to open device
+// Args:
+// ibv_dev device to open
+// Returns:
+// context of the opened device
+ibv_context* open_device(ibv_device* ibv_dev) {
+ ibv_context* context = ibv_open_device(ibv_dev);
+
+ CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
+ return context;
+}
+
+// Function to count the number of active ports for device
+// Args:
+// device - to check active ports
+// Returns:
+// number of active ports of the given device
+int get_dev_active_port_count(ibv_device* device) {
+ ibv_device_attr device_att;
+ ibv_port_attr port_attr;
+ ibv_context* context = NULL;
+ int rc, port_index, active_ports = 0;
+
+ context = ibv_open_device(device);
+ CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
+ rc = ibv_query_device(context, &device_att);
+ CHECK(!rc) << "Failed to query the device";
+
+ for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
+ rc = ibv_query_port(context, port_index, &port_attr);
+ CHECK(!rc) << "Failed to query the port" << port_index;
+ if (port_attr.state == IBV_PORT_ACTIVE) {
+ active_ports++;
+ }
+ }
+ ibv_close_device(context);
+ return active_ports;
+}
+
+// Function to set the device. If RDMA_DEVICE is not set, search for a device
+// with an active port.
+// Fails if more than one device with an active port is found.
+// Returns:
+// device to use
+ibv_device* set_device() {
ibv_device** dev_list;
- ibv_device* ib_dev;
- dev_list = ibv_get_device_list(NULL);
+ int dev_num, device_index, device_to_open = 0;
+ int num_devs_with_active_port = 0;
+ string env_p_rdma_device, str_port_num;
+
+ dev_list = ibv_get_device_list(&dev_num);
CHECK(dev_list) << "No InfiniBand device found";
- ib_dev = dev_list[0];
- CHECK(ib_dev) << "No InfiniBand device found";
- ibv_context* context = ibv_open_device(ib_dev);
- CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
- return context;
+
+ env_p_rdma_device = get_env_var("RDMA_DEVICE");
+ if (!env_p_rdma_device.empty()) {
+ for (device_index = 0; device_index < dev_num; device_index++) {
+ if (!env_p_rdma_device.compare(
+ ibv_get_device_name(dev_list[device_index]))) {
+ CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
+ << "Device " << ibv_get_device_name(dev_list[device_index])
+ << " has no active ports";
+ return dev_list[device_index];
+ }
+ }
+ // check validity of input device
+ CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
+ } else {
+ // set default device
+ str_port_num = get_env_var("RDMA_DEVICE_PORT");
+ CHECK(str_port_num.empty())
+ << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
+ for (device_index = 0; device_index < dev_num; device_index++) {
+ // get port_num
+ if (get_dev_active_port_count(dev_list[device_index]) > 0) {
+ num_devs_with_active_port++;
+ CHECK(num_devs_with_active_port <= 1) << ". More than one device with "
+ "active port in the system. "
+ "Please enter RDMA_DEVICE";
+ // found device with at least 1 active port
+ device_to_open = device_index;
+ }
+ }
+ CHECK(num_devs_with_active_port > 0)
+ << "There is no active port in the system";
+ return dev_list[device_to_open];
+ }
+ CHECK(false) << "No device was set!";
+ return NULL; // never happens
+}
+
+// Function to set the port for the device.
+// If RDMA_DEVICE_PORT is not set, the first active port of the device is used.
+// Args:
+// context of the device
+// Returns:
+// port to use
+uint8_t set_port(ibv_context* context) {
+ uint8_t port_num = 0; // 0 is illegal port number
+ string str_port_num;
+ ibv_device_attr device_att;
+ ibv_port_attr port_attr;
+ int rc, port_index;
+
+ rc = ibv_query_device(context, &device_att);
+ CHECK(!rc) << "Failed to query the device\n";
+
+ str_port_num = get_env_var("RDMA_DEVICE_PORT");
+ // user defined port
+ if (!str_port_num.empty()) {
+ port_num = stoi(str_port_num);
+ CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
+ CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
+ "less than or equal to the number of "
+ "available ports";
+ rc = ibv_query_port(context, port_num, &port_attr);
+ CHECK(!rc) << "Failed to query the port" << port_num;
+ // check if port is active
+ CHECK(port_attr.state == IBV_PORT_ACTIVE)
+ << "Selected RDMA_DEVICE_PORT is not active";
+ } else { // set default port
+ for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
+ rc = ibv_query_port(context, port_index, &port_attr);
+ CHECK(!rc) << "Failed to query the port" << port_index;
+ if (port_attr.state == IBV_PORT_ACTIVE) {
+ port_num = port_index;
+ break;
+ }
+ }
+ CHECK_GT(port_num, 0) << "No active ports";
+ }
+ return port_num;
+}
+
+// Function to read from a sysfs file
+// Args:
+// dir - directory
+// file - file
+// buff - buffer for the result
+// size - buffer size
+// Returns:
+// number of bytes read, or -1 on failure
+int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
+ char* path;
+ int fd;
+ int len;
+
+ if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;
+
+ fd = open(path, O_RDONLY);
+ if (fd < 0) {
+ free(path);
+ return -1;
+ }
+
+ len = read(fd, buf, size);
+
+ close(fd);
+ free(path);
+
+ if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';
+
+ return len;
+}
+
+// Function to check if a GID index supports RoCE V2
+// Args:
+// context - device context
+// port_num - port number
+// index - GID index
+// Returns:
+// if GID supports RoCE V2 - true, otherwise - false.
+bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
+ uint8_t index) {
+ char name[32];
+ char buff[41];
+
+ snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
+ if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
+ 0) {
+ return false;
+ }
+ return !strcmp(buff, RoCE_V2);
+}
+
+// Function to set GID index.
+// If the port link is IB, no GID index should be selected.
+// If Ethernet and RDMA_GID_INDEX is not set, a GID index that supports
+// RoCE V2 will be chosen (fails if more than one IP is configured).
+// Args:
+// context - device context
+// port_num - port number
+// Returns:
+// GID index to use
+uint8_t set_gid(uint8_t port_num, ibv_context* context) {
+ ibv_port_attr port_attr;
+ string gid_str;
+ int rc, i, gids_num = 0, v2_ip_num = 0;
+ union ibv_gid gid;
+ uint8_t gid_index = 0;
+
+ rc = ibv_query_port(context, port_num, &port_attr);
+ CHECK(!rc) << "Failed to query the port" << port_num;
+
+ for (i = 0; i < port_attr.gid_tbl_len; i++) {
+ rc = ibv_query_gid(context, port_num, i, &gid);
+ CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index "
+ << i;
+ if (gid.global.interface_id) {
+ gids_num++;
+ if (gid.global.subnet_prefix == 0 &&
+ is_gid_type_roce_v2(context, port_num, i)) {
+ if (v2_ip_num == 0) {
+ // can be overwritten by RDMA_GID_INDEX later
+ gid_index = i;
+ }
+ v2_ip_num++;
+ }
+ }
+ }
+ switch (port_attr.link_layer) {
+ case (IBV_LINK_LAYER_ETHERNET):
+ gid_str = get_env_var("RDMA_GID_INDEX");
+ if (!gid_str.empty()) {
+ gid_index = stoi(gid_str);
+ CHECK(gid_index < gids_num)
+ << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num;
+ } else {
+ CHECK(v2_ip_num <= 1)
+ << "More than one IP is available, please specify GID_INDEX";
+ }
+ break;
+ case (IBV_LINK_LAYER_INFINIBAND): // no need in GID index
+ break;
+ default:
+ LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
+ "InfiniBand only. ";
+ }
+ if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
+ LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
+ }
+ return gid_index;
+}
+
+// Set a configuration parameter from its environment variable, falling back to the default.
+// Args:
+// default_val- the default value for this parameter
+// env_param- the environment parameter's name
+// Returns:
+// 32-bit value
+uint32_t set_param(uint32_t default_val, const char* env_param) {
+ uint32_t val = default_val;
+ string val_s;
+
+ val_s = get_env_var(env_param);
+
+ if (!val_s.empty()) {
+ val = stoi(val_s);
+ }
+ return val;
+}
+
+enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
+ ibv_port_attr port_attr;
+ enum ibv_mtu mtu;
+ string mtu_s;
+ int rc, mtu_i;
+
+ rc = ibv_query_port(context, port_num, &port_attr);
+ CHECK(!rc) << "Failed to query the port" << port_num;
+
+ mtu_s = get_env_var("RDMA_MTU");
+
+ if (!mtu_s.empty()) {
+ mtu_i = stoi(mtu_s);
+ switch (mtu_i) {
+ case 256:
+ mtu = IBV_MTU_256;
+ break;
+ case 512:
+ mtu = IBV_MTU_512;
+ break;
+ case 1024:
+ mtu = IBV_MTU_1024;
+ break;
+ case 2048:
+ mtu = IBV_MTU_2048;
+ break;
+ case 4096:
+ mtu = IBV_MTU_4096;
+ break;
+ default:
+ CHECK(0) << "Error: MTU input value must be one of the following: 256, "
+ "512, 1024, 2048, 4096. MTU "
+ << mtu_i << " is invalid\n";
+ break;
+ }
+ CHECK(mtu <= port_attr.active_mtu)
+ << "MTU configuration for the QPs is larger than the active MTU";
+ } else {
+ mtu = port_attr.active_mtu;
+ }
+ return mtu;
+}
+
+RdmaParams params_init(ibv_context* context) {
+ RdmaParams params;
+
+ params.port_num = set_port(context);
+ params.sgid_index = set_gid(params.port_num, context);
+ params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
+ params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
+ params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
+ params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
+ params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
+ CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
+ << ". Valid values are 0-7.";
+ params.mtu = set_mtu(params.port_num, context);
+ params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
+ return params;
}
ibv_pd* alloc_protection_domain(ibv_context* context) {
@@ -85,7 +408,8 @@ ibv_pd* alloc_protection_domain(ibv_context* context) {
}
RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
- : context_(open_default_device()),
+ : context_(open_device(set_device())),
+ params_(params_init(context_)),
pd_(alloc_protection_domain(context_)),
worker_env_(worker_env) {
event_channel_ = ibv_create_comp_channel(context_);
@@ -242,8 +566,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
memset(&attr, 0, sizeof(ibv_qp_init_attr));
attr.send_cq = adapter_->cq_;
attr.recv_cq = adapter_->cq_;
- attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
- attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_send_wr = adapter_->params_.queue_depth;
+ attr.cap.max_recv_wr = adapter_->params_.queue_depth;
attr.cap.max_send_sge = 1;
attr.cap.max_recv_sge = 1;
attr.qp_type = IBV_QPT_RC;
@@ -257,8 +581,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_INIT;
- attr.pkey_index = 0;
- attr.port_num = 1;
+ attr.pkey_index = adapter_->params_.pkey_index;
+ attr.port_num = adapter_->params_.port_num;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int mask =
@@ -269,13 +593,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// Local address
{
struct ibv_port_attr attr;
- CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
+ CHECK(
+ !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
<< "Query port";
self_.lid = attr.lid;
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
union ibv_gid gid;
- CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
+ CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
+ adapter_->params_.sgid_index, &gid))
<< "Query gid";
self_.snp = gid.global.subnet_prefix;
self_.iid = gid.global.interface_id;
@@ -479,11 +805,9 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTR;
- struct ibv_port_attr port_attr;
- CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr))
- << "Query port failed";
+
// This assumes both QP's ports are configured with the same MTU
- attr.path_mtu = port_attr.active_mtu;
+ attr.path_mtu = adapter_->params_.mtu;
attr.dest_qp_num = remoteAddr.qpn;
attr.rq_psn = remoteAddr.psn;
attr.max_dest_rd_atomic = 1;
@@ -494,9 +818,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.hop_limit = 255;
attr.ah_attr.dlid = remoteAddr.lid;
- attr.ah_attr.sl = 0;
+ attr.ah_attr.sl = adapter_->params_.sl;
attr.ah_attr.src_path_bits = 0;
- attr.ah_attr.port_num = 1;
+ attr.ah_attr.port_num = adapter_->params_.port_num;
+ attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
+ attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
int r;
CHECK(!(r = ibv_modify_qp(qp_, &attr,
@@ -509,8 +835,8 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = self_.psn;
- attr.timeout = 14;
- attr.retry_cnt = 7;
+ attr.timeout = adapter_->params_.timeout;
+ attr.retry_cnt = adapter_->params_.retry_cnt;
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
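The hunks above replace the previously hard-coded port number, pkey index, path MTU, service level, timeout and retry count with the fields of `adapter_->params_`. As a condensed illustration only (a sketch with a hypothetical function name, not the full `RdmaChannel::Connect()` logic), the resolved parameters would drive the RTR/RTS transitions of a reliable-connected QP roughly like this:

```c++
#include <cstdint>
#include <cstring>

#include <infiniband/verbs.h>

#include "tensorflow/contrib/verbs/rdma.h"     // RdmaParams, RdmaAddress
#include "tensorflow/core/platform/logging.h"  // CHECK

// Illustrative sketch: drive the RTR/RTS transitions of an RC queue pair from
// the environment-resolved RdmaParams instead of hard-coded constants.
void ConnectQpSketch(ibv_qp* qp, const tensorflow::RdmaParams& params,
                     const tensorflow::RdmaAddress& remote, uint32_t local_psn) {
  ibv_qp_attr attr;

  // Ready-to-receive: remote address plus path attributes taken from params.
  memset(&attr, 0, sizeof(attr));
  attr.qp_state = IBV_QPS_RTR;
  attr.path_mtu = params.mtu;               // MTU resolved by set_mtu()
  attr.dest_qp_num = remote.qpn;
  attr.rq_psn = remote.psn;
  attr.max_dest_rd_atomic = 1;
  attr.min_rnr_timer = 12;
  attr.ah_attr.dlid = remote.lid;
  attr.ah_attr.sl = params.sl;              // service level from params_init()
  attr.ah_attr.src_path_bits = 0;
  attr.ah_attr.port_num = params.port_num;  // port chosen by set_port()
  attr.ah_attr.grh.sgid_index = params.sgid_index;
  attr.ah_attr.grh.traffic_class = params.traffic_class;
  attr.ah_attr.grh.hop_limit = 255;
  int rc = ibv_modify_qp(qp, &attr,
                         IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
                             IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
                             IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);

  // Ready-to-send: local retransmission policy taken from params.
  memset(&attr, 0, sizeof(attr));
  attr.qp_state = IBV_QPS_RTS;
  attr.sq_psn = local_psn;
  attr.timeout = params.timeout;            // retransmission timeout
  attr.retry_cnt = params.retry_cnt;        // retry count
  attr.rnr_retry = 7;                       // "infinite"
  attr.max_rd_atomic = 1;
  rc |= ibv_modify_qp(qp, &attr,
                      IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
                          IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
                          IBV_QP_MAX_QP_RD_ATOMIC);
  CHECK(!rc) << "Failed to move the QP to RTR/RTS";
}
```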
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index e1e07db776..00217c81d4 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -36,7 +36,24 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
-
+#define PKEY_DEFAULT 0
+#define QUEUE_DEPTH_DEFAULT 1024
+#define TIMEOUT_DEFAULT 14
+#define RETRY_CNT_DEFAULT 7
+#define SL_DEFAULT 0
+#define TRAFFIC_CLASS 0
+
+struct RdmaParams {
+ uint8_t port_num;
+ uint8_t sgid_index;
+ uint8_t pkey_index;
+ uint32_t queue_depth;
+ uint8_t timeout;
+ uint8_t retry_cnt;
+ uint8_t sl;
+ enum ibv_mtu mtu;
+ uint8_t traffic_class;
+};
// structure to save the address of remote channels.
struct RdmaAddress {
uint32_t lid;
@@ -84,6 +101,8 @@ class RdmaAdapter {
protected:
static const int MAX_CONCURRENT_WRITES = 1000;
ibv_context* context_;
+ // RDMA configuration parameters
+ RdmaParams params_;
// ibverbs protection domain
ibv_pd* pd_;
// Completion event channel, to wait for work completions
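For reference, the sketch below (illustrative only, with a hypothetical helper name, not part of the change) shows what an `RdmaParams` instance resolves to when none of the RDMA_* environment variables are exported: the scalar fields take the defaults defined above, while the port number, GID index and MTU are discovered from the device at runtime by `set_port()`, `set_gid()` and `set_mtu()`.

```c++
#include <cstdint>

#include <infiniband/verbs.h>

#include "tensorflow/contrib/verbs/rdma.h"  // RdmaParams and the *_DEFAULT macros

// Illustration only: an all-defaults configuration. The three arguments stand
// in for whatever set_port(), set_gid() and set_mtu() discover at runtime.
tensorflow::RdmaParams AllDefaultParams(uint8_t active_port,
                                        uint8_t roce_v2_gid_index,
                                        enum ibv_mtu active_mtu) {
  tensorflow::RdmaParams params;
  params.port_num = active_port;             // first active port found
  params.sgid_index = roce_v2_gid_index;     // RoCE v2 GID index if available
  params.pkey_index = PKEY_DEFAULT;          // 0
  params.queue_depth = QUEUE_DEPTH_DEFAULT;  // 1024
  params.timeout = TIMEOUT_DEFAULT;          // 14
  params.retry_cnt = RETRY_CNT_DEFAULT;      // 7
  params.sl = SL_DEFAULT;                    // 0
  params.mtu = active_mtu;                   // active MTU from ibv_query_port()
  params.traffic_class = TRAFFIC_CLASS;      // 0
  return params;
}
```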