aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/coordinator_test.py
blob: ce9126caf4764c636617f68a45a4f49b5238070a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Tests for Coordinator."""
import sys
import threading
import time

import tensorflow.python.platform

import tensorflow as tf


def StopInN(coord, n_secs):
  """Blocks for `n_secs` seconds, then asks `coord` to stop.

  Used as a threading.Thread target by the tests below.

  Args:
    coord: the coordinator whose request_stop() is called (no arguments).
    n_secs: delay in seconds before the stop request.
  """
  delay_secs = n_secs
  time.sleep(delay_secs)
  coord.request_stop()


def RaiseInN(coord, n_secs, ex, report_exception):
  """Sleeps, raises `ex`, and reports it to `coord` through request_stop().

  Args:
    coord: the coordinator to notify via request_stop().
    n_secs: seconds to sleep before raising.
    ex: exception instance to raise. Only RuntimeError (and subclasses) is
      caught and reported; any other exception propagates to the caller.
    report_exception: if True, pass the exception object itself to
      request_stop(); otherwise pass the sys.exc_info() triple. The tests use
      both forms.
  """
  try:
    time.sleep(n_secs)
    raise ex
  except RuntimeError as e:
    # `except X as e` is valid on Python 2.6+ and 3; the previous
    # `except X, e` spelling is a SyntaxError on Python 3.
    if report_exception:
      coord.request_stop(e)
    else:
      # sys.exc_info() is only populated while the exception is being
      # handled, so it must be captured inside this except clause.
      coord.request_stop(sys.exc_info())


def SleepABit(n_secs):
  """Blocks the calling thread for `n_secs` seconds; a trivial thread body."""
  pause_secs = n_secs
  time.sleep(pause_secs)


class CoordinatorTest(tf.test.TestCase):
  """Tests for tf.train.Coordinator: stop requests, join() and error reports.

  Worker threads that must act in a particular order are sequenced with
  threading.Event objects rather than relative sleep lengths, so results do
  not depend on how promptly the OS schedules each thread.
  """

  def _assertRaisesRegex(self, expected_exception, expected_regex):
    """Returns the assertRaisesRegex context manager on Python 2 and 3.

    Python 3.2 renamed unittest's assertRaisesRegexp to assertRaisesRegex and
    Python 3.12 removed the old alias; Python 2.7 only has the old name.
    """
    assert_fn = getattr(self, "assertRaisesRegex", None)
    if assert_fn is None:
      assert_fn = self.assertRaisesRegexp
    return assert_fn(expected_exception, expected_regex)

  def testStopAPI(self):
    coord = tf.train.Coordinator()
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.01))
    coord.request_stop()
    self.assertTrue(coord.should_stop())
    self.assertTrue(coord.wait_for_stop(0.01))

  def testStopAsync(self):
    coord = tf.train.Coordinator()
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.1))
    # A sleep-based stop (worker sleeps N secs, test waits slightly longer)
    # fails whenever the worker is scheduled late. Gating the stop on an
    # event guarantees "not stopped yet" until the test releases it.
    release = threading.Event()

    def _StopWhenReleased():
      release.wait()
      coord.request_stop()

    worker = threading.Thread(target=_StopWhenReleased)
    # Daemon so that a failed assertion before release.set() cannot leave a
    # non-daemon thread blocked forever and hang interpreter exit.
    worker.daemon = True
    worker.start()
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.01))
    release.set()
    # wait_for_stop() waits *up to* the timeout and returns once a stop is
    # requested, so a generous timeout costs nothing on the passing path.
    self.assertTrue(coord.wait_for_stop(10.0))
    self.assertTrue(coord.should_stop())
    worker.join()

  def testJoin(self):
    coord = tf.train.Coordinator()
    threads = [
        threading.Thread(target=SleepABit, args=(0.01,)),
        threading.Thread(target=SleepABit, args=(0.02,)),
        threading.Thread(target=SleepABit, args=(0.01,))]
    for t in threads:
      t.start()
    coord.join(threads)

  def testJoinGraceExpires(self):
    coord = tf.train.Coordinator()
    threads = [
        threading.Thread(target=StopInN, args=(coord, 0.01)),
        threading.Thread(target=SleepABit, args=(10.0,))]
    for t in threads:
      t.daemon = True
      t.start()
    with self._assertRaisesRegex(RuntimeError, "threads still running"):
      coord.join(threads, stop_grace_period_secs=0.02)

  def _CheckJoinReraisesFirstException(self, report_exception):
    """Checks that join() re-raises the first exception reported to coord.

    Args:
      report_exception: forwarded to RaiseInN. True reports the exception
        object itself; False reports the sys.exc_info() triple.
    """
    coord = tf.train.Coordinator()
    first_reported = threading.Event()

    def _RaiseFirst():
      try:
        RaiseInN(coord, 0.01, RuntimeError("First"), report_exception)
      finally:
        # Always release the second thread, even if reporting failed.
        first_reported.set()

    def _RaiseTooLate():
      # Waiting for the first report (instead of sleeping longer) makes
      # "First" deterministically the earlier exception.
      first_reported.wait()
      RaiseInN(coord, 0.01, RuntimeError("Too late"), report_exception)

    threads = [threading.Thread(target=_RaiseFirst),
               threading.Thread(target=_RaiseTooLate)]
    for t in threads:
      t.start()
    with self._assertRaisesRegex(RuntimeError, "First"):
      coord.join(threads)

  def testJoinRaiseReportExcInfo(self):
    self._CheckJoinReraisesFirstException(report_exception=False)

  def testJoinRaiseReportException(self):
    self._CheckJoinReraisesFirstException(report_exception=True)


if __name__ == "__main__":
  # Run all unit tests in this module through TensorFlow's test runner.
  run_tests = tf.test.main
  run_tests()