diff options
Diffstat (limited to 'src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py')
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py | 302 |
1 files changed, 153 insertions, 149 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 20115fb22c..d77f5ecb27 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -26,7 +26,6 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - """Test making many calls and immediately cancelling most of them.""" import threading @@ -51,173 +50,178 @@ _SUCCESS_CALL_FRACTION = 1.0 / 8.0 class _State(object): - def __init__(self): - self.condition = threading.Condition() - self.handlers_released = False - self.parked_handlers = 0 - self.handled_rpcs = 0 + def __init__(self): + self.condition = threading.Condition() + self.handlers_released = False + self.parked_handlers = 0 + self.handled_rpcs = 0 def _is_cancellation_event(event): - return ( - event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and - event.batch_operations[0].received_cancelled) + return (event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and + event.batch_operations[0].received_cancelled) class _Handler(object): - def __init__(self, state, completion_queue, rpc_event): - self._state = state - self._lock = threading.Lock() - self._completion_queue = completion_queue - self._call = rpc_event.operation_call - - def __call__(self): - with self._state.condition: - self._state.parked_handlers += 1 - if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY: - self._state.condition.notify_all() - while not self._state.handlers_released: - self._state.condition.wait() - - with self._lock: - self._call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), - _RECEIVE_CLOSE_ON_SERVER_TAG) - self._call.start_server_batch( - cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)), - _RECEIVE_MESSAGE_TAG) - first_event = self._completion_queue.poll() - if _is_cancellation_event(first_event): - self._completion_queue.poll() - else: - with self._lock: - operations = ( - cygrpc.operation_send_initial_metadata( - _EMPTY_METADATA, _EMPTY_FLAGS), - cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS), - cygrpc.operation_send_status_from_server( - _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', - _EMPTY_FLAGS), - ) - self._call.start_server_batch( - cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG) - self._completion_queue.poll() - self._completion_queue.poll() + def __init__(self, state, completion_queue, rpc_event): + self._state = state + self._lock = threading.Lock() + self._completion_queue = completion_queue + self._call = rpc_event.operation_call + + def __call__(self): + with self._state.condition: + self._state.parked_handlers += 1 + if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY: + self._state.condition.notify_all() + while not self._state.handlers_released: + self._state.condition.wait() + + with self._lock: + self._call.start_server_batch( + cygrpc.Operations( + (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), + _RECEIVE_CLOSE_ON_SERVER_TAG) + self._call.start_server_batch( + cygrpc.Operations( + (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + _RECEIVE_MESSAGE_TAG) + first_event = self._completion_queue.poll() + if _is_cancellation_event(first_event): + self._completion_queue.poll() + else: + with self._lock: + operations = ( + cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, + _EMPTY_FLAGS), + cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', + _EMPTY_FLAGS),) + self._call.start_server_batch( + cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG) + self._completion_queue.poll() + self._completion_queue.poll() def _serve(state, server, server_completion_queue, thread_pool): - for _ in range(test_constants.RPC_CONCURRENCY): - call_completion_queue = cygrpc.CompletionQueue() - server.request_call( - call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG) - rpc_event = server_completion_queue.poll() - thread_pool.submit(_Handler(state, call_completion_queue, rpc_event)) - with state.condition: - state.handled_rpcs += 1 - if test_constants.RPC_CONCURRENCY <= state.handled_rpcs: - state.condition.notify_all() - server_completion_queue.poll() + for _ in range(test_constants.RPC_CONCURRENCY): + call_completion_queue = cygrpc.CompletionQueue() + server.request_call(call_completion_queue, server_completion_queue, + _REQUEST_CALL_TAG) + rpc_event = server_completion_queue.poll() + thread_pool.submit(_Handler(state, call_completion_queue, rpc_event)) + with state.condition: + state.handled_rpcs += 1 + if test_constants.RPC_CONCURRENCY <= state.handled_rpcs: + state.condition.notify_all() + server_completion_queue.poll() class _QueueDriver(object): - def __init__(self, condition, completion_queue, due): - self._condition = condition - self._completion_queue = completion_queue - self._due = due - self._events = [] - self._returned = False - - def start(self): - def in_thread(): - while True: - event = self._completion_queue.poll() + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + + thread = threading.Thread(target=in_thread) + thread.start() + + def events(self, at_least): with self._condition: - self._events.append(event) - self._due.remove(event.tag) - self._condition.notify_all() - if not self._due: - self._returned = True - return - thread = threading.Thread(target=in_thread) - thread.start() - - def events(self, at_least): - with self._condition: - while len(self._events) < at_least: - self._condition.wait() - return tuple(self._events) + while len(self._events) < at_least: + self._condition.wait() + return tuple(self._events) class CancelManyCallsTest(unittest.TestCase): - def testCancelManyCalls(self): - server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) - - server_completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server(cygrpc.ChannelArgs([])) - server.register_completion_queue(server_completion_queue) - port = server.add_http2_port(b'[::]:0') - server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), - cygrpc.ChannelArgs([])) - - state = _State() - - server_thread_args = ( - state, server, server_completion_queue, server_thread_pool,) - server_thread = threading.Thread(target=_serve, args=server_thread_args) - server_thread.start() - - client_condition = threading.Condition() - client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver( - client_condition, client_completion_queue, client_due) - client_driver.start() - - with client_condition: - client_calls = [] - for index in range(test_constants.RPC_CONCURRENCY): - client_call = channel.create_call( - None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None, - _INFINITE_FUTURE) - operations = ( - cygrpc.operation_send_initial_metadata( - _EMPTY_METADATA, _EMPTY_FLAGS), - cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS), - cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), - cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), - cygrpc.operation_receive_message(_EMPTY_FLAGS), - cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), - ) - tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_client_batch(cygrpc.Operations(operations), tag) - client_due.add(tag) - client_calls.append(client_call) - - with state.condition: - while True: - if state.parked_handlers < test_constants.THREAD_CONCURRENCY: - state.condition.wait() - elif state.handled_rpcs < test_constants.RPC_CONCURRENCY: - state.condition.wait() - else: - state.handlers_released = True - state.condition.notify_all() - break - - client_driver.events( - test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) - with client_condition: - for client_call in client_calls: - client_call.cancel() - - with state.condition: - server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) + def testCancelManyCalls(self): + server_thread_pool = logging_pool.pool( + test_constants.THREAD_CONCURRENCY) + + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server(cygrpc.ChannelArgs([])) + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port(b'[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), + cygrpc.ChannelArgs([])) + + state = _State() + + server_thread_args = ( + state, + server, + server_completion_queue, + server_thread_pool,) + server_thread = threading.Thread(target=_serve, args=server_thread_args) + server_thread.start() + + client_condition = threading.Condition() + client_due = set() + client_completion_queue = cygrpc.CompletionQueue() + client_driver = _QueueDriver(client_condition, client_completion_queue, + client_due) + client_driver.start() + + with client_condition: + client_calls = [] + for index in range(test_constants.RPC_CONCURRENCY): + client_call = channel.create_call( + None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', + None, _INFINITE_FUTURE) + operations = ( + cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, + _EMPTY_FLAGS), + cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) + tag = 'client_complete_call_{0:04d}_tag'.format(index) + client_call.start_client_batch( + cygrpc.Operations(operations), tag) + client_due.add(tag) + client_calls.append(client_call) + + with state.condition: + while True: + if state.parked_handlers < test_constants.THREAD_CONCURRENCY: + state.condition.wait() + elif state.handled_rpcs < test_constants.RPC_CONCURRENCY: + state.condition.wait() + else: + state.handlers_released = True + state.condition.notify_all() + break + + client_driver.events(test_constants.RPC_CONCURRENCY * + _SUCCESS_CALL_FRACTION) + with client_condition: + for client_call in client_calls: + client_call.cancel() + + with state.condition: + server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) if __name__ == '__main__': - unittest.main(verbosity=2) + unittest.main(verbosity=2) |