about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
author: Yuefeng Zhou <yuefengz@google.com> 2016-11-07 18:37:18 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-11-08 16:14:51 -0800
commit: 6c8e9b059f3747aaee71f6f2e2138bb80120b71e (patch)
tree: 0a52b5fabc24a1982a751ac6bede65995dcbf2a7 /tensorflow/cc/training
parent: 8373430ce1fe008b061777324c097731826da68b (diff)
Make Coordinator::RegisterRunner and Coordinator::Join thread-safe.
Change: 138467240
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- tensorflow/cc/training/coordinator.cc | 28
-rw-r--r-- tensorflow/cc/training/coordinator.h | 12
-rw-r--r-- tensorflow/cc/training/coordinator_test.cc | 10
3 files changed, 42 insertions, 8 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc
index 254538d778..e1a06123da 100644
--- a/tensorflow/cc/training/coordinator.cc
+++ b/tensorflow/cc/training/coordinator.cc
@@ -36,6 +36,14 @@ Coordinator::~Coordinator() {
}
Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
+ {
+ mutex_lock l(mu_);
+ if (should_stop_) {
+ return Status(error::FAILED_PRECONDITION,
+ "The coordinator has been stopped.");
+ }
+ }
+ mutex_lock l(runners_lock_);
runners_.push_back(std::move(runner));
return Status::OK();
}
@@ -57,13 +65,23 @@ bool Coordinator::ShouldStop() {
}
Status Coordinator::Join() {
- // TODO(yuefengz): deal with unexpected calls to Join().
// TODO(yuefengz): deal with stragglers.
- for (const auto& t : runners_) {
- ReportStatus(t->Join());
+ {
+ mutex_lock l(mu_);
+ if (!should_stop_) {
+ return Status(error::FAILED_PRECONDITION,
+ "Joining coordinator without requesting to stop.");
+ }
}
- runners_.clear();
- return status_;
+
+ {
+ mutex_lock l(runners_lock_);
+ for (const auto& t : runners_) {
+ ReportStatus(t->Join());
+ }
+ runners_.clear();
+ }
+ return GetStatus();
}
void Coordinator::ReportStatus(const Status& status) {
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
index 987d243fbd..1c3f0e3cda 100644
--- a/tensorflow/cc/training/coordinator.h
+++ b/tensorflow/cc/training/coordinator.h
@@ -94,13 +94,19 @@ class Coordinator {
void WaitForStop();
private:
- std::vector<std::unique_ptr<RunnerInterface>> runners_;
std::unordered_set<int> clean_stop_errors_;
+ condition_variable wait_for_stop_;
+
mutex mu_;
bool should_stop_ GUARDED_BY(mu_);
+
mutex status_lock_;
- Status status_;
- condition_variable wait_for_stop_;
+ Status status_ GUARDED_BY(status_lock_);
+
+ mutex runners_lock_;
+ std::vector<std::unique_ptr<RunnerInterface>> runners_
+ GUARDED_BY(runners_lock_);
+
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
};
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
index 3bdce5f07f..6870ea65c5 100644
--- a/tensorflow/cc/training/coordinator_test.cc
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -155,6 +155,7 @@ TEST(CoordinatorTest, TestJoin) {
new MockQueueRunner(&coord, &join_counter));
coord.RegisterRunner(std::move(qr2));
+ coord.RequestStop();
TF_EXPECT_OK(coord.Join());
EXPECT_EQ(join_counter, 2);
}
@@ -176,8 +177,17 @@ TEST(CoordinatorTest, StatusReporting) {
coord.RegisterRunner(std::move(qr3));
counter.Wait();
+ coord.RequestStop();
EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
}
+TEST(CoordinatorTest, JoinWithoutStop) {
+ Coordinator coord;
+ std::unique_ptr<MockQueueRunner> qr(new MockQueueRunner(&coord));
+ coord.RegisterRunner(std::move(qr));
+
+ EXPECT_EQ(coord.Join().code(), Code::FAILED_PRECONDITION);
+}
+
} // namespace
} // namespace tensorflow