aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/master_session.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/master_session.cc')
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc83
1 files changed, 46 insertions, 37 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 878a1398c9..01da54fcb3 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
- debug_opts_(bopts.debug_options),
+ debug_opts_(bopts.callable_options.run_options().debug_options()),
worker_cache_(worker_cache),
should_deregister_(should_deregister) {
VLOG(1) << "Created ReffedClientGraph for node with "
@@ -921,61 +921,70 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
}
}
+namespace {
+void CopyAndSortStrings(size_t size,
+ const std::function<string(size_t)>& input_accessor,
+ protobuf::RepeatedPtrField<string>* output) {
+ std::vector<string> temp;
+ temp.reserve(size);
+ for (size_t i = 0; i < size; ++i) {
+ output->Add(input_accessor(i));
+ }
+ std::sort(output->begin(), output->end());
+}
+} // namespace
+
void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
BuildGraphOptions* opts) {
- for (size_t i = 0; i < req.num_feeds(); ++i) {
- opts->feed_endpoints.push_back(req.feed_name(i));
- }
- for (size_t i = 0; i < req.num_fetches(); ++i) {
- opts->fetch_endpoints.push_back(req.fetch_name(i));
- }
- for (size_t i = 0; i < req.num_targets(); ++i) {
- opts->target_nodes.push_back(req.target_name(i));
- }
+ CallableOptions* callable_opts = &opts->callable_options;
+ CopyAndSortStrings(req.num_feeds(),
+ [&req](size_t i) { return req.feed_name(i); },
+ callable_opts->mutable_feed());
+ CopyAndSortStrings(req.num_fetches(),
+ [&req](size_t i) { return req.fetch_name(i); },
+ callable_opts->mutable_fetch());
+ CopyAndSortStrings(req.num_targets(),
+ [&req](size_t i) { return req.target_name(i); },
+ callable_opts->mutable_target());
if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
- opts->debug_options = req.options().debug_options();
+ *callable_opts->mutable_run_options()->mutable_debug_options() =
+ req.options().debug_options();
}
-
- std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
- std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
- std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
BuildGraphOptions* opts) {
- for (const auto& feed : req.feed()) {
- opts->feed_endpoints.push_back(feed);
- }
- for (const auto& fetch : req.fetch()) {
- opts->fetch_endpoints.push_back(fetch);
- }
- for (const auto& target : req.target()) {
- opts->target_nodes.push_back(target);
- }
+ CallableOptions* callable_opts = &opts->callable_options;
+ CopyAndSortStrings(req.feed_size(), [&req](size_t i) { return req.feed(i); },
+ callable_opts->mutable_feed());
+ CopyAndSortStrings(req.fetch_size(),
+ [&req](size_t i) { return req.fetch(i); },
+ callable_opts->mutable_fetch());
+ CopyAndSortStrings(req.target_size(),
+ [&req](size_t i) { return req.target(i); },
+ callable_opts->mutable_target());
// TODO(cais): Add TFDBG support to partial runs.
-
- std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
- std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
- std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
uint64 h = 0x2b992ddfa23249d6ull;
- for (const string& name : opts.feed_endpoints) {
+ for (const string& name : opts.callable_options.feed()) {
h = Hash64(name.c_str(), name.size(), h);
}
- for (const string& name : opts.target_nodes) {
+ for (const string& name : opts.callable_options.target()) {
h = Hash64(name.c_str(), name.size(), h);
}
- for (const string& name : opts.fetch_endpoints) {
+ for (const string& name : opts.callable_options.fetch()) {
h = Hash64(name.c_str(), name.size(), h);
}
- if (!opts.debug_options.debug_tensor_watch_opts().empty()) {
- const string watch_summary = SummarizeDebugTensorWatches(
- opts.debug_options.debug_tensor_watch_opts());
+ const DebugOptions& debug_options =
+ opts.callable_options.run_options().debug_options();
+ if (!debug_options.debug_tensor_watch_opts().empty()) {
+ const string watch_summary =
+ SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
@@ -984,15 +993,15 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
string BuildGraphOptionsString(const BuildGraphOptions& opts) {
string buf;
- for (const string& name : opts.feed_endpoints) {
+ for (const string& name : opts.callable_options.feed()) {
strings::StrAppend(&buf, " FdE: ", name);
}
strings::StrAppend(&buf, "\n");
- for (const string& name : opts.target_nodes) {
+ for (const string& name : opts.callable_options.target()) {
strings::StrAppend(&buf, " TN: ", name);
}
strings::StrAppend(&buf, "\n");
- for (const string& name : opts.fetch_endpoints) {
+ for (const string& name : opts.callable_options.fetch()) {
strings::StrAppend(&buf, " FeE: ", name);
}
strings::StrAppend(&buf, "\n");