diff options
author | 2018-06-08 18:12:16 -0700 | |
---|---|---|
committer | 2018-06-08 18:14:57 -0700 | |
commit | 9070f24ae15a4f589219d4cb9c962b14612c2d8c (patch) | |
tree | 561e2362e67fc2c45ddd0e8736de2d9e5b5a022f /tensorflow/core/common_runtime | |
parent | 53901f9bb9a3965ed5dce65284053b0eb387b0c4 (diff) |
Collective Ops Part 8
Enable collective op execution in distributed mode:
Pass collective_graph_key into graph building and
step execution contexts (MasterSession) where it triggers
allocation of an RpcCollectiveExecutorMgr that becomes
accessible via the WorkerEnv and MasterEnv.
The collective_graph_key is used to synchronize step_ids
(which are otherwise random) between otherwise independent
graph executions that contain collective ops that need
to rendezvous.
All APIs for using collectives are still non-public and
experimental.
PiperOrigin-RevId: 199879087
Diffstat (limited to 'tensorflow/core/common_runtime')
7 files changed, 38 insertions, 18 deletions
diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc index a9dc6ca6cd..00f7a8e645 100644 --- a/tensorflow/core/common_runtime/build_graph_options.cc +++ b/tensorflow/core/common_runtime/build_graph_options.cc @@ -32,6 +32,9 @@ string BuildGraphOptions::DebugString() const { for (auto& s : callable_options.target()) { strings::StrAppend(&rv, s, ", "); } + if (collective_graph_key != kNoCollectiveGraphKey) { + strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key); + } return rv; } diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h index 5ca170e922..3d0f242ea5 100644 --- a/tensorflow/core/common_runtime/build_graph_options.h +++ b/tensorflow/core/common_runtime/build_graph_options.h @@ -31,6 +31,9 @@ struct BuildGraphOptions { // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval. bool use_function_convention = false; + static const int64 kNoCollectiveGraphKey = 0; + int64 collective_graph_key = kNoCollectiveGraphKey; + string DebugString() const; }; diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc index e07829b286..4f03a5e13a 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc @@ -25,11 +25,11 @@ namespace tensorflow { CollectiveExecutorMgr::CollectiveExecutorMgr( const ConfigProto& config, const DeviceMgr* dev_mgr, - DeviceResolverInterface* dev_resolver, - ParamResolverInterface* param_resolver) + std::unique_ptr<DeviceResolverInterface> dev_resolver, + std::unique_ptr<ParamResolverInterface> param_resolver) : dev_mgr_(dev_mgr), - dev_resolver_(dev_resolver), - param_resolver_(param_resolver) {} + dev_resolver_(std::move(dev_resolver)), + param_resolver_(std::move(param_resolver)) {} 
CollectiveExecutorMgr::~CollectiveExecutorMgr() { for (auto iter : executor_table_) { @@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { if (it != executor_table_.end()) { ce = it->second; } else { - CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal( - dev_mgr_, dev_resolver_.get(), step_id); - ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); + ce = Create(step_id); executor_table_[step_id] = ce; } ce->Ref(); @@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { return ce; } +CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessLocal* rma = + new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); +} + void CollectiveExecutorMgr::Cleanup(int64 step_id) { CollectiveExecutor* ce = nullptr; { diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h index 4b42e2b4d1..9de6ab8968 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/collective_executor_mgr.h @@ -25,8 +25,8 @@ class DeviceMgr; class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { public: CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr, - DeviceResolverInterface* dev_resolver, - ParamResolverInterface* param_resolver); + std::unique_ptr<DeviceResolverInterface> dev_resolver, + std::unique_ptr<ParamResolverInterface> param_resolver); virtual ~CollectiveExecutorMgr(); @@ -56,11 +56,16 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { void RetireStepId(int64 graph_key, int64 step_id) override {} protected: + // Called by FindOrCreate when table entry does not yet exist. 
+ virtual CollectiveExecutor* Create(int64 step_id); + const DeviceMgr* dev_mgr_; std::unique_ptr<DeviceResolverInterface> dev_resolver_; std::unique_ptr<ParamResolverInterface> param_resolver_; CollectiveRemoteAccess* remote_access_; string task_name_; + + private: mutex exec_mu_; // Map from step_id to CollectiveExecutor gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_); diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc index 34c9163d6a..91994c5731 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc @@ -40,10 +40,13 @@ class CollectiveExecutorMgrTest : public ::testing::Test { device_count->insert({"CPU", NUM_DEVS}); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); device_mgr_.reset(new DeviceMgr(devices_)); - DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get()); - cme_.reset(new CollectiveExecutorMgr( - cp, device_mgr_.get(), drl, - new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name))); + std::unique_ptr<DeviceResolverInterface> drl( + new DeviceResolverLocal(device_mgr_.get())); + std::unique_ptr<ParamResolverInterface> prl( + new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(), + task_name)); + cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl), + std::move(prl))); } std::unique_ptr<CollectiveExecutorMgr> cme_; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 3a871f962d..43c404f2ec 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -201,7 +201,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { LOCKS_EXCLUDED(irec->out_mu); const DeviceMgr* 
dev_mgr_; - DeviceResolverInterface* dev_resolver_; + DeviceResolverInterface* dev_resolver_; // Not owned. string task_name_; mutex group_mu_; gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_ diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 07c1eafedc..5cef93c605 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -450,11 +450,13 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, // Set up for collectives if the RunOption declares a key. if (run_options.experimental().collective_graph_key() > 0) { if (!collective_executor_mgr_) { - DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get()); + std::unique_ptr<DeviceResolverInterface> drl( + new DeviceResolverLocal(device_mgr_.get())); + std::unique_ptr<ParamResolverInterface> cprl( + new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(), + "/job:localhost/replica:0/task:0")); collective_executor_mgr_.reset(new CollectiveExecutorMgr( - options_.config, device_mgr_.get(), drl, - new CollectiveParamResolverLocal(device_mgr_.get(), drl, - "/job:localhost/replica:0/task:0"))); + options_.config, device_mgr_.get(), std::move(drl), std::move(cprl))); } run_state.collective_executor.reset(new CollectiveExecutor::Handle( collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/)); |