aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/stage_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/stage_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/stage_op_test.py87
1 files changed, 45 insertions, 42 deletions
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index 4a89fb64e3..1a6a869e3d 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -23,9 +23,11 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+TIMEOUT = 1
class StageTest(test.TestCase):
+
def testSimple(self):
with ops.Graph().as_default() as G:
with ops.device('/cpu:0'):
@@ -172,8 +174,7 @@ class StageTest(test.TestCase):
import threading
queue = Queue.Queue()
- n = 5
- missed = 0
+ n = 8
with self.test_session(use_gpu=True, graph=G) as sess:
# Stage data in a separate thread which will block
@@ -185,31 +186,33 @@ class StageTest(test.TestCase):
queue.put(0)
t = threading.Thread(target=thread_run)
+ t.daemon = True
t.start()
- # Get tokens from the queue, making notes of when we timeout
- for i in range(n):
- try:
- queue.get(timeout=0.05)
- except Queue.Empty:
- missed += 1
-
- # We timed out n - capacity times waiting for queue puts
- self.assertTrue(missed == n - capacity)
-
- # Clear the staging area out a bit
- for i in range(n - capacity):
- self.assertTrue(sess.run(ret) == i)
-
- # Thread should be able to join now
- t.join()
-
+ # Get tokens from the queue until a timeout occurs
+ try:
+ for i in range(n):
+ queue.get(timeout=TIMEOUT)
+ except Queue.Empty:
+ pass
+
+ # Should've timed out on the iteration 'capacity'
+ if not i == capacity:
+ self.fail("Expected to timeout on iteration '{}' "
+ "but instead timed out on iteration '{}' "
+ "Staging Area size is '{}' and configured "
+ "capacity is '{}'.".format(capacity, i,
+ sess.run(size),
+ capacity))
+
+ # Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
- for i in range(capacity):
- self.assertTrue(sess.run(ret) == i+(n-capacity))
+ for i in range(n):
+ self.assertTrue(sess.run(ret) == i)
+ # It should now be empty
self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
@@ -234,8 +237,7 @@ class StageTest(test.TestCase):
import numpy as np
queue = Queue.Queue()
- n = 5
- missed = 0
+ n = 8
with self.test_session(use_gpu=True, graph=G) as sess:
# Stage data in a separate thread which will block
@@ -247,30 +249,31 @@ class StageTest(test.TestCase):
queue.put(0)
t = threading.Thread(target=thread_run)
+ t.daemon = True
t.start()
- # Get tokens from the queue, making notes of when we timeout
- for i in range(n):
- try:
- queue.get(timeout=0.05)
- except Queue.Empty:
- missed += 1
-
- # We timed out n - capacity times waiting for queue puts
- self.assertTrue(missed == n - capacity)
-
- # Clear the staging area out a bit
- for i in range(n - capacity):
- self.assertTrue(sess.run(ret)[0] == i)
-
- # Thread should be able to join now
- t.join()
-
+ # Get tokens from the queue until a timeout occurs
+ try:
+ for i in range(n):
+ queue.get(timeout=TIMEOUT)
+ except Queue.Empty:
+ pass
+
+ # Should've timed out on the iteration 'capacity'
+ if not i == capacity:
+ self.fail("Expected to timeout on iteration '{}' "
+ "but instead timed out on iteration '{}' "
+ "Staging Area size is '{}' and configured "
+ "capacity is '{}'.".format(capacity, i,
+ sess.run(size),
+ capacity))
+
+ # Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
- for i in range(capacity):
- self.assertTrue(sess.run(ret)[0] == i+(n-capacity))
+ for i in range(n):
+ self.assertTrue(np.all(sess.run(ret) == i))
self.assertTrue(sess.run(size) == 0)