diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/stage_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/stage_op_test.py | 87 |
1 files changed, 42 insertions, 45 deletions
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py index 1a6a869e3d..4a89fb64e3 100644 --- a/tensorflow/python/kernel_tests/stage_op_test.py +++ b/tensorflow/python/kernel_tests/stage_op_test.py @@ -23,11 +23,9 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -TIMEOUT = 1 class StageTest(test.TestCase): - def testSimple(self): with ops.Graph().as_default() as G: with ops.device('/cpu:0'): @@ -174,7 +172,8 @@ class StageTest(test.TestCase): import threading queue = Queue.Queue() - n = 8 + n = 5 + missed = 0 with self.test_session(use_gpu=True, graph=G) as sess: # Stage data in a separate thread which will block @@ -186,33 +185,31 @@ class StageTest(test.TestCase): queue.put(0) t = threading.Thread(target=thread_run) - t.daemon = True t.start() - # Get tokens from the queue until a timeout occurs - try: - for i in range(n): - queue.get(timeout=TIMEOUT) - except Queue.Empty: - pass - - # Should've timed out on the iteration 'capacity' - if not i == capacity: - self.fail("Expected to timeout on iteration '{}' " - "but instead timed out on iteration '{}' " - "Staging Area size is '{}' and configured " - "capacity is '{}'.".format(capacity, i, - sess.run(size), - capacity)) - - # Should have capacity elements in the staging area + # Get tokens from the queue, making notes of when we timeout + for i in range(n): + try: + queue.get(timeout=0.05) + except Queue.Empty: + missed += 1 + + # We timed out n - capacity times waiting for queue puts + self.assertTrue(missed == n - capacity) + + # Clear the staging area out a bit + for i in range(n - capacity): + self.assertTrue(sess.run(ret) == i) + + # Thread should be able to join now + t.join() + self.assertTrue(sess.run(size) == capacity) # Clear the staging area completely - for i in range(n): - self.assertTrue(sess.run(ret) == i) + for i in range(capacity): + self.assertTrue(sess.run(ret) == i+(n-capacity)) - # It should now be empty self.assertTrue(sess.run(size) == 0) def testMemoryLimit(self): @@ -237,7 +234,8 @@ class StageTest(test.TestCase): import numpy as np queue = Queue.Queue() - n = 8 + n = 5 + missed = 0 with self.test_session(use_gpu=True, graph=G) as sess: # Stage data in a separate thread which will block @@ -249,31 +247,30 @@ class StageTest(test.TestCase): queue.put(0) t = threading.Thread(target=thread_run) - t.daemon = True t.start() - # Get tokens from the queue until a timeout occurs - try: - for i in range(n): - queue.get(timeout=TIMEOUT) - except Queue.Empty: - pass - - # Should've timed out on the iteration 'capacity' - if not i == capacity: - self.fail("Expected to timeout on iteration '{}' " - "but instead timed out on iteration '{}' " - "Staging Area size is '{}' and configured " - "capacity is '{}'.".format(capacity, i, - sess.run(size), - capacity)) - - # Should have capacity elements in the staging area + # Get tokens from the queue, making notes of when we timeout + for i in range(n): + try: + queue.get(timeout=0.05) + except Queue.Empty: + missed += 1 + + # We timed out n - capacity times waiting for queue puts + self.assertTrue(missed == n - capacity) + + # Clear the staging area out a bit + for i in range(n - capacity): + self.assertTrue(sess.run(ret)[0] == i) + + # Thread should be able to join now + t.join() + self.assertTrue(sess.run(size) == capacity) # Clear the staging area completely - for i in range(n): - self.assertTrue(np.all(sess.run(ret) == i)) + for i in range(capacity): + self.assertTrue(sess.run(ret)[0] == i+(n-capacity)) self.assertTrue(sess.run(size) == 0) |