aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs
diff options
context:
space:
mode:
author: Yifei Feng <yifeif@google.com> 2017-11-22 00:39:22 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-11-22 00:42:30 -0800
commit e70c00950d295c519fd9c7f8b12e13a3c5aaf710 (patch)
tree ee1210f8810e0b0fec9346f762e854b371899919 /tensorflow/contrib/verbs
parent ad7eeec1cc06d7fdba6ee404f03a35fab9cd3e6a (diff)
Automated g4 rollback of changelist 176615107
PiperOrigin-RevId: 176622438
Diffstat (limited to 'tensorflow/contrib/verbs')
-rw-r--r-- tensorflow/contrib/verbs/README.md | 14
-rw-r--r-- tensorflow/contrib/verbs/rdma.cc | 413
-rw-r--r-- tensorflow/contrib/verbs/rdma.h | 40
3 files changed, 49 insertions, 418 deletions
diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md
index dcb390b0a5..da5f2b0223 100644
--- a/tensorflow/contrib/verbs/README.md
+++ b/tensorflow/contrib/verbs/README.md
@@ -1,4 +1,4 @@
-## How to compile, use and configure RDMA-enabled TensorFlow
+## How to compile and use RDMA-enabled TensorFlow
1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based RDMA support, answer yes to this question:
```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```
@@ -7,18 +7,6 @@
```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
-3. RDMA configuration is done by setting the following environment variables:
- * **RDMA_DEVICE**: The RDMA device name to be used. If not defined by user, a default device with an active port will be set if exists.
- * **RDMA_DEVICE_PORT**: The port within the selected device. Not relevant if RDMA_DEVICE is not defined. If not defined by user, a default active port will be set if exists.
- * **RDMA_GID_INDEX**: The GID index of the port. If not defined by user, a default suitable GID index will be set (RoCEV2 is favourable as default).
- * **RDMA_QP_PKEY_INDEX**: The Pkey for the QP. If not defined by user, the default value is 0.
- * **RDMA_QP_QUEUE_DEPTH**: TX/RX queue size for the QP. If not defined by user, the default value is 1024.
- * **RDMA_QP_TIMEOUT**: The retransmission timeout for QPs. If not defined by user, the default value is 14.
- * **RDMA_QP_RETRY_COUNT**: Number of retransmission for QPs. If not defined by user, the default value is 7.
- * **RDMA_QP_SL**: Service level configuration for QOS and ECN, valid values are 0-7. If not defined by user, the default value is 0.
- * **RDMA_QP_MTU**: MTU configuration for the QPs. If not defined by user, the default value is active MTU from query_port.
- * **RDMA_TRAFFIC_CLASS**: Traffic class configuration for QP, in case of DSCP trust level QoS configuration. If not defined by user, the default value is 0. For more info see [HowTo Configure Trust state on Mellanox Adapters](https://community.mellanox.com/docs/DOC-2866).
-
## Overview
The design is based on TensorFlow r1.0. An RDMA path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the RDMA path, exchanging computation graphs, etc.
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 331943a3ef..26e18b28aa 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include "tensorflow/contrib/verbs/rdma.h"
#include <cstdlib>
-#include <fcntl.h>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -34,8 +33,6 @@ limitations under the License.
namespace tensorflow {
-#define RoCE_V2 "RoCE v2"
-
namespace {
// hash name to 32-bit integer
uint32_t NameHash(const string& name) {
@@ -69,337 +66,16 @@ string MessageTypeToString(RdmaMessageType rmt) {
}
} // namespace
-// Function to get environment variable
-// Args:
-// var_name - the name of the environmental variable
-// Returns:
-// string with it's value or empty string if not set
-string get_env_var(char const* var_name) {
- char const* var_temp = getenv(var_name);
-
- return (var_temp == NULL) ? string() : string(var_temp);
-}
-
-// Function to open device
-// Args:
-// ibv_dev device to open
-// Returns:
-// context of the opened device
-ibv_context* open_device(ibv_device* ibv_dev) {
- ibv_context* context = ibv_open_device(ibv_dev);
-
- CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
- return context;
-}
-
-// Function to count the number of active ports for device
-// Args:
-// device - to check active ports
-// Returns:
-// number of active ports of the given device
-int get_dev_active_port_count(ibv_device* device) {
- ibv_device_attr device_att;
- ibv_port_attr port_attr;
- ibv_context* context = NULL;
- int rc, port_index, active_ports = 0;
-
- context = ibv_open_device(device);
- CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
- rc = ibv_query_device(context, &device_att);
- CHECK(!rc) << "Failed to query the device";
-
- for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
- rc = ibv_query_port(context, port_index, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_index;
- if (port_attr.state == IBV_PORT_ACTIVE) {
- active_ports++;
- }
- }
- ibv_close_device(context);
- return active_ports;
-}
-
-// Function to set device. If RDMA_DEVICE not set, search for device with active
-// port.
-// Fails if more than one device with active port was found.
-// Returns:
-// device to use
-ibv_device* set_device() {
+ibv_context* open_default_device() {
ibv_device** dev_list;
- int dev_num, device_index, device_to_open = 0;
- int num_devs_with_active_port = 0;
- string env_p_rdma_device, str_port_num;
-
- dev_list = ibv_get_device_list(&dev_num);
+ ibv_device* ib_dev;
+ dev_list = ibv_get_device_list(NULL);
CHECK(dev_list) << "No InfiniBand device found";
-
- env_p_rdma_device = get_env_var("RDMA_DEVICE");
- if (!env_p_rdma_device.empty()) {
- for (device_index = 0; device_index < dev_num; device_index++) {
- if (!env_p_rdma_device.compare(
- ibv_get_device_name(dev_list[device_index]))) {
- CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
- << "Device " << ibv_get_device_name(dev_list[device_index])
- << " has no active ports";
- return dev_list[device_index];
- }
- }
- // check validity of input device
- CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
- } else {
- // set default device
- str_port_num = get_env_var("RDMA_DEVICE_PORT");
- CHECK(str_port_num.empty())
- << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
- for (device_index = 0; device_index < dev_num; device_index++) {
- // get port_num
- if (get_dev_active_port_count(dev_list[device_index]) > 0) {
- num_devs_with_active_port++;
- CHECK(num_devs_with_active_port <= 1) << ". More than one device with "
- "active port in the system. "
- "Please enter RDMA_DEVICE";
- // found device with at least 1 active port
- device_to_open = device_index;
- }
- }
- CHECK(num_devs_with_active_port > 0)
- << "There is no active port in the system";
- return dev_list[device_to_open];
- }
- CHECK(false) << "No device was set!";
- return NULL; // never happens
-}
-
-// Function to set port for device.
-// If RDMA_DEVICE_PORT not set, first active port of the device will be set.
-// Args:
-// context of the device
-// Returns:
-// port to use
-uint8_t set_port(ibv_context* context) {
- uint8_t port_num = 0; //0 is illegal port number
- string str_port_num;
- ibv_device_attr device_att;
- ibv_port_attr port_attr;
- int rc, port_index;
-
- rc = ibv_query_device(context, &device_att);
- CHECK(!rc) << "Failed to query the device\n";
-
- str_port_num = get_env_var("RDMA_DEVICE_PORT");
- // user defined port
- if (!str_port_num.empty()) {
- port_num = stoi(str_port_num);
- CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
- CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
- "less or equal to amount of "
- "available ports";
- rc = ibv_query_port(context, port_num, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_num;
- // check if port id active
- CHECK(port_attr.state == IBV_PORT_ACTIVE)
- << "Selected RDMA_DEVICE_PORT is not active";
- }
- // set default port
- else {
- for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
- rc = ibv_query_port(context, port_index, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_index;
- if (port_attr.state == IBV_PORT_ACTIVE) {
- port_num = port_index;
- break;
- }
- }
- CHECK_GT(port_num, 0) << "No active ports";
- }
- return port_num;
-}
-
-// Function read from sysfs file
-// Args:
-// dir - directory
-// file - file
-// buff - buffer for the result
-// size - buffer size
-// Returns:
-// number of bytes were read or -1 if failed
-int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
- char* path;
- int fd;
- int len;
-
- if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;
-
- fd = open(path, O_RDONLY);
- if (fd < 0) {
- free(path);
- return -1;
- }
-
- len = read(fd, buf, size);
-
- close(fd);
- free(path);
-
- if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';
-
- return len;
-}
-
-// Function to check if GID index support RoCE V2
-// Args:
-// context - device context
-// port_num - port number
-// index - GID index
-// Returns:
-// if GID supports RoCE V2 - true, otherwise - false.
-bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
- uint8_t index) {
- char name[32];
- char buff[41];
-
- snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
- if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
- 0) {
- return false;
- }
- return !strcmp(buff, RoCE_V2);
-}
-
-// Function to set GID index.
-// If the port link is IB, no GID index should be selected.
-// If Ethernet but RDMA_GID_INDEX not set gid index that supports
-// RoCE V2 will be chosen(fails if more then one IP is configured)
-// Args:
-// context - device context
-// port_num - port number
-// Returns:
-// GID index to use
-uint8_t set_gid(uint8_t port_num, ibv_context* context) {
- ibv_port_attr port_attr;
- string gid_str;
- int rc, i, gids_num = 0, v2_ip_num = 0;
- union ibv_gid gid;
- uint8_t gid_index = 0;
-
- rc = ibv_query_port(context, port_num, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_num;
-
- for (i = 0; i < port_attr.gid_tbl_len; i++) {
- rc = ibv_query_gid(context, port_num, i, &gid);
- CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index "
- << i;
- if (gid.global.interface_id) {
- gids_num++;
- if (gid.global.subnet_prefix == 0 &&
- is_gid_type_roce_v2(context, port_num, i)) {
- if (v2_ip_num == 0) {
- // can be overwritten by RDMA_GID_INDEX later
- gid_index = i;
- }
- v2_ip_num++;
- }
- }
- }
- switch (port_attr.link_layer) {
- case(IBV_LINK_LAYER_ETHERNET) :
- gid_str = get_env_var("RDMA_GID_INDEX");
- if (!gid_str.empty()) {
- gid_index = stoi(gid_str);
- CHECK(gid_index < gids_num)
- << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num;
- } else {
- CHECK(v2_ip_num <= 1)
- << "More than one IP is available, please specify GID_INDEX";
- }
- break;
- case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index
- break;
- default:
- LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
- "InfiniBand only. ";
- }
- if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
- LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
- }
- return gid_index;
-}
-
-// set the default or environment value to the configuration parameter.
-// Args:
-// default_val- the default value for this parameter
-// env_param- the environment parameter's name
-// Returns:
-// 32-bit value
-uint32_t set_param(uint32_t default_val, const char* env_param) {
- uint32_t val = default_val;
- string val_s;
-
- val_s = get_env_var(env_param);
-
- if (!val_s.empty()) {
- val = stoi(val_s);
- }
- return val;
-}
-
-enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
- ibv_port_attr port_attr;
- enum ibv_mtu mtu;
- string mtu_s;
- int rc, mtu_i;
-
- rc = ibv_query_port(context, port_num, &port_attr);
- CHECK(!rc) << "Failed to query the port" << port_num;
-
- mtu_s = get_env_var("RDMA_MTU");
-
- if (!mtu_s.empty()) {
- mtu_i = stoi(mtu_s);
- switch (mtu_i) {
- case 256:
- mtu = IBV_MTU_256;
- break;
- case 512:
- mtu = IBV_MTU_512;
- break;
- case 1024:
- mtu = IBV_MTU_1024;
- break;
- case 2048:
- mtu = IBV_MTU_2048;
- break;
- case 4096:
- mtu = IBV_MTU_4096;
- break;
- default:
- CHECK(0) << "Error: MTU input value must be one of the following: 256, "
- "512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n";
- break;
- }
- CHECK(mtu < port_attr.active_mtu)
- << "MTU configuration for the QPs is larger than active MTU";
- } else {
- mtu = port_attr.active_mtu;
- }
- return mtu;
-}
-
-RdmaParams params_init(ibv_context* context) {
- RdmaParams params;
-
- params.port_num = set_port(context);
- params.sgid_index = set_gid(params.port_num, context);
- params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
- params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
- params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
- params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
- params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
- CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
- << ". Valid values are 0-7.";
- params.mtu = set_mtu(params.port_num, context);
- params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
- return params;
+ ib_dev = dev_list[0];
+ CHECK(ib_dev) << "No InfiniBand device found";
+ ibv_context* context = ibv_open_device(ib_dev);
+ CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
+ return context;
}
ibv_pd* alloc_protection_domain(ibv_context* context) {
@@ -409,8 +85,7 @@ ibv_pd* alloc_protection_domain(ibv_context* context) {
}
RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
- : context_(open_device(set_device())),
- params_(params_init(context_)),
+ : context_(open_default_device()),
pd_(alloc_protection_domain(context_)),
worker_env_(worker_env) {
event_channel_ = ibv_create_comp_channel(context_);
@@ -453,9 +128,9 @@ void RdmaAdapter::Process_CQ() {
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS)
- << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
- << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
- << wc_[i].vendor_err;
+ << "Failed status \n"
+ << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+ << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@@ -567,8 +242,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
memset(&attr, 0, sizeof(ibv_qp_init_attr));
attr.send_cq = adapter_->cq_;
attr.recv_cq = adapter_->cq_;
- attr.cap.max_send_wr = adapter_->params_.queue_depth;
- attr.cap.max_recv_wr = adapter_->params_.queue_depth;
+ attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
attr.cap.max_send_sge = 1;
attr.cap.max_recv_sge = 1;
attr.qp_type = IBV_QPT_RC;
@@ -582,8 +257,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_INIT;
- attr.pkey_index = adapter_->params_.pkey_index;
- attr.port_num = adapter_->params_.port_num;
+ attr.pkey_index = 0;
+ attr.port_num = 1;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int mask =
@@ -594,15 +269,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// Local address
{
struct ibv_port_attr attr;
- CHECK(
- !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
+ CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
<< "Query port";
self_.lid = attr.lid;
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
union ibv_gid gid;
- CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
- adapter_->params_.sgid_index, &gid))
+ CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
<< "Query gid";
self_.snp = gid.global.subnet_prefix;
self_.iid = gid.global.interface_id;
@@ -611,7 +284,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// create message and ack buffers, then initialize the tables.
{
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
- "tx_ack_buffer", "rx_ack_buffer"};
+ "tx_ack_buffer", "rx_ack_buffer"};
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
@@ -672,7 +345,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@@ -806,9 +479,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTR;
-
+ struct ibv_port_attr port_attr;
+ CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr))
+ << "Query port failed";
// This assumes both QP's ports are configured with the same MTU
- attr.path_mtu = adapter_->params_.mtu;
+ attr.path_mtu = port_attr.active_mtu;
attr.dest_qp_num = remoteAddr.qpn;
attr.rq_psn = remoteAddr.psn;
attr.max_dest_rd_atomic = 1;
@@ -819,32 +494,30 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.hop_limit = 255;
attr.ah_attr.dlid = remoteAddr.lid;
- attr.ah_attr.sl = adapter_->params_.sl;
+ attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
- attr.ah_attr.port_num = adapter_->params_.port_num;
- attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
- attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
+ attr.ah_attr.port_num = 1;
int r;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
- IBV_QP_PATH_MTU |
- IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
- IBV_QP_MAX_DEST_RD_ATOMIC |
- IBV_QP_MIN_RNR_TIMER)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = self_.psn;
- attr.timeout = adapter_->params_.timeout;
- attr.retry_cnt = adapter_->params_.retry_cnt;
+ attr.timeout = 14;
+ attr.retry_cnt = 7;
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
- IBV_QP_RETRY_CNT |
- IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
- IBV_QP_MAX_QP_RD_ATOMIC)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
@@ -931,7 +604,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@@ -1026,9 +699,9 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
- CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
- << " gpu_info: "
- << src_dev->tensorflow_gpu_device_info();
+ CHECK(send_args.device_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
if (can_memcpy) {
AllocatorAttributes host_alloc_attrs;
@@ -1054,8 +727,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
// aync instead
GPUUtil::SetProtoFromGPU(
in, src_dev, send_args.device_context, &proto, is_dead,
- [this, proto, buffer_size, key, in, step_id, key_with_step_id,
- is_dead, send_args, recv_args](const Status& s) mutable {
+ [this, proto, buffer_size, key, in, step_id, key_with_step_id,
+ is_dead, send_args, recv_args](const Status& s) mutable {
CHECK(s.ok()) << "copy proto from gpu sync";
auto tensor_bytes = proto.ByteSize();
buffer_size += tensor_bytes;
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index 52d92a7c5b..e1e07db776 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -36,24 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
-#define PKEY_DEFAULT 0
-#define QUEUE_DEPTH_DEFAULT 1024
-#define TIMEOUT_DEFAULT 14
-#define RETRY_CNT_DEFAULT 7
-#define SL_DEFAULT 0
-#define TRAFFIC_CLASS 0
-
-struct RdmaParams {
- uint8_t port_num;
- uint8_t sgid_index;
- uint8_t pkey_index;
- uint32_t queue_depth;
- uint8_t timeout;
- uint8_t retry_cnt;
- uint8_t sl;
- enum ibv_mtu mtu;
- uint8_t traffic_class;
-};
+
// structure to save the address of remote channels.
struct RdmaAddress {
uint32_t lid;
@@ -67,20 +50,9 @@ struct RemoteMR {
uint64_t remote_addr;
uint32_t rkey;
};
-enum BufferStatus {
- none,
- idle,
- busy
-};
-enum Location {
- local,
- remote
-};
-enum BufferType {
- ACK,
- MESSAGE,
- TENSOR
-};
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
+enum BufferType { ACK, MESSAGE, TENSOR };
enum RdmaMessageType {
RDMA_MESSAGE_ACK,
RDMA_MESSAGE_BUFFER_IDLE,
@@ -112,8 +84,6 @@ class RdmaAdapter {
protected:
static const int MAX_CONCURRENT_WRITES = 1000;
ibv_context* context_;
- // RDMA configuration parameters
- RdmaParams params_;
// ibverbs protection domain
ibv_pd* pd_;
// Completion event channel, to wait for work completions
@@ -213,7 +183,7 @@ class RdmaBuffer {
}
void FreeBuffer();
void EnqueueItem(string Item);
- virtual void SendNextItem() {};
+ virtual void SendNextItem(){};
void CreateCPUBuffer(size_t size, bool lock = true);
void SetRemoteMR(RemoteMR rmi, bool override);
uint32_t LookupBufferIndex(const string& buffer_name) {