diff options
Diffstat (limited to 'tensorflow/core/common_runtime/collective_executor_mgr.cc')
-rw-r--r-- | tensorflow/core/common_runtime/collective_executor_mgr.cc | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc index e07829b286..4f03a5e13a 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc @@ -25,11 +25,11 @@ namespace tensorflow { CollectiveExecutorMgr::CollectiveExecutorMgr( const ConfigProto& config, const DeviceMgr* dev_mgr, - DeviceResolverInterface* dev_resolver, - ParamResolverInterface* param_resolver) + std::unique_ptr<DeviceResolverInterface> dev_resolver, + std::unique_ptr<ParamResolverInterface> param_resolver) : dev_mgr_(dev_mgr), - dev_resolver_(dev_resolver), - param_resolver_(param_resolver) {} + dev_resolver_(std::move(dev_resolver)), + param_resolver_(std::move(param_resolver)) {} CollectiveExecutorMgr::~CollectiveExecutorMgr() { for (auto iter : executor_table_) { @@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { if (it != executor_table_.end()) { ce = it->second; } else { - CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal( - dev_mgr_, dev_resolver_.get(), step_id); - ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); + ce = Create(step_id); executor_table_[step_id] = ce; } ce->Ref(); @@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { return ce; } +CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessLocal* rma = + new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); +} + void CollectiveExecutorMgr::Cleanup(int64 step_id) { CollectiveExecutor* ce = nullptr; { |