author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-06-08 18:12:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-08 18:14:57 -0700
commit  9070f24ae15a4f589219d4cb9c962b14612c2d8c (patch)
tree    561e2362e67fc2c45ddd0e8736de2d9e5b5a022f /tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
parent  53901f9bb9a3965ed5dce65284053b0eb387b0c4 (diff)
Collective Ops Part 8
Enable collective op execution in distributed mode: pass collective_graph_key into the graph-building and step-execution contexts (MasterSession), where it triggers allocation of an RpcCollectiveExecutorMgr that becomes accessible via the WorkerEnv and MasterEnv. The collective_graph_key is used to synchronize step_ids (which are otherwise random) across independent graph executions that contain collective ops needing to rendezvous. All APIs for using collectives are still non-public and experimental. PiperOrigin-RevId: 199879087
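For readers of this patch, a minimal sketch of the intended flow follows. It is a hypothetical driver, not part of the change: RefreshStepIdSequenceAsync, NextStepId, and RetireStepId are added below, while FindOrCreate and Unref come from the pre-existing CollectiveExecutorMgr interface.

#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical driver: obtain a step_id synchronized across all tasks that
// share `graph_key`, run one collective step, then retire the id.
void RunOneCollectiveStep(RpcCollectiveExecutorMgr* mgr, int64 graph_key) {
  mgr->RefreshStepIdSequenceAsync(
      graph_key, [mgr, graph_key](const Status& s) {
        if (!s.ok()) return;
        // All tasks sharing graph_key now agree on the same next step_id.
        const int64 step_id = mgr->NextStepId(graph_key);
        CollectiveExecutor* ce = mgr->FindOrCreate(step_id);
        // ... launch the graph execution containing the collective ops ...
        mgr->RetireStepId(graph_key, step_id);
        ce->Unref();
      });
}

}  // namespace tensorflow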
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc')
-rw-r--r--  tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc  142
1 file changed, 142 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
new file mode 100644
index 0000000000..5eeed6e382
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name)
+ : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
+ std::move(param_resolver)),
+ worker_cache_(worker_cache),
+ task_name_(task_name) {
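+  // By convention, an empty group_leader_ means this task is itself the
+  // group leader; otherwise group_leader_ names the task that owns the
+  // step_id sequences.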
+ group_leader_ = (task_name == config.experimental().collective_group_leader())
+ ? ""
+ : config.experimental().collective_group_leader();
+}
+
+RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
+  for (const auto& it : sequence_table_) {
+ delete it.second;
+ }
+}
+
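+// Creates a CollectiveExecutor for one step, backed by a remote-access
+// object that can exchange tensor data with peer workers over RPC.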
+CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessDistributed* rma =
+ new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+ worker_cache_, step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
+namespace {
+// StepId must leave the most-significant 7 bits empty for future use;
+// kStepIdMask keeps the low 57 bits of a 64-bit id.
+static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));
+
+int64 NewRandomStepId() {
+ int64 step_id = random::New64();
+  // Leave the most-significant 7 bits clear for future use.
+ step_id &= kStepIdMask;
+ return step_id;
+}
+} // namespace
+
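+// If this task is the group leader, locally pick a fresh random step_id for
+// graph_key; otherwise fetch the leader's sequence with a GetStepSequence
+// RPC so that all tasks agree on the next step_id.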
+void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+ int64 graph_key, const StatusCallback& done) {
+ if (group_leader_.empty()) {
+ mutex_lock l(sequence_mu_);
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(graph_key);
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(graph_key);
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = NewRandomStepId();
+ done(Status::OK());
+ } else {
+ WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
+ GetStepSequenceRequest* req = new GetStepSequenceRequest;
+ GetStepSequenceResponse* resp = new GetStepSequenceResponse;
+ req->add_graph_key(graph_key);
+    wi->GetStepSequenceAsync(
+        req, resp, [this, wi, req, resp, done](const Status& s) {
+          if (!s.ok()) {
+            LOG(ERROR) << "Bad response [" << s
+                       << "] from GetStepSequenceAsync call to "
+                       << group_leader_;
+            done(s);
+          } else {
+            done(UpdateStepSequences(*resp));
+          }
+          delete req;
+          delete resp;
+          // Return the worker to the cache; otherwise it leaks.
+          worker_cache_->ReleaseWorker(group_leader_, wi);
+        });
+ }
+}
+
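+// Copies the sequences returned by the group leader into the local table,
+// creating entries for any graph keys not yet seen.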
+Status RpcCollectiveExecutorMgr::UpdateStepSequences(
+ const GetStepSequenceResponse& resp) {
+ mutex_lock l(sequence_mu_);
+ for (const StepSequence& ss : resp.step_sequence()) {
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(ss.graph_key());
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(ss.graph_key());
+ sequence_table_[ss.graph_key()] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = ss.next_step_id();
+ }
+ return Status::OK();
+}
+
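+// Returns the cached next step_id for graph_key, or
+// CollectiveExecutor::kInvalidId if the key is unknown; in the latter case
+// the caller should invoke RefreshStepIdSequenceAsync first.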
+int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ return it->second->next_step_id_;
+ }
+ return CollectiveExecutor::kInvalidId;
+}
+
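+// Advances the sequence for graph_key past a completed step. An unexpected
+// step_id poisons the sequence with kInvalidId, forcing a refresh before
+// the next use.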
+void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ if (step_id == it->second->next_step_id_) {
+ it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
+ } else {
+ it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
+ }
+ } else {
+ LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
+ }
+}
+
+} // namespace tensorflow
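
One plausible way to construct this manager, mirroring what a server init path would do. The RpcCollectiveExecutorMgr constructor is from this patch; the resolver constructor signatures are assumed from the companion distributed resolver classes.

#include <memory>

#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"

namespace tensorflow {

// Hypothetical wiring: build the distributed resolvers, then hand them to
// the manager, e.g. with task_name = "/job:worker/replica:0/task:1".
RpcCollectiveExecutorMgr* MakeRpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    WorkerCacheInterface* worker_cache, const string& task_name) {
  std::unique_ptr<DeviceResolverDistributed> dev_resolver(
      new DeviceResolverDistributed(dev_mgr, worker_cache, task_name));
  std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
      new CollectiveParamResolverDistributed(config, dev_mgr,
                                             dev_resolver.get(), worker_cache,
                                             task_name));
  return new RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                                      std::move(param_resolver), worker_cache,
                                      task_name);
}

}  // namespace tensorflow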