diff options
Diffstat (limited to 'tensorflow/contrib/mpi')
-rw-r--r-- | tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc | 13 | ||||
-rw-r--r-- | tensorflow/contrib/mpi/mpi_rendezvous_mgr.h | 33 |
2 files changed, 16 insertions, 30 deletions
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc index e97e8d0163..1a2563d20f 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc @@ -44,7 +44,8 @@ MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env) // extract worker-name auto parsed = env->local_devices[0]->parsed_name(); - const std::string task_id = strings::StrCat(parsed.job, ":", parsed.replica); + const std::string task_id = + strings::StrCat(parsed.job, ":", parsed.replica, ":", parsed.task); mpiutils_ = new MPIUtils(task_id); background_thread_ = @@ -66,8 +67,8 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync( VLOG(2) << "MPI User requested " << parsed.FullKey() << " @ step: " << step_id_; - std::string src_task = - strings::StrCat(parsed.src.job, ":", parsed.src.replica); + std::string src_task = strings::StrCat( + parsed.src.job, ":", parsed.src.replica, ":", parsed.src.task); const int dst = mpiutils_->GetSourceID(src_task); Device* dst_device; @@ -138,11 +139,7 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync( std::move(request_call), rendezvous_call); } -MPIRemoteRendezvous::~MPIRemoteRendezvous() { - MPIRendezvousMgr* mgr = - reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_); - mgr->RemoveStepID(step_id_); -} +MPIRemoteRendezvous::~MPIRemoteRendezvous() {} /* * Add the request for one of our Tensors by a remote process diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h index 50fc380496..24e784df3e 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h @@ -147,15 +147,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { MPIRequestTensorCall* rCall) { mutex_lock l(mrq_); request_queue_.push(RequestQueueEntry(key, std::move(request_call))); - recv_tensor_map_[step_id][key] = - std::shared_ptr<MPIRequestTensorCall>(rCall); - } - - void RemoveStepID(const int64 step_id) { - mutex_lock l(mrq_); - CHECK(recv_tensor_map_[step_id].size() == 0) << "Removing unfinished step"; - recv_tensor_map_.erase(step_id); - // TODO(jbedorf) Should we verify that the step_id is clear before remove? + const std::string key_id = strings::StrCat(key, "_", step_id); + recv_tensor_map_[key_id] = std::shared_ptr<MPIRequestTensorCall>(rCall); } protected: @@ -181,9 +174,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { std::queue<SendQueueEntry> send_queue_ GUARDED_BY(msq_); std::queue<RequestQueueEntry> request_queue_ GUARDED_BY(mrq_); - std::map<int64, std::unordered_map<std::string, - std::shared_ptr<MPIRequestTensorCall>>> - recv_tensor_map_ GUARDED_BY(mrq_); + std::map<std::string, std::shared_ptr<MPIRequestTensorCall>> recv_tensor_map_ + GUARDED_BY(mrq_); void AddRequest(RecvTensorRequest, const int); void MPIBackgroundThread(); @@ -196,22 +188,19 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { void GetRecvCall(const int64 step_id, const std::string& key, std::shared_ptr<MPIRequestTensorCall>* call) { mutex_lock l(mrq_); - if (recv_tensor_map_.find(step_id) == recv_tensor_map_.end()) { - LOG(FATAL) << "Step not found in recv_tensor_map_, step: " << step_id - << " key: " << key << std::endl; - } - if (recv_tensor_map_[step_id].find(key) != - recv_tensor_map_[step_id].end()) { - *call = recv_tensor_map_[step_id][key]; - } else { - LOG(FATAL) << "Key not found in recv_tensor_map_, step: " << step_id + + const std::string key_id = strings::StrCat(key, "_", step_id); + if (recv_tensor_map_.find(key_id) == recv_tensor_map_.end()) { + LOG(FATAL) << "Key/step not found in recv_tensor_map_, step: " << step_id << " key: " << key << std::endl; } + *call = recv_tensor_map_[key_id]; } void RemoveRecvCall(const int64 step_id, const std::string& key) { mutex_lock l(mrq_); - recv_tensor_map_[step_id].erase(key); + const std::string key_id = strings::StrCat(key, "_", step_id); + recv_tensor_map_.erase(key_id); } bool GetRequest(RequestQueueEntry* req) { |