aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc29
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h4
2 files changed, 17 insertions, 16 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index cacaf83816..1918eae875 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -970,7 +970,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
handle_(strings::FpToString(random::New64())),
stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0),
- runs_(5),
+ run_graphs_(5),
+ partial_run_graphs_(5),
cancellation_manager_(new CancellationManager) {
UpdateLastAccessTime();
@@ -996,8 +997,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
MasterSession::~MasterSession() {
delete cancellation_manager_;
- for (const auto& iter : runs_) iter.second->Unref();
- for (const auto& iter : obsolete_) iter.second->Unref();
+ for (const auto& iter : run_graphs_) iter.second->Unref();
+ for (const auto& iter : partial_run_graphs_) iter.second->Unref();
for (Device* dev : remote_devs_) delete dev;
}
@@ -1065,23 +1066,23 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
// this session.
int64* c = &subgraph_execution_counts_[hash];
*count = (*c)++;
- auto iter = runs_.find(hash);
- if (iter == runs_.end()) {
+ // TODO(suharshs): We cache partial run graphs and run graphs separately
+ // because there is preprocessing that needs to only be run for partial
+ // run calls.
+ RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
+ auto iter = m->find(hash);
+ if (iter == m->end()) {
// We have not seen this subgraph before. Build the subgraph and
// cache it.
VLOG(1) << "Unseen hash " << hash << " for "
- << BuildGraphOptionsString(opts);
+ << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
+ << "\n";
std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial);
- iter = runs_.insert({hash, entry}).first;
- auto obs_iter = obsolete_.find(hash);
- if (obs_iter != obsolete_.end()) {
- to_unref = obs_iter->second;
- obsolete_.erase(obs_iter);
- }
+ iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
*rcg = iter->second;
@@ -1383,8 +1384,8 @@ Status MasterSession::Close() {
while (num_running_ != 0) {
num_running_is_zero_.wait(l);
}
- ClearRunsTable(&to_unref, &runs_);
- ClearRunsTable(&to_unref, &obsolete_);
+ ClearRunsTable(&to_unref, &run_graphs_);
+ ClearRunsTable(&to_unref, &partial_run_graphs_);
}
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
delete this;
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 96d759d9c8..4af6ab6681 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -119,8 +119,8 @@ class MasterSession {
// scope and lose their state.
class ReffedClientGraph;
typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
- RCGMap runs_ GUARDED_BY(mu_);
- RCGMap obsolete_ GUARDED_BY(mu_);
+ RCGMap run_graphs_ GUARDED_BY(mu_);
+ RCGMap partial_run_graphs_ GUARDED_BY(mu_);
struct PerStepState {
bool collect_costs = false;