diff options
author | siddharthshukla <siddharthshukla@outlook.com> | 2016-08-03 17:55:10 +0200 |
---|---|---|
committer | siddharthshukla <siddharthshukla@outlook.com> | 2016-08-03 22:13:37 +0200 |
commit | de84d566b8fad6808e5263a25a17fa231cb5713c (patch) | |
tree | d12f09f10b50db7a9e57264161e5a19607f7aa3b /src/python/grpcio_tests | |
parent | eedc335580c6691f5401674841be0362596e1d9c (diff) |
Fix the ThreadPoolExecutor: max_workers can't be 0
Add a RecordingThreadPool that inherits from Executor, contains a
ThreadPoolExecutor and has an extra method 'was_used' to indicate if
submit method was ever called i.e. if the thread pool was ever used.
Diffstat (limited to 'src/python/grpcio_tests')
3 files changed, 59 insertions, 5 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py index 3c00f686ce..9cae96a00d 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py @@ -32,12 +32,12 @@ import threading import time import unittest -from concurrent import futures import grpc from grpc import _channel from grpc import _server from tests.unit.framework.common import test_constants +from tests.unit import _thread_pool def _ready_in_connectivities(connectivities): @@ -104,7 +104,8 @@ class ChannelConnectivityTest(unittest.TestCase): grpc.ChannelConnectivity.READY, fifth_connectivities) def test_immediately_connectable_channel_connectivity(self): - server = _server.Server(futures.ThreadPoolExecutor(max_workers=0), ()) + thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) + server = _server.Server(thread_pool, ()) port = server.add_insecure_port('[::]:0') server.start() first_callback = _Callback() @@ -141,9 +142,11 @@ class ChannelConnectivityTest(unittest.TestCase): fourth_connectivities) self.assertNotIn( grpc.ChannelConnectivity.SHUTDOWN, fourth_connectivities) + self.assertFalse(thread_pool.was_used()) def test_reachable_then_unreachable_channel_connectivity(self): - server = _server.Server(futures.ThreadPoolExecutor(max_workers=0), ()) + thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) + server = _server.Server(thread_pool, ()) port = server.add_insecure_port('[::]:0') server.start() callback = _Callback() @@ -155,6 +158,7 @@ class ChannelConnectivityTest(unittest.TestCase): server.stop(None) callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready) channel.unsubscribe(callback.update) + self.assertFalse(thread_pool.was_used()) if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py index e8982ed2de..24f5b45b18 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py @@ -31,12 +31,12 @@ import threading import unittest -from concurrent import futures import grpc from grpc import _channel from grpc import _server from tests.unit.framework.common import test_constants +from tests.unit import _thread_pool class _Callback(object): @@ -78,7 +78,8 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertFalse(ready_future.running()) def test_immediately_connectable_channel_connectivity(self): - server = _server.Server(futures.ThreadPoolExecutor(max_workers=0), ()) + thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) + server = _server.Server(thread_pool, ()) port = server.add_insecure_port('[::]:0') server.start() channel = grpc.insecure_channel('localhost:{}'.format(port)) @@ -97,6 +98,7 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertFalse(ready_future.cancelled()) self.assertTrue(ready_future.done()) self.assertFalse(ready_future.running()) + self.assertFalse(thread_pool.was_used()) if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_thread_pool.py b/src/python/grpcio_tests/tests/unit/_thread_pool.py new file mode 100644 index 0000000000..f13cc2f86f --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_thread_pool.py @@ -0,0 +1,48 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import threading +from concurrent import futures + + +class RecordingThreadPool(futures.Executor): + """A thread pool that records if used.""" + def __init__(self, max_workers): + self._tp_executor = futures.ThreadPoolExecutor(max_workers=max_workers) + self._lock = threading.Lock() + self._was_used = False + + def submit(self, fn, *args, **kwargs): + with self._lock: + self._was_used = True + self._tp_executor.submit(fn, *args, **kwargs) + + def was_used(self): + with self._lock: + return self._was_used |