diff options
author | 2017-04-13 15:44:29 -0800 | |
---|---|---|
committer | 2017-04-13 17:04:07 -0700 | |
commit | dd3a6d364a739496b864e61e9a93c21cbc1b5d1d (patch) | |
tree | 6fe163e06959207b586ea6ce6a5277eac90535ae /tensorflow/cc/training/queue_runner.cc | |
parent | c8399f61ea845b0c440d13407429a92f6a0591e3 (diff) |
Only record the cost graph in the queue runner: this ensures that the memory
usage remains bounded over time.
Change: 153123196
Diffstat (limited to 'tensorflow/cc/training/queue_runner.cc')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 46 |
1 files changed, 23 insertions, 23 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 51eba5d8a1..324a62e1a9 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -82,9 +82,9 @@ QueueRunner::~QueueRunner() { Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } -Status QueueRunner::StartAndCollectRunMetadata(Session* sess, - const RunOptions* run_options) { - SetRunArgumentsAndRunMetadata(run_options); +Status QueueRunner::StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); return Start(sess, 0); } @@ -115,10 +115,9 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } -Status QueueRunner::StartAndCollectRunMetadata(Session* session, - int wait_for_ms, - const RunOptions* run_options) { - SetRunArgumentsAndRunMetadata(run_options); +Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); return Start(session, wait_for_ms); } @@ -127,7 +126,7 @@ void QueueRunner::Stop(Session* sess) { coord_->WaitForStop(); } if (!cancel_op_name_.empty()) { - UpdateStatus(RealRun(sess, cancel_op_name_)); + UpdateStatus(RealRun(sess, cancel_op_name_, false)); } stopped_ = true; } @@ -162,7 +161,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { if (coord_ && coord_->ShouldStop()) { break; } - status = RealRun(sess, enqueue_op); + status = RealRun(sess, enqueue_op, true); if (first_iteration) { if (!status.ok()) { mutex_lock l(mu_); @@ -183,7 +182,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { // will be run anway in this case. if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { if (last_run && !close_op_name_.empty()) { - UpdateStatus(RealRun(sess, close_op_name_)); + UpdateStatus(RealRun(sess, close_op_name_, false)); } } else if (!status.ok()) { LOG(ERROR) << "Queue runner thread got a failure status: " @@ -200,34 +199,35 @@ Status QueueRunner::GetStatus() { return status_; } -Status QueueRunner::ExportRunMetadata(RunMetadata* metadata) const { - if (!rm_mu_) { +Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { + if (!cg_mu_) { return Status(error::FAILED_PRECONDITION, - "This QueueRunner doesn't collect and store RunMetadata."); + "This QueueRunner doesn't collect a cost graph."); } - mutex_lock l(*rm_mu_); - metadata->MergeFrom(*run_metadata_); + mutex_lock l(*cg_mu_); + cost_graph->MergeFrom(*cost_graph_); return Status::OK(); } -void QueueRunner::SetRunArgumentsAndRunMetadata(const RunOptions* run_options) { - rm_mu_.reset(new mutex()); +void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions* run_options) { + cg_mu_.reset(new mutex()); { - mutex_lock l(*rm_mu_); - run_metadata_.reset(new RunMetadata()); + mutex_lock l(*cg_mu_); + cost_graph_.reset(new CostGraphDef()); } if (run_options) { run_options_ = *run_options; } } -Status QueueRunner::RealRun(Session* sess, const string& op) { +Status QueueRunner::RealRun(Session* sess, const string& op, + bool update_costs) { Status s; - if (rm_mu_) { + if (update_costs && cg_mu_) { RunMetadata metadata; s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); - mutex_lock l(*rm_mu_); - run_metadata_->MergeFrom(metadata); + mutex_lock l(*cg_mu_); + cost_graph_->Swap(metadata.mutable_cost_graph()); } else { s = sess->Run({}, {}, {op}, nullptr); } |