aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/mpi
diff options
context:
space:
mode:
authorGravatar Jeremy Lau <lauj@google.com>2018-01-22 17:26:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-22 17:30:59 -0800
commit6042b5d267f42d004087b44c29525951700579f9 (patch)
treec4aa81597cc062a64b200758ce6dec5a1f5f16e5 /tensorflow/contrib/mpi
parent2968447d32bdfd0dd6fafabfcd1aafd6dc261803 (diff)
Reject retried RecvTensor requests.
Retried RecvTensorRequests are problematic because a RecvTensor with no corresponding sender will wait forever, and the tensor may have been delivered to a previous retry. This change adds a unique request_id to each RecvTensor request, and we check these request_ids against a set of recent request_ids. If a request_id is in the recent set, we reject the RecvTensor request. PiperOrigin-RevId: 182863245
Diffstat (limited to 'tensorflow/contrib/mpi')
-rw-r--r--tensorflow/contrib/mpi/BUILD2
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc8
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.h6
3 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD
index d9d55faf50..23f90cf77e 100644
--- a/tensorflow/contrib/mpi/BUILD
+++ b/tensorflow/contrib/mpi/BUILD
@@ -71,6 +71,8 @@ cc_library(
"//tensorflow/core:protos_cc",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:recent_request_ids",
+ "//tensorflow/core/distributed_runtime:request_id",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index 1a2563d20f..8d14a3ef04 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -33,8 +33,10 @@ limitations under the License.
namespace tensorflow {
MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env)
- : BaseRendezvousMgr(env), worker_env_2(env), use_optimal_transfer_(false) {
-
+ : BaseRendezvousMgr(env),
+ worker_env_2(env),
+ use_optimal_transfer_(false),
+ recv_tensor_recent_request_ids_(100000) {
const char* mpienv = getenv("MPI_OPTIMAL_PATH");
if (mpienv && mpienv[0] == '1') {
LOG(INFO) << "MPI Optimal copy path enabled (Requires CUDA-Aware MPI when "
@@ -149,6 +151,8 @@ MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
*/
void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
const int mpi_dst) {
+ TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique(
+ req.request_id(), "RecvTensor (MPIRendezvousMgr)", req));
const int64 step_id = request.step_id();
const std::string& key = request.rendezvous_key();
Rendezvous::ParsedKey parsed;
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index b15748d63c..ca42ee2f6d 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -30,10 +30,11 @@ limitations under the License.
#include <iostream>
+#include "tensorflow/contrib/mpi/mpi_msg.pb.h"
#include "tensorflow/contrib/mpi/mpi_utils.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
-#include "tensorflow/contrib/mpi/mpi_msg.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#define TAG_REQTENSOR 1010
@@ -104,6 +105,7 @@ class MPIRequestTensorCall {
void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id) {
req_.set_step_id(step_id);
req_.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size());
+ req_.set_request_id(GetUniqueRequestId());
request_buffer_size_ = req_.ByteSize();
// request_buffer_ = new char[request_buffer_size_];
// req_.SerializeToArray(request_buffer_, request_buffer_size_);
@@ -177,6 +179,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
std::map<std::string, std::shared_ptr<MPIRequestTensorCall>> recv_tensor_map_
GUARDED_BY(mrq_);
+ RecentRequestIds recv_tensor_recent_request_ids_;
+
void AddRequest(RecvTensorRequest, const int);
void MPIBackgroundThread();