aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/coordinator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/coordinator_test.py')
-rw-r--r--tensorflow/python/training/coordinator_test.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
new file mode 100644
index 0000000000..ce9126caf4
--- /dev/null
+++ b/tensorflow/python/training/coordinator_test.py
@@ -0,0 +1,98 @@
+"""Tests for Coordinator."""
+import sys
+import threading
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+def StopInN(coord, n_secs):
+ time.sleep(n_secs)
+ coord.request_stop()
+
+
+def RaiseInN(coord, n_secs, ex, report_exception):
+ try:
+ time.sleep(n_secs)
+ raise ex
+ except RuntimeError, e:
+ if report_exception:
+ coord.request_stop(e)
+ else:
+ coord.request_stop(sys.exc_info())
+
+
+def SleepABit(n_secs):
+ time.sleep(n_secs)
+
+
+class CoordinatorTest(tf.test.TestCase):
+
+ def testStopAPI(self):
+ coord = tf.train.Coordinator()
+ self.assertFalse(coord.should_stop())
+ self.assertFalse(coord.wait_for_stop(0.01))
+ coord.request_stop()
+ self.assertTrue(coord.should_stop())
+ self.assertTrue(coord.wait_for_stop(0.01))
+
+ def testStopAsync(self):
+ coord = tf.train.Coordinator()
+ self.assertFalse(coord.should_stop())
+ self.assertFalse(coord.wait_for_stop(0.1))
+ threading.Thread(target=StopInN, args=(coord, 0.02)).start()
+ self.assertFalse(coord.should_stop())
+ self.assertFalse(coord.wait_for_stop(0.01))
+ self.assertTrue(coord.wait_for_stop(0.03))
+ self.assertTrue(coord.should_stop())
+
+ def testJoin(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=SleepABit, args=(0.01,)),
+ threading.Thread(target=SleepABit, args=(0.02,)),
+ threading.Thread(target=SleepABit, args=(0.01,))]
+ for t in threads:
+ t.start()
+ coord.join(threads)
+
+ def testJoinGraceExpires(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=StopInN, args=(coord, 0.01)),
+ threading.Thread(target=SleepABit, args=(10.0,))]
+ for t in threads:
+ t.daemon = True
+ t.start()
+ with self.assertRaisesRegexp(RuntimeError, "threads still running"):
+ coord.join(threads, stop_grace_period_secs=0.02)
+
+ def testJoinRaiseReportExcInfo(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.01, RuntimeError("First"), False)),
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.02, RuntimeError("Too late"), False))]
+ for t in threads:
+ t.start()
+ with self.assertRaisesRegexp(RuntimeError, "First"):
+ coord.join(threads)
+
+ def testJoinRaiseReportException(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.01, RuntimeError("First"), True)),
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.02, RuntimeError("Too late"), True))]
+ for t in threads:
+ t.start()
+ with self.assertRaisesRegexp(RuntimeError, "First"):
+ coord.join(threads)
+
+
+if __name__ == "__main__":
+ tf.test.main()