diff options
author | Benoit Steiner <bsteiner@google.com> | 2016-10-31 16:18:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-31 17:36:05 -0700 |
commit | 5de9d3c392d5531eb3bbcefd007fcc25db7448cd (patch) | |
tree | 73525f3255e34689b309002a87f01979e0a9ac27 /tensorflow/cc/training | |
parent | 5399bfaebfb4666b01ee6afc76dbf0455731bddd (diff) |
Added the ability to wait for queues to start running before returning from
QueueRunner::Start(). This provides a reliable way to check the value of the status_ variable.
Change: 137769682
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 32 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 9 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 17 |
3 files changed, 56 insertions, 2 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index ed1d0a5da0..6435e88d48 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -48,6 +48,7 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { thread_pool_.reset(new thread::ThreadPool( Env::Default(), SanitizeThreadSuffix(queue_name_), runs_)); should_stop_ = false; + return Status::OK(); } @@ -57,11 +58,29 @@ QueueRunner::~QueueRunner() { Join(); } -Status QueueRunner::Start(Session* sess) { +Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } + +Status QueueRunner::Start(Session* sess, int wait_for) { + counter_.reset(new BlockingCounter(runs_)); for (const string& enqueue_op : enqueue_op_names_) { thread_pool_->Schedule( std::bind(&QueueRunner::Run, this, sess, enqueue_op)); } + // Wait for up to 'wait_for' milliseconds. + if (wait_for > 0) { + if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) { + return Status(error::DEADLINE_EXCEEDED, + "Queues not fed before the timeout"); + } + // Check the status of the queue runner as well as the result of the enqueue + // operations. + mutex_lock l(mu_); + if (!enqueue_status_.ok()) { + return enqueue_status_; + } else { + return status_; + } + } return Status::OK(); } @@ -76,13 +95,23 @@ Status QueueRunner::Stop(Session* sess) { Status QueueRunner::Join() { thread_pool_.reset(); + mutex_lock l(mu_); return status_; } void QueueRunner::Run(Session* sess, const string& enqueue_op) { bool decremented = false; + bool first_iteration = true; while (!should_stop_.load()) { auto status = sess->Run({}, {}, {enqueue_op}, nullptr); + if (first_iteration) { + if (!status.ok()) { + mutex_lock l(mu_); + enqueue_status_ = status; + } + counter_->DecrementCount(); + first_iteration = false; + } if (status.ok()) { continue; } else if (queue_closed_exception_types_.count( @@ -114,6 +143,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { // subsequent queues. Stop(sess); } + first_iteration = false; } if (!decremented) { diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 9374fe3605..3affe45e55 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -21,6 +21,7 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -46,6 +47,10 @@ class QueueRunner { // Starts the queue runner with the given session. Status Start(Session* sess); + // Starts the queue runner with the given session, and wait for up to the + // specified time (in milliseconds) for the queues to start to fill up. + Status Start(Session* sess, int wait_for); + // Requests to stop and runs the cancel op. Status Stop(Session* sess); @@ -78,7 +83,9 @@ class QueueRunner { mutex mu_; // TODO(yuefengz): implement c++ coordinator. int runs_ = 0; - Status status_; + Status status_ GUARDED_BY(mu_); + Status enqueue_status_ GUARDED_BY(mu_); + std::unique_ptr<BlockingCounter> counter_; }; } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 0d06c62056..df48b25543 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -317,5 +317,22 @@ TEST(QueueRunnerTest, EmptyEnqueueOps) { Code::INVALID_ARGUMENT); } +TEST(QueueRunnerTest, StartTimeout) { + GraphDef graph_def = BuildDoubleQueueGraph(); + SessionOptions options; + std::unique_ptr<Session> session(NewSession(options)); + TF_CHECK_OK(session->Create(graph_def)); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); + + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + // This will timeout since queue0 is not fed and queue1 is fetching data from + // queue0. + EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); + session->Close(); +} + } // namespace } // namespace tensorflow |