aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio_tests
diff options
context:
space:
mode:
authorGravatar siddharthshukla <siddharthshukla@outlook.com>2016-08-03 17:55:10 +0200
committerGravatar siddharthshukla <siddharthshukla@outlook.com>2016-08-03 22:13:37 +0200
commitde84d566b8fad6808e5263a25a17fa231cb5713c (patch)
treed12f09f10b50db7a9e57264161e5a19607f7aa3b /src/python/grpcio_tests
parenteedc335580c6691f5401674841be0362596e1d9c (diff)
Fix the ThreadPoolExecutor: max_workers can't be 0
Add a RecordingThreadPool that inherits from Executor, contains a ThreadPoolExecutor and has an extra method 'was_used' to indicate if submit method was ever called i.e. if the thread pool was ever used.
Diffstat (limited to 'src/python/grpcio_tests')
-rw-r--r--src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py10
-rw-r--r--src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py6
-rw-r--r--src/python/grpcio_tests/tests/unit/_thread_pool.py48
3 files changed, 59 insertions, 5 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
index 3c00f686ce..9cae96a00d 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
@@ -32,12 +32,12 @@
import threading
import time
import unittest
-from concurrent import futures
import grpc
from grpc import _channel
from grpc import _server
from tests.unit.framework.common import test_constants
+from tests.unit import _thread_pool
def _ready_in_connectivities(connectivities):
@@ -104,7 +104,8 @@ class ChannelConnectivityTest(unittest.TestCase):
grpc.ChannelConnectivity.READY, fifth_connectivities)
def test_immediately_connectable_channel_connectivity(self):
- server = _server.Server(futures.ThreadPoolExecutor(max_workers=0), ())
+ thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
+ server = _server.Server(thread_pool, ())
port = server.add_insecure_port('[::]:0')
server.start()
first_callback = _Callback()
@@ -141,9 +142,11 @@ class ChannelConnectivityTest(unittest.TestCase):
fourth_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.SHUTDOWN, fourth_connectivities)
+ self.assertFalse(thread_pool.was_used())
def test_reachable_then_unreachable_channel_connectivity(self):
- server = _server.Server(futures.ThreadPoolExecutor(max_workers=0), ())
+ thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
+ server = _server.Server(thread_pool, ())
port = server.add_insecure_port('[::]:0')
server.start()
callback = _Callback()
@@ -155,6 +158,7 @@ class ChannelConnectivityTest(unittest.TestCase):
server.stop(None)
callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready)
channel.unsubscribe(callback.update)
+ self.assertFalse(thread_pool.was_used())
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
index e8982ed2de..24f5b45b18 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
@@ -31,12 +31,12 @@
import threading
import unittest
-from concurrent import futures
import grpc
from grpc import _channel
from grpc import _server
from tests.unit.framework.common import test_constants
+from tests.unit import _thread_pool
class _Callback(object):
@@ -78,7 +78,8 @@ class ChannelReadyFutureTest(unittest.TestCase):
self.assertFalse(ready_future.running())
def test_immediately_connectable_channel_connectivity(self):
- server = _server.Server(futures.ThreadPoolExecutor(max_workers=0), ())
+ thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
+ server = _server.Server(thread_pool, ())
port = server.add_insecure_port('[::]:0')
server.start()
channel = grpc.insecure_channel('localhost:{}'.format(port))
@@ -97,6 +98,7 @@ class ChannelReadyFutureTest(unittest.TestCase):
self.assertFalse(ready_future.cancelled())
self.assertTrue(ready_future.done())
self.assertFalse(ready_future.running())
+ self.assertFalse(thread_pool.was_used())
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_thread_pool.py b/src/python/grpcio_tests/tests/unit/_thread_pool.py
new file mode 100644
index 0000000000..f13cc2f86f
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_thread_pool.py
@@ -0,0 +1,48 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import threading
+from concurrent import futures
+
+
+class RecordingThreadPool(futures.Executor):
+ """A thread pool that records if used."""
+ def __init__(self, max_workers):
+ self._tp_executor = futures.ThreadPoolExecutor(max_workers=max_workers)
+ self._lock = threading.Lock()
+ self._was_used = False
+
+ def submit(self, fn, *args, **kwargs):
+ with self._lock:
+ self._was_used = True
+ self._tp_executor.submit(fn, *args, **kwargs)
+
+ def was_used(self):
+ with self._lock:
+ return self._was_used