aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-04-13 15:44:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-13 17:04:07 -0700
commitdd3a6d364a739496b864e61e9a93c21cbc1b5d1d (patch)
tree6fe163e06959207b586ea6ce6a5277eac90535ae /tensorflow/cc/training
parentc8399f61ea845b0c440d13407429a92f6a0591e3 (diff)
Only record the cost graph in the queue runner: this ensures that the memory
usage remains bounded over time. Change: 153123196
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r--tensorflow/cc/training/coordinator.cc14
-rw-r--r--tensorflow/cc/training/coordinator.h4
-rw-r--r--tensorflow/cc/training/queue_runner.cc46
-rw-r--r--tensorflow/cc/training/queue_runner.h24
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc51
5 files changed, 75 insertions, 64 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc
index 4618c932c3..fe45931f7f 100644
--- a/tensorflow/cc/training/coordinator.cc
+++ b/tensorflow/cc/training/coordinator.cc
@@ -116,17 +116,13 @@ void Coordinator::WaitForStop() {
}
Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const {
- RunMetadata tmp_metadata;
- {
- mutex_lock l(runners_lock_);
- for (auto& t : runners_) {
- Status s = t->ExportRunMetadata(&tmp_metadata);
- if (!s.ok()) {
- return s;
- }
+ mutex_lock l(runners_lock_);
+ for (auto& t : runners_) {
+ Status s = t->ExportCostGraph(cost_graph);
+ if (!s.ok()) {
+ return s;
}
}
- cost_graph->MergeFrom(tmp_metadata.cost_graph());
return Status::OK();
}
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
index 632418c5ca..0e01b19cd9 100644
--- a/tensorflow/cc/training/coordinator.h
+++ b/tensorflow/cc/training/coordinator.h
@@ -36,8 +36,8 @@ class RunnerInterface {
public:
virtual ~RunnerInterface() {}
virtual Status Join() = 0;
- virtual Status ExportRunMetadata(RunMetadata* metadata) const {
- return Status(error::INVALID_ARGUMENT, "No RunMetadata to export.");
+ virtual Status ExportCostGraph(CostGraphDef* cost_graph) const {
+ return Status(error::INVALID_ARGUMENT, "No cost model to export.");
}
/// Returns true iff the runner is running, i.e. if it is trying to populate
/// its queue.
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 51eba5d8a1..324a62e1a9 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -82,9 +82,9 @@ QueueRunner::~QueueRunner() {
Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
-Status QueueRunner::StartAndCollectRunMetadata(Session* sess,
- const RunOptions* run_options) {
- SetRunArgumentsAndRunMetadata(run_options);
+Status QueueRunner::StartAndCollectCostGraph(Session* sess,
+ const RunOptions* run_options) {
+ SetRunArgumentsAndCostGraph(run_options);
return Start(sess, 0);
}
@@ -115,10 +115,9 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
return Status::OK();
}
-Status QueueRunner::StartAndCollectRunMetadata(Session* session,
- int wait_for_ms,
- const RunOptions* run_options) {
- SetRunArgumentsAndRunMetadata(run_options);
+Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
+ const RunOptions* run_options) {
+ SetRunArgumentsAndCostGraph(run_options);
return Start(session, wait_for_ms);
}
@@ -127,7 +126,7 @@ void QueueRunner::Stop(Session* sess) {
coord_->WaitForStop();
}
if (!cancel_op_name_.empty()) {
- UpdateStatus(RealRun(sess, cancel_op_name_));
+ UpdateStatus(RealRun(sess, cancel_op_name_, false));
}
stopped_ = true;
}
@@ -162,7 +161,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
if (coord_ && coord_->ShouldStop()) {
break;
}
- status = RealRun(sess, enqueue_op);
+ status = RealRun(sess, enqueue_op, true);
if (first_iteration) {
if (!status.ok()) {
mutex_lock l(mu_);
@@ -183,7 +182,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
// will be run anyway in this case.
if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
if (last_run && !close_op_name_.empty()) {
- UpdateStatus(RealRun(sess, close_op_name_));
+ UpdateStatus(RealRun(sess, close_op_name_, false));
}
} else if (!status.ok()) {
LOG(ERROR) << "Queue runner thread got a failure status: "
@@ -200,34 +199,35 @@ Status QueueRunner::GetStatus() {
return status_;
}
-Status QueueRunner::ExportRunMetadata(RunMetadata* metadata) const {
- if (!rm_mu_) {
+Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
+ if (!cg_mu_) {
return Status(error::FAILED_PRECONDITION,
- "This QueueRunner doesn't collect and store RunMetadata.");
+ "This QueueRunner doesn't collect a cost graph.");
}
- mutex_lock l(*rm_mu_);
- metadata->MergeFrom(*run_metadata_);
+ mutex_lock l(*cg_mu_);
+ cost_graph->MergeFrom(*cost_graph_);
return Status::OK();
}
-void QueueRunner::SetRunArgumentsAndRunMetadata(const RunOptions* run_options) {
- rm_mu_.reset(new mutex());
+void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions* run_options) {
+ cg_mu_.reset(new mutex());
{
- mutex_lock l(*rm_mu_);
- run_metadata_.reset(new RunMetadata());
+ mutex_lock l(*cg_mu_);
+ cost_graph_.reset(new CostGraphDef());
}
if (run_options) {
run_options_ = *run_options;
}
}
-Status QueueRunner::RealRun(Session* sess, const string& op) {
+Status QueueRunner::RealRun(Session* sess, const string& op,
+ bool update_costs) {
Status s;
- if (rm_mu_) {
+ if (update_costs && cg_mu_) {
RunMetadata metadata;
s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
- mutex_lock l(*rm_mu_);
- run_metadata_->MergeFrom(metadata);
+ mutex_lock l(*cg_mu_);
+ cost_graph_->Swap(metadata.mutable_cost_graph());
} else {
s = sess->Run({}, {}, {op}, nullptr);
}
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index c69f28793a..71ed44c9c6 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -60,15 +60,15 @@ class QueueRunner : public RunnerInterface {
Status Start(Session* sess);
/// Starts the queue runner with the given session and sets the run arguments
- /// for sess->Run. It also collects and stores the run metedata.
- Status StartAndCollectRunMetadata(Session* sess,
- const RunOptions* run_options = nullptr);
+ /// for sess->Run. It also collects and stores the cost model.
+ Status StartAndCollectCostGraph(Session* sess,
+ const RunOptions* run_options = nullptr);
/// Starts the queue runner with the given session, and wait for up to the
/// specified time (in milliseconds) for the queues to start to fill up.
Status Start(Session* sess, int wait_for_ms);
- Status StartAndCollectRunMetadata(Session* session, int wait_for_ms,
- const RunOptions* run_options = nullptr);
+ Status StartAndCollectCostGraph(Session* session, int wait_for_ms,
+ const RunOptions* run_options = nullptr);
/// Requests to stop and runs the cancel op. It would be called in a separate
/// thread when coordinator is set. If there is no coordinator it should be
@@ -82,11 +82,11 @@ class QueueRunner : public RunnerInterface {
/// Returns the latest status.
Status GetStatus();
- // Returns the stored run metadata.
- Status ExportRunMetadata(RunMetadata* metadata) const override;
+ // Returns the stored cost model.
+ Status ExportCostGraph(CostGraphDef* cost_graph) const override;
private:
- QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {}
+ QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {}
// Initializes the instance with the QueueRunnerDef proto.
Status Init(const QueueRunnerDef& queue_runner_def);
@@ -105,9 +105,9 @@ class QueueRunner : public RunnerInterface {
bool IsRunning() const override { return !stopped_; }
- void SetRunArgumentsAndRunMetadata(const RunOptions* run_options);
+ void SetRunArgumentsAndCostGraph(const RunOptions* run_options);
- Status RealRun(Session* sess, const string& op);
+ Status RealRun(Session* sess, const string& op, bool update_costs);
string queue_name_;
std::vector<string> enqueue_op_names_;
@@ -130,8 +130,8 @@ class QueueRunner : public RunnerInterface {
mutex cb_mu_;
std::vector<std::function<void(Status)>> callbacks_;
- mutable std::unique_ptr<mutex> rm_mu_;
- std::unique_ptr<RunMetadata> run_metadata_ GUARDED_BY(rm_mu_);
+ mutable std::unique_ptr<mutex> cg_mu_;
+ std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
RunOptions run_options_;
};
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index c37a69a7f7..e814be8630 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -44,6 +44,7 @@ using ops::FIFOQueue;
using ops::QueueClose;
using ops::QueueDequeue;
using ops::QueueEnqueue;
+using ops::RandomNormal;
using ops::Square;
using ops::Variable;
@@ -84,7 +85,7 @@ QueueRunnerDef BuildQueueRunnerDef(
const std::string& close_op, const std::string& cancel_op,
const std::vector<Code>& queue_closed_error_codes) {
QueueRunnerDef queue_runner_def;
- *queue_runner_def.mutable_queue_name() = kQueueName;
+ *queue_runner_def.mutable_queue_name() = queue_name;
for (const std::string& enqueue_op : enqueue_ops) {
*queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
}
@@ -345,37 +346,51 @@ TEST(QueueRunnerTest, CallbackCalledOnError) {
}
TEST(QueueRunnerTest, RunMetaDataTest) {
+ Scope root = Scope::NewRootScope();
+ auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT});
+ Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT);
+ Output square = Square(root.WithOpName(kSquareOpName), rnd);
+ auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square});
+ auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
+ auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
+ QueueClose::CancelPendingEnqueues(true));
+ auto dequeue0 =
+ QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT});
+
+ GraphDef graph_def;
+ TF_EXPECT_OK(root.ToGraphDef(&graph_def));
+ for (auto& node : *graph_def.mutable_node()) {
+ node.set_device("/cpu:0");
+ }
SessionOptions sess_options;
sess_options.config.mutable_graph_options()->set_build_cost_model(1);
std::unique_ptr<Session> session(NewSession(sess_options));
- GraphDef graph_def = BuildSimpleGraph();
TF_CHECK_OK(session->Create(graph_def));
- TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
- RunOptions run_options;
- run_options.set_trace_level(RunOptions::HARDWARE_TRACE);
-
- QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
+ QueueRunnerDef queue_runner_def =
+ BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
- TF_CHECK_OK(qr->StartAndCollectRunMetadata(session.get(), &run_options));
+ RunOptions run_options;
+ TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), &run_options));
- TF_EXPECT_OK(qr->Join());
- RunMetadata run_metadata;
- TF_CHECK_OK(qr->ExportRunMetadata(&run_metadata));
+ // Make sure there was at least one element enqueued in q0: this prevents a
+ // race condition where we close the queue before it was populated.
+ std::vector<Tensor> dq0;
+ TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));
+
+ CostGraphDef cost_graph;
+ TF_CHECK_OK(qr->ExportCostGraph(&cost_graph));
+ EXPECT_TRUE(cost_graph.node_size() > 0);
- EXPECT_TRUE(run_metadata.has_cost_graph());
+ qr->Stop(session.get());
}
TEST(QueueRunnerTest, NoRunMetaDataTest) {
GraphDef graph_def = BuildSimpleGraph();
auto session = BuildSessionAndInitVariable(graph_def);
- RunOptions run_options;
- run_options.set_trace_level(RunOptions::HARDWARE_TRACE);
-
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
std::unique_ptr<QueueRunner> qr;
@@ -383,8 +398,8 @@ TEST(QueueRunnerTest, NoRunMetaDataTest) {
TF_CHECK_OK(qr->Start(session.get()));
TF_EXPECT_OK(qr->Join());
- RunMetadata run_metadata;
- EXPECT_EQ(qr->ExportRunMetadata(&run_metadata).code(),
+ CostGraphDef cost_graph;
+ EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(),
error::FAILED_PRECONDITION);
}