aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/collective_executor_mgr.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/collective_executor_mgr.cc')
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.cc18
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index e07829b286..4f03a5e13a 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -25,11 +25,11 @@ namespace tensorflow {
CollectiveExecutorMgr::CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver)
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver)
: dev_mgr_(dev_mgr),
- dev_resolver_(dev_resolver),
- param_resolver_(param_resolver) {}
+ dev_resolver_(std::move(dev_resolver)),
+ param_resolver_(std::move(param_resolver)) {}
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
for (auto iter : executor_table_) {
@@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
if (it != executor_table_.end()) {
ce = it->second;
} else {
- CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
- dev_mgr_, dev_resolver_.get(), step_id);
- ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+ ce = Create(step_id);
executor_table_[step_id] = ce;
}
ce->Ref();
@@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
return ce;
}
+CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessLocal* rma =
+ new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
void CollectiveExecutorMgr::Cleanup(int64 step_id) {
CollectiveExecutor* ce = nullptr;
{