From 9070f24ae15a4f589219d4cb9c962b14612c2d8c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 8 Jun 2018 18:12:16 -0700 Subject: Collective Ops Part 8 Enable collective op execution in distibuted mode: Pass collective_graph_key into graph building and step execution contexts (MasterSession) where it triggers allocation of an RpcCollectiveExecutorMgr that becomes accessible via the WorkerEnv and MasterEnv. The collective_graph_key is used to synchronize step_ids (which are otherwise random) between otherwise independent graph executions that contain collective ops that need to rendezvous. All APIs for using collectives are still non-public and experimental. PiperOrigin-RevId: 199879087 --- .../core/common_runtime/build_graph_options.cc | 3 + .../core/common_runtime/build_graph_options.h | 3 + .../core/common_runtime/collective_executor_mgr.cc | 18 ++- .../core/common_runtime/collective_executor_mgr.h | 9 +- .../common_runtime/collective_executor_mgr_test.cc | 11 +- .../collective_param_resolver_local.h | 2 +- tensorflow/core/common_runtime/direct_session.cc | 10 +- tensorflow/core/distributed_runtime/BUILD | 50 ++++++++ .../core/distributed_runtime/cancellable_call.h | 65 ++++++++++ .../collective_param_resolver_distributed.cc | 48 +------ .../collective_param_resolver_distributed_test.cc | 7 +- .../collective_rma_distributed.cc | 42 +----- tensorflow/core/distributed_runtime/graph_mgr.cc | 26 +++- tensorflow/core/distributed_runtime/graph_mgr.h | 8 +- tensorflow/core/distributed_runtime/master_env.h | 5 + .../core/distributed_runtime/master_session.cc | 78 ++++++++--- .../core/distributed_runtime/master_session.h | 3 + tensorflow/core/distributed_runtime/rpc/BUILD | 3 + .../rpc/eager/eager_grpc_server_lib.h | 2 +- .../distributed_runtime/rpc/grpc_server_lib.cc | 39 +++++- .../core/distributed_runtime/rpc/grpc_server_lib.h | 11 +- .../rpc_collective_executor_mgr.cc | 142 +++++++++++++++++++++ .../rpc_collective_executor_mgr.h | 79 ++++++++++++ .../rpc_collective_executor_mgr_test.cc | 124 ++++++++++++++++++ tensorflow/core/distributed_runtime/worker.cc | 10 +- 25 files changed, 659 insertions(+), 139 deletions(-) create mode 100644 tensorflow/core/distributed_runtime/cancellable_call.h create mode 100644 tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc create mode 100644 tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h create mode 100644 tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc index a9dc6ca6cd..00f7a8e645 100644 --- a/tensorflow/core/common_runtime/build_graph_options.cc +++ b/tensorflow/core/common_runtime/build_graph_options.cc @@ -32,6 +32,9 @@ string BuildGraphOptions::DebugString() const { for (auto& s : callable_options.target()) { strings::StrAppend(&rv, s, ", "); } + if (collective_graph_key != kNoCollectiveGraphKey) { + strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key); + } return rv; } diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h index 5ca170e922..3d0f242ea5 100644 --- a/tensorflow/core/common_runtime/build_graph_options.h +++ b/tensorflow/core/common_runtime/build_graph_options.h @@ -31,6 +31,9 @@ struct BuildGraphOptions { // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval. 
bool use_function_convention = false; + static const int64 kNoCollectiveGraphKey = 0; + int64 collective_graph_key = kNoCollectiveGraphKey; + string DebugString() const; }; diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc index e07829b286..4f03a5e13a 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc @@ -25,11 +25,11 @@ namespace tensorflow { CollectiveExecutorMgr::CollectiveExecutorMgr( const ConfigProto& config, const DeviceMgr* dev_mgr, - DeviceResolverInterface* dev_resolver, - ParamResolverInterface* param_resolver) + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver) : dev_mgr_(dev_mgr), - dev_resolver_(dev_resolver), - param_resolver_(param_resolver) {} + dev_resolver_(std::move(dev_resolver)), + param_resolver_(std::move(param_resolver)) {} CollectiveExecutorMgr::~CollectiveExecutorMgr() { for (auto iter : executor_table_) { @@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { if (it != executor_table_.end()) { ce = it->second; } else { - CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal( - dev_mgr_, dev_resolver_.get(), step_id); - ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); + ce = Create(step_id); executor_table_[step_id] = ce; } ce->Ref(); @@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { return ce; } +CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessLocal* rma = + new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); +} + void CollectiveExecutorMgr::Cleanup(int64 step_id) { CollectiveExecutor* ce = nullptr; { diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h index 4b42e2b4d1..9de6ab8968 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/collective_executor_mgr.h @@ -25,8 +25,8 @@ class DeviceMgr; class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { public: CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr, - DeviceResolverInterface* dev_resolver, - ParamResolverInterface* param_resolver); + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver); virtual ~CollectiveExecutorMgr(); @@ -56,11 +56,16 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { void RetireStepId(int64 graph_key, int64 step_id) override {} protected: + // Called by FindOrCreate when table entry does not yet exist. 
+ virtual CollectiveExecutor* Create(int64 step_id); + const DeviceMgr* dev_mgr_; std::unique_ptr dev_resolver_; std::unique_ptr param_resolver_; CollectiveRemoteAccess* remote_access_; string task_name_; + + private: mutex exec_mu_; // Map from step_id to CollectiveExecutor gtl::FlatMap executor_table_ GUARDED_BY(exec_mu_); diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc index 34c9163d6a..91994c5731 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc @@ -40,10 +40,13 @@ class CollectiveExecutorMgrTest : public ::testing::Test { device_count->insert({"CPU", NUM_DEVS}); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); device_mgr_.reset(new DeviceMgr(devices_)); - DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get()); - cme_.reset(new CollectiveExecutorMgr( - cp, device_mgr_.get(), drl, - new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name))); + std::unique_ptr drl( + new DeviceResolverLocal(device_mgr_.get())); + std::unique_ptr prl( + new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(), + task_name)); + cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl), + std::move(prl))); } std::unique_ptr cme_; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 3a871f962d..43c404f2ec 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -201,7 +201,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { LOCKS_EXCLUDED(irec->out_mu); const DeviceMgr* dev_mgr_; - DeviceResolverInterface* dev_resolver_; + DeviceResolverInterface* dev_resolver_; // Not owned. string task_name_; mutex group_mu_; gtl::FlatMap> group_table_ diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 07c1eafedc..5cef93c605 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -450,11 +450,13 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, // Set up for collectives if the RunOption declares a key. 
if (run_options.experimental().collective_graph_key() > 0) { if (!collective_executor_mgr_) { - DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get()); + std::unique_ptr drl( + new DeviceResolverLocal(device_mgr_.get())); + std::unique_ptr cprl( + new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(), + "/job:localhost/replica:0/task:0")); collective_executor_mgr_.reset(new CollectiveExecutorMgr( - options_.config, device_mgr_.get(), drl, - new CollectiveParamResolverLocal(device_mgr_.get(), drl, - "/job:localhost/replica:0/task:0"))); + options_.config, device_mgr_.get(), std::move(drl), std::move(cprl))); } run_state.collective_executor.reset(new CollectiveExecutor::Handle( collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/)); diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index ead698d787..9032823e17 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -145,9 +145,11 @@ tf_cc_test( deps = [ ":session_mgr", ":worker_env", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", ], ) @@ -226,6 +228,17 @@ tf_cc_test( ], ) +cc_library( + name = "cancellable_call", + hdrs = ["cancellable_call.h"], + deps = [ + ":call_options", + ":worker_cache", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + tf_cc_test( name = "tensor_coding_test", size = "small", @@ -392,6 +405,7 @@ cc_library( hdrs = ["master_env.h"], deps = [ ":worker_cache", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", ], @@ -452,11 +466,46 @@ cc_library( ], ) +cc_library( + name = "rpc_collective_executor_mgr", + srcs = ["rpc_collective_executor_mgr.cc"], + hdrs = ["rpc_collective_executor_mgr.h"], + deps = [ + ":base_rendezvous_mgr", + ":collective_param_resolver_distributed", + ":collective_rma_distributed", + ":device_resolver_distributed", + ":worker_cache", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:worker_proto_cc", + ], +) + +tf_cc_test( + name = "rpc_collective_executor_mgr_test", + srcs = ["rpc_collective_executor_mgr_test.cc"], + deps = [ + ":collective_param_resolver_distributed", + ":device_resolver_distributed", + ":rpc_collective_executor_mgr", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:session_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "collective_rma_distributed", srcs = ["collective_rma_distributed.cc"], hdrs = ["collective_rma_distributed.h"], deps = [ + ":cancellable_call", ":worker_cache", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -492,6 +541,7 @@ cc_library( hdrs = ["collective_param_resolver_distributed.h"], deps = [ ":call_options", + ":cancellable_call", ":device_resolver_distributed", ":worker_cache", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/core/distributed_runtime/cancellable_call.h b/tensorflow/core/distributed_runtime/cancellable_call.h new file mode 100644 index 0000000000..05089c7d15 --- /dev/null +++ b/tensorflow/core/distributed_runtime/cancellable_call.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_ + +#include +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// Supports client side cancellation of WorkerInterface calls via +// registration with a CancellationManager. +class CancellableCall { + public: + CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker, + WorkerCacheInterface* wc) + : cancel_mgr_(cancel_mgr), + remote_worker_(remote_worker), + wc_(wc), + wi_(wc_->CreateWorker(remote_worker_)) {} + + virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); } + + virtual void IssueCall(const StatusCallback& done) = 0; + + void Start(const StatusCallback& done) { + CancellationToken token = cancel_mgr_->get_cancellation_token(); + const bool not_yet_cancelled = cancel_mgr_->RegisterCallback( + token, [this, token]() { opts_.StartCancel(); }); + if (not_yet_cancelled) { + IssueCall([this, token, done](const Status& s) { + cancel_mgr_->DeregisterCallback(token); + done(s); + }); + } else { + done(errors::Cancelled("RPC Request was cancelled")); + } + } + + protected: + mutable mutex mu_; + CancellationManager* const cancel_mgr_; // Not owned + const string remote_worker_; + WorkerCacheInterface* const wc_; // Not owned + WorkerInterface* const wi_; // Owned by wc_, must be released. + CallOptions opts_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_ diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 7a93b54eae..612ac14e22 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -14,55 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" -#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/cancellable_call.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/protobuf/config.pb.h" -// TODO(tucker): When we're ready to enable collectives this const will -// transition to a settable config member. 
-static const char FLAGS_collective_group_leader[] = - "/job:worker/replica:0/task:0"; - namespace tensorflow { namespace { -// Supports client side cancellation of WorkerInterface calls via -// registration with a CancellationManager. Note that ParamResolverInterface -// calls are done on behalf of an Op execution which needs to abort if the -// step in which it executes is cancelled. -class CancellableCall { - public: - CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker, - WorkerCacheInterface* wc) - : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) { - wi_ = wc_->CreateWorker(remote_worker_); - } - virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); } - - virtual void IssueCall(const StatusCallback& done) = 0; - - void Start(const StatusCallback& done) { - CancellationToken token = cancel_mgr_->get_cancellation_token(); - const bool not_yet_cancelled = cancel_mgr_->RegisterCallback( - token, [this, token]() { opts_.StartCancel(); }); - if (not_yet_cancelled) { - IssueCall([this, token, done](const Status& s) { - cancel_mgr_->DeregisterCallback(token); - done(s); - }); - } else { - done(errors::Cancelled("RPC Request was cancelled")); - } - } - - protected: - mutable mutex mu_; - CancellationManager* cancel_mgr_; // Not owned - const string remote_worker_; - WorkerCacheInterface* wc_; // Not owned - WorkerInterface* wi_; // Owned by wc_, must be released. - CallOptions opts_; -}; class CompleteGroupCall : public CancellableCall { public: @@ -126,9 +84,9 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed( const string& task_name) : CollectiveParamResolverLocal(dev_mgr, dev_resolver, task_name), worker_cache_(worker_cache), - group_leader_(task_name == FLAGS_collective_group_leader + group_leader_(task_name == config.experimental().collective_group_leader() ? "" - : FLAGS_collective_group_leader) {} + : config.experimental().collective_group_leader()) {} void CollectiveParamResolverDistributed::CompleteParamsAsync( const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index 95a010286d..4eed856759 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -147,10 +147,9 @@ class DeviceResDistTest : public ::testing::Test { ConfigProto config; for (int w = 0; w < num_workers; ++w) { string name = strings::StrCat("/job:worker/replica:0/task:", w); - // TODO(tucker): When config option becomes available, set here. - // if (w == 0) { - // config.set_collective_group_leader(name); - // } + if (w == 0) { + config.mutable_experimental()->set_collective_group_leader(name); + } DefineWorker(config, name, device_type, num_devices); } } diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index c15878bfd3..d4c47cab49 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/cancellable_call.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/transport_options.pb.h" @@ -28,45 +29,6 @@ namespace tensorflow { namespace { -// Supports client side cancellation of WorkerInterface calls via -// registration with a CancellationManager. -// -// TODO(tucker): Maybe unify this with CancellableCall in -// collective_param_resolver_distributed.cc. -class CancellableCall { - public: - CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker, - WorkerCacheInterface* wc) - : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) { - wi_ = wc_->CreateWorker(remote_worker_); - } - virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); } - - virtual void IssueCall(const StatusCallback& done) = 0; - - void Start(const StatusCallback& done) { - CancellationToken token = cancel_mgr_->get_cancellation_token(); - const bool not_yet_cancelled = cancel_mgr_->RegisterCallback( - token, [this, token]() { opts_.StartCancel(); }); - if (not_yet_cancelled) { - IssueCall([this, token, done](const Status& s) { - cancel_mgr_->DeregisterCallback(token); - done(s); - }); - } else { - done(errors::Cancelled("RPC Request was cancelled")); - } - } - - protected: - mutable mutex mu_; - CancellationManager* cancel_mgr_; // Not owned - const string remote_worker_; - WorkerCacheInterface* wc_; // Not owned - WorkerInterface* wi_; // Owned by wc_, must be released. - CallOptions opts_; -}; - class RecvBufCall : public CancellableCall { public: RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task, @@ -119,7 +81,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( }; State* state = new State; - // Logic to be executed on the RecvBufferAsync callback. + // Logic to be executed on the RecvBufAsync callback. auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr, to_device_ctx, to_tensor, done](const Status& s) { if (s.ok()) { diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 8447c55bf4..e2f13df19f 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/build_graph_options.h" #include "tensorflow/core/common_runtime/constant_folding.h" #include "tensorflow/core/common_runtime/debugger_state_interface.h" #include "tensorflow/core/common_runtime/device.h" @@ -30,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -118,9 +120,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug( Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, + int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, Item* item) { item->session = session; + item->collective_graph_key = collective_graph_key; item->lib_def.reset( new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library())); @@ -280,11 +284,12 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, Status GraphMgr::Register(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, + int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, string* handle) { Item* item = new Item; - Status s = - InitItem(session, gdef, graph_options, debug_options, cluster_flr, item); + Status s = InitItem(session, gdef, graph_options, debug_options, + collective_graph_key, cluster_flr, item); if (!s.ok()) { item->Unref(); return s; @@ -415,7 +420,12 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = rendezvous->Initialize(session); - + CollectiveExecutor::Handle* ce_handle = + item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey + ? new CollectiveExecutor::Handle( + worker_env_->collective_executor_mgr->FindOrCreate(step_id), + true) + : nullptr; // Sends values specified by the caller. if (s.ok()) { std::vector keys; @@ -431,22 +441,25 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, if (!s.ok()) { done(s); + delete ce_handle; item->Unref(); rendezvous->Unref(); return; } - StartParallelExecutors(handle, step_id, item, rendezvous, collector, - cost_graph, cancellation_manager, - [item, rendezvous, done](const Status& s) { + StartParallelExecutors(handle, step_id, item, rendezvous, ce_handle, + collector, cost_graph, cancellation_manager, + [item, rendezvous, ce_handle, done](const Status& s) { done(s); rendezvous->Unref(); item->Unref(); + delete ce_handle; }); } void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, Item* item, Rendezvous* rendezvous, + CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, @@ -471,6 +484,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, args.step_id = ++next_id_; } args.rendezvous = rendezvous; + args.collective_executor = ce_handle ? ce_handle->get() : nullptr; args.cancellation_manager = cancellation_manager; args.stats_collector = collector; args.step_container = step_container; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index cc35264b8f..5196046c19 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/message_wrappers.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/lib/core/refcount.h" @@ -75,7 +76,7 @@ class GraphMgr { // reference to cluster_flr to do cross process function calls. Status Register(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, - const DebugOptions& debug_options, + const DebugOptions& debug_options, int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, string* handle); @@ -138,6 +139,8 @@ class GraphMgr { // Used to deregister a cost model when cost model is required in graph // manager. GraphMgr* graph_mgr; + + int64 collective_graph_key; }; const WorkerEnv* worker_env_; // Not owned. @@ -161,6 +164,7 @@ class GraphMgr { void StartParallelExecutors(const string& handle, int64 step_id, Item* item, Rendezvous* rendezvous, + CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, @@ -175,7 +179,7 @@ class GraphMgr { Status InitItem(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, - const DebugOptions& debug_options, + const DebugOptions& debug_options, int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, Item* item); Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options, diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index 16f4d93c8b..da26c42aca 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -26,6 +26,7 @@ limitations under the License. namespace tensorflow { +class CollectiveExecutorMgrInterface; class Device; class DeviceSet; class Env; @@ -90,6 +91,10 @@ struct MasterEnv { std::function worker_cache_factory; + + // Generates per-step CollectiveExecutors and has access to utilities + // supporting collective operations. + CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr; }; } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index e29bb76ddf..d34ca53f73 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -69,6 +70,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { bool is_partial, WorkerCacheInterface* worker_cache, bool should_deregister) : session_handle_(handle), + bg_opts_(bopts), client_graph_(std::move(cg)), session_opts_(session_opts), is_partial_(is_partial), @@ -100,6 +102,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { const CallableOptions& callable_options() { return callable_opts_; } + const BuildGraphOptions& build_graph_options() { return bg_opts_; } + std::unique_ptr GetProfileHandler(uint64 step, int64 execution_count, const RunOptions& ropts) { @@ -225,6 +229,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { private: const string session_handle_; + const BuildGraphOptions bg_opts_; const std::unique_ptr client_graph_; const SessionOptions session_opts_; const bool is_partial_; @@ -444,6 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( *c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_debug_options() = callable_opts_.run_options().debug_options(); + c->req.set_collective_graph_key(bg_opts_.collective_graph_key); VLOG(2) << "Register " << c->req.graph_def().DebugString(); auto cb = [c, &done](const Status& s) { c->status = s; @@ -1065,6 +1071,9 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req, *callable_opts->mutable_run_options()->mutable_debug_options() = req.options().debug_options(); } + + opts->collective_graph_key = + req.options().experimental().collective_graph_key(); } void BuildBuildGraphOptions(const PartialRunSetupRequest& req, @@ -1102,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { h = Hash64(watch_summary.c_str(), watch_summary.size(), h); } + if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) { + h = Hash64Combine(opts.collective_graph_key, h); + } + return h; } @@ -1118,6 +1131,9 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) { for (const string& name : opts.callable_options.fetch()) { strings::StrAppend(&buf, " FeE: ", name); } + if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) { + strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key); + } strings::StrAppend(&buf, "\n"); return buf; } @@ -1430,11 +1446,35 @@ void MasterSession::ClearRunsTable(std::vector* to_unref, rcg_map->clear(); } -namespace { -uint64 MakeStepId() { - return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56); +uint64 MasterSession::NewStepId(int64 graph_key) { + if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) { + // StepId must leave the most-significant 7 bits empty for future use. 
+ return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56)); + } else { + uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key); + int32 retry_count = 0; + while (step_id == CollectiveExecutor::kInvalidId) { + Notification note; + Status status; + env_->collective_executor_mgr->RefreshStepIdSequenceAsync( + graph_key, [&status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + note.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Bad status from " + "collective_executor_mgr->RefreshStepIdSequence: " + << status << ". Retrying."; + int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count); + Env::Default()->SleepForMicroseconds(delay_micros); + } else { + step_id = env_->collective_executor_mgr->NextStepId(graph_key); + } + } + return step_id; + } } -} // namespace Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, PartialRunSetupResponse* resp) { @@ -1456,15 +1496,13 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, // Prepare. BuildGraphOptions opts; BuildBuildGraphOptions(*req, &opts); - int64 count; + int64 count = 0; TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count)); - // Keeps the highest 8 bits 0x01: we reserve some bits of the - // step_id for future use. - const uint64 step_id = MakeStepId(); - TRACEPRINTF("stepid %llu", step_id); rcg->Ref(); - RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count); + RunState* run_state = + new RunState(inputs, outputs, rcg, + NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count); { mutex_lock l(mu_); partial_runs_.emplace( @@ -1566,6 +1604,13 @@ Status MasterSession::DoPartialRun(CallOptions* opts, } run_state = it->second.get(); } + // CollectiveOps are not supported in partial runs. + if (req.options().experimental().collective_graph_key() != + BuildGraphOptions::kNoCollectiveGraphKey) { + return errors::InvalidArgument( + "PartialRun does not support Collective ops. collective_graph_key " + "must be kNoCollectiveGraphKey."); + } // If this is the first partial run, initialize the PerStepState. if (!run_state->step_started) { @@ -1743,7 +1788,11 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, Status s = run_status; if (s.ok()) { pss->end_micros = Env::Default()->NowMicros(); - + if (rcg->build_graph_options().collective_graph_key != + BuildGraphOptions::kNoCollectiveGraphKey) { + env_->collective_executor_mgr->RetireStepId( + rcg->build_graph_options().collective_graph_key, step_id); + } // Schedule post-processing and cleanup to be done asynchronously. rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata); } else if (errors::IsCancelled(s)) { @@ -1801,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution( // Keeps the highest 8 bits 0x01: we reserve some bits of the // step_id for future use. - const uint64 step_id = MakeStepId(); + uint64 step_id = NewStepId(bgopts.collective_graph_key); TRACEPRINTF("stepid %llu", step_id); std::unique_ptr ph; @@ -1865,9 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, // Prepare. int64 count = rcg->get_and_increment_execution_count(); - // Keeps the highest 8 bits 0x01: we reserve some bits of the - // step_id for future use. 
- const uint64 step_id = MakeStepId(); + const uint64 step_id = + NewStepId(rcg->build_graph_options().collective_graph_key); TRACEPRINTF("stepid %llu", step_id); const RunOptions& run_options = rcg->callable_options().run_options(); diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index ec34e20b79..449a6d3e3c 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -141,6 +141,8 @@ class MasterSession : public core::RefCounted { std::atomic partial_run_handle_counter_ = {0}; + uint64 NewStepId(int64 graph_key); + mutex mu_; std::unique_ptr execution_state_ GUARDED_BY(mu_); int64 graph_version_; @@ -175,6 +177,7 @@ class MasterSession : public core::RefCounted { std::unordered_map pending_outputs; // true if fetched ReffedClientGraph* rcg = nullptr; uint64 step_id; + int64 collective_graph_key; int64 count = 0; PerStepState pss; std::unique_ptr ph; diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 4b2747f26d..2eadfcde54 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -274,11 +274,14 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime:graph_mgr", "//tensorflow/core/distributed_runtime:local_master", "//tensorflow/core/distributed_runtime:master", "//tensorflow/core/distributed_runtime:master_env", "//tensorflow/core/distributed_runtime:master_session", + "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:worker_env", diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h index f5dc4c831d..9b863ccee5 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h @@ -74,7 +74,7 @@ class EagerGrpcServer : public GrpcServer { this->eager_service_.reset( new eager::GrpcEagerServiceImpl(worker_env, server_builder)); }, - nullptr)); + nullptr, nullptr)); worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( "", worker_name_, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index c0a9b43bf4..43dbe20836 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" #include "tensorflow/core/distributed_runtime/local_master.h" #include "tensorflow/core/distributed_runtime/master.h" @@ -38,6 +40,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/op.h" @@ -106,6 +109,7 @@ GrpcServer::~GrpcServer() { Status GrpcServer::Init( ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func, const WorkerCreationFunction& worker_func, const StatsPublisherFactory& stats_factory) { mutex_lock l(mu_); @@ -204,6 +208,26 @@ Status GrpcServer::Init( WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); CHECK_NE(nullptr, worker_cache); + if (collective_mgr_func) { + worker_env_.collective_executor_mgr = + collective_mgr_func(config, &worker_env_, worker_cache); + if (!worker_env_.collective_executor_mgr) { + return errors::Internal( + "collective_mgr_func did not return CollectiveExecutorMgr"); + } + } else { + std::unique_ptr dev_resolver( + new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache, + default_worker_name)); + std::unique_ptr param_resolver( + new CollectiveParamResolverDistributed(config, worker_env_.device_mgr, + dev_resolver.get(), worker_cache, + default_worker_name)); + worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr( + config, worker_env_.device_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, default_worker_name); + } + // Set up worker environment. worker_env_.session_mgr = new SessionMgr( &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), @@ -246,18 +270,21 @@ Status GrpcServer::Init( Status GrpcServer::Init( ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func, const WorkerCreationFunction& worker_func) { - return Init(std::move(service_func), rendezvous_mgr_func, worker_func, - CreateNoOpStatsPublisher); + return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func, + worker_func, CreateNoOpStatsPublisher); } Status GrpcServer::Init( ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func) { - return Init(service_func, rendezvous_mgr_func, nullptr); + const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func) { + return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func, + nullptr); } -Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); } +Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); } Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { @@ -403,7 +430,7 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, std::unique_ptr ret( new GrpcServer(server_def, env == nullptr ? 
Env::Default() : env)); ServiceInitFunction service_func = nullptr; - TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr)); + TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr)); *out_server = std::move(ret); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index b1c2eda0cf..ca9946cafc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/platform/env.h" @@ -41,6 +42,11 @@ class Master; typedef std::function RendezvousMgrCreationFunction; +// function that creates a CollectiveExecutorMgr. +typedef std::function + CollectiveMgrCreationFunction; + // function that registers a service to the server. The service needs to // be registered before builder.BuildAndStart(). typedef std::function @@ -71,15 +77,18 @@ class GrpcServer : public ServerInterface { protected: Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func, const WorkerCreationFunction& worker_func, const StatsPublisherFactory& stats_factory); Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func, const WorkerCreationFunction& worker_func); Status Init(ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func); + const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func); Status Init(); diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc new file mode 100644 index 0000000000..5eeed6e382 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + WorkerCacheInterface* worker_cache, const string& task_name) + : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver), + std::move(param_resolver)), + worker_cache_(worker_cache), + task_name_(task_name) { + group_leader_ = (task_name == config.experimental().collective_group_leader()) + ? "" + : config.experimental().collective_group_leader(); +} + +RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() { + for (auto it : sequence_table_) { + delete it.second; + } +} + +CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessDistributed* rma = + new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), + worker_cache_, step_id); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); +} + +namespace { +// StepId must leave the most-significant 7 bits empty for future use. +static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56)); + +int64 NewRandomStepId() { + int64 step_id = random::New64(); + // Leave MS 8 bits clear for future use. 
+ step_id &= kStepIdMask; + return step_id; +} +} // namespace + +void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync( + int64 graph_key, const StatusCallback& done) { + if (group_leader_.empty()) { + mutex_lock l(sequence_mu_); + GraphKeySequence* gks = nullptr; + auto it = sequence_table_.find(graph_key); + if (it == sequence_table_.end()) { + gks = new GraphKeySequence(graph_key); + sequence_table_[graph_key] = gks; + } else { + gks = it->second; + } + gks->next_step_id_ = NewRandomStepId(); + done(Status::OK()); + } else { + WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_); + GetStepSequenceRequest* req = new GetStepSequenceRequest; + GetStepSequenceResponse* resp = new GetStepSequenceResponse; + req->add_graph_key(graph_key); + wi->GetStepSequenceAsync( + req, resp, [this, req, resp, done](const Status& s) { + if (!s.ok()) { + LOG(ERROR) << "Bad response [" << s + << "] from GetStepSequenceAsync call to " + << group_leader_; + done(s); + } else { + done(UpdateStepSequences(*resp)); + } + delete req; + delete resp; + }); + } +} + +Status RpcCollectiveExecutorMgr::UpdateStepSequences( + const GetStepSequenceResponse& resp) { + mutex_lock l(sequence_mu_); + for (const StepSequence& ss : resp.step_sequence()) { + GraphKeySequence* gks = nullptr; + auto it = sequence_table_.find(ss.graph_key()); + if (it == sequence_table_.end()) { + gks = new GraphKeySequence(ss.graph_key()); + sequence_table_[ss.graph_key()] = gks; + } else { + gks = it->second; + } + gks->next_step_id_ = ss.next_step_id(); + } + return Status::OK(); +} + +int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) { + mutex_lock l(sequence_mu_); + auto it = sequence_table_.find(graph_key); + if (it != sequence_table_.end()) { + return it->second->next_step_id_; + } + return CollectiveExecutor::kInvalidId; +} + +void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) { + mutex_lock l(sequence_mu_); + auto it = sequence_table_.find(graph_key); + if (it != sequence_table_.end()) { + if (step_id == it->second->next_step_id_) { + it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask; + } else { + it->second->next_step_id_ = CollectiveExecutor::kInvalidId; + } + } else { + LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire."; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h new file mode 100644 index 0000000000..e9f3f0ebe8 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class CollectiveParamResolverDistributed; +class ConfigProto; +class DeviceMgr; +class DeviceResolverDistributed; +class WorkerCacheInterface; +class StepSequenceRequest; +class StepSequenceResponse; + +// An implementation of CollectiveExecutorMgr for a distributed environment +// that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs. +// +// In some execution environments it may be possible to implement a +// higher-performance solution and use it in place of this class. +class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { + public: + RpcCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + WorkerCacheInterface* worker_cache, const string& task_name); + + virtual ~RpcCollectiveExecutorMgr(); + + void RefreshStepIdSequenceAsync(int64 graph_key, + const StatusCallback& done) override; + + int64 NextStepId(int64 graph_key) override; + + void RetireStepId(int64 graph_key, int64 step_id) override; + + protected: + CollectiveExecutor* Create(int64 step_id) override; + + WorkerCacheInterface* const worker_cache_; // Not owned. + const string task_name_; + string group_leader_; + friend class RpcCollectiveExecutorMgrTest; + + private: + Status UpdateStepSequences(const GetStepSequenceResponse& resp); + + // This class maintains the step_id sequencing for a single + // collective_graph_key. + struct GraphKeySequence { + explicit GraphKeySequence(int64 k) + : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {} + + const int64 graph_key_; + int64 next_step_id_; + }; + + mutex sequence_mu_; + gtl::FlatMap sequence_table_ + GUARDED_BY(sequence_mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc new file mode 100644 index 0000000000..37b83d82be --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +#define NUM_DEVS 3 + +class RpcCollectiveExecutorMgrTest : public ::testing::Test { + protected: + RpcCollectiveExecutorMgrTest() { + string task_name = "/job:localhost/replica:0/task:0"; + SessionOptions options; + options.config.mutable_experimental()->set_collective_group_leader( + task_name); + WorkerCacheInterface* worker_cache = nullptr; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", NUM_DEVS}); + TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); + device_mgr_.reset(new DeviceMgr(devices_)); + std::unique_ptr dr(new DeviceResolverDistributed( + device_mgr_.get(), worker_cache, task_name)); + std::unique_ptr cpr( + new CollectiveParamResolverDistributed(options.config, + device_mgr_.get(), dr.get(), + worker_cache, task_name)); + // This CME is the group leader. + cme_.reset(new RpcCollectiveExecutorMgr(options.config, device_mgr_.get(), + std::move(dr), std::move(cpr), + worker_cache, task_name)); + } + + std::unique_ptr cme_; + std::vector devices_; + std::unique_ptr device_mgr_; +}; + +TEST_F(RpcCollectiveExecutorMgrTest, FindOrCreate) { + CollectiveExecutor::Handle* h = + new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true); + EXPECT_TRUE(h->get()); + CollectiveExecutor::Handle* h2 = + new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true); + EXPECT_EQ(h->get(), h2->get()); + CollectiveExecutor* ce = h->get(); + delete h; + delete h2; + CollectiveExecutor* ce2 = cme_->FindOrCreate(1); + EXPECT_EQ(ce, ce2); + ce2->Unref(); + cme_->Cleanup(1); +} + +TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) { + int64 x = cme_->NextStepId(7); + EXPECT_EQ(x, CollectiveExecutor::kInvalidId); + // Calling Refresh should generate a valid id. + { + Notification note; + Status status; + cme_->RefreshStepIdSequenceAsync(7, + [this, &status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + EXPECT_TRUE(status.ok()); + } + x = cme_->NextStepId(7); + EXPECT_NE(x, CollectiveExecutor::kInvalidId); + // Should keep returning same number. + EXPECT_EQ(x, cme_->NextStepId(7)); + EXPECT_EQ(x, cme_->NextStepId(7)); + // Retire on a different graph_key should have no effect. + cme_->RetireStepId(6, x); + EXPECT_EQ(x, cme_->NextStepId(7)); + // Retire on same graph_key should advance. + cme_->RetireStepId(7, x); + int64 y = cme_->NextStepId(7); + EXPECT_EQ((x + 1) & (((1uLL << 56) - 1) | (1uLL << 56)), y); + // Calling refresh should jump to a different point in the random space. 
+ { + Notification note; + Status status; + cme_->RefreshStepIdSequenceAsync(7, + [this, &status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + + note.WaitForNotification(); + EXPECT_TRUE(status.ok()); + } + int64 z = cme_->NextStepId(7); + // z should not be equal to or a successor of y. + EXPECT_NE(y, z); + EXPECT_GT(llabs(y - z), 3); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 4e6500fbc6..1ea19c48f0 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/tensor_coding.h" @@ -72,7 +73,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, s = session->graph_mgr->Register( request->session_handle(), request->graph_def(), request->graph_options(), request->debug_options(), - session->cluster_flr.get(), response->mutable_graph_handle()); + request->collective_graph_key(), session->cluster_flr.get(), + response->mutable_graph_handle()); } done(s); } @@ -315,6 +317,12 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, if (env_->collective_executor_mgr) { env_->collective_executor_mgr->Cleanup(step_id); } + for (Device* d : env_->local_devices) { + ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr(); + if (sam) { + sam->Cleanup(step_id); + } + } done(Status::OK()); } -- cgit v1.2.3
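
Usage sketch (not part of the patch): the change above threads a collective_graph_key from the client's RunOptions through BuildGraphOptions, MasterSession, and GraphMgr, and stands up an RpcCollectiveExecutorMgr on each worker so that otherwise-independent graph executions can agree on step_ids. The snippet below is a minimal illustration of the client-side plumbing under stated assumptions: the GraphDef is assumed to already contain collective ops (the op-construction APIs are still non-public and experimental), and the fetch name "reduced:0" and helper RunCollectiveStep are hypothetical; only the two experimental config fields are taken directly from the diff.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"

// Illustrative sketch only -- not part of the patch. Assumes graph_def already
// contains collective ops built with the (still experimental) collective APIs.
tensorflow::Status RunCollectiveStep(const tensorflow::GraphDef& graph_def) {
  tensorflow::SessionOptions session_options;
  // The group leader task hosts step-id sequencing and group/param resolution;
  // CollectiveParamResolverDistributed and RpcCollectiveExecutorMgr read it
  // from ConfigProto.experimental (see the diff above).
  session_options.config.mutable_experimental()->set_collective_group_leader(
      "/job:worker/replica:0/task:0");

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(session_options));
  TF_RETURN_IF_ERROR(session->Create(graph_def));

  // A non-zero collective_graph_key selects the collective code path
  // (DirectSession::RunInternal locally, MasterSession in distributed mode)
  // and keys the shared step_id sequence that lets independent executions
  // containing collective ops rendezvous.
  tensorflow::RunOptions run_options;
  run_options.mutable_experimental()->set_collective_graph_key(1);

  std::vector<tensorflow::Tensor> outputs;
  tensorflow::RunMetadata run_metadata;
  // "reduced:0" is a placeholder fetch name for whatever the collective
  // graph actually produces.
  return session->Run(run_options, /*inputs=*/{},
                      /*output_tensor_names=*/{"reduced:0"},
                      /*target_node_names=*/{}, &outputs, &run_metadata);
}

All executions that should share a step-id sequence must pass the same collective_graph_key. Per the diff, PartialRun rejects a non-zero key, and the key also feeds HashBuildGraphOptions, so runs with different keys map to distinct ReffedClientGraphs in MasterSession.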