aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-08 18:12:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 18:14:57 -0700
commit9070f24ae15a4f589219d4cb9c962b14612c2d8c (patch)
tree561e2362e67fc2c45ddd0e8736de2d9e5b5a022f /tensorflow/core/common_runtime
parent53901f9bb9a3965ed5dce65284053b0eb387b0c4 (diff)
Collective Ops Part 8
Enable collective op execution in distributed mode: Pass collective_graph_key into graph building and step execution contexts (MasterSession) where it triggers allocation of an RpcCollectiveExecutorMgr that becomes accessible via the WorkerEnv and MasterEnv. The collective_graph_key is used to synchronize step_ids (which are otherwise random) between otherwise independent graph executions that contain collective ops that need to rendezvous. All APIs for using collectives are still non-public and experimental. PiperOrigin-RevId: 199879087
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.cc3
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.h3
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.cc18
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h9
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr_test.cc11
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc10
7 files changed, 38 insertions, 18 deletions
diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc
index a9dc6ca6cd..00f7a8e645 100644
--- a/tensorflow/core/common_runtime/build_graph_options.cc
+++ b/tensorflow/core/common_runtime/build_graph_options.cc
@@ -32,6 +32,9 @@ string BuildGraphOptions::DebugString() const {
for (auto& s : callable_options.target()) {
strings::StrAppend(&rv, s, ", ");
}
+ if (collective_graph_key != kNoCollectiveGraphKey) {
+ strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key);
+ }
return rv;
}
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index 5ca170e922..3d0f242ea5 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -31,6 +31,9 @@ struct BuildGraphOptions {
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
+ static const int64 kNoCollectiveGraphKey = 0;
+ int64 collective_graph_key = kNoCollectiveGraphKey;
+
string DebugString() const;
};
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index e07829b286..4f03a5e13a 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -25,11 +25,11 @@ namespace tensorflow {
CollectiveExecutorMgr::CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver)
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver)
: dev_mgr_(dev_mgr),
- dev_resolver_(dev_resolver),
- param_resolver_(param_resolver) {}
+ dev_resolver_(std::move(dev_resolver)),
+ param_resolver_(std::move(param_resolver)) {}
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
for (auto iter : executor_table_) {
@@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
if (it != executor_table_.end()) {
ce = it->second;
} else {
- CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
- dev_mgr_, dev_resolver_.get(), step_id);
- ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+ ce = Create(step_id);
executor_table_[step_id] = ce;
}
ce->Ref();
@@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
return ce;
}
+CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessLocal* rma =
+ new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
void CollectiveExecutorMgr::Cleanup(int64 step_id) {
CollectiveExecutor* ce = nullptr;
{
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 4b42e2b4d1..9de6ab8968 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -25,8 +25,8 @@ class DeviceMgr;
class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
public:
CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver);
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver);
virtual ~CollectiveExecutorMgr();
@@ -56,11 +56,16 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
void RetireStepId(int64 graph_key, int64 step_id) override {}
protected:
+ // Called by FindOrCreate when table entry does not yet exist.
+ virtual CollectiveExecutor* Create(int64 step_id);
+
const DeviceMgr* dev_mgr_;
std::unique_ptr<DeviceResolverInterface> dev_resolver_;
std::unique_ptr<ParamResolverInterface> param_resolver_;
CollectiveRemoteAccess* remote_access_;
string task_name_;
+
+ private:
mutex exec_mu_;
// Map from step_id to CollectiveExecutor
gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
index 34c9163d6a..91994c5731 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -40,10 +40,13 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
- cme_.reset(new CollectiveExecutorMgr(
- cp, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> prl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ task_name));
+ cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
+ std::move(prl)));
}
std::unique_ptr<CollectiveExecutorMgr> cme_;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 3a871f962d..43c404f2ec 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -201,7 +201,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
LOCKS_EXCLUDED(irec->out_mu);
const DeviceMgr* dev_mgr_;
- DeviceResolverInterface* dev_resolver_;
+ DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
mutex group_mu_;
gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 07c1eafedc..5cef93c605 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -450,11 +450,13 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Set up for collectives if the RunOption declares a key.
if (run_options.experimental().collective_graph_key() > 0) {
if (!collective_executor_mgr_) {
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> cprl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ "/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
- options_.config, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl,
- "/job:localhost/replica:0/task:0")));
+ options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
}
run_state.collective_executor.reset(new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));