diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/master_session.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.cc | 83 |
1 files changed, 46 insertions, 37 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 878a1398c9..01da54fcb3 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { client_graph_(std::move(cg)), session_opts_(session_opts), is_partial_(is_partial), - debug_opts_(bopts.debug_options), + debug_opts_(bopts.callable_options.run_options().debug_options()), worker_cache_(worker_cache), should_deregister_(should_deregister) { VLOG(1) << "Created ReffedClientGraph for node with " @@ -921,61 +921,70 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() { } } +namespace { +void CopyAndSortStrings(size_t size, + const std::function<string(size_t)>& input_accessor, + protobuf::RepeatedPtrField<string>* output) { + std::vector<string> temp; + temp.reserve(size); + for (size_t i = 0; i < size; ++i) { + output->Add(input_accessor(i)); + } + std::sort(output->begin(), output->end()); +} +} // namespace + void BuildBuildGraphOptions(const RunStepRequestWrapper& req, BuildGraphOptions* opts) { - for (size_t i = 0; i < req.num_feeds(); ++i) { - opts->feed_endpoints.push_back(req.feed_name(i)); - } - for (size_t i = 0; i < req.num_fetches(); ++i) { - opts->fetch_endpoints.push_back(req.fetch_name(i)); - } - for (size_t i = 0; i < req.num_targets(); ++i) { - opts->target_nodes.push_back(req.target_name(i)); - } + CallableOptions* callable_opts = &opts->callable_options; + CopyAndSortStrings(req.num_feeds(), + [&req](size_t i) { return req.feed_name(i); }, + callable_opts->mutable_feed()); + CopyAndSortStrings(req.num_fetches(), + [&req](size_t i) { return req.fetch_name(i); }, + callable_opts->mutable_fetch()); + CopyAndSortStrings(req.num_targets(), + [&req](size_t i) { return req.target_name(i); }, + callable_opts->mutable_target()); if (!req.options().debug_options().debug_tensor_watch_opts().empty()) { - opts->debug_options = req.options().debug_options(); + *callable_opts->mutable_run_options()->mutable_debug_options() = + req.options().debug_options(); } - - std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end()); - std::sort(opts->target_nodes.begin(), opts->target_nodes.end()); - std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end()); } void BuildBuildGraphOptions(const PartialRunSetupRequest& req, BuildGraphOptions* opts) { - for (const auto& feed : req.feed()) { - opts->feed_endpoints.push_back(feed); - } - for (const auto& fetch : req.fetch()) { - opts->fetch_endpoints.push_back(fetch); - } - for (const auto& target : req.target()) { - opts->target_nodes.push_back(target); - } + CallableOptions* callable_opts = &opts->callable_options; + CopyAndSortStrings(req.feed_size(), [&req](size_t i) { return req.feed(i); }, + callable_opts->mutable_feed()); + CopyAndSortStrings(req.fetch_size(), + [&req](size_t i) { return req.fetch(i); }, + callable_opts->mutable_fetch()); + CopyAndSortStrings(req.target_size(), + [&req](size_t i) { return req.target(i); }, + callable_opts->mutable_target()); // TODO(cais): Add TFDBG support to partial runs. - - std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end()); - std::sort(opts->target_nodes.begin(), opts->target_nodes.end()); - std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end()); } uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { uint64 h = 0x2b992ddfa23249d6ull; - for (const string& name : opts.feed_endpoints) { + for (const string& name : opts.callable_options.feed()) { h = Hash64(name.c_str(), name.size(), h); } - for (const string& name : opts.target_nodes) { + for (const string& name : opts.callable_options.target()) { h = Hash64(name.c_str(), name.size(), h); } - for (const string& name : opts.fetch_endpoints) { + for (const string& name : opts.callable_options.fetch()) { h = Hash64(name.c_str(), name.size(), h); } - if (!opts.debug_options.debug_tensor_watch_opts().empty()) { - const string watch_summary = SummarizeDebugTensorWatches( - opts.debug_options.debug_tensor_watch_opts()); + const DebugOptions& debug_options = + opts.callable_options.run_options().debug_options(); + if (!debug_options.debug_tensor_watch_opts().empty()) { + const string watch_summary = + SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts()); h = Hash64(watch_summary.c_str(), watch_summary.size(), h); } @@ -984,15 +993,15 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { string BuildGraphOptionsString(const BuildGraphOptions& opts) { string buf; - for (const string& name : opts.feed_endpoints) { + for (const string& name : opts.callable_options.feed()) { strings::StrAppend(&buf, " FdE: ", name); } strings::StrAppend(&buf, "\n"); - for (const string& name : opts.target_nodes) { + for (const string& name : opts.callable_options.target()) { strings::StrAppend(&buf, " TN: ", name); } strings::StrAppend(&buf, "\n"); - for (const string& name : opts.fetch_endpoints) { + for (const string& name : opts.callable_options.fetch()) { strings::StrAppend(&buf, " FeE: ", name); } strings::StrAppend(&buf, "\n"); |