aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2017-03-06 23:41:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-06 23:51:00 -0800
commit2823db46405a8b77a61b9da1f9a13019331e5390 (patch)
treeee283d88f4a0bd2107eeec0bcfa766c00ce0aaac /tensorflow/cc/training
parent3d725349272ca0a5f443ec631374a24474e5a513 (diff)
Make queue runner accept run arguments.
Change: 149388619
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r--tensorflow/cc/training/queue_runner.cc46
-rw-r--r--tensorflow/cc/training/queue_runner.h18
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc24
3 files changed, 84 insertions, 4 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 1f6794cce2..5bf11a3788 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -82,6 +82,12 @@ QueueRunner::~QueueRunner() {
Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
+Status QueueRunner::Start(Session* sess, RunMetadata* metadata, mutex* rm_mu,
+ const RunOptions* run_options) {
+ SetRunArguments(run_options, metadata, rm_mu);
+ return Start(sess, 0);
+}
+
Status QueueRunner::Start(Session* sess, int wait_for) {
counter_.reset(new BlockingCounter(runs_));
for (const string& enqueue_op : enqueue_op_names_) {
@@ -109,12 +115,19 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
return Status::OK();
}
+Status QueueRunner::Start(Session* session, int wait_for_ms,
+ RunMetadata* metadata, mutex* rm_mu,
+ const RunOptions* run_options) {
+ SetRunArguments(run_options, metadata, rm_mu);
+ return Start(session, wait_for_ms);
+}
+
void QueueRunner::Stop(Session* sess) {
if (coord_ != nullptr) {
coord_->WaitForStop();
}
if (!cancel_op_name_.empty()) {
- UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
+ UpdateStatus(RealRun(sess, cancel_op_name_));
}
stopped_ = true;
}
@@ -149,7 +162,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
if (coord_ && coord_->ShouldStop()) {
break;
}
- status = sess->Run({}, {}, {enqueue_op}, nullptr);
+ status = RealRun(sess, enqueue_op);
if (first_iteration) {
if (!status.ok()) {
mutex_lock l(mu_);
@@ -170,7 +183,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
// will be run anway in this case.
if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
if (last_run && !close_op_name_.empty()) {
- UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr));
+ UpdateStatus(RealRun(sess, close_op_name_));
}
} else if (!status.ok()) {
UpdateStatus(status);
@@ -185,4 +198,31 @@ Status QueueRunner::GetStatus() {
return status_;
}
+void QueueRunner::SetRunArguments(const RunOptions* run_options,
+ RunMetadata* metadata, mutex* rm_mu) {
+ DCHECK(metadata != nullptr);
+ DCHECK(rm_mu != nullptr);
+ rm_mu_ = rm_mu;
+ {
+ mutex_lock l(*rm_mu_);
+ run_metadata_ = metadata;
+ }
+ if (run_options) {
+ run_options_ = *run_options;
+ }
+}
+
+Status QueueRunner::RealRun(Session* sess, const string& op) {
+ Status s;
+ if (rm_mu_) {
+ RunMetadata metadata;
+ s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
+ mutex_lock l(*rm_mu_);
+ run_metadata_->MergeFrom(metadata);
+ } else {
+ s = sess->Run({}, {}, {op}, nullptr);
+ }
+ return s;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index bfe6a30593..46ee26eec4 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -58,9 +58,16 @@ class QueueRunner : public RunnerInterface {
/// Starts the queue runner with the given session.
Status Start(Session* sess);
+ // Starts the queue runner with the given session and sets the run arguments
+ // for sess->Run. The mutex lock rm_mu is hold when metadata is being changed.
+ Status Start(Session* sess, RunMetadata* metadata, mutex* rm_mu,
+ const RunOptions* run_options = nullptr);
+
/// Starts the queue runner with the given session, and wait for up to the
/// specified time (in milliseconds) for the queues to start to fill up.
Status Start(Session* sess, int wait_for_ms);
+ Status Start(Session* session, int wait_for_ms, RunMetadata* metadata,
+ mutex* rm_mu, const RunOptions* run_options = nullptr);
/// Requests to stop and runs the cancel op. It would be called in a separate
/// thread when coordinator is set. If there is no coordinator it should be
@@ -75,7 +82,7 @@ class QueueRunner : public RunnerInterface {
Status GetStatus();
private:
- QueueRunner() : coord_(nullptr), stopped_(false) {}
+ QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {}
// Initializes the instance with the QueueRunnerDef proto.
Status Init(const QueueRunnerDef& queue_runner_def);
@@ -94,6 +101,11 @@ class QueueRunner : public RunnerInterface {
bool IsRunning() const override { return !stopped_; }
+ void SetRunArguments(const RunOptions* run_options, RunMetadata* metadata,
+ mutex* rm_mu);
+
+ Status RealRun(Session* sess, const string& op);
+
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
@@ -114,6 +126,10 @@ class QueueRunner : public RunnerInterface {
mutex cb_mu_;
std::vector<std::function<void(Status)>> callbacks_;
+
+ mutex* rm_mu_;
+ RunMetadata* run_metadata_ GUARDED_BY(rm_mu_);
+ RunOptions run_options_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 27c302ab28..0b6b800058 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -344,5 +344,29 @@ TEST(QueueRunnerTest, CallbackCalledOnError) {
EXPECT_TRUE(error_caught);
}
+TEST(QueueRunnerTest, RunMetaDataTest) {
+ SessionOptions sess_options;
+ sess_options.config.mutable_graph_options()->set_build_cost_model(1);
+ std::unique_ptr<Session> session(NewSession(sess_options));
+
+ GraphDef graph_def = BuildSimpleGraph();
+ TF_CHECK_OK(session->Create(graph_def));
+ TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
+
+ RunOptions run_options;
+ run_options.set_trace_level(RunOptions::HARDWARE_TRACE);
+ RunMetadata run_metadata;
+ mutex mu;
+
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_CHECK_OK(qr->Start(session.get(), &run_metadata, &mu, &run_options));
+ TF_EXPECT_OK(qr->Join());
+
+ EXPECT_TRUE(run_metadata.has_cost_graph());
+}
+
} // namespace
} // namespace tensorflow