aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/core/notification_test.cc
blob: a9e8942f05b52529810de7dea2842b3284390846 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <gtest/gtest.h>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/port.h"

namespace tensorflow {
namespace {

TEST(NotificationTest, TestSingleNotification) {
  thread::ThreadPool* thread_pool =
      new thread::ThreadPool(Env::Default(), "test", 1);

  int counter = 0;
  Notification start;
  Notification proceed;
  thread_pool->Schedule([&start, &proceed, &counter] {
    start.Notify();
    proceed.WaitForNotification();
    ++counter;
  });

  // Wait for the thread to start
  start.WaitForNotification();

  // The thread should be waiting for the 'proceed' notification.
  EXPECT_EQ(0, counter);

  // Unblock the thread
  proceed.Notify();

  delete thread_pool;  // Wait for closure to finish.

  // Verify the counter has been incremented
  EXPECT_EQ(1, counter);
}

TEST(NotificationTest, TestMultipleThreadsWaitingOnNotification) {
  const int num_closures = 4;
  thread::ThreadPool* thread_pool =
      new thread::ThreadPool(Env::Default(), "test", num_closures);

  mutex lock;
  int counter = 0;
  Notification n;

  for (int i = 0; i < num_closures; ++i) {
    thread_pool->Schedule([&n, &lock, &counter] {
      n.WaitForNotification();
      mutex_lock l(lock);
      ++counter;
    });
  }
  sleep(1);

  EXPECT_EQ(0, counter);

  n.Notify();
  delete thread_pool;  // Wait for all closures to finish.
  EXPECT_EQ(4, counter);
}

}  // namespace
}  // namespace tensorflow