diff options
author | Jeremy Lau <lauj@google.com> | 2018-01-22 17:26:43 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-22 17:30:59 -0800 |
commit | 6042b5d267f42d004087b44c29525951700579f9 (patch) | |
tree | c4aa81597cc062a64b200758ce6dec5a1f5f16e5 /tensorflow/contrib/mpi | |
parent | 2968447d32bdfd0dd6fafabfcd1aafd6dc261803 (diff) |
Reject retried RecvTensor requests.
Retried RecvTensorRequests are problematic because a RecvTensor with no
corresponding sender will wait forever, and the tensor may have been delivered
to a previous retry.
This change adds a unique request_id to each RecvTensor request, and we check
these request_ids against a set of recent request_ids. If a request_id is in the
recent set, we reject the RecvTensor request.
PiperOrigin-RevId: 182863245
Diffstat (limited to 'tensorflow/contrib/mpi')
-rw-r--r-- | tensorflow/contrib/mpi/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc | 8 | ||||
-rw-r--r-- | tensorflow/contrib/mpi/mpi_rendezvous_mgr.h | 6 |
3 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD index d9d55faf50..23f90cf77e 100644 --- a/tensorflow/contrib/mpi/BUILD +++ b/tensorflow/contrib/mpi/BUILD @@ -71,6 +71,8 @@ cc_library( "//tensorflow/core:protos_cc", "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:recent_request_ids", + "//tensorflow/core/distributed_runtime:request_id", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_env", diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc index 1a2563d20f..8d14a3ef04 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc @@ -33,8 +33,10 @@ limitations under the License. namespace tensorflow { MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env) - : BaseRendezvousMgr(env), worker_env_2(env), use_optimal_transfer_(false) { - + : BaseRendezvousMgr(env), + worker_env_2(env), + use_optimal_transfer_(false), + recv_tensor_recent_request_ids_(100000) { const char* mpienv = getenv("MPI_OPTIMAL_PATH"); if (mpienv && mpienv[0] == '1') { LOG(INFO) << "MPI Optimal copy path enabled (Requires CUDA-Aware MPI when " @@ -149,6 +151,8 @@ MPIRemoteRendezvous::~MPIRemoteRendezvous() {} */ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, const int mpi_dst) { + TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique( + req.request_id(), "RecvTensor (MPIRendezvousMgr)", req)); const int64 step_id = request.step_id(); const std::string& key = request.rendezvous_key(); Rendezvous::ParsedKey parsed; diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h index b15748d63c..ca42ee2f6d 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h @@ -30,10 +30,11 @@ limitations under the License. #include <iostream> +#include "tensorflow/contrib/mpi/mpi_msg.pb.h" #include "tensorflow/contrib/mpi/mpi_utils.h" #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/contrib/mpi/mpi_msg.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #define TAG_REQTENSOR 1010 @@ -104,6 +105,7 @@ class MPIRequestTensorCall { void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id) { req_.set_step_id(step_id); req_.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size()); + req_.set_request_id(GetUniqueRequestId()); request_buffer_size_ = req_.ByteSize(); // request_buffer_ = new char[request_buffer_size_]; // req_.SerializeToArray(request_buffer_, request_buffer_size_); @@ -177,6 +179,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { std::map<std::string, std::shared_ptr<MPIRequestTensorCall>> recv_tensor_map_ GUARDED_BY(mrq_); + RecentRequestIds recv_tensor_recent_request_ids_; + void AddRequest(RecvTensorRequest, const int); void MPIBackgroundThread(); |