diff options
author | Jianwei Xie <xiejw@google.com> | 2018-01-24 10:02:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-24 10:06:06 -0800 |
commit | d9f93c42a50b1f1401d9c186eac0ae8dc9093c3b (patch) | |
tree | 178d1a692f56580c266139642b5a1d0d155c477e /tensorflow/contrib/verbs/rdma.h | |
parent | 7b62a71e2d46c148df7d5704972f4592bc5e0f1b (diff) |
Merge changes from github.
PiperOrigin-RevId: 183100142
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.h')
-rw-r--r-- | tensorflow/contrib/verbs/rdma.h | 504 |
1 files changed, 334 insertions, 170 deletions
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h index fea2327d77..68b3d59f56 100644 --- a/tensorflow/contrib/verbs/rdma.h +++ b/tensorflow/contrib/verbs/rdma.h @@ -27,6 +27,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" @@ -43,6 +44,11 @@ namespace tensorflow { #define SL_DEFAULT 0 #define TRAFFIC_CLASS 0 +#define RDMA_LOG_0 LOG(INFO) +#define RDMA_LOG_1 VLOG(1) +#define RDMA_LOG_2 VLOG(2) +#define RDMA_LOG(LEVEL) RDMA_LOG_##LEVEL + struct RdmaParams { uint8_t port_num; uint8_t sgid_index; @@ -76,29 +82,303 @@ enum Location { local, remote }; -enum BufferType { - ACK, - MESSAGE, - TENSOR -}; + enum RdmaMessageType { - RDMA_MESSAGE_ACK, - RDMA_MESSAGE_BUFFER_IDLE, - RDMA_MESSAGE_BUFFER_REQUEST, - RDMA_MESSAGE_BUFFER_RESPONSE, + RDMA_MESSAGE_META_DATA_UPDATE, + RDMA_MESSAGE_TENSOR_RE_REQUEST, RDMA_MESSAGE_TENSOR_REQUEST, - RDMA_MESSAGE_TENSOR_WRITE + RDMA_MESSAGE_ERROR_STATUS, +}; + +struct RdmaMessage { + RdmaMessageType type_; + uint16_t name_size_; + string name_; + int64 step_id_; + uint64_t request_index_; + union { + uint64_t remote_addr_; +#ifdef RDMA_DATA_VALIDATION + uint64_t checksum_; +#endif + }; + uint32_t rkey_; + bool is_dead_; + DataType data_type_; + TensorShape tensor_shape_; + size_t tensor_bytes_; + + // For error status: + Status status_; + + // type|name_size|name|step_id|request_index|remote_addr/checksum|rkey|... + // 1B| 2B | 512| 8B | 8B | 8B | 4B |... + // ...|is_dead|data_type|tensor_shape|tensor_bytes|error_status | + // ...| 1B | XB | XB | 8B |size - 4B, proto - XB | + static const size_t kNameCapacity = 512; + static const size_t kTypeStartIndex = 0; + static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_); + static const size_t kNameStartIndex = + kNameSizeStartIndex + sizeof(name_size_); + static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity; + static const size_t kRequestIndexStartIndex = + kStepIdStartIndex + sizeof(step_id_); + static const size_t kRemoteAddrStartIndex = + kRequestIndexStartIndex + sizeof(request_index_); + static const size_t kChecksumStartIndex = kRemoteAddrStartIndex; + static const size_t kRkeyStartIndex = + kRemoteAddrStartIndex + sizeof(remote_addr_); + static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_); + static const size_t kDataTypeStartIndex = + kIsDeadStartIndex + sizeof(is_dead_); + static const size_t kTensorShapeStartIndex = + kDataTypeStartIndex + sizeof(data_type_); + static const size_t kTensorBytesStartIndex = + kTensorShapeStartIndex + sizeof(TensorShape); + static const size_t kErrorStatusStartIndex = + kTensorBytesStartIndex + sizeof(tensor_bytes_); + static const size_t kErrorStatusMaxSize = 4096; + + static const size_t kMessageTotalBytes = kErrorStatusStartIndex; + static const size_t kRdmaMessageBufferSize = + kMessageTotalBytes + kErrorStatusMaxSize; + static string CreateMessage(const RdmaMessage& rm); + static void ParseMessage(RdmaMessage& rm, void* buffer); +}; + +// Immediate types for RDMA write +enum RdmaImmDataType { + RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD, + RDMA_IMM_DATA_ACK = 0xFFFFFFFE, + RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFF +}; + +// Write types for RDMA write-complete events +enum RdmaWriteIDType { + RDMA_WRITE_ID_ACK, + RDMA_WRITE_ID_MESSAGE, + RDMA_WRITE_ID_TENSOR_WRITE +}; + +// Context for RDMA write-complete events +class RdmaWriteID { + public: + RdmaWriteID(RdmaWriteIDType write_type, void* write_context) + : write_type(write_type), write_context(write_context) {} + + RdmaWriteIDType write_type; + void* write_context; +}; + +// Tensor meta-data +class TensorMetaData { + public: + TensorShape tensor_shape_; + DataType data_type_; + size_t proto_size_; + bool is_dead_; + + std::ostream& print(std::ostream& out) const { + out << "Dtype = " << DataTypeString(data_type_) + << ", Shape = " << tensor_shape_.DebugString() << ", Proto size = 0x" + << std::hex << proto_size_ << ", Is dead = " << is_dead_; + return out; + } +}; + +inline std::ostream& operator<<(std::ostream& out, + const TensorMetaData& meta_data) { + return meta_data.print(out); +} + +class RdmaChannel; + +void MRDeleter(ibv_mr* mr); +using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>; + +// RdmaMemoryMgr +// Manages the local meta-data cache, and the registered RDMA memory regions. +class RdmaMemoryMgr { + public: + static RdmaMemoryMgr& Singleton() { + static RdmaMemoryMgr instance; + return instance; + } + + // Memory regions + ibv_mr* FindMemoryRegion(void* addr, size_t length); + void InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name); + void EvictMemoryRegion(void* addr, size_t length); + + // Tensor meta-data cache + const TensorMetaData* GetTensorMetaData(const std::string& tensor_name); + const TensorMetaData* SetTensorMetaData(const std::string& tensor_name, + DataType dtype, + const TensorShape& shape, + bool is_dead, size_t proto_size); + + struct ibv_pd* pd_; + + protected: + RdmaMemoryMgr() : pd_(nullptr) {} + + static bool Comparator(const void* ptr, const MemoryRegionPtr& other) { + return ptr < reinterpret_cast<char*>(other->addr) + other->length; + } + + private: + mutex tensor_meta_data_mu_; + std::unordered_map<std::string, TensorMetaData> tensors_meta_data_; + + // Managed memory regions + mutex mrs_mu_; + std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(mrs_mu_); }; -class RdmaBuffer; + +// RdmaTensorRequest +// Represents a single tensor request. +class RdmaTensorRequest { + public: + typedef Rendezvous::DoneCallback RecvDoneCallback; + + // Creates a tensor request identified by index. + RdmaTensorRequest(uint32_t index, const string& key, int64 step_id, + RdmaChannel* channel, Device* dst_dev, + const Rendezvous::Args recv_args, + const RecvDoneCallback& done); + ~RdmaTensorRequest(); + + // Request unique index. + uint32_t index() { return index_; } + + // Start the tensor request sequence. + // + // 1. Allocate the result tensor (and proxy tensor if required). + // 2. Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side. + void Start(); + + // Receive tensor meta-data. + // + // 1. Update the local meta-data cache. + // 2. Reallocate the result tensor (and proxy tensor if required). + // 3. Re-send the request to the remote side. + void RecvTensorMetaData(DataType dtype, TensorShape shape, bool is_dead, + size_t proto_size); + + // Receive tensor content (RDMA write was completed). + // + // Decode proto if required and/or move to GPU if the content was not + // written to it directly (GPU direct is not avaliable). Afterwards, + // invoke Done(). + void RecvTensorContent(); + + // Receive error status (in case of a remote error). + // Invoke Done() with the status code. + void RecvErrorStatus(const Status& status); + +#ifdef RDMA_DATA_VALIDATION + // Receive tensor checksum + // + // For validation: Get and store the Tensor's expected checksum for the + // current request. Compare the result Tensor's checksum with the stored + // checksum right before invoking Done(). + void RecvTensorChecksum(uint64_t checksum) { checksum_ = checksum; } +#endif + + private: + void Done(const Status& s); + void Send(RdmaMessageType message_type); + bool AllocateTensors(); + void AllocateTensorsAsync(StatusCallback done); + void DeallocateTensors(); + + uint32_t index_; + string key_; + int64 step_id_; + RdmaChannel* channel_; + Device* dst_dev_; + Rendezvous::Args recv_args_; + const TensorMetaData* meta_data_; + Tensor* result_tensor_; + Tensor* proxy_tensor_; + void* rdma_addr_; + ibv_mr* mr_; + RecvDoneCallback done_; +#ifdef RDMA_DATA_VALIDATION + uint64_t checksum_; +#endif +}; + +// RdmaTensorResponse +// Represents a single tensor response. +class RdmaTensorResponse { + public: + // Creates a response for request message. + RdmaTensorResponse(RdmaChannel* channel, const RdmaMessage& rm) + : channel_(channel), rm_(rm) {} + + void Update(const RdmaMessage& rm) { rm_ = rm; } + + // Start the tensor response sequence. + // + // 1. Find the tensor in the local tag-match table and invoke RecvHandler. + // (Using RecvLocalAsync()). + // 2. Compare the tensor's meta-data to the meta-data in the message (taken + // from the requester's local cache). + // If meta-data changed: + // a. Clone the tensor to be sent later. + // b. Send a meta-data update message and wait for re-request. + // Else: + // a. Send the tensor's content (using direct RDMA write). + void Start(); + + // Resume the response sequence, after a re-request. + // + // 1. Send the tensor's content that was cloned earlier. + void Resume(); + + // Destroy the response's resources and remove it from the pending list. + void Destroy(); + + private: + void RecvHandler(Rendezvous::ParsedKey parsed, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, + bool is_dead); + void Clone(const Tensor& in, const TensorProto& proto, bool is_dead); + void Send(const Tensor& in, const TensorProto& proto, bool is_dead, + const Status& status); + bool TensorMetaDataChanged(const Tensor& in, bool is_dead); + Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, + Device** src_dev); + void SendMetaData(const Tensor& in, const TensorProto& proto, bool is_dead); + void SendContent(const Tensor& in, const TensorProto& proto, bool is_dead); + void SendErrorStatus(const Status& status); + + RdmaChannel* channel_; + RdmaMessage rm_; // The request message + Device* src_dev_ = nullptr; + TensorBuffer* src_buffer_ = nullptr; + void* src_addr_ = nullptr; + ibv_mr* mr_ = nullptr; + uint64_t checksum_ = 0; + bool meta_data_changed_ = false; + + // Re-item: + TensorProto* proto_ = nullptr; + Tensor* tensor_ = nullptr; + bool is_dead_ = false; +}; + +class RdmaMessageBuffer; // Class that represents the Rdma Adapter. // Responsible for creation of the completion queue, and handling // of work completions. class RdmaAdapter { friend class RdmaChannel; - friend class RdmaBuffer; - friend class RdmaAckBuffer; friend class RdmaMessageBuffer; - friend class RdmaTensorBuffer; + friend class RdmaTensorResponse; friend class RdmaMgr; friend class RdmaRemoteRendezvous; @@ -133,10 +413,10 @@ class RdmaAdapter { // Responsible for connecting queue pairs. class RdmaChannel { friend class RdmaAdapter; - friend class RdmaBuffer; - friend class RdmaAckBuffer; friend class RdmaMessageBuffer; friend class RdmaTensorBuffer; + friend class RdmaTensorRequest; + friend class RdmaTensorResponse; friend class RdmaMgr; friend class RdmaRemoteRendezvous; @@ -146,22 +426,28 @@ class RdmaChannel { ~RdmaChannel(); inline const RdmaAddress& self() { return self_; } RdmaAddress address() const; - inline const std::vector<RdmaBuffer*>& message_buffers() const { + inline const std::vector<RdmaMessageBuffer*>& message_buffers() const { return message_buffers_; } void Connect(const RdmaAddress& remoteAddr); void Connect(); void Recv(); - RdmaBuffer* FindBuffer(const uint32_t index); - RdmaBuffer* FindBuffer(const string& name); - RdmaBuffer* FindOrCreateBuffer(const string& name, - BufferType buffer_type = TENSOR); - uint32_t LookupBufferIndex(const string& buffer_name); void SetRemoteAddress(const RdmaAddress& ra, bool override); - void InsertRecvCallback(const string& key, std::function<void()> recv_done); - void RemoveRecvCallback(const string& key); - void RunRecvCallback(const string& key); - static const int kNumMessageBuffers = 4; + + // Requests: + RdmaTensorRequest* InsertTensorRequest( + const string& key, int64 step_id, Device* dst_dev, + const Rendezvous::Args recv_args, + const RdmaTensorRequest::RecvDoneCallback& done); + void RemoveTensorRequest(uint32_t request_index); + RdmaTensorRequest* GetTensorRequest(uint32_t request_index); + + // Responses: + RdmaTensorResponse* AddTensorResponse(const RdmaMessage& rm); + RdmaTensorResponse* UpdateTensorResponse(const RdmaMessage& rm); + void RemoveTensorResponse(uint32_t request_index); + + static const int kNumMessageBuffers = 2; static const int kPingRecvWrid = 0; private: @@ -179,36 +465,31 @@ class RdmaChannel { string remote_name_; ibv_qp* qp_; mutex mu_; - bool connected_ GUARDED_BY(bt_mu_) = false; - RdmaAddress remote_ GUARDED_BY(bt_mu_); - bool remote_set_ GUARDED_BY(bt_mu_) = false; + bool connected_ GUARDED_BY(mu_) = false; + RdmaAddress remote_ GUARDED_BY(mu_); + bool remote_set_ GUARDED_BY(mu_) = false; mutex ct_mu_; - typedef std::unordered_map<string, std::function<void()> > CallbackTable; - CallbackTable callback_table_ GUARDED_BY(ct_mu_); - mutex bt_mu_; - typedef std::unordered_map<unsigned int, RdmaBuffer*> BufferTable; - BufferTable buffer_table_ GUARDED_BY(bt_mu_); - typedef std::unordered_map<uint32_t, string> BufferIndexNameTable; - BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_); - typedef std::unordered_map<string, uint32_t> BufferNameIndexTable; - BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_); - RdmaBuffer* tx_message_buffer_; - RdmaBuffer* rx_message_buffer_; - RdmaBuffer* tx_ack_buffer_; - RdmaBuffer* rx_ack_buffer_; - std::vector<RdmaBuffer*> message_buffers_; + typedef std::unordered_map<uint32_t, RdmaTensorRequest> RequestTable; + RequestTable request_table_ GUARDED_BY(ct_mu_); + uint32_t request_serial_ GUARDED_BY(ct_mu_); + mutex responses_mu_; + typedef std::unordered_map<uint32_t, RdmaTensorResponse> ResponsesTable; + ResponsesTable responses_table_ GUARDED_BY(responses_mu_); + RdmaMessageBuffer* tx_message_buffer_; + RdmaMessageBuffer* rx_message_buffer_; + std::vector<RdmaMessageBuffer*> message_buffers_; }; -// Class that represents a buffer for Rdma writes and reads. -class RdmaBuffer { +// Class that represents a buffer for Rdma message sending. +class RdmaMessageBuffer { friend class RdmaChannel; friend class RdmaAdapter; friend class RdmaMgr; friend class RdmaRemoteRendezvous; public: - explicit RdmaBuffer(RdmaChannel* channel, string name); - virtual ~RdmaBuffer(); + explicit RdmaMessageBuffer(RdmaChannel* channel, string name); + ~RdmaMessageBuffer(); inline void* buffer() const { return buffer_; } inline ibv_mr* self() const { return self_; } @@ -223,13 +504,15 @@ class RdmaBuffer { } void FreeBuffer(); void EnqueueItem(string Item); - virtual void SendNextItem() {}; + void SendNextItem(); void CreateCPUBuffer(size_t size, bool lock = true); void SetRemoteMR(RemoteMR rmi, bool override); - uint32_t LookupBufferIndex(const string& buffer_name) { - return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name); - } void Write(uint32_t imm_data, size_t buffer_size); + static void Write(const RdmaChannel* channel, uint32_t imm_data, + size_t buffer_size, uint64_t src_addr, uint32_t lkey, + uint64_t remote_addr, uint32_t rkey, + RdmaWriteIDType write_type, void* write_context); + static void SendAck(const RdmaChannel* channel); protected: const RdmaChannel* channel_; @@ -245,125 +528,6 @@ class RdmaBuffer { BufferStatus remote_status_ GUARDED_BY(mu_) = none; }; -class RdmaAckBuffer : public RdmaBuffer { - public: - explicit RdmaAckBuffer(RdmaChannel* channel, string name); - virtual ~RdmaAckBuffer() override {} - void SendNextItem() override; -}; - -class RdmaMessageBuffer : public RdmaBuffer { - friend class RdmaChannel; - friend class RdmaAapater; - - public: - explicit RdmaMessageBuffer(RdmaChannel* channel, string name); - virtual ~RdmaMessageBuffer() override {} - void SendNextItem() override; -}; - -class RdmaTensorBuffer : public RdmaBuffer { - public: - explicit RdmaTensorBuffer(RdmaChannel* channel, string name); - virtual ~RdmaTensorBuffer() override; - void SendNextItem() override; - void PostCopyOperations(bool can_memcpy, size_t buffer_size, - size_t tensor_bytes, const string& key, - const Tensor& in, int64 step_id, bool is_dead, - const string& key_with_step_id, const Tensor* copy, - const TensorProto* proto, const StringPiece* copy_buf, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args); - - void ReSendNextItem(); - - private: - Rendezvous::DoneCallback getRecvTensorCallback( - const string& key_with_step_id, const string& key, int64 step_id, - const Rendezvous::ParsedKey& parsed); - - struct ReItem { - Rendezvous::Args send_args; - Rendezvous::Args recv_args; - Tensor in; - bool is_dead; - - ReItem(const Rendezvous::Args& send_args_, - const Rendezvous::Args& recv_args_, const Tensor& in_, bool is_dead_) - : send_args(send_args_), - recv_args(recv_args_), - in(in_), - is_dead(is_dead_) { - if (send_args.device_context) { - send_args.device_context->Ref(); - } - if (recv_args.device_context) { - recv_args.device_context->Ref(); - } - } - - ~ReItem() { - if (send_args.device_context) { - send_args.device_context->Unref(); - } - if (recv_args.device_context) { - recv_args.device_context->Unref(); - } - } - }; - typedef std::map<string, ReItem*> Table; - typedef Table::iterator Itable; - - std::queue<string> requeue GUARDED_BY(mu_); - Table retable GUARDED_BY(mu_); -}; - -struct RdmaMessage { - RdmaMessageType type_; - uint16_t name_size_; - string name_; - int64 step_id_; - uint64_t buffer_size_; - uint64_t remote_addr_; - uint32_t rkey_; - bool is_dead_; - DataType data_type_; - TensorShape tensor_shape_; - size_t tensor_bytes_; - - // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|... - // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... - // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer - // ...| XB | XB | 8B |... - // - static const size_t kNameCapacity = 512; - static const size_t kTypeStartIndex = 0; - static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_); - static const size_t kNameStartIndex = - kNameSizeStartIndex + sizeof(name_size_); - static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity; - static const size_t kBufferSizeStartIndex = - kStepIdStartIndex + sizeof(step_id_); - static const size_t kRemoteAddrStartIndex = - kBufferSizeStartIndex + sizeof(buffer_size_); - static const size_t kRkeyStartIndex = - kRemoteAddrStartIndex + sizeof(remote_addr_); - static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_); - static const size_t kDataTypeStartIndex = - kIsDeadStartIndex + sizeof(is_dead_); - static const size_t kTensorShapeStartIndex = - kDataTypeStartIndex + sizeof(data_type_); - static const size_t kTensorBytesStartIndex = - kTensorShapeStartIndex + sizeof(TensorShape); - static const size_t kTensorBufferStartIndex = - kTensorBytesStartIndex + sizeof(tensor_bytes_); - static const size_t kMessageTotalBytes = kTensorBufferStartIndex; - static const size_t kRdmaMessageBufferSize = kMessageTotalBytes; - static const size_t kRdmaAckBufferSize = kMessageTotalBytes; - static string CreateMessage(const RdmaMessage& rm); - static void ParseMessage(RdmaMessage& rm, void* buffer); -}; - } // namespace tensorflow #endif // TENSORFLOW_USE_VERBS |