about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/verbs/rdma.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/verbs/rdma.h')
-rw-r--r-- tensorflow/contrib/verbs/rdma.h 50
1 files changed, 47 insertions, 3 deletions
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index 16ef58bc62..e1e07db776 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
@@ -224,14 +225,57 @@ class RdmaMessageBuffer : public RdmaBuffer {
class RdmaTensorBuffer : public RdmaBuffer {
public:
explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
- virtual ~RdmaTensorBuffer() override {}
+ // The destructor is only declared here; its definition lives outside this
+ // header.
+ virtual ~RdmaTensorBuffer() override;
void SendNextItem() override;
void PostCopyOperations(bool can_memcpy, size_t buffer_size,
size_t tensor_bytes, const string& key,
const Tensor& in, int64 step_id, bool is_dead,
const string& key_with_step_id, const Tensor* copy,
- const TensorProto* proto,
- const StringPiece* copy_buf);
+ const TensorProto* proto, const StringPiece* copy_buf,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args);
+
+ // NOTE(review): presumably re-sends an entry previously recorded in
+ // `requeue` / `retable`; the implementation is not visible in this header,
+ // so confirm against the .cc file.
+ void ReSendNextItem();
+
+ private:
+ // Builds a Rendezvous::DoneCallback from the given keys, step id and parsed
+ // key. NOTE(review): judging by the name it is the callback handed to the
+ // local rendezvous receive; confirm against the .cc file.
+ Rendezvous::DoneCallback getRecvTensorCallback(
+ const string& key_with_step_id, const string& key, int64 step_id,
+ const Rendezvous::ParsedKey& parsed);
+
+ // Bundles the send/receive Rendezvous::Args, the tensor and the is_dead
+ // flag of one item. The constructor calls Ref() on each non-null
+ // device_context and the destructor calls Unref() on it, so the contexts
+ // stay referenced for as long as the ReItem exists.
+ struct ReItem {
+ Rendezvous::Args send_args;
+ Rendezvous::Args recv_args;
+ Tensor in;
+ bool is_dead;
+
+ ReItem(const Rendezvous::Args& send_args_,
+ const Rendezvous::Args& recv_args_, const Tensor& in_, bool is_dead_)
+ : send_args(send_args_),
+ recv_args(recv_args_),
+ in(in_),
+ is_dead(is_dead_) {
+ if (send_args.device_context) {
+ send_args.device_context->Ref();
+ }
+ if (recv_args.device_context) {
+ recv_args.device_context->Ref();
+ }
+ }
+
+ ~ReItem() {
+ if (send_args.device_context) {
+ send_args.device_context->Unref();
+ }
+ if (recv_args.device_context) {
+ recv_args.device_context->Unref();
+ }
+ }
+
+ // Copying is disabled: a memberwise copy would duplicate the
+ // device_context pointers without calling Ref(), and the copy's destructor
+ // would then call Unref() a second time on the same contexts.
+ ReItem(const ReItem&) = delete;
+ ReItem& operator=(const ReItem&) = delete;
+ };
+ // Maps a string key to a ReItem. NOTE(review): the values are raw pointers,
+ // so whoever inserts and erases entries presumably owns them; confirm the
+ // matching delete in the .cc file.
+ typedef std::map<string, ReItem*> Table;
+ typedef Table::iterator Itable;
+
+ // Both containers are annotated as protected by mu_.
+ std::queue<string> requeue GUARDED_BY(mu_);
+ Table retable GUARDED_BY(mu_);
};
struct RdmaMessage {