diff options
Diffstat (limited to 'tensorflow/python/training/queue_runner_test.py')
-rw-r--r-- | tensorflow/python/training/queue_runner_test.py | 186 |
1 files changed, 186 insertions, 0 deletions
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py new file mode 100644 index 0000000000..c94c02da66 --- /dev/null +++ b/tensorflow/python/training/queue_runner_test.py @@ -0,0 +1,186 @@ +"""Tests for QueueRunner.""" +import time + +import tensorflow.python.platform + +import tensorflow as tf + + +class QueueRunnerTest(tf.test.TestCase): + + def testBasic(self): + with self.test_session() as sess: + # CountUpTo will raise OUT_OF_RANGE when it reaches the count. + zero64 = tf.constant(0, dtype=tf.int64) + var = tf.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = tf.FIFOQueue(10, tf.float32) + tf.initialize_all_variables().run() + qr = tf.train.QueueRunner(queue, [count_up_to]) + threads = qr.create_threads(sess) + for t in threads: + t.start() + for t in threads: + t.join() + self.assertEqual(0, len(qr.exceptions_raised)) + # The variable should be 3. + self.assertEqual(3, var.eval()) + + def testTwoOps(self): + with self.test_session() as sess: + # CountUpTo will raise OUT_OF_RANGE when it reaches the count. + zero64 = tf.constant(0, dtype=tf.int64) + var0 = tf.Variable(zero64) + count_up_to_3 = var0.count_up_to(3) + var1 = tf.Variable(zero64) + count_up_to_30 = var1.count_up_to(30) + queue = tf.FIFOQueue(10, tf.float32) + qr = tf.train.QueueRunner(queue, [count_up_to_3, count_up_to_30]) + threads = qr.create_threads(sess) + tf.initialize_all_variables().run() + for t in threads: + t.start() + for t in threads: + t.join() + self.assertEqual(0, len(qr.exceptions_raised)) + self.assertEqual(3, var0.eval()) + self.assertEqual(30, var1.eval()) + + def testExceptionsCaptured(self): + with self.test_session() as sess: + queue = tf.FIFOQueue(10, tf.float32) + qr = tf.train.QueueRunner(queue, ["i fail", "so fail"]) + threads = qr.create_threads(sess) + tf.initialize_all_variables().run() + for t in threads: + t.start() + for t in threads: + t.join() + exceptions = qr.exceptions_raised + self.assertEqual(2, len(exceptions)) + self.assertTrue("Operation not in the graph" in str(exceptions[0])) + self.assertTrue("Operation not in the graph" in str(exceptions[1])) + + def testRealDequeueEnqueue(self): + with self.test_session() as sess: + q0 = tf.FIFOQueue(3, tf.float32) + enqueue0 = q0.enqueue((10.0,)) + close0 = q0.close() + q1 = tf.FIFOQueue(30, tf.float32) + enqueue1 = q1.enqueue((q0.dequeue(),)) + dequeue1 = q1.dequeue() + qr = tf.train.QueueRunner(q1, [enqueue1]) + threads = qr.create_threads(sess) + for t in threads: + t.start() + # Enqueue 2 values, then close queue0. + enqueue0.run() + enqueue0.run() + close0.run() + # Wait for the queue runner to terminate. + for t in threads: + t.join() + # It should have terminated cleanly. + self.assertEqual(0, len(qr.exceptions_raised)) + # The 2 values should be in queue1. + self.assertEqual(10.0, dequeue1.eval()) + self.assertEqual(10.0, dequeue1.eval()) + # And queue1 should now be closed. + with self.assertRaisesRegexp(tf.errors.OutOfRangeError, "is closed"): + dequeue1.eval() + + def testRespectCoordShouldStop(self): + with self.test_session() as sess: + # CountUpTo will raise OUT_OF_RANGE when it reaches the count. + zero64 = tf.constant(0, dtype=tf.int64) + var = tf.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = tf.FIFOQueue(10, tf.float32) + tf.initialize_all_variables().run() + qr = tf.train.QueueRunner(queue, [count_up_to]) + # As the coordinator to stop. The queue runner should + # finish immediately. + coord = tf.train.Coordinator() + coord.request_stop() + threads = qr.create_threads(sess, coord) + for t in threads: + t.start() + coord.join(threads) + self.assertEqual(0, len(qr.exceptions_raised)) + # The variable should be 0. + self.assertEqual(0, var.eval()) + + def testRequestStopOnException(self): + with self.test_session() as sess: + queue = tf.FIFOQueue(10, tf.float32) + qr = tf.train.QueueRunner(queue, ["not an op"]) + coord = tf.train.Coordinator() + threads = qr.create_threads(sess, coord) + for t in threads: + t.start() + # The exception should be re-raised when joining. + with self.assertRaisesRegexp(ValueError, "Operation not in the graph"): + coord.join(threads) + + def testGracePeriod(self): + with self.test_session() as sess: + # The enqueue will quickly block. + queue = tf.FIFOQueue(2, tf.float32) + enqueue = queue.enqueue((10.0,)) + dequeue = queue.dequeue() + qr = tf.train.QueueRunner(queue, [enqueue]) + coord = tf.train.Coordinator() + threads = qr.create_threads(sess, coord, start=True) + # Dequeue one element and then request stop. + dequeue.op.run() + time.sleep(0.02) + coord.request_stop() + # We should be able to join because the RequestStop() will cause + # the queue to be closed and the enqueue to terminate. + coord.join(threads, stop_grace_period_secs=0.05) + + def testNoMultiThreads(self): + with self.test_session() as sess: + # CountUpTo will raise OUT_OF_RANGE when it reaches the count. + zero64 = tf.constant(0, dtype=tf.int64) + var = tf.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = tf.FIFOQueue(10, tf.float32) + tf.initialize_all_variables().run() + coord = tf.train.Coordinator() + qr = tf.train.QueueRunner(queue, [count_up_to]) + threads = [] + threads.extend(qr.create_threads(sess, coord=coord)) + with self.assertRaisesRegexp( + RuntimeError, + "Threads are already running"): + threads.extend(qr.create_threads(sess, coord=coord)) + coord.request_stop() + coord.join(threads, stop_grace_period_secs=0.5) + + def testThreads(self): + with self.test_session() as sess: + # CountUpTo will raise OUT_OF_RANGE when it reaches the count. + zero64 = tf.constant(0, dtype=tf.int64) + var = tf.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = tf.FIFOQueue(10, tf.float32) + tf.initialize_all_variables().run() + qr = tf.train.QueueRunner(queue, [count_up_to, "bad op"]) + threads = qr.create_threads(sess, start=True) + for t in threads: + t.join() + exceptions = qr.exceptions_raised + self.assertEqual(1, len(exceptions)) + self.assertTrue("Operation not in the graph" in str(exceptions[0])) + + threads = qr.create_threads(sess, start=True) + for t in threads: + t.join() + exceptions = qr.exceptions_raised + self.assertEqual(1, len(exceptions)) + self.assertTrue("Operation not in the graph" in str(exceptions[0])) + + +if __name__ == "__main__": + tf.test.main() |