about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
author: Benoit Steiner <bsteiner@google.com>, 2016-12-08 14:25:56 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2016-12-08 14:44:44 -0800
commit: 8daaae4c89a7da16caec240db582fea6d4d4b512 (patch)
tree: 6a8050958496d6d4d6d07ff85958ebfee563c588 /tensorflow/cc/training
parent: 056c0877adbc6ce21842f0bcdc6bb62e18f15a03 (diff)
Added a new AllRunnersStopped() to check if all the runners have been stopped.
Change: 141484062
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- tensorflow/cc/training/coordinator.cc | 10
-rw-r--r-- tensorflow/cc/training/coordinator.h | 9
-rw-r--r-- tensorflow/cc/training/coordinator_test.cc | 16
-rw-r--r-- tensorflow/cc/training/queue_runner.cc | 8
-rw-r--r-- tensorflow/cc/training/queue_runner.h | 6
5 files changed, 44 insertions, 5 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc
index e1a06123da..53a566db95 100644
--- a/tensorflow/cc/training/coordinator.cc
+++ b/tensorflow/cc/training/coordinator.cc
@@ -48,6 +48,16 @@ Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
return Status::OK();
}
+bool Coordinator::AllRunnersStopped() {
+ mutex_lock l(runners_lock_);
+ for (const auto& runner : runners_) {
+ if (runner->IsRunning()) {
+ return false;
+ }
+ }
+ return true;
+}
+
Status Coordinator::RequestStop() {
mutex_lock l(mu_);
if (should_stop_) {
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
index 1c3f0e3cda..e1ef4b0f23 100644
--- a/tensorflow/cc/training/coordinator.h
+++ b/tensorflow/cc/training/coordinator.h
@@ -32,6 +32,10 @@ class RunnerInterface {
public:
virtual ~RunnerInterface() {}
virtual Status Join() = 0;
+
+ // Returns true iff the runner is running, i.e. if it is trying to populate
+ // its queue.
+ virtual bool IsRunning() const = 0;
};
// Coordinator class manages the termination of a collection of QueueRunners.
@@ -74,6 +78,9 @@ class Coordinator {
// supposed to be in running state when they are registered here.
Status RegisterRunner(std::unique_ptr<RunnerInterface> runner);
+ // Returns true iff all the registered runners have been stopped.
+ bool AllRunnersStopped();
+
// Requests all running threads to stop.
Status RequestStop();
@@ -107,6 +114,8 @@ class Coordinator {
std::vector<std::unique_ptr<RunnerInterface>> runners_
GUARDED_BY(runners_lock_);
+ std::atomic<int> num_runners_to_cancel_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
};
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
index 6870ea65c5..5e4a696690 100644
--- a/tensorflow/cc/training/coordinator_test.cc
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -58,6 +58,7 @@ class MockQueueRunner : public RunnerInterface {
coord_ = coord;
join_counter_ = nullptr;
thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
+ stopped_ = false;
}
MockQueueRunner(Coordinator* coord, int* join_counter)
@@ -87,6 +88,10 @@ class MockQueueRunner : public RunnerInterface {
void SetStatus(const Status& status) { status_ = status; }
+ bool IsRunning() const override { return !stopped_; };
+
+ void Stop() { stopped_ = true; }
+
private:
void CountThread(std::atomic<int>* counter, int until) {
while (!coord_->ShouldStop() && counter->load() < until) {
@@ -104,6 +109,7 @@ class MockQueueRunner : public RunnerInterface {
Status status_;
Coordinator* coord_;
int* join_counter_;
+ bool stopped_;
};
TEST(CoordinatorTest, TestRealStop) {
@@ -189,5 +195,15 @@ TEST(CoordinatorTest, JoinWithoutStop) {
EXPECT_EQ(coord.Join().code(), Code::FAILED_PRECONDITION);
}
+TEST(CoordinatorTest, AllRunnersStopped) {
+ Coordinator coord;
+ MockQueueRunner* qr = new MockQueueRunner(&coord);
+ coord.RegisterRunner(std::unique_ptr<RunnerInterface>(qr));
+
+ EXPECT_FALSE(coord.AllRunnersStopped());
+ qr->Stop();
+ EXPECT_TRUE(coord.AllRunnersStopped());
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index e703a9bb30..cd6cc67327 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -110,13 +110,13 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
}
void QueueRunner::Stop(Session* sess) {
- if (cancel_op_name_.empty()) {
- return;
- }
if (coord_ != nullptr) {
coord_->WaitForStop();
}
- UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
+ if (!cancel_op_name_.empty()) {
+ UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
+ }
+ stopped_ = true;
}
Status QueueRunner::Join() {
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index fd9f97a958..e5aae8219f 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -75,7 +75,7 @@ class QueueRunner : public RunnerInterface {
Status GetStatus();
private:
- QueueRunner() : coord_(nullptr) {}
+ QueueRunner() : coord_(nullptr), stopped_(false) {}
// Initializes the instance with the QueueRunnerDef proto.
Status Init(const QueueRunnerDef& queue_runner_def);
@@ -92,6 +92,8 @@ class QueueRunner : public RunnerInterface {
static_cast<int>(status.code())) > 0;
}
+ bool IsRunning() const override { return !stopped_; }
+
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
@@ -108,6 +110,8 @@ class QueueRunner : public RunnerInterface {
Coordinator* coord_;
+ std::atomic<bool> stopped_;
+
mutex cb_mu_;
std::vector<std::function<void(Status)>> callbacks_;
};