about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/mpi
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/mpi')
-rw-r--r-- tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc 13
-rw-r--r-- tensorflow/contrib/mpi/mpi_rendezvous_mgr.h 33
2 files changed, 16 insertions, 30 deletions
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index e97e8d0163..1a2563d20f 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -44,7 +44,8 @@ MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env)
// extract worker-name
auto parsed = env->local_devices[0]->parsed_name();
- const std::string task_id = strings::StrCat(parsed.job, ":", parsed.replica);
+ const std::string task_id =
+ strings::StrCat(parsed.job, ":", parsed.replica, ":", parsed.task);
mpiutils_ = new MPIUtils(task_id);
background_thread_ =
@@ -66,8 +67,8 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync(
VLOG(2) << "MPI User requested " << parsed.FullKey()
<< " @ step: " << step_id_;
- std::string src_task =
- strings::StrCat(parsed.src.job, ":", parsed.src.replica);
+ std::string src_task = strings::StrCat(
+ parsed.src.job, ":", parsed.src.replica, ":", parsed.src.task);
const int dst = mpiutils_->GetSourceID(src_task);
Device* dst_device;
@@ -138,11 +139,7 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync(
std::move(request_call), rendezvous_call);
}
-MPIRemoteRendezvous::~MPIRemoteRendezvous() {
- MPIRendezvousMgr* mgr =
- reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
- mgr->RemoveStepID(step_id_);
-}
+MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
/*
* Add the request for one of our Tensors by a remote process
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index 50fc380496..24e784df3e 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -147,15 +147,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
MPIRequestTensorCall* rCall) {
mutex_lock l(mrq_);
request_queue_.push(RequestQueueEntry(key, std::move(request_call)));
- recv_tensor_map_[step_id][key] =
- std::shared_ptr<MPIRequestTensorCall>(rCall);
- }
-
- void RemoveStepID(const int64 step_id) {
- mutex_lock l(mrq_);
- CHECK(recv_tensor_map_[step_id].size() == 0) << "Removing unfinished step";
- recv_tensor_map_.erase(step_id);
- // TODO(jbedorf) Should we verify that the step_id is clear before remove?
+ const std::string key_id = strings::StrCat(key, "_", step_id);
+ recv_tensor_map_[key_id] = std::shared_ptr<MPIRequestTensorCall>(rCall);
}
protected:
@@ -181,9 +174,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
std::queue<SendQueueEntry> send_queue_ GUARDED_BY(msq_);
std::queue<RequestQueueEntry> request_queue_ GUARDED_BY(mrq_);
- std::map<int64, std::unordered_map<std::string,
- std::shared_ptr<MPIRequestTensorCall>>>
- recv_tensor_map_ GUARDED_BY(mrq_);
+ std::map<std::string, std::shared_ptr<MPIRequestTensorCall>> recv_tensor_map_
+ GUARDED_BY(mrq_);
void AddRequest(RecvTensorRequest, const int);
void MPIBackgroundThread();
@@ -196,22 +188,19 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
void GetRecvCall(const int64 step_id, const std::string& key,
std::shared_ptr<MPIRequestTensorCall>* call) {
mutex_lock l(mrq_);
- if (recv_tensor_map_.find(step_id) == recv_tensor_map_.end()) {
- LOG(FATAL) << "Step not found in recv_tensor_map_, step: " << step_id
- << " key: " << key << std::endl;
- }
- if (recv_tensor_map_[step_id].find(key) !=
- recv_tensor_map_[step_id].end()) {
- *call = recv_tensor_map_[step_id][key];
- } else {
- LOG(FATAL) << "Key not found in recv_tensor_map_, step: " << step_id
+
+ const std::string key_id = strings::StrCat(key, "_", step_id);
+ if (recv_tensor_map_.find(key_id) == recv_tensor_map_.end()) {
+ LOG(FATAL) << "Key/step not found in recv_tensor_map_, step: " << step_id
<< " key: " << key << std::endl;
}
+ *call = recv_tensor_map_[key_id];
}
void RemoveRecvCall(const int64 step_id, const std::string& key) {
mutex_lock l(mrq_);
- recv_tensor_map_[step_id].erase(key);
+ const std::string key_id = strings::StrCat(key, "_", step_id);
+ recv_tensor_map_.erase(key_id);
}
bool GetRequest(RequestQueueEntry* req) {