diff options
author | Benoit Steiner <bsteiner@google.com> | 2016-12-08 14:25:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-08 14:44:44 -0800 |
commit | 8daaae4c89a7da16caec240db582fea6d4d4b512 (patch) | |
tree | 6a8050958496d6d4d6d07ff85958ebfee563c588 /tensorflow/cc/training | |
parent | 056c0877adbc6ce21842f0bcdc6bb62e18f15a03 (diff) |
Added a new AllRunnersStopped() to check if all the runners have been stopped.
Change: 141484062
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/coordinator.cc | 10 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator.h | 9 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator_test.cc | 16 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 8 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 6 |
5 files changed, 44 insertions, 5 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index e1a06123da..53a566db95 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -48,6 +48,16 @@ Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) { return Status::OK(); } +bool Coordinator::AllRunnersStopped() { + mutex_lock l(runners_lock_); + for (const auto& runner : runners_) { + if (runner->IsRunning()) { + return false; + } + } + return true; +} + Status Coordinator::RequestStop() { mutex_lock l(mu_); if (should_stop_) { diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 1c3f0e3cda..e1ef4b0f23 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -32,6 +32,10 @@ class RunnerInterface { public: virtual ~RunnerInterface() {} virtual Status Join() = 0; + + // Returns true iff the runner is running, i.e. if it is trying to populate + // its queue. + virtual bool IsRunning() const = 0; }; // Coordinator class manages the termination of a collection of QueueRunners. @@ -74,6 +78,9 @@ class Coordinator { // supposed to be in running state when they are registered here. Status RegisterRunner(std::unique_ptr<RunnerInterface> runner); + // Returns true iff all the registered runners have been stopped. + bool AllRunnersStopped(); + // Requests all running threads to stop. Status RequestStop(); @@ -107,6 +114,8 @@ class Coordinator { std::vector<std::unique_ptr<RunnerInterface>> runners_ GUARDED_BY(runners_lock_); + std::atomic<int> num_runners_to_cancel_; + TF_DISALLOW_COPY_AND_ASSIGN(Coordinator); }; diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc index 6870ea65c5..5e4a696690 100644 --- a/tensorflow/cc/training/coordinator_test.cc +++ b/tensorflow/cc/training/coordinator_test.cc @@ -58,6 +58,7 @@ class MockQueueRunner : public RunnerInterface { coord_ = coord; join_counter_ = nullptr; thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10)); + stopped_ = false; } MockQueueRunner(Coordinator* coord, int* join_counter) @@ -87,6 +88,10 @@ class MockQueueRunner : public RunnerInterface { void SetStatus(const Status& status) { status_ = status; } + bool IsRunning() const override { return !stopped_; }; + + void Stop() { stopped_ = true; } + private: void CountThread(std::atomic<int>* counter, int until) { while (!coord_->ShouldStop() && counter->load() < until) { @@ -104,6 +109,7 @@ class MockQueueRunner : public RunnerInterface { Status status_; Coordinator* coord_; int* join_counter_; + bool stopped_; }; TEST(CoordinatorTest, TestRealStop) { @@ -189,5 +195,15 @@ TEST(CoordinatorTest, JoinWithoutStop) { EXPECT_EQ(coord.Join().code(), Code::FAILED_PRECONDITION); } +TEST(CoordinatorTest, AllRunnersStopped) { + Coordinator coord; + MockQueueRunner* qr = new MockQueueRunner(&coord); + coord.RegisterRunner(std::unique_ptr<RunnerInterface>(qr)); + + EXPECT_FALSE(coord.AllRunnersStopped()); + qr->Stop(); + EXPECT_TRUE(coord.AllRunnersStopped()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index e703a9bb30..cd6cc67327 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -110,13 +110,13 @@ Status QueueRunner::Start(Session* sess, int wait_for) { } void QueueRunner::Stop(Session* sess) { - if (cancel_op_name_.empty()) { - return; - } if (coord_ != nullptr) { coord_->WaitForStop(); } - UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr)); + if (!cancel_op_name_.empty()) { + UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr)); + } + stopped_ = true; } Status QueueRunner::Join() { diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index fd9f97a958..e5aae8219f 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -75,7 +75,7 @@ class QueueRunner : public RunnerInterface { Status GetStatus(); private: - QueueRunner() : coord_(nullptr) {} + QueueRunner() : coord_(nullptr), stopped_(false) {} // Initializes the instance with the QueueRunnerDef proto. Status Init(const QueueRunnerDef& queue_runner_def); @@ -92,6 +92,8 @@ class QueueRunner : public RunnerInterface { static_cast<int>(status.code())) > 0; } + bool IsRunning() const override { return !stopped_; } + string queue_name_; std::vector<string> enqueue_op_names_; string close_op_name_; @@ -108,6 +110,8 @@ class QueueRunner : public RunnerInterface { Coordinator* coord_; + std::atomic<bool> stopped_; + mutex cb_mu_; std::vector<std::function<void(Status)>> callbacks_; }; |