diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc | 142 |
1 file changed, 142 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc new file mode 100644 index 0000000000..5eeed6e382 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr<DeviceResolverDistributed> dev_resolver, + std::unique_ptr<CollectiveParamResolverDistributed> param_resolver, + WorkerCacheInterface* worker_cache, const string& task_name) + : CollectiveExecutorMgr(config, dev_mgr, 
std::move(dev_resolver), + std::move(param_resolver)), + worker_cache_(worker_cache), + task_name_(task_name) { + group_leader_ = (task_name == config.experimental().collective_group_leader()) + ? "" + : config.experimental().collective_group_leader(); +} + +RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() { + for (auto it : sequence_table_) { + delete it.second; + } +} + +CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessDistributed* rma = + new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), + worker_cache_, step_id); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); +} + +namespace { +// StepId must leave the most-significant 7 bits empty for future use. +static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56)); + +int64 NewRandomStepId() { + int64 step_id = random::New64(); + // Leave MS 8 bits clear for future use. + step_id &= kStepIdMask; + return step_id; +} +} // namespace + +void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync( + int64 graph_key, const StatusCallback& done) { + if (group_leader_.empty()) { + mutex_lock l(sequence_mu_); + GraphKeySequence* gks = nullptr; + auto it = sequence_table_.find(graph_key); + if (it == sequence_table_.end()) { + gks = new GraphKeySequence(graph_key); + sequence_table_[graph_key] = gks; + } else { + gks = it->second; + } + gks->next_step_id_ = NewRandomStepId(); + done(Status::OK()); + } else { + WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_); + GetStepSequenceRequest* req = new GetStepSequenceRequest; + GetStepSequenceResponse* resp = new GetStepSequenceResponse; + req->add_graph_key(graph_key); + wi->GetStepSequenceAsync( + req, resp, [this, req, resp, done](const Status& s) { + if (!s.ok()) { + LOG(ERROR) << "Bad response [" << s + << "] from GetStepSequenceAsync call to " + << group_leader_; + done(s); + } else { + done(UpdateStepSequences(*resp)); + } + delete req; + delete resp; + }); + } +} + 
+Status RpcCollectiveExecutorMgr::UpdateStepSequences( + const GetStepSequenceResponse& resp) { + mutex_lock l(sequence_mu_); + for (const StepSequence& ss : resp.step_sequence()) { + GraphKeySequence* gks = nullptr; + auto it = sequence_table_.find(ss.graph_key()); + if (it == sequence_table_.end()) { + gks = new GraphKeySequence(ss.graph_key()); + sequence_table_[ss.graph_key()] = gks; + } else { + gks = it->second; + } + gks->next_step_id_ = ss.next_step_id(); + } + return Status::OK(); +} + +int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) { + mutex_lock l(sequence_mu_); + auto it = sequence_table_.find(graph_key); + if (it != sequence_table_.end()) { + return it->second->next_step_id_; + } + return CollectiveExecutor::kInvalidId; +} + +void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) { + mutex_lock l(sequence_mu_); + auto it = sequence_table_.find(graph_key); + if (it != sequence_table_.end()) { + if (step_id == it->second->next_step_id_) { + it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask; + } else { + it->second->next_step_id_ = CollectiveExecutor::kInvalidId; + } + } else { + LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire."; + } +} + +} // namespace tensorflow |