diff options
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.h')
-rw-r--r-- | tensorflow/contrib/verbs/rdma.h | 50 |
1 files changed, 47 insertions, 3 deletions
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index 16ef58bc62..e1e07db776 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -28,6 +28,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
@@ -224,14 +225,57 @@ class RdmaMessageBuffer : public RdmaBuffer {
 class RdmaTensorBuffer : public RdmaBuffer {
  public:
   explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
-  virtual ~RdmaTensorBuffer() override {}
+  virtual ~RdmaTensorBuffer() override;
   void SendNextItem() override;
   void PostCopyOperations(bool can_memcpy, size_t buffer_size,
                           size_t tensor_bytes, const string& key,
                           const Tensor& in, int64 step_id, bool is_dead,
                           const string& key_with_step_id, const Tensor* copy,
-                          const TensorProto* proto,
-                          const StringPiece* copy_buf);
+                          const TensorProto* proto, const StringPiece* copy_buf,
+                          const Rendezvous::Args& send_args,
+                          const Rendezvous::Args& recv_args);
+
+  void ReSendNextItem();
+
+ private:
+  Rendezvous::DoneCallback getRecvTensorCallback(
+      const string& key_with_step_id, const string& key, int64 step_id,
+      const Rendezvous::ParsedKey& parsed);
+
+  struct ReItem {
+    Rendezvous::Args send_args;
+    Rendezvous::Args recv_args;
+    Tensor in;
+    bool is_dead;
+
+    ReItem(const Rendezvous::Args& send_args_,
+           const Rendezvous::Args& recv_args_, const Tensor& in_, bool is_dead_)
+        : send_args(send_args_),
+          recv_args(recv_args_),
+          in(in_),
+          is_dead(is_dead_) {
+      if (send_args.device_context) {
+        send_args.device_context->Ref();
+      }
+      if (recv_args.device_context) {
+        recv_args.device_context->Ref();
+      }
+    }
+
+    ~ReItem() {
+      if (send_args.device_context) {
+        send_args.device_context->Unref();
+      }
+      if (recv_args.device_context) {
+        recv_args.device_context->Unref();
+      }
+    }
+  };
+  typedef std::map<string, ReItem*> Table;
+  typedef Table::iterator Itable;
+
+  std::queue<string> requeue GUARDED_BY(mu_);
+  Table retable GUARDED_BY(mu_);
 };
 
 struct RdmaMessage {