diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2017-03-06 23:41:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-06 23:51:00 -0800 |
commit | 2823db46405a8b77a61b9da1f9a13019331e5390 (patch) | |
tree | ee283d88f4a0bd2107eeec0bcfa766c00ce0aaac /tensorflow/cc/training | |
parent | 3d725349272ca0a5f443ec631374a24474e5a513 (diff) |
Make queue runner accept run arguments.
Change: 149388619
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 46 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 18 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 24 |
3 files changed, 84 insertions, 4 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 1f6794cce2..5bf11a3788 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -82,6 +82,12 @@ QueueRunner::~QueueRunner() { Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } +Status QueueRunner::Start(Session* sess, RunMetadata* metadata, mutex* rm_mu, + const RunOptions* run_options) { + SetRunArguments(run_options, metadata, rm_mu); + return Start(sess, 0); +} + Status QueueRunner::Start(Session* sess, int wait_for) { counter_.reset(new BlockingCounter(runs_)); for (const string& enqueue_op : enqueue_op_names_) { @@ -109,12 +115,19 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } +Status QueueRunner::Start(Session* session, int wait_for_ms, + RunMetadata* metadata, mutex* rm_mu, + const RunOptions* run_options) { + SetRunArguments(run_options, metadata, rm_mu); + return Start(session, wait_for_ms); +} + void QueueRunner::Stop(Session* sess) { if (coord_ != nullptr) { coord_->WaitForStop(); } if (!cancel_op_name_.empty()) { - UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr)); + UpdateStatus(RealRun(sess, cancel_op_name_)); } stopped_ = true; } @@ -149,7 +162,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { if (coord_ && coord_->ShouldStop()) { break; } - status = sess->Run({}, {}, {enqueue_op}, nullptr); + status = RealRun(sess, enqueue_op); if (first_iteration) { if (!status.ok()) { mutex_lock l(mu_); @@ -170,7 +183,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { // will be run anway in this case. if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { if (last_run && !close_op_name_.empty()) { - UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr)); + UpdateStatus(RealRun(sess, close_op_name_)); } } else if (!status.ok()) { UpdateStatus(status); @@ -185,4 +198,31 @@ Status QueueRunner::GetStatus() { return status_; } +void QueueRunner::SetRunArguments(const RunOptions* run_options, + RunMetadata* metadata, mutex* rm_mu) { + DCHECK(metadata != nullptr); + DCHECK(rm_mu != nullptr); + rm_mu_ = rm_mu; + { + mutex_lock l(*rm_mu_); + run_metadata_ = metadata; + } + if (run_options) { + run_options_ = *run_options; + } +} + +Status QueueRunner::RealRun(Session* sess, const string& op) { + Status s; + if (rm_mu_) { + RunMetadata metadata; + s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); + mutex_lock l(*rm_mu_); + run_metadata_->MergeFrom(metadata); + } else { + s = sess->Run({}, {}, {op}, nullptr); + } + return s; +} + } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index bfe6a30593..46ee26eec4 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -58,9 +58,16 @@ class QueueRunner : public RunnerInterface { /// Starts the queue runner with the given session. Status Start(Session* sess); + // Starts the queue runner with the given session and sets the run arguments + // for sess->Run. The mutex lock rm_mu is hold when metadata is being changed. + Status Start(Session* sess, RunMetadata* metadata, mutex* rm_mu, + const RunOptions* run_options = nullptr); + /// Starts the queue runner with the given session, and wait for up to the /// specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); + Status Start(Session* session, int wait_for_ms, RunMetadata* metadata, + mutex* rm_mu, const RunOptions* run_options = nullptr); /// Requests to stop and runs the cancel op. It would be called in a separate /// thread when coordinator is set. If there is no coordinator it should be @@ -75,7 +82,7 @@ class QueueRunner : public RunnerInterface { Status GetStatus(); private: - QueueRunner() : coord_(nullptr), stopped_(false) {} + QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {} // Initializes the instance with the QueueRunnerDef proto. Status Init(const QueueRunnerDef& queue_runner_def); @@ -94,6 +101,11 @@ class QueueRunner : public RunnerInterface { bool IsRunning() const override { return !stopped_; } + void SetRunArguments(const RunOptions* run_options, RunMetadata* metadata, + mutex* rm_mu); + + Status RealRun(Session* sess, const string& op); + string queue_name_; std::vector<string> enqueue_op_names_; string close_op_name_; @@ -114,6 +126,10 @@ class QueueRunner : public RunnerInterface { mutex cb_mu_; std::vector<std::function<void(Status)>> callbacks_; + + mutex* rm_mu_; + RunMetadata* run_metadata_ GUARDED_BY(rm_mu_); + RunOptions run_options_; }; } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 27c302ab28..0b6b800058 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -344,5 +344,29 @@ TEST(QueueRunnerTest, CallbackCalledOnError) { EXPECT_TRUE(error_caught); } +TEST(QueueRunnerTest, RunMetaDataTest) { + SessionOptions sess_options; + sess_options.config.mutable_graph_options()->set_build_cost_model(1); + std::unique_ptr<Session> session(NewSession(sess_options)); + + GraphDef graph_def = BuildSimpleGraph(); + TF_CHECK_OK(session->Create(graph_def)); + TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr)); + + RunOptions run_options; + run_options.set_trace_level(RunOptions::HARDWARE_TRACE); + RunMetadata run_metadata; + mutex mu; + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get(), &run_metadata, &mu, &run_options)); + TF_EXPECT_OK(qr->Join()); + + EXPECT_TRUE(run_metadata.has_cost_graph()); +} + } // namespace } // namespace tensorflow |