aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2016-10-31 16:18:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 17:36:05 -0700
commit5de9d3c392d5531eb3bbcefd007fcc25db7448cd (patch)
tree73525f3255e34689b309002a87f01979e0a9ac27 /tensorflow/cc/training
parent5399bfaebfb4666b01ee6afc76dbf0455731bddd (diff)
Added the ability to wait for queues to start running before returning from
QueueRunner::Start(). This provides a reliable way to check the value of the status_ variable. Change: 137769682
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r--tensorflow/cc/training/queue_runner.cc32
-rw-r--r--tensorflow/cc/training/queue_runner.h9
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc17
3 files changed, 56 insertions, 2 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index ed1d0a5da0..6435e88d48 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -48,6 +48,7 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
thread_pool_.reset(new thread::ThreadPool(
Env::Default(), SanitizeThreadSuffix(queue_name_), runs_));
should_stop_ = false;
+
return Status::OK();
}
@@ -57,11 +58,29 @@ QueueRunner::~QueueRunner() {
Join();
}
-Status QueueRunner::Start(Session* sess) {
+Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
+
+Status QueueRunner::Start(Session* sess, int wait_for) {
+ counter_.reset(new BlockingCounter(runs_));
for (const string& enqueue_op : enqueue_op_names_) {
thread_pool_->Schedule(
std::bind(&QueueRunner::Run, this, sess, enqueue_op));
}
+ // Wait for up to 'wait_for' milliseconds.
+ if (wait_for > 0) {
+ if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
+ return Status(error::DEADLINE_EXCEEDED,
+ "Queues not fed before the timeout");
+ }
+ // Check the status of the queue runner as well as the result of the enqueue
+ // operations.
+ mutex_lock l(mu_);
+ if (!enqueue_status_.ok()) {
+ return enqueue_status_;
+ } else {
+ return status_;
+ }
+ }
return Status::OK();
}
@@ -76,13 +95,23 @@ Status QueueRunner::Stop(Session* sess) {
Status QueueRunner::Join() {
thread_pool_.reset();
+ mutex_lock l(mu_);
return status_;
}
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
bool decremented = false;
+ bool first_iteration = true;
while (!should_stop_.load()) {
auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
+ if (first_iteration) {
+ if (!status.ok()) {
+ mutex_lock l(mu_);
+ enqueue_status_ = status;
+ }
+ counter_->DecrementCount();
+ first_iteration = false;
+ }
if (status.ok()) {
continue;
} else if (queue_closed_exception_types_.count(
@@ -114,6 +143,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
// subsequent queues.
Stop(sess);
}
+ first_iteration = false;
}
if (!decremented) {
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 9374fe3605..3affe45e55 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -46,6 +47,10 @@ class QueueRunner {
// Starts the queue runner with the given session.
Status Start(Session* sess);
+ // Starts the queue runner with the given session, and wait for up to the
+ // specified time (in milliseconds) for the queues to start to fill up.
+ Status Start(Session* sess, int wait_for);
+
// Requests to stop and runs the cancel op.
Status Stop(Session* sess);
@@ -78,7 +83,9 @@ class QueueRunner {
mutex mu_;
// TODO(yuefengz): implement c++ coordinator.
int runs_ = 0;
- Status status_;
+ Status status_ GUARDED_BY(mu_);
+ Status enqueue_status_ GUARDED_BY(mu_);
+ std::unique_ptr<BlockingCounter> counter_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 0d06c62056..df48b25543 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -317,5 +317,22 @@ TEST(QueueRunnerTest, EmptyEnqueueOps) {
Code::INVALID_ARGUMENT);
}
+TEST(QueueRunnerTest, StartTimeout) {
+ GraphDef graph_def = BuildDoubleQueueGraph();
+ SessionOptions options;
+ std::unique_ptr<Session> session(NewSession(options));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
+
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ // This will timeout since queue0 is not fed and queue1 is fetching data from
+ // queue0.
+ EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
+ session->Close();
+}
+
} // namespace
} // namespace tensorflow