aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-11-04 12:33:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-04 13:43:40 -0700
commit8f9493858a5a89fcda5ace0b86b0f964dc470d20 (patch)
tree58ab6ea7040d84243cae6807eb91473211dcd15a /tensorflow/cc/training
parent879e0accd1c833771c8058d3eb5f2d4f06f895d4 (diff)
C++ QueueRunner: Bug fixes.
Three bug fixes: (1) There was a thread-unsafe access to runs_ which could result in the queue close operation being invoked multiple times. (2) The Run() loop would not exit when there were multiple threads and the queue was closed (i.e., the enqueue failed with a queue_closed_exception_types_ error). Without this fix, the changed QueueRunnerTest.QueueClosedCode test would fail with a timeout, since qr->Join() would be blocked on the never-exiting Run() call. (3) Errors in invoking the close operation were being ignored. Without this fix, the added QueueRunnerTest.QueueCloseFails test would fail, as Join() would return OK instead of NOT_FOUND. Two other minor changes: - Slight simplification to QueueRunner::Run() so that runs_ is manipulated only once and the body of the loop is clearer. - Avoid starting an extra thread which will not be used when there is no Coordinator (though in practice I suppose we always intend to have a coordinator). Change: 138228243
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r--tensorflow/cc/training/queue_runner.cc66
-rw-r--r--tensorflow/cc/training/queue_runner.h11
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc24
3 files changed, 55 insertions, 46 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index bc48a41ff5..10ff80c9cd 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -53,8 +53,13 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
}
}
+ int nthreads = runs_;
+ if (coord_) {
+ // One more thread to call Stop()
+ nthreads++;
+ }
thread_pool_.reset(new thread::ThreadPool(
- Env::Default(), SanitizeThreadSuffix(queue_name_), runs_ + 1));
+ Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads));
return Status::OK();
}
@@ -95,13 +100,14 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
}
void QueueRunner::Stop(Session* sess) {
+ DCHECK(coord_ != nullptr);
if (cancel_op_name_.empty()) {
return;
- } else {
- CHECK(coord_ != nullptr);
+ }
+ if (coord_ != nullptr) {
coord_->WaitForStop();
- UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
}
+ UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
}
Status QueueRunner::Join() {
@@ -113,9 +119,7 @@ Status QueueRunner::Join() {
void QueueRunner::UpdateStatus(const Status& status) {
{
mutex_lock l(mu_);
- if (!status_.ok() || status.ok() ||
- queue_closed_exception_types_.count(static_cast<int>(status.code())) >
- 0) {
+ if (!status_.ok() || status.ok() || IsQueueClosed(status)) {
return;
}
status_ = status;
@@ -126,13 +130,13 @@ void QueueRunner::UpdateStatus(const Status& status) {
}
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
- bool decremented = false;
bool first_iteration = true;
- while (true) {
+ Status status;
+ while (status.ok()) {
if (coord_ && coord_->ShouldStop()) {
break;
}
- auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
+ status = sess->Run({}, {}, {enqueue_op}, nullptr);
if (first_iteration) {
if (!status.ok()) {
mutex_lock l(mu_);
@@ -141,37 +145,23 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
counter_->DecrementCount();
first_iteration = false;
}
- if (status.ok()) {
- continue;
- } else if (queue_closed_exception_types_.count(
- static_cast<int>(status.code())) > 0) {
- {
- mutex_lock l(mu_);
- runs_--;
- decremented = true;
- }
-
- // If all enqueue ops have finished, run the close op.
- if (runs_ == 0) {
- if (!close_op_name_.empty()) {
- auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
- UpdateStatus(status);
- }
- break;
- }
- } else {
- UpdateStatus(status);
- if (coord_) {
- coord_->RequestStop();
- }
- break;
- }
- first_iteration = false;
}
-
- if (!decremented) {
+ bool last_run = false;
+ {
mutex_lock l(mu_);
runs_--;
+ last_run = (runs_ == 0);
+ }
+
+ if (IsQueueClosed(status)) {
+ if (last_run && !close_op_name_.empty()) {
+ UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr));
+ }
+ } else if (!status.ok()) {
+ UpdateStatus(status);
+ if (coord_) {
+ coord_->RequestStop();
+ }
}
}
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 01dd745951..273eb39671 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -54,13 +54,13 @@ class QueueRunner : public RunnerInterface {
// Starts the queue runner with the given session, and wait for up to the
// specified time (in milliseconds) for the queues to start to fill up.
- Status Start(Session* sess, int wait_for);
+ Status Start(Session* sess, int wait_for_ms);
// Joins all the threads. Returns okay if all threads run successfully;
// otherwise returns the first captured failure status.
Status Join() final;
- // Returns the lastest status.
+ // Returns the latest status.
Status GetStatus();
private:
@@ -80,6 +80,11 @@ class QueueRunner : public RunnerInterface {
// status.
void UpdateStatus(const Status& status);
+ bool IsQueueClosed(Status status) const {
+ return queue_closed_exception_types_.count(
+ static_cast<int>(status.code())) > 0;
+ }
+
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
@@ -88,9 +93,7 @@ class QueueRunner : public RunnerInterface {
std::unordered_set<int> queue_closed_exception_types_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- condition_variable wait_to_close_;
mutex mu_;
- // TODO(yuefengz): implement c++ coordinator.
int runs_ = 0;
Status status_ GUARDED_BY(mu_);
Status enqueue_status_ GUARDED_BY(mu_);
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 73ea5a307f..0e7e94b40f 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -129,9 +129,10 @@ TEST(QueueRunnerTest, QueueClosedCode) {
GraphDef graph_def = BuildSimpleGraph();
auto session = BuildSessionAndInitVariable(graph_def);
- QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "",
- {Code::OUT_OF_RANGE, Code::CANCELLED});
+ // Use two enqueue ops on the same queue so that multiple threads are in Run().
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "",
+ {Code::OUT_OF_RANGE, Code::CANCELLED});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
@@ -144,7 +145,22 @@ TEST(QueueRunnerTest, QueueClosedCode) {
EXPECT_EQ(square_value, 100);
}
-TEST(QueueRunnerDef, CatchErrorInJoin) {
+TEST(QueueRunnerTest, QueueCloseFails) {
+ GraphDef graph_def = BuildSimpleGraph();
+ auto session = BuildSessionAndInitVariable(graph_def);
+
+ QueueRunnerDef queue_runner_def =
+ BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kIllegalOpName1, "",
+ {Code::OUT_OF_RANGE});
+
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_EXPECT_OK(qr->Start(session.get()));
+ auto status = qr->Join();
+ EXPECT_EQ(status.code(), Code::NOT_FOUND) << status;
+}
+
+TEST(QueueRunnerTest, CatchErrorInJoin) {
GraphDef graph_def = BuildSimpleGraph();
auto session = BuildSessionAndInitVariable(graph_def);