diff options
author | 2017-04-13 15:44:29 -0800 | |
---|---|---|
committer | 2017-04-13 17:04:07 -0700 | |
commit | dd3a6d364a739496b864e61e9a93c21cbc1b5d1d (patch) | |
tree | 6fe163e06959207b586ea6ce6a5277eac90535ae /tensorflow/cc/training | |
parent | c8399f61ea845b0c440d13407429a92f6a0591e3 (diff) |
Only record the cost graph in the queue runner: this ensures that the memory
usage remains bounded over time.
Change: 153123196
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/coordinator.cc | 14 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator.h | 4 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 46 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 24 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 51 |
5 files changed, 75 insertions, 64 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index 4618c932c3..fe45931f7f 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -116,17 +116,13 @@ void Coordinator::WaitForStop() { } Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const { - RunMetadata tmp_metadata; - { - mutex_lock l(runners_lock_); - for (auto& t : runners_) { - Status s = t->ExportRunMetadata(&tmp_metadata); - if (!s.ok()) { - return s; - } + mutex_lock l(runners_lock_); + for (auto& t : runners_) { + Status s = t->ExportCostGraph(cost_graph); + if (!s.ok()) { + return s; } } - cost_graph->MergeFrom(tmp_metadata.cost_graph()); return Status::OK(); } diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 632418c5ca..0e01b19cd9 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -36,8 +36,8 @@ class RunnerInterface { public: virtual ~RunnerInterface() {} virtual Status Join() = 0; - virtual Status ExportRunMetadata(RunMetadata* metadata) const { - return Status(error::INVALID_ARGUMENT, "No RunMetadata to export."); + virtual Status ExportCostGraph(CostGraphDef* cost_graph) const { + return Status(error::INVALID_ARGUMENT, "No cost model to export."); } /// Returns true iff the runner is running, i.e. if it is trying to populate /// its queue. diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 51eba5d8a1..324a62e1a9 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -82,9 +82,9 @@ QueueRunner::~QueueRunner() { Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } -Status QueueRunner::StartAndCollectRunMetadata(Session* sess, - const RunOptions* run_options) { - SetRunArgumentsAndRunMetadata(run_options); +Status QueueRunner::StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); return Start(sess, 0); } @@ -115,10 +115,9 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } -Status QueueRunner::StartAndCollectRunMetadata(Session* session, - int wait_for_ms, - const RunOptions* run_options) { - SetRunArgumentsAndRunMetadata(run_options); +Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); return Start(session, wait_for_ms); } @@ -127,7 +126,7 @@ void QueueRunner::Stop(Session* sess) { coord_->WaitForStop(); } if (!cancel_op_name_.empty()) { - UpdateStatus(RealRun(sess, cancel_op_name_)); + UpdateStatus(RealRun(sess, cancel_op_name_, false)); } stopped_ = true; } @@ -162,7 +161,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { if (coord_ && coord_->ShouldStop()) { break; } - status = RealRun(sess, enqueue_op); + status = RealRun(sess, enqueue_op, true); if (first_iteration) { if (!status.ok()) { mutex_lock l(mu_); @@ -183,7 +182,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { // will be run anway in this case. if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { if (last_run && !close_op_name_.empty()) { - UpdateStatus(RealRun(sess, close_op_name_)); + UpdateStatus(RealRun(sess, close_op_name_, false)); } } else if (!status.ok()) { LOG(ERROR) << "Queue runner thread got a failure status: " @@ -200,34 +199,35 @@ Status QueueRunner::GetStatus() { return status_; } -Status QueueRunner::ExportRunMetadata(RunMetadata* metadata) const { - if (!rm_mu_) { +Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { + if (!cg_mu_) { return Status(error::FAILED_PRECONDITION, - "This QueueRunner doesn't collect and store RunMetadata."); + "This QueueRunner doesn't collect a cost graph."); } - mutex_lock l(*rm_mu_); - metadata->MergeFrom(*run_metadata_); + mutex_lock l(*cg_mu_); + cost_graph->MergeFrom(*cost_graph_); return Status::OK(); } -void QueueRunner::SetRunArgumentsAndRunMetadata(const RunOptions* run_options) { - rm_mu_.reset(new mutex()); +void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions* run_options) { + cg_mu_.reset(new mutex()); { - mutex_lock l(*rm_mu_); - run_metadata_.reset(new RunMetadata()); + mutex_lock l(*cg_mu_); + cost_graph_.reset(new CostGraphDef()); } if (run_options) { run_options_ = *run_options; } } -Status QueueRunner::RealRun(Session* sess, const string& op) { +Status QueueRunner::RealRun(Session* sess, const string& op, + bool update_costs) { Status s; - if (rm_mu_) { + if (update_costs && cg_mu_) { RunMetadata metadata; s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); - mutex_lock l(*rm_mu_); - run_metadata_->MergeFrom(metadata); + mutex_lock l(*cg_mu_); + cost_graph_->Swap(metadata.mutable_cost_graph()); } else { s = sess->Run({}, {}, {op}, nullptr); } diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index c69f28793a..71ed44c9c6 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -60,15 +60,15 @@ class QueueRunner : public RunnerInterface { Status Start(Session* sess); /// Starts the queue runner with the given session and sets the run arguments - /// for sess->Run. It also collects and stores the run metedata. - Status StartAndCollectRunMetadata(Session* sess, - const RunOptions* run_options = nullptr); + /// for sess->Run. It also collects and stores the cost model. + Status StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options = nullptr); /// Starts the queue runner with the given session, and wait for up to the /// specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); - Status StartAndCollectRunMetadata(Session* session, int wait_for_ms, - const RunOptions* run_options = nullptr); + Status StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options = nullptr); /// Requests to stop and runs the cancel op. It would be called in a separate /// thread when coordinator is set. If there is no coordinator it should be @@ -82,11 +82,11 @@ class QueueRunner : public RunnerInterface { /// Returns the latest status. Status GetStatus(); - // Returns the stored run metadata. - Status ExportRunMetadata(RunMetadata* metadata) const override; + // Returns the stored cost model. + Status ExportCostGraph(CostGraphDef* cost_graph) const override; private: - QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {} + QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {} // Initializes the instance with the QueueRunnerDef proto. Status Init(const QueueRunnerDef& queue_runner_def); @@ -105,9 +105,9 @@ class QueueRunner : public RunnerInterface { bool IsRunning() const override { return !stopped_; } - void SetRunArgumentsAndRunMetadata(const RunOptions* run_options); + void SetRunArgumentsAndCostGraph(const RunOptions* run_options); - Status RealRun(Session* sess, const string& op); + Status RealRun(Session* sess, const string& op, bool update_costs); string queue_name_; std::vector<string> enqueue_op_names_; @@ -130,8 +130,8 @@ class QueueRunner : public RunnerInterface { mutex cb_mu_; std::vector<std::function<void(Status)>> callbacks_; - mutable std::unique_ptr<mutex> rm_mu_; - std::unique_ptr<RunMetadata> run_metadata_ GUARDED_BY(rm_mu_); + mutable std::unique_ptr<mutex> cg_mu_; + std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_); RunOptions run_options_; }; diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index c37a69a7f7..e814be8630 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -44,6 +44,7 @@ using ops::FIFOQueue; using ops::QueueClose; using ops::QueueDequeue; using ops::QueueEnqueue; +using ops::RandomNormal; using ops::Square; using ops::Variable; @@ -84,7 +85,7 @@ QueueRunnerDef BuildQueueRunnerDef( const std::string& close_op, const std::string& cancel_op, const std::vector<Code>& queue_closed_error_codes) { QueueRunnerDef queue_runner_def; - *queue_runner_def.mutable_queue_name() = kQueueName; + *queue_runner_def.mutable_queue_name() = queue_name; for (const std::string& enqueue_op : enqueue_ops) { *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op; } @@ -345,37 +346,51 @@ TEST(QueueRunnerTest, CallbackCalledOnError) { } TEST(QueueRunnerTest, RunMetaDataTest) { + Scope root = Scope::NewRootScope(); + auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT}); + Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT); + Output square = Square(root.WithOpName(kSquareOpName), rnd); + auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square}); + auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); + auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, + QueueClose::CancelPendingEnqueues(true)); + auto dequeue0 = + QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT}); + + GraphDef graph_def; + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + for (auto& node : *graph_def.mutable_node()) { + node.set_device("/cpu:0"); + } SessionOptions sess_options; sess_options.config.mutable_graph_options()->set_build_cost_model(1); std::unique_ptr<Session> session(NewSession(sess_options)); - GraphDef graph_def = BuildSimpleGraph(); TF_CHECK_OK(session->Create(graph_def)); - TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr)); - RunOptions run_options; - run_options.set_trace_level(RunOptions::HARDWARE_TRACE); - - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {}); std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - TF_CHECK_OK(qr->StartAndCollectRunMetadata(session.get(), &run_options)); + RunOptions run_options; + TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), &run_options)); - TF_EXPECT_OK(qr->Join()); - RunMetadata run_metadata; - TF_CHECK_OK(qr->ExportRunMetadata(&run_metadata)); + // Make sure there was at least one element enqueued in q0: this prevents a + // race condition where we close the queue before it was populated. + std::vector<Tensor> dq0; + TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); + + CostGraphDef cost_graph; + TF_CHECK_OK(qr->ExportCostGraph(&cost_graph)); + EXPECT_TRUE(cost_graph.node_size() > 0); - EXPECT_TRUE(run_metadata.has_cost_graph()); + qr->Stop(session.get()); } TEST(QueueRunnerTest, NoRunMetaDataTest) { GraphDef graph_def = BuildSimpleGraph(); auto session = BuildSessionAndInitVariable(graph_def); - RunOptions run_options; - run_options.set_trace_level(RunOptions::HARDWARE_TRACE); - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); std::unique_ptr<QueueRunner> qr; @@ -383,8 +398,8 @@ TEST(QueueRunnerTest, NoRunMetaDataTest) { TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(qr->Join()); - RunMetadata run_metadata; - EXPECT_EQ(qr->ExportRunMetadata(&run_metadata).code(), + CostGraphDef cost_graph; + EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(), error::FAILED_PRECONDITION); } |