aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/map_stage_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/map_stage_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/map_stage_op_test.py93
1 files changed, 49 insertions, 44 deletions
diff --git a/tensorflow/python/kernel_tests/map_stage_op_test.py b/tensorflow/python/kernel_tests/map_stage_op_test.py
index 2d2169c310..4ceb24862f 100644
--- a/tensorflow/python/kernel_tests/map_stage_op_test.py
+++ b/tensorflow/python/kernel_tests/map_stage_op_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+TIMEOUT = 1
class MapStageTest(test.TestCase):
@@ -194,8 +195,7 @@ class MapStageTest(test.TestCase):
import threading
queue = Queue.Queue()
- n = 5
- missed = 0
+ n = 8
with self.test_session(use_gpu=True, graph=G) as sess:
# Stage data in a separate thread which will block
@@ -207,31 +207,34 @@ class MapStageTest(test.TestCase):
queue.put(0)
t = threading.Thread(target=thread_run)
+ t.daemon = True
t.start()
- # Get tokens from the queue, making notes of when we timeout
- for i in range(n):
- try:
- queue.get(timeout=0.05)
- except Queue.Empty:
- missed += 1
-
- # We timed out n - capacity times waiting for queue puts
- self.assertTrue(missed == n - capacity)
-
- # Clear the staging area out a bit
- for i in range(n - capacity):
- sess.run(get)
-
- # This should now succeed
- t.join()
-
+ # Get tokens from the queue until a timeout occurs
+ try:
+ for i in range(n):
+ queue.get(timeout=TIMEOUT)
+ except Queue.Empty:
+ pass
+
+ # Should've timed out on the iteration 'capacity'
+ if not i == capacity:
+ self.fail("Expected to timeout on iteration '{}' "
+ "but instead timed out on iteration '{}' "
+ "Staging Area size is '{}' and configured "
+ "capacity is '{}'.".format(capacity, i,
+ sess.run(size),
+ capacity))
+
+ # Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
- # Clear out the staging area completely
- for i in range(capacity):
+ # Clear the staging area completely
+ for i in range(n):
sess.run(get)
+ self.assertTrue(sess.run(size) == 0)
+
def testMemoryLimit(self):
memory_limit = 512*1024 # 512K
chunk = 200*1024 # 256K
@@ -256,8 +259,7 @@ class MapStageTest(test.TestCase):
import numpy as np
queue = Queue.Queue()
- n = 5
- missed = 0
+ n = 8
with self.test_session(use_gpu=True, graph=G) as sess:
# Stage data in a separate thread which will block
@@ -265,36 +267,39 @@ class MapStageTest(test.TestCase):
# not fill the queue with n tokens
def thread_run():
for i in range(n):
- sess.run(stage, feed_dict={x: np.full(chunk, i, dtype=np.uint8),
- pi: i})
+ data = np.full(chunk, i, dtype=np.uint8)
+ sess.run(stage, feed_dict={x: data, pi: i})
queue.put(0)
t = threading.Thread(target=thread_run)
+ t.daemon = True
t.start()
- # Get tokens from the queue, making notes of when we timeout
- for i in range(n):
- try:
- queue.get(timeout=0.05)
- except Queue.Empty:
- missed += 1
-
- # We timed out n - capacity times waiting for queue puts
- self.assertTrue(missed == n - capacity)
-
- # Clear the staging area out a bit
- for i in range(n - capacity):
- sess.run(get)
-
- # This should now succeed
- t.join()
-
+ # Get tokens from the queue until a timeout occurs
+ try:
+ for i in range(n):
+ queue.get(timeout=TIMEOUT)
+ except Queue.Empty:
+ pass
+
+ # Should've timed out on the iteration 'capacity'
+ if not i == capacity:
+ self.fail("Expected to timeout on iteration '{}' "
+ "but instead timed out on iteration '{}' "
+ "Staging Area size is '{}' and configured "
+ "capacity is '{}'.".format(capacity, i,
+ sess.run(size),
+ capacity))
+
+ # Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
- # Clear out the staging area completely
- for i in range(capacity):
+ # Clear the staging area completely
+ for i in range(n):
sess.run(get)
+ self.assertTrue(sess.run(size) == 0)
+
def testOrdering(self):
import six
import random