diff options
author | Nathaniel Manista <nathaniel@google.com> | 2016-05-21 13:55:03 +0000 |
---|---|---|
committer | Nathaniel Manista <nathaniel@google.com> | 2016-05-23 13:14:01 +0000 |
commit | 29f2cb868741967ed34fcb293451784f7effe2d8 (patch) | |
tree | d1dac394da16ba2cbbf04924e439c6b0007a7eed /src/python/grpcio/tests/unit | |
parent | 52aeb54077713f5cf139ba495516be2ee3132865 (diff) |
Add a Cython-level cancel-many-calls test
Diffstat (limited to 'src/python/grpcio/tests/unit')
-rw-r--r-- | src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py new file mode 100644 index 0000000000..c1de779014 --- /dev/null +++ b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py @@ -0,0 +1,222 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test making many calls and immediately cancelling most of them.""" + +import threading +import unittest + +from grpc._cython import cygrpc +from grpc.framework.foundation import logging_pool +from tests.unit.framework.common import test_constants + +_INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = cygrpc.Metadata(()) + +_SERVER_SHUTDOWN_TAG = 'server_shutdown' +_REQUEST_CALL_TAG = 'request_call' +_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server' +_RECEIVE_MESSAGE_TAG = 'receive_message' +_SERVER_COMPLETE_CALL_TAG = 'server_complete_call' + +_SUCCESS_CALL_FRACTION = 1.0 / 8.0 + + +class _State(object): + + def __init__(self): + self.condition = threading.Condition() + self.handlers_released = False + self.parked_handlers = 0 + self.handled_rpcs = 0 + + +def _is_cancellation_event(event): + return ( + event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and + event.batch_operations[0].received_cancelled) + + +class _Handler(object): + + def __init__(self, state, completion_queue, rpc_event): + self._state = state + self._lock = threading.Lock() + self._completion_queue = completion_queue + self._call = rpc_event.operation_call + + def __call__(self): + with self._state.condition: + self._state.parked_handlers += 1 + if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY: + self._state.condition.notify_all() + while not self._state.handlers_released: + self._state.condition.wait() + + with self._lock: + self._call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), + _RECEIVE_CLOSE_ON_SERVER_TAG) + self._call.start_batch( + cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + _RECEIVE_MESSAGE_TAG) + first_event = self._completion_queue.poll() + if _is_cancellation_event(first_event): + self._completion_queue.poll() + else: + with self._lock: + operations = ( + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', + _EMPTY_FLAGS), + ) + self._call.start_batch( + cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG) + self._completion_queue.poll() + self._completion_queue.poll() + + +def _serve(state, server, server_completion_queue, thread_pool): + for _ in range(test_constants.RPC_CONCURRENCY): + call_completion_queue = cygrpc.CompletionQueue() + server.request_call( + call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG) + rpc_event = server_completion_queue.poll() + thread_pool.submit(_Handler(state, call_completion_queue, rpc_event)) + with state.condition: + state.handled_rpcs += 1 + if test_constants.RPC_CONCURRENCY <= state.handled_rpcs: + state.condition.notify_all() + server_completion_queue.poll() + + +class _QueueDriver(object): + + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + thread = threading.Thread(target=in_thread) + thread.start() + + def events(self, at_least): + with self._condition: + while len(self._events) < at_least: + self._condition.wait() + return tuple(self._events) + + +class CancelManyCallsTest(unittest.TestCase): + + def testCancelManyCalls(self): + server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server() + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port('[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port)) + + state = _State() + + server_thread_args = ( + state, server, server_completion_queue, server_thread_pool,) + server_thread = threading.Thread(target=_serve, args=server_thread_args) + server_thread.start() + + client_condition = threading.Condition() + client_due = set() + client_completion_queue = cygrpc.CompletionQueue() + client_driver = _QueueDriver( + client_condition, client_completion_queue, client_due) + client_driver.start() + + with client_condition: + client_calls = [] + for index in range(test_constants.RPC_CONCURRENCY): + client_call = channel.create_call( + None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None, + _INFINITE_FUTURE) + operations = ( + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ) + tag = 'client_complete_call_{0:04d}_tag'.format(index) + client_call.start_batch(cygrpc.Operations(operations), tag) + client_due.add(tag) + client_calls.append(client_call) + + with state.condition: + while True: + if state.parked_handlers < test_constants.THREAD_CONCURRENCY: + state.condition.wait() + elif state.handled_rpcs < test_constants.RPC_CONCURRENCY: + state.condition.wait() + else: + state.handlers_released = True + state.condition.notify_all() + break + + client_driver.events( + test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) + with client_condition: + for client_call in client_calls: + client_call.cancel() + + with state.condition: + server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) + + +if __name__ == '__main__': + unittest.main(verbosity=2) |