diff options
author | Ayush Dubey <ayushd@google.com> | 2018-08-30 19:49:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-30 19:54:26 -0700 |
commit | 73a3477356990f2451e220f553c9d7782df836ac (patch) | |
tree | 8894ebceebe247e59c0fbdb35ce92b45f63b6cdf /tensorflow/core/distributed_runtime | |
parent | 1cb6544826551524f0f53f3e9632f71e67ea7851 (diff) |
Initialize collective_graph_key based on the graph if unspecified in RunOptions.
Before this CL, for collective_ops to work, the client had to specify a
collective_graph_key in the RunOptions of a session.Run call.
After this change, if a client does not specify a collective_graph_key for a
graph that contains collective ops, a graph key is generated automatically as a
hash of the set of keys of collective instances in the placed graph.
PiperOrigin-RevId: 211024617
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.cc | 15 |
1 file changed, 5 insertions, 10 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index abd07e37b7..8e9eec1ed9 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( *c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_debug_options() = callable_opts_.run_options().debug_options(); - c->req.set_collective_graph_key(bg_opts_.collective_graph_key); + c->req.set_collective_graph_key(client_graph()->collective_graph_key); VLOG(2) << "Register " << c->req.graph_def().DebugString(); auto cb = [c, &done](const Status& s) { c->status = s; @@ -1111,10 +1111,6 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { h = Hash64(watch_summary.c_str(), watch_summary.size(), h); } - if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) { - h = Hash64Combine(opts.collective_graph_key, h); - } - return h; } @@ -1788,10 +1784,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, Status s = run_status; if (s.ok()) { pss->end_micros = Env::Default()->NowMicros(); - if (rcg->build_graph_options().collective_graph_key != + if (rcg->client_graph()->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) { env_->collective_executor_mgr->RetireStepId( - rcg->build_graph_options().collective_graph_key, step_id); + rcg->client_graph()->collective_graph_key, step_id); } // Schedule post-processing and cleanup to be done asynchronously. rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata); @@ -1850,7 +1846,7 @@ Status MasterSession::DoRunWithLocalExecution( // Keeps the highest 8 bits 0x01: we reserve some bits of the // step_id for future use. 
- uint64 step_id = NewStepId(bgopts.collective_graph_key); + uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key); TRACEPRINTF("stepid %llu", step_id); std::unique_ptr<ProfileHandler> ph; @@ -1914,8 +1910,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, // Prepare. int64 count = rcg->get_and_increment_execution_count(); - const uint64 step_id = - NewStepId(rcg->build_graph_options().collective_graph_key); + const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key); TRACEPRINTF("stepid %llu", step_id); const RunOptions& run_options = rcg->callable_options().run_options(); |