diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/stage_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/stage_op_test.py | 87 |
1 files changed, 45 insertions, 42 deletions
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py index 4a89fb64e3..1a6a869e3d 100644 --- a/tensorflow/python/kernel_tests/stage_op_test.py +++ b/tensorflow/python/kernel_tests/stage_op_test.py @@ -23,9 +23,11 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +TIMEOUT = 1 class StageTest(test.TestCase): + def testSimple(self): with ops.Graph().as_default() as G: with ops.device('/cpu:0'): @@ -172,8 +174,7 @@ class StageTest(test.TestCase): import threading queue = Queue.Queue() - n = 5 - missed = 0 + n = 8 with self.test_session(use_gpu=True, graph=G) as sess: # Stage data in a separate thread which will block @@ -185,31 +186,33 @@ class StageTest(test.TestCase): queue.put(0) t = threading.Thread(target=thread_run) + t.daemon = True t.start() - # Get tokens from the queue, making notes of when we timeout - for i in range(n): - try: - queue.get(timeout=0.05) - except Queue.Empty: - missed += 1 - - # We timed out n - capacity times waiting for queue puts - self.assertTrue(missed == n - capacity) - - # Clear the staging area out a bit - for i in range(n - capacity): - self.assertTrue(sess.run(ret) == i) - - # Thread should be able to join now - t.join() - + # Get tokens from the queue until a timeout occurs + try: + for i in range(n): + queue.get(timeout=TIMEOUT) + except Queue.Empty: + pass + + # Should've timed out on the iteration 'capacity' + if not i == capacity: + self.fail("Expected to timeout on iteration '{}' " + "but instead timed out on iteration '{}' " + "Staging Area size is '{}' and configured " + "capacity is '{}'.".format(capacity, i, + sess.run(size), + capacity)) + + # Should have capacity elements in the staging area self.assertTrue(sess.run(size) == capacity) # Clear the staging area completely - for i in range(capacity): - self.assertTrue(sess.run(ret) == i+(n-capacity)) + for i in range(n): + self.assertTrue(sess.run(ret) == i) + # It should now be empty self.assertTrue(sess.run(size) == 0) def testMemoryLimit(self): @@ -234,8 +237,7 @@ class StageTest(test.TestCase): import numpy as np queue = Queue.Queue() - n = 5 - missed = 0 + n = 8 with self.test_session(use_gpu=True, graph=G) as sess: # Stage data in a separate thread which will block @@ -247,30 +249,31 @@ class StageTest(test.TestCase): queue.put(0) t = threading.Thread(target=thread_run) + t.daemon = True t.start() - # Get tokens from the queue, making notes of when we timeout - for i in range(n): - try: - queue.get(timeout=0.05) - except Queue.Empty: - missed += 1 - - # We timed out n - capacity times waiting for queue puts - self.assertTrue(missed == n - capacity) - - # Clear the staging area out a bit - for i in range(n - capacity): - self.assertTrue(sess.run(ret)[0] == i) - - # Thread should be able to join now - t.join() - + # Get tokens from the queue until a timeout occurs + try: + for i in range(n): + queue.get(timeout=TIMEOUT) + except Queue.Empty: + pass + + # Should've timed out on the iteration 'capacity' + if not i == capacity: + self.fail("Expected to timeout on iteration '{}' " + "but instead timed out on iteration '{}' " + "Staging Area size is '{}' and configured " + "capacity is '{}'.".format(capacity, i, + sess.run(size), + capacity)) + + # Should have capacity elements in the staging area self.assertTrue(sess.run(size) == capacity) # Clear the staging area completely - for i in range(capacity): - self.assertTrue(sess.run(ret)[0] == i+(n-capacity)) + for i in range(n): + self.assertTrue(np.all(sess.run(ret) == i)) self.assertTrue(sess.run(size) == 0) |