aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/cc/training/coordinator_test.cc53
1 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
index 79f2a955d5..a87913deaf 100644
--- a/tensorflow/cc/training/coordinator_test.cc
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -29,9 +29,10 @@ namespace {
using error::Code;
-void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) {
+void WaitForStopThread(Coordinator* coord, Notification* about_to_wait,
+ Notification* done) {
+ about_to_wait->Notify();
coord->WaitForStop();
- *stopped = true;
done->Notify();
}
@@ -39,17 +40,17 @@ TEST(CoordinatorTest, TestStopAndWaitOnStop) {
Coordinator coord;
EXPECT_EQ(coord.ShouldStop(), false);
- bool stopped = false;
+ Notification about_to_wait;
Notification done;
Env::Default()->SchedClosure(
- std::bind(&WaitForStopThread, &coord, &stopped, &done));
- Env::Default()->SleepForMicroseconds(10000000);
- EXPECT_EQ(stopped, false);
+ std::bind(&WaitForStopThread, &coord, &about_to_wait, &done));
+ about_to_wait.WaitForNotification();
+ Env::Default()->SleepForMicroseconds(1000 * 1000);
+ EXPECT_FALSE(done.HasBeenNotified());
TF_EXPECT_OK(coord.RequestStop());
done.WaitForNotification();
- EXPECT_EQ(stopped, true);
- EXPECT_EQ(coord.ShouldStop(), true);
+ EXPECT_TRUE(coord.ShouldStop());
}
class MockQueueRunner : public RunnerInterface {
@@ -66,14 +67,16 @@ class MockQueueRunner : public RunnerInterface {
join_counter_ = join_counter;
}
- void StartCounting(std::atomic<int>* counter, int until) {
+ void StartCounting(std::atomic<int>* counter, int until,
+ Notification* start = nullptr) {
thread_pool_->Schedule(
- std::bind(&MockQueueRunner::CountThread, this, counter, until));
+ std::bind(&MockQueueRunner::CountThread, this, counter, until, start));
}
- void StartSettingStatus(const Status& status, BlockingCounter* counter) {
- thread_pool_->Schedule(
- std::bind(&MockQueueRunner::SetStatusThread, this, status, counter));
+ void StartSettingStatus(const Status& status, BlockingCounter* counter,
+ Notification* start) {
+ thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this,
+ status, counter, start));
}
Status Join() {
@@ -93,15 +96,17 @@ class MockQueueRunner : public RunnerInterface {
void Stop() { stopped_ = true; }
private:
- void CountThread(std::atomic<int>* counter, int until) {
+ void CountThread(std::atomic<int>* counter, int until, Notification* start) {
+ if (start != nullptr) start->WaitForNotification();
while (!coord_->ShouldStop() && counter->load() < until) {
(*counter)++;
- Env::Default()->SleepForMicroseconds(100000);
+ Env::Default()->SleepForMicroseconds(10 * 1000);
}
coord_->RequestStop().IgnoreError();
}
- void SetStatusThread(const Status& status, BlockingCounter* counter) {
- Env::Default()->SleepForMicroseconds(100000);
+ void SetStatusThread(const Status& status, BlockingCounter* counter,
+ Notification* start) {
+ start->WaitForNotification();
SetStatus(status);
counter->DecrementCount();
}
@@ -130,7 +135,7 @@ TEST(CoordinatorTest, TestRealStop) {
TF_EXPECT_OK(coord.RequestStop());
int temp_counter = counter.load();
- Env::Default()->SleepForMicroseconds(10000000);
+ Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_EQ(temp_counter, counter.load());
TF_EXPECT_OK(coord.Join());
}
@@ -138,12 +143,14 @@ TEST(CoordinatorTest, TestRealStop) {
TEST(CoordinatorTest, TestRequestStop) {
Coordinator coord;
std::atomic<int> counter(0);
+ Notification start;
std::unique_ptr<MockQueueRunner> qr;
for (int i = 0; i < 10; i++) {
qr.reset(new MockQueueRunner(&coord));
- qr->StartCounting(&counter, 10);
+ qr->StartCounting(&counter, 10, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr)));
}
+ start.Notify();
coord.WaitForStop();
EXPECT_EQ(coord.ShouldStop(), true);
@@ -168,20 +175,22 @@ TEST(CoordinatorTest, TestJoin) {
TEST(CoordinatorTest, StatusReporting) {
Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
+ Notification start;
BlockingCounter counter(3);
std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
- qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter);
+ qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));
std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
- qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter);
+ qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));
std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
- qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter);
+ qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3)));
+ start.Notify();
counter.Wait();
TF_EXPECT_OK(coord.RequestStop());
EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);