aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-08 18:12:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 18:14:57 -0700
commit9070f24ae15a4f589219d4cb9c962b14612c2d8c (patch)
tree561e2362e67fc2c45ddd0e8736de2d9e5b5a022f
parent53901f9bb9a3965ed5dce65284053b0eb387b0c4 (diff)
Collective Ops Part 8
Enable collective op execution in distibuted mode: Pass collective_graph_key into graph building and step execution contexts (MasterSession) where it triggers allocation of an RpcCollectiveExecutorMgr that becomes accessible via the WorkerEnv and MasterEnv. The collective_graph_key is used to synchronize step_ids (which are otherwise random) between otherwise independent graph executions that contain collective ops that need to rendezvous. All APIs for using collectives are still non-public and experimental. PiperOrigin-RevId: 199879087
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.cc3
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.h3
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.cc18
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h9
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr_test.cc11
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc10
-rw-r--r--tensorflow/core/distributed_runtime/BUILD50
-rw-r--r--tensorflow/core/distributed_runtime/cancellable_call.h65
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc48
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc7
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.cc42
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc26
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.h8
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h5
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc78
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h3
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD3
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc39
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h11
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc142
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h79
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc124
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc10
25 files changed, 659 insertions, 139 deletions
diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc
index a9dc6ca6cd..00f7a8e645 100644
--- a/tensorflow/core/common_runtime/build_graph_options.cc
+++ b/tensorflow/core/common_runtime/build_graph_options.cc
@@ -32,6 +32,9 @@ string BuildGraphOptions::DebugString() const {
for (auto& s : callable_options.target()) {
strings::StrAppend(&rv, s, ", ");
}
+ if (collective_graph_key != kNoCollectiveGraphKey) {
+ strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key);
+ }
return rv;
}
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index 5ca170e922..3d0f242ea5 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -31,6 +31,9 @@ struct BuildGraphOptions {
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
+ static const int64 kNoCollectiveGraphKey = 0;
+ int64 collective_graph_key = kNoCollectiveGraphKey;
+
string DebugString() const;
};
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index e07829b286..4f03a5e13a 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -25,11 +25,11 @@ namespace tensorflow {
CollectiveExecutorMgr::CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver)
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver)
: dev_mgr_(dev_mgr),
- dev_resolver_(dev_resolver),
- param_resolver_(param_resolver) {}
+ dev_resolver_(std::move(dev_resolver)),
+ param_resolver_(std::move(param_resolver)) {}
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
for (auto iter : executor_table_) {
@@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
if (it != executor_table_.end()) {
ce = it->second;
} else {
- CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
- dev_mgr_, dev_resolver_.get(), step_id);
- ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+ ce = Create(step_id);
executor_table_[step_id] = ce;
}
ce->Ref();
@@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
return ce;
}
+CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessLocal* rma =
+ new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
void CollectiveExecutorMgr::Cleanup(int64 step_id) {
CollectiveExecutor* ce = nullptr;
{
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 4b42e2b4d1..9de6ab8968 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -25,8 +25,8 @@ class DeviceMgr;
class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
public:
CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver);
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver);
virtual ~CollectiveExecutorMgr();
@@ -56,11 +56,16 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
void RetireStepId(int64 graph_key, int64 step_id) override {}
protected:
+ // Called by FindOrCreate when table entry does not yet exist.
+ virtual CollectiveExecutor* Create(int64 step_id);
+
const DeviceMgr* dev_mgr_;
std::unique_ptr<DeviceResolverInterface> dev_resolver_;
std::unique_ptr<ParamResolverInterface> param_resolver_;
CollectiveRemoteAccess* remote_access_;
string task_name_;
+
+ private:
mutex exec_mu_;
// Map from step_id to CollectiveExecutor
gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
index 34c9163d6a..91994c5731 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -40,10 +40,13 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
- cme_.reset(new CollectiveExecutorMgr(
- cp, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> prl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ task_name));
+ cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
+ std::move(prl)));
}
std::unique_ptr<CollectiveExecutorMgr> cme_;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 3a871f962d..43c404f2ec 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -201,7 +201,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
LOCKS_EXCLUDED(irec->out_mu);
const DeviceMgr* dev_mgr_;
- DeviceResolverInterface* dev_resolver_;
+ DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
mutex group_mu_;
gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 07c1eafedc..5cef93c605 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -450,11 +450,13 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Set up for collectives if the RunOption declares a key.
if (run_options.experimental().collective_graph_key() > 0) {
if (!collective_executor_mgr_) {
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> cprl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ "/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
- options_.config, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl,
- "/job:localhost/replica:0/task:0")));
+ options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
}
run_state.collective_executor.reset(new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index ead698d787..9032823e17 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -145,9 +145,11 @@ tf_cc_test(
deps = [
":session_mgr",
":worker_env",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
)
@@ -226,6 +228,17 @@ tf_cc_test(
],
)
+cc_library(
+ name = "cancellable_call",
+ hdrs = ["cancellable_call.h"],
+ deps = [
+ ":call_options",
+ ":worker_cache",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
tf_cc_test(
name = "tensor_coding_test",
size = "small",
@@ -392,6 +405,7 @@ cc_library(
hdrs = ["master_env.h"],
deps = [
":worker_cache",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
],
@@ -453,10 +467,45 @@ cc_library(
)
cc_library(
+ name = "rpc_collective_executor_mgr",
+ srcs = ["rpc_collective_executor_mgr.cc"],
+ hdrs = ["rpc_collective_executor_mgr.h"],
+ deps = [
+ ":base_rendezvous_mgr",
+ ":collective_param_resolver_distributed",
+ ":collective_rma_distributed",
+ ":device_resolver_distributed",
+ ":worker_cache",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "rpc_collective_executor_mgr_test",
+ srcs = ["rpc_collective_executor_mgr_test.cc"],
+ deps = [
+ ":collective_param_resolver_distributed",
+ ":device_resolver_distributed",
+ ":rpc_collective_executor_mgr",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "collective_rma_distributed",
srcs = ["collective_rma_distributed.cc"],
hdrs = ["collective_rma_distributed.h"],
deps = [
+ ":cancellable_call",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -492,6 +541,7 @@ cc_library(
hdrs = ["collective_param_resolver_distributed.h"],
deps = [
":call_options",
+ ":cancellable_call",
":device_resolver_distributed",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/core/distributed_runtime/cancellable_call.h b/tensorflow/core/distributed_runtime/cancellable_call.h
new file mode 100644
index 0000000000..05089c7d15
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/cancellable_call.h
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
+
+#include <string>
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Supports client side cancellation of WorkerInterface calls via
+// registration with a CancellationManager.
+class CancellableCall {
+ public:
+ CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+ WorkerCacheInterface* wc)
+ : cancel_mgr_(cancel_mgr),
+ remote_worker_(remote_worker),
+ wc_(wc),
+ wi_(wc_->CreateWorker(remote_worker_)) {}
+
+ virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+ virtual void IssueCall(const StatusCallback& done) = 0;
+
+ void Start(const StatusCallback& done) {
+ CancellationToken token = cancel_mgr_->get_cancellation_token();
+ const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
+ token, [this, token]() { opts_.StartCancel(); });
+ if (not_yet_cancelled) {
+ IssueCall([this, token, done](const Status& s) {
+ cancel_mgr_->DeregisterCallback(token);
+ done(s);
+ });
+ } else {
+ done(errors::Cancelled("RPC Request was cancelled"));
+ }
+ }
+
+ protected:
+ mutable mutex mu_;
+ CancellationManager* const cancel_mgr_; // Not owned
+ const string remote_worker_;
+ WorkerCacheInterface* const wc_; // Not owned
+ WorkerInterface* const wi_; // Owned by wc_, must be released.
+ CallOptions opts_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 7a93b54eae..612ac14e22 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -14,55 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
-#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/protobuf/config.pb.h"
-// TODO(tucker): When we're ready to enable collectives this const will
-// transition to a settable config member.
-static const char FLAGS_collective_group_leader[] =
- "/job:worker/replica:0/task:0";
-
namespace tensorflow {
namespace {
-// Supports client side cancellation of WorkerInterface calls via
-// registration with a CancellationManager. Note that ParamResolverInterface
-// calls are done on behalf of an Op execution which needs to abort if the
-// step in which it executes is cancelled.
-class CancellableCall {
- public:
- CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
- WorkerCacheInterface* wc)
- : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
- wi_ = wc_->CreateWorker(remote_worker_);
- }
- virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
-
- virtual void IssueCall(const StatusCallback& done) = 0;
-
- void Start(const StatusCallback& done) {
- CancellationToken token = cancel_mgr_->get_cancellation_token();
- const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
- token, [this, token]() { opts_.StartCancel(); });
- if (not_yet_cancelled) {
- IssueCall([this, token, done](const Status& s) {
- cancel_mgr_->DeregisterCallback(token);
- done(s);
- });
- } else {
- done(errors::Cancelled("RPC Request was cancelled"));
- }
- }
-
- protected:
- mutable mutex mu_;
- CancellationManager* cancel_mgr_; // Not owned
- const string remote_worker_;
- WorkerCacheInterface* wc_; // Not owned
- WorkerInterface* wi_; // Owned by wc_, must be released.
- CallOptions opts_;
-};
class CompleteGroupCall : public CancellableCall {
public:
@@ -126,9 +84,9 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
const string& task_name)
: CollectiveParamResolverLocal(dev_mgr, dev_resolver, task_name),
worker_cache_(worker_cache),
- group_leader_(task_name == FLAGS_collective_group_leader
+ group_leader_(task_name == config.experimental().collective_group_leader()
? ""
- : FLAGS_collective_group_leader) {}
+ : config.experimental().collective_group_leader()) {}
void CollectiveParamResolverDistributed::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 95a010286d..4eed856759 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -147,10 +147,9 @@ class DeviceResDistTest : public ::testing::Test {
ConfigProto config;
for (int w = 0; w < num_workers; ++w) {
string name = strings::StrCat("/job:worker/replica:0/task:", w);
- // TODO(tucker): When config option becomes available, set here.
- // if (w == 0) {
- // config.set_collective_group_leader(name);
- // }
+ if (w == 0) {
+ config.mutable_experimental()->set_collective_group_leader(name);
+ }
DefineWorker(config, name, device_type, num_devices);
}
}
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index c15878bfd3..d4c47cab49 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
@@ -28,45 +29,6 @@ namespace tensorflow {
namespace {
-// Supports client side cancellation of WorkerInterface calls via
-// registration with a CancellationManager.
-//
-// TODO(tucker): Maybe unify this with CancellableCall in
-// collective_param_resolver_distributed.cc.
-class CancellableCall {
- public:
- CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
- WorkerCacheInterface* wc)
- : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
- wi_ = wc_->CreateWorker(remote_worker_);
- }
- virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
-
- virtual void IssueCall(const StatusCallback& done) = 0;
-
- void Start(const StatusCallback& done) {
- CancellationToken token = cancel_mgr_->get_cancellation_token();
- const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
- token, [this, token]() { opts_.StartCancel(); });
- if (not_yet_cancelled) {
- IssueCall([this, token, done](const Status& s) {
- cancel_mgr_->DeregisterCallback(token);
- done(s);
- });
- } else {
- done(errors::Cancelled("RPC Request was cancelled"));
- }
- }
-
- protected:
- mutable mutex mu_;
- CancellationManager* cancel_mgr_; // Not owned
- const string remote_worker_;
- WorkerCacheInterface* wc_; // Not owned
- WorkerInterface* wi_; // Owned by wc_, must be released.
- CallOptions opts_;
-};
-
class RecvBufCall : public CancellableCall {
public:
RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
@@ -119,7 +81,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
};
State* state = new State;
- // Logic to be executed on the RecvBufferAsync callback.
+ // Logic to be executed on the RecvBufAsync callback.
auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
to_device_ctx, to_tensor, done](const Status& s) {
if (s.ok()) {
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8447c55bf4..e2f13df19f 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -30,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -118,9 +120,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug(
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
+ int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
Item* item) {
item->session = session;
+ item->collective_graph_key = collective_graph_key;
item->lib_def.reset(
new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));
@@ -280,11 +284,12 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
+ int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* handle) {
Item* item = new Item;
- Status s =
- InitItem(session, gdef, graph_options, debug_options, cluster_flr, item);
+ Status s = InitItem(session, gdef, graph_options, debug_options,
+ collective_graph_key, cluster_flr, item);
if (!s.ok()) {
item->Unref();
return s;
@@ -415,7 +420,12 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
-
+ CollectiveExecutor::Handle* ce_handle =
+ item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
+ ? new CollectiveExecutor::Handle(
+ worker_env_->collective_executor_mgr->FindOrCreate(step_id),
+ true)
+ : nullptr;
// Sends values specified by the caller.
if (s.ok()) {
std::vector<string> keys;
@@ -431,22 +441,25 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
if (!s.ok()) {
done(s);
+ delete ce_handle;
item->Unref();
rendezvous->Unref();
return;
}
- StartParallelExecutors(handle, step_id, item, rendezvous, collector,
- cost_graph, cancellation_manager,
- [item, rendezvous, done](const Status& s) {
+ StartParallelExecutors(handle, step_id, item, rendezvous, ce_handle,
+ collector, cost_graph, cancellation_manager,
+ [item, rendezvous, ce_handle, done](const Status& s) {
done(s);
rendezvous->Unref();
item->Unref();
+ delete ce_handle;
});
}
void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
Item* item, Rendezvous* rendezvous,
+ CollectiveExecutor::Handle* ce_handle,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -471,6 +484,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
args.step_id = ++next_id_;
}
args.rendezvous = rendezvous;
+ args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
args.step_container = step_container;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index cc35264b8f..5196046c19 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -75,7 +76,7 @@ class GraphMgr {
// reference to cluster_flr to do cross process function calls.
Status Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
- const DebugOptions& debug_options,
+ const DebugOptions& debug_options, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* handle);
@@ -138,6 +139,8 @@ class GraphMgr {
// Used to deregister a cost model when cost model is required in graph
// manager.
GraphMgr* graph_mgr;
+
+ int64 collective_graph_key;
};
const WorkerEnv* worker_env_; // Not owned.
@@ -161,6 +164,7 @@ class GraphMgr {
void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
Rendezvous* rendezvous,
+ CollectiveExecutor::Handle* ce_handle,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -175,7 +179,7 @@ class GraphMgr {
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
- const DebugOptions& debug_options,
+ const DebugOptions& debug_options, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index 16f4d93c8b..da26c42aca 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -26,6 +26,7 @@ limitations under the License.
namespace tensorflow {
+class CollectiveExecutorMgrInterface;
class Device;
class DeviceSet;
class Env;
@@ -90,6 +91,10 @@ struct MasterEnv {
std::function<Status(const WorkerCacheFactoryOptions&,
WorkerCacheInterface**)>
worker_cache_factory;
+
+ // Generates per-step CollectiveExecutors and has access to utilities
+ // supporting collective operations.
+ CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
};
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index e29bb76ddf..d34ca53f73 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -69,6 +70,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
bool is_partial, WorkerCacheInterface* worker_cache,
bool should_deregister)
: session_handle_(handle),
+ bg_opts_(bopts),
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
@@ -100,6 +102,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const CallableOptions& callable_options() { return callable_opts_; }
+ const BuildGraphOptions& build_graph_options() { return bg_opts_; }
+
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
const RunOptions& ropts) {
@@ -225,6 +229,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
private:
const string session_handle_;
+ const BuildGraphOptions bg_opts_;
const std::unique_ptr<ClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
@@ -444,6 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
+ c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -1065,6 +1071,9 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
*callable_opts->mutable_run_options()->mutable_debug_options() =
req.options().debug_options();
}
+
+ opts->collective_graph_key =
+ req.options().experimental().collective_graph_key();
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
@@ -1102,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
+ if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+ h = Hash64Combine(opts.collective_graph_key, h);
+ }
+
return h;
}
@@ -1118,6 +1131,9 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
for (const string& name : opts.callable_options.fetch()) {
strings::StrAppend(&buf, " FeE: ", name);
}
+ if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+ strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
+ }
strings::StrAppend(&buf, "\n");
return buf;
}
@@ -1430,11 +1446,35 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
rcg_map->clear();
}
-namespace {
-uint64 MakeStepId() {
- return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+uint64 MasterSession::NewStepId(int64 graph_key) {
+ if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // StepId must leave the most-significant 7 bits empty for future use.
+ return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
+ } else {
+ uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
+ int32 retry_count = 0;
+ while (step_id == CollectiveExecutor::kInvalidId) {
+ Notification note;
+ Status status;
+ env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
+ graph_key, [&status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ if (!status.ok()) {
+ LOG(ERROR) << "Bad status from "
+ "collective_executor_mgr->RefreshStepIdSequence: "
+ << status << ". Retrying.";
+ int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
+ Env::Default()->SleepForMicroseconds(delay_micros);
+ } else {
+ step_id = env_->collective_executor_mgr->NextStepId(graph_key);
+ }
+ }
+ return step_id;
+ }
}
-} // namespace
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp) {
@@ -1456,15 +1496,13 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
// Prepare.
BuildGraphOptions opts;
BuildBuildGraphOptions(*req, &opts);
- int64 count;
+ int64 count = 0;
TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
- // Keeps the highest 8 bits 0x01: we reserve some bits of the
- // step_id for future use.
- const uint64 step_id = MakeStepId();
- TRACEPRINTF("stepid %llu", step_id);
rcg->Ref();
- RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
+ RunState* run_state =
+ new RunState(inputs, outputs, rcg,
+ NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
{
mutex_lock l(mu_);
partial_runs_.emplace(
@@ -1566,6 +1604,13 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
}
run_state = it->second.get();
}
+ // CollectiveOps are not supported in partial runs.
+ if (req.options().experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ return errors::InvalidArgument(
+ "PartialRun does not support Collective ops. collective_graph_key "
+ "must be kNoCollectiveGraphKey.");
+ }
// If this is the first partial run, initialize the PerStepState.
if (!run_state->step_started) {
@@ -1743,7 +1788,11 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
-
+ if (rcg->build_graph_options().collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ env_->collective_executor_mgr->RetireStepId(
+ rcg->build_graph_options().collective_graph_key, step_id);
+ }
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
} else if (errors::IsCancelled(s)) {
@@ -1801,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- const uint64 step_id = MakeStepId();
+ uint64 step_id = NewStepId(bgopts.collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1865,9 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- // Keeps the highest 8 bits 0x01: we reserve some bits of the
- // step_id for future use.
- const uint64 step_id = MakeStepId();
+ const uint64 step_id =
+ NewStepId(rcg->build_graph_options().collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index ec34e20b79..449a6d3e3c 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -141,6 +141,8 @@ class MasterSession : public core::RefCounted {
std::atomic<int64> partial_run_handle_counter_ = {0};
+ uint64 NewStepId(int64 graph_key);
+
mutex mu_;
std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
int64 graph_version_;
@@ -175,6 +177,7 @@ class MasterSession : public core::RefCounted {
std::unordered_map<string, bool> pending_outputs; // true if fetched
ReffedClientGraph* rcg = nullptr;
uint64 step_id;
+ int64 collective_graph_key;
int64 count = 0;
PerStepState pss;
std::unique_ptr<ProfileHandler> ph;
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 4b2747f26d..2eadfcde54 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -274,11 +274,14 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
+ "//tensorflow/core/distributed_runtime:device_resolver_distributed",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:local_master",
"//tensorflow/core/distributed_runtime:master",
"//tensorflow/core/distributed_runtime:master_env",
"//tensorflow/core/distributed_runtime:master_session",
+ "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
index f5dc4c831d..9b863ccee5 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
@@ -74,7 +74,7 @@ class EagerGrpcServer : public GrpcServer {
this->eager_service_.reset(
new eager::GrpcEagerServiceImpl(worker_env, server_builder));
},
- nullptr));
+ nullptr, nullptr));
worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
"", worker_name_,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index c0a9b43bf4..43dbe20836 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master.h"
@@ -38,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
@@ -106,6 +109,7 @@ GrpcServer::~GrpcServer() {
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory) {
mutex_lock l(mu_);
@@ -204,6 +208,26 @@ Status GrpcServer::Init(
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache);
+ if (collective_mgr_func) {
+ worker_env_.collective_executor_mgr =
+ collective_mgr_func(config, &worker_env_, worker_cache);
+ if (!worker_env_.collective_executor_mgr) {
+ return errors::Internal(
+ "collective_mgr_func did not return CollectiveExecutorMgr");
+ }
+ } else {
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver(
+ new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
+ default_worker_name));
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
+ new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
+ dev_resolver.get(), worker_cache,
+ default_worker_name));
+ worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
+ config, worker_env_.device_mgr, std::move(dev_resolver),
+ std::move(param_resolver), worker_cache, default_worker_name);
+ }
+
// Set up worker environment.
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
@@ -246,18 +270,21 @@ Status GrpcServer::Init(
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func) {
- return Init(std::move(service_func), rendezvous_mgr_func, worker_func,
- CreateNoOpStatsPublisher);
+ return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
+ worker_func, CreateNoOpStatsPublisher);
}
Status GrpcServer::Init(
ServiceInitFunction service_func,
- const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
- return Init(service_func, rendezvous_mgr_func, nullptr);
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func) {
+ return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
+ nullptr);
}
-Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
+Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
@@ -403,7 +430,7 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<GrpcServer> ret(
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
ServiceInitFunction service_func = nullptr;
- TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr));
*out_server = std::move(ret);
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index b1c2eda0cf..ca9946cafc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
@@ -41,6 +42,11 @@ class Master;
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
+// function that creates a CollectiveExecutorMgr.
+typedef std::function<CollectiveExecutorMgrInterface*(
+ const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
+ CollectiveMgrCreationFunction;
+
// function that registers a service to the server. The service needs to
// be registered before builder.BuildAndStart().
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
@@ -71,15 +77,18 @@ class GrpcServer : public ServerInterface {
protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory);
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func);
Status Init(ServiceInitFunction service_func,
- const RendezvousMgrCreationFunction& rendezvous_mgr_func);
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func);
Status Init();
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
new file mode 100644
index 0000000000..5eeed6e382
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name)
+ : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
+ std::move(param_resolver)),
+ worker_cache_(worker_cache),
+ task_name_(task_name) {
+ group_leader_ = (task_name == config.experimental().collective_group_leader())
+ ? ""
+ : config.experimental().collective_group_leader();
+}
+
+RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
+ for (auto it : sequence_table_) {
+ delete it.second;
+ }
+}
+
+CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessDistributed* rma =
+ new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+ worker_cache_, step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
+namespace {
+// StepId must leave the most-significant 7 bits empty for future use.
+static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));
+
+int64 NewRandomStepId() {
+ int64 step_id = random::New64();
+ // Leave MS 8 bits clear for future use.
+ step_id &= kStepIdMask;
+ return step_id;
+}
+} // namespace
+
+void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+ int64 graph_key, const StatusCallback& done) {
+ if (group_leader_.empty()) {
+ mutex_lock l(sequence_mu_);
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(graph_key);
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(graph_key);
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = NewRandomStepId();
+ done(Status::OK());
+ } else {
+ WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
+ GetStepSequenceRequest* req = new GetStepSequenceRequest;
+ GetStepSequenceResponse* resp = new GetStepSequenceResponse;
+ req->add_graph_key(graph_key);
+ wi->GetStepSequenceAsync(
+ req, resp, [this, req, resp, done](const Status& s) {
+ if (!s.ok()) {
+ LOG(ERROR) << "Bad response [" << s
+ << "] from GetStepSequenceAsync call to "
+ << group_leader_;
+ done(s);
+ } else {
+ done(UpdateStepSequences(*resp));
+ }
+ delete req;
+ delete resp;
+ });
+ }
+}
+
+Status RpcCollectiveExecutorMgr::UpdateStepSequences(
+ const GetStepSequenceResponse& resp) {
+ mutex_lock l(sequence_mu_);
+ for (const StepSequence& ss : resp.step_sequence()) {
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(ss.graph_key());
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(ss.graph_key());
+ sequence_table_[ss.graph_key()] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = ss.next_step_id();
+ }
+ return Status::OK();
+}
+
+int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ return it->second->next_step_id_;
+ }
+ return CollectiveExecutor::kInvalidId;
+}
+
+void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ if (step_id == it->second->next_step_id_) {
+ it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
+ } else {
+ it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
+ }
+ } else {
+ LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
new file mode 100644
index 0000000000..e9f3f0ebe8
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+
+namespace tensorflow {
+class CollectiveParamResolverDistributed;
+class ConfigProto;
+class DeviceMgr;
+class DeviceResolverDistributed;
+class WorkerCacheInterface;
+class StepSequenceRequest;
+class StepSequenceResponse;
+
+// An implementation of CollectiveExecutorMgr for a distributed environment
+// that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs.
+//
+// In some execution environments it may be possible to implement a
+// higher-performance solution and use it in place of this class.
+class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
+ public:
+ RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name);
+
+ virtual ~RpcCollectiveExecutorMgr();
+
+ void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) override;
+
+ int64 NextStepId(int64 graph_key) override;
+
+ void RetireStepId(int64 graph_key, int64 step_id) override;
+
+ protected:
+ CollectiveExecutor* Create(int64 step_id) override;
+
+ WorkerCacheInterface* const worker_cache_; // Not owned.
+ const string task_name_;
+ string group_leader_;
+ friend class RpcCollectiveExecutorMgrTest;
+
+ private:
+ Status UpdateStepSequences(const GetStepSequenceResponse& resp);
+
+ // This class maintains the step_id sequencing for a single
+ // collective_graph_key.
+ struct GraphKeySequence {
+ explicit GraphKeySequence(int64 k)
+ : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {}
+
+ const int64 graph_key_;
+ int64 next_step_id_;
+ };
+
+ mutex sequence_mu_;
+ gtl::FlatMap<int64, GraphKeySequence*> sequence_table_
+ GUARDED_BY(sequence_mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
new file mode 100644
index 0000000000..37b83d82be
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+#define NUM_DEVS 3
+
+class RpcCollectiveExecutorMgrTest : public ::testing::Test {
+ protected:
+ RpcCollectiveExecutorMgrTest() {
+ string task_name = "/job:localhost/replica:0/task:0";
+ SessionOptions options;
+ options.config.mutable_experimental()->set_collective_group_leader(
+ task_name);
+ WorkerCacheInterface* worker_cache = nullptr;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
+ device_mgr_.get(), worker_cache, task_name));
+ std::unique_ptr<CollectiveParamResolverDistributed> cpr(
+ new CollectiveParamResolverDistributed(options.config,
+ device_mgr_.get(), dr.get(),
+ worker_cache, task_name));
+ // This CME is the group leader.
+ cme_.reset(new RpcCollectiveExecutorMgr(options.config, device_mgr_.get(),
+ std::move(dr), std::move(cpr),
+ worker_cache, task_name));
+ }
+
+ std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
+TEST_F(RpcCollectiveExecutorMgrTest, FindOrCreate) {
+ CollectiveExecutor::Handle* h =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_TRUE(h->get());
+ CollectiveExecutor::Handle* h2 =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(h->get(), h2->get());
+ CollectiveExecutor* ce = h->get();
+ delete h;
+ delete h2;
+ CollectiveExecutor* ce2 = cme_->FindOrCreate(1);
+ EXPECT_EQ(ce, ce2);
+ ce2->Unref();
+ cme_->Cleanup(1);
+}
+
+TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) {
+ int64 x = cme_->NextStepId(7);
+ EXPECT_EQ(x, CollectiveExecutor::kInvalidId);
+ // Calling Refresh should generate a valid id.
+ {
+ Notification note;
+ Status status;
+ cme_->RefreshStepIdSequenceAsync(7,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ EXPECT_TRUE(status.ok());
+ }
+ x = cme_->NextStepId(7);
+ EXPECT_NE(x, CollectiveExecutor::kInvalidId);
+ // Should keep returning same number.
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ // Retire on a different graph_key should have no effect.
+ cme_->RetireStepId(6, x);
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ // Retire on same graph_key should advance.
+ cme_->RetireStepId(7, x);
+ int64 y = cme_->NextStepId(7);
+ EXPECT_EQ((x + 1) & (((1uLL << 56) - 1) | (1uLL << 56)), y);
+ // Calling refresh should jump to a different point in the random space.
+ {
+ Notification note;
+ Status status;
+ cme_->RefreshStepIdSequenceAsync(7,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ int64 z = cme_->NextStepId(7);
+ // z should not be equal to or a successor of y.
+ EXPECT_NE(y, z);
+ EXPECT_GT(llabs(y - z), 3);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 4e6500fbc6..1ea19c48f0 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
@@ -72,7 +73,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(),
request->graph_options(), request->debug_options(),
- session->cluster_flr.get(), response->mutable_graph_handle());
+ request->collective_graph_key(), session->cluster_flr.get(),
+ response->mutable_graph_handle());
}
done(s);
}
@@ -315,6 +317,12 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
if (env_->collective_executor_mgr) {
env_->collective_executor_mgr->Cleanup(step_id);
}
+ for (Device* d : env_->local_devices) {
+ ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
+ if (sam) {
+ sam->Cleanup(step_id);
+ }
+ }
done(Status::OK());
}