aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/stage_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/stage_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/stage_op_test.py87
1 files changed, 42 insertions, 45 deletions
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index 1a6a869e3d..4a89fb64e3 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -23,11 +23,9 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-TIMEOUT = 1
class StageTest(test.TestCase):
-
def testSimple(self):
with ops.Graph().as_default() as G:
with ops.device('/cpu:0'):
@@ -174,7 +172,8 @@ class StageTest(test.TestCase):
import threading
queue = Queue.Queue()
- n = 8
+ n = 5
+ missed = 0
with self.test_session(use_gpu=True, graph=G) as sess:
# Stage data in a separate thread which will block
@@ -186,33 +185,31 @@ class StageTest(test.TestCase):
queue.put(0)
t = threading.Thread(target=thread_run)
- t.daemon = True
t.start()
- # Get tokens from the queue until a timeout occurs
- try:
- for i in range(n):
- queue.get(timeout=TIMEOUT)
- except Queue.Empty:
- pass
-
- # Should've timed out on the iteration 'capacity'
- if not i == capacity:
- self.fail("Expected to timeout on iteration '{}' "
- "but instead timed out on iteration '{}' "
- "Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
-
- # Should have capacity elements in the staging area
+ # Get tokens from the queue, making notes of when we timeout
+ for i in range(n):
+ try:
+ queue.get(timeout=0.05)
+ except Queue.Empty:
+ missed += 1
+
+ # We timed out n - capacity times waiting for queue puts
+ self.assertTrue(missed == n - capacity)
+
+ # Clear the staging area out a bit
+ for i in range(n - capacity):
+ self.assertTrue(sess.run(ret) == i)
+
+ # Thread should be able to join now
+ t.join()
+
self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
- for i in range(n):
- self.assertTrue(sess.run(ret) == i)
+ for i in range(capacity):
+ self.assertTrue(sess.run(ret) == i+(n-capacity))
- # It should now be empty
self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
@@ -237,7 +234,8 @@ class StageTest(test.TestCase):
import numpy as np
queue = Queue.Queue()
- n = 8
+ n = 5
+ missed = 0
with self.test_session(use_gpu=True, graph=G) as sess:
# Stage data in a separate thread which will block
@@ -249,31 +247,30 @@ class StageTest(test.TestCase):
queue.put(0)
t = threading.Thread(target=thread_run)
- t.daemon = True
t.start()
- # Get tokens from the queue until a timeout occurs
- try:
- for i in range(n):
- queue.get(timeout=TIMEOUT)
- except Queue.Empty:
- pass
-
- # Should've timed out on the iteration 'capacity'
- if not i == capacity:
- self.fail("Expected to timeout on iteration '{}' "
- "but instead timed out on iteration '{}' "
- "Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
-
- # Should have capacity elements in the staging area
+ # Get tokens from the queue, making notes of when we timeout
+ for i in range(n):
+ try:
+ queue.get(timeout=0.05)
+ except Queue.Empty:
+ missed += 1
+
+ # We timed out n - capacity times waiting for queue puts
+ self.assertTrue(missed == n - capacity)
+
+ # Clear the staging area out a bit
+ for i in range(n - capacity):
+ self.assertTrue(sess.run(ret)[0] == i)
+
+ # Thread should be able to join now
+ t.join()
+
self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
- for i in range(n):
- self.assertTrue(np.all(sess.run(ret) == i))
+ for i in range(capacity):
+ self.assertTrue(sess.run(ret)[0] == i+(n-capacity))
self.assertTrue(sess.run(size) == 0)