aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/queue_runner_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/queue_runner_test.py')
-rw-r--r--tensorflow/python/training/queue_runner_test.py186
1 files changed, 186 insertions, 0 deletions
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
new file mode 100644
index 0000000000..c94c02da66
--- /dev/null
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -0,0 +1,186 @@
+"""Tests for QueueRunner."""
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class QueueRunnerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ threads = qr.create_threads(sess)
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The variable should be 3.
+ self.assertEqual(3, var.eval())
+
+ def testTwoOps(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var0 = tf.Variable(zero64)
+ count_up_to_3 = var0.count_up_to(3)
+ var1 = tf.Variable(zero64)
+ count_up_to_30 = var1.count_up_to(30)
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, [count_up_to_3, count_up_to_30])
+ threads = qr.create_threads(sess)
+ tf.initialize_all_variables().run()
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ self.assertEqual(0, len(qr.exceptions_raised))
+ self.assertEqual(3, var0.eval())
+ self.assertEqual(30, var1.eval())
+
+ def testExceptionsCaptured(self):
+ with self.test_session() as sess:
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, ["i fail", "so fail"])
+ threads = qr.create_threads(sess)
+ tf.initialize_all_variables().run()
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(2, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+ self.assertTrue("Operation not in the graph" in str(exceptions[1]))
+
+ def testRealDequeueEnqueue(self):
+ with self.test_session() as sess:
+ q0 = tf.FIFOQueue(3, tf.float32)
+ enqueue0 = q0.enqueue((10.0,))
+ close0 = q0.close()
+ q1 = tf.FIFOQueue(30, tf.float32)
+ enqueue1 = q1.enqueue((q0.dequeue(),))
+ dequeue1 = q1.dequeue()
+ qr = tf.train.QueueRunner(q1, [enqueue1])
+ threads = qr.create_threads(sess)
+ for t in threads:
+ t.start()
+ # Enqueue 2 values, then close queue0.
+ enqueue0.run()
+ enqueue0.run()
+ close0.run()
+ # Wait for the queue runner to terminate.
+ for t in threads:
+ t.join()
+ # It should have terminated cleanly.
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The 2 values should be in queue1.
+ self.assertEqual(10.0, dequeue1.eval())
+ self.assertEqual(10.0, dequeue1.eval())
+ # And queue1 should now be closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError, "is closed"):
+ dequeue1.eval()
+
+ def testRespectCoordShouldStop(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ # As the coordinator to stop. The queue runner should
+ # finish immediately.
+ coord = tf.train.Coordinator()
+ coord.request_stop()
+ threads = qr.create_threads(sess, coord)
+ for t in threads:
+ t.start()
+ coord.join(threads)
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The variable should be 0.
+ self.assertEqual(0, var.eval())
+
+ def testRequestStopOnException(self):
+ with self.test_session() as sess:
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, ["not an op"])
+ coord = tf.train.Coordinator()
+ threads = qr.create_threads(sess, coord)
+ for t in threads:
+ t.start()
+ # The exception should be re-raised when joining.
+ with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
+ coord.join(threads)
+
+ def testGracePeriod(self):
+ with self.test_session() as sess:
+ # The enqueue will quickly block.
+ queue = tf.FIFOQueue(2, tf.float32)
+ enqueue = queue.enqueue((10.0,))
+ dequeue = queue.dequeue()
+ qr = tf.train.QueueRunner(queue, [enqueue])
+ coord = tf.train.Coordinator()
+ threads = qr.create_threads(sess, coord, start=True)
+ # Dequeue one element and then request stop.
+ dequeue.op.run()
+ time.sleep(0.02)
+ coord.request_stop()
+ # We should be able to join because the RequestStop() will cause
+ # the queue to be closed and the enqueue to terminate.
+ coord.join(threads, stop_grace_period_secs=0.05)
+
+ def testNoMultiThreads(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ coord = tf.train.Coordinator()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ threads = []
+ threads.extend(qr.create_threads(sess, coord=coord))
+ with self.assertRaisesRegexp(
+ RuntimeError,
+ "Threads are already running"):
+ threads.extend(qr.create_threads(sess, coord=coord))
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=0.5)
+
+ def testThreads(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to, "bad op"])
+ threads = qr.create_threads(sess, start=True)
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(1, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+
+ threads = qr.create_threads(sess, start=True)
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(1, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()