diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2016-11-07 18:37:18 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-08 16:14:51 -0800 |
commit | 6c8e9b059f3747aaee71f6f2e2138bb80120b71e (patch) | |
tree | 0a52b5fabc24a1982a751ac6bede65995dcbf2a7 /tensorflow/cc/training | |
parent | 8373430ce1fe008b061777324c097731826da68b (diff) |
Make Coordinator::RegisterRunner and Coordinator::Join thread-safe.
Change: 138467240
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/coordinator.cc | 28 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator.h | 12 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator_test.cc | 10 |
3 files changed, 42 insertions, 8 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc
index 254538d778..e1a06123da 100644
--- a/tensorflow/cc/training/coordinator.cc
+++ b/tensorflow/cc/training/coordinator.cc
@@ -36,6 +36,14 @@ Coordinator::~Coordinator() {
 }
 
 Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
+  {
+    mutex_lock l(mu_);
+    if (should_stop_) {
+      return Status(error::FAILED_PRECONDITION,
+                    "The coordinator has been stopped.");
+    }
+  }
+  mutex_lock l(runners_lock_);
   runners_.push_back(std::move(runner));
   return Status::OK();
 }
@@ -57,13 +65,23 @@ bool Coordinator::ShouldStop() {
 }
 
 Status Coordinator::Join() {
-  // TODO(yuefengz): deal with unexpected calls to Join().
   // TODO(yuefengz): deal with stragglers.
-  for (const auto& t : runners_) {
-    ReportStatus(t->Join());
+  {
+    mutex_lock l(mu_);
+    if (!should_stop_) {
+      return Status(error::FAILED_PRECONDITION,
+                    "Joining coordinator without requesting to stop.");
+    }
   }
-  runners_.clear();
-  return status_;
+
+  {
+    mutex_lock l(runners_lock_);
+    for (const auto& t : runners_) {
+      ReportStatus(t->Join());
+    }
+    runners_.clear();
+  }
+  return GetStatus();
 }
 
 void Coordinator::ReportStatus(const Status& status) {
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
index 987d243fbd..1c3f0e3cda 100644
--- a/tensorflow/cc/training/coordinator.h
+++ b/tensorflow/cc/training/coordinator.h
@@ -94,13 +94,19 @@ class Coordinator {
   void WaitForStop();
 
  private:
-  std::vector<std::unique_ptr<RunnerInterface>> runners_;
   std::unordered_set<int> clean_stop_errors_;
+  condition_variable wait_for_stop_;
+
   mutex mu_;
   bool should_stop_ GUARDED_BY(mu_);
+
   mutex status_lock_;
-  Status status_;
-  condition_variable wait_for_stop_;
+  Status status_ GUARDED_BY(status_lock_);
+
+  mutex runners_lock_;
+  std::vector<std::unique_ptr<RunnerInterface>> runners_
+      GUARDED_BY(runners_lock_);
+
   TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
 };
 
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
index 3bdce5f07f..6870ea65c5 100644
--- a/tensorflow/cc/training/coordinator_test.cc
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -155,6 +155,7 @@ TEST(CoordinatorTest, TestJoin) {
       new MockQueueRunner(&coord, &join_counter));
   coord.RegisterRunner(std::move(qr2));
 
+  coord.RequestStop();
   TF_EXPECT_OK(coord.Join());
   EXPECT_EQ(join_counter, 2);
 }
@@ -176,8 +177,17 @@ TEST(CoordinatorTest, StatusReporting) {
   coord.RegisterRunner(std::move(qr3));
 
   counter.Wait();
+  coord.RequestStop();
   EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
 }
 
+TEST(CoordinatorTest, JoinWithoutStop) {
+  Coordinator coord;
+  std::unique_ptr<MockQueueRunner> qr(new MockQueueRunner(&coord));
+  coord.RegisterRunner(std::move(qr));
+
+  EXPECT_EQ(coord.Join().code(), Code::FAILED_PRECONDITION);
+}
+
 }  // namespace
 }  // namespace tensorflow