diff options
Diffstat (limited to 'src/python/grpcio/grpc/_server.py')
-rw-r--r-- | src/python/grpcio/grpc/_server.py | 94 |
1 files changed, 57 insertions, 37 deletions
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 84e096d4c0..47838c2c98 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -504,37 +504,37 @@ def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, def _handle_unary_unary(rpc_event, state, method_handler, thread_pool): unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) - thread_pool.submit(_unary_response_in_pool, rpc_event, state, - method_handler.unary_unary, unary_request, - method_handler.request_deserializer, - method_handler.response_serializer) + return thread_pool.submit(_unary_response_in_pool, rpc_event, state, + method_handler.unary_unary, unary_request, + method_handler.request_deserializer, + method_handler.response_serializer) def _handle_unary_stream(rpc_event, state, method_handler, thread_pool): unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) - thread_pool.submit(_stream_response_in_pool, rpc_event, state, - method_handler.unary_stream, unary_request, - method_handler.request_deserializer, - method_handler.response_serializer) + return thread_pool.submit(_stream_response_in_pool, rpc_event, state, + method_handler.unary_stream, unary_request, + method_handler.request_deserializer, + method_handler.response_serializer) def _handle_stream_unary(rpc_event, state, method_handler, thread_pool): request_iterator = _RequestIterator(state, rpc_event.operation_call, method_handler.request_deserializer) - thread_pool.submit(_unary_response_in_pool, rpc_event, state, - method_handler.stream_unary, lambda: request_iterator, - method_handler.request_deserializer, - method_handler.response_serializer) + return thread_pool.submit( + _unary_response_in_pool, rpc_event, state, method_handler.stream_unary, + lambda: request_iterator, method_handler.request_deserializer, + method_handler.response_serializer) def _handle_stream_stream(rpc_event, state, method_handler, thread_pool): request_iterator = _RequestIterator(state, rpc_event.operation_call, method_handler.request_deserializer) - thread_pool.submit(_stream_response_in_pool, rpc_event, state, - method_handler.stream_stream, lambda: request_iterator, - method_handler.request_deserializer, - method_handler.response_serializer) + return thread_pool.submit( + _stream_response_in_pool, rpc_event, state, + method_handler.stream_stream, lambda: request_iterator, + method_handler.request_deserializer, method_handler.response_serializer) def _find_method_handler(rpc_event, generic_handlers): @@ -549,13 +549,12 @@ def _find_method_handler(rpc_event, generic_handlers): return None -def _handle_unrecognized_method(rpc_event): +def _reject_rpc(rpc_event, status, details): operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA, _EMPTY_FLAGS), cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), cygrpc.operation_send_status_from_server( - _common.EMPTY_METADATA, cygrpc.StatusCode.unimplemented, - b'Method not found!', _EMPTY_FLAGS),) + _common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),) rpc_state = _RPCState() rpc_event.operation_call.start_server_batch( operations, lambda ignored_event: (rpc_state, (),)) @@ -572,33 +571,37 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool): state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) if method_handler.request_streaming: if method_handler.response_streaming: - _handle_stream_stream(rpc_event, state, method_handler, - thread_pool) + return state, _handle_stream_stream(rpc_event, state, + method_handler, thread_pool) else: - _handle_stream_unary(rpc_event, state, method_handler, - thread_pool) + return state, _handle_stream_unary(rpc_event, state, + method_handler, thread_pool) else: if method_handler.response_streaming: - _handle_unary_stream(rpc_event, state, method_handler, - thread_pool) + return state, _handle_unary_stream(rpc_event, state, + method_handler, thread_pool) else: - _handle_unary_unary(rpc_event, state, method_handler, - thread_pool) - return state + return state, _handle_unary_unary(rpc_event, state, + method_handler, thread_pool) -def _handle_call(rpc_event, generic_handlers, thread_pool): +def _handle_call(rpc_event, generic_handlers, thread_pool, + concurrency_exceeded): if not rpc_event.success: - return None + return None, None if rpc_event.request_call_details.method is not None: method_handler = _find_method_handler(rpc_event, generic_handlers) if method_handler is None: - return _handle_unrecognized_method(rpc_event) + return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented, + b'Method not found!'), None + elif concurrency_exceeded: + return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted, + b'Concurrent RPC limit exceeded!'), None else: return _handle_with_method_handler(rpc_event, method_handler, thread_pool) else: - return None + return None, None @enum.unique @@ -610,7 +613,8 @@ class _ServerStage(enum.Enum): class _ServerState(object): - def __init__(self, completion_queue, server, generic_handlers, thread_pool): + def __init__(self, completion_queue, server, generic_handlers, thread_pool, + maximum_concurrent_rpcs): self.lock = threading.Lock() self.completion_queue = completion_queue self.server = server @@ -618,6 +622,8 @@ class _ServerState(object): self.thread_pool = thread_pool self.stage = _ServerStage.STOPPED self.shutdown_events = None + self.maximum_concurrent_rpcs = maximum_concurrent_rpcs + self.active_rpc_count = 0 # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields. self.rpc_states = set() @@ -657,6 +663,11 @@ def _stop_serving(state): return False +def _on_call_completed(state): + with state.lock: + state.active_rpc_count -= 1 + + def _serve(state): while True: event = state.completion_queue.poll() @@ -668,10 +679,18 @@ def _serve(state): elif event.tag is _REQUEST_CALL_TAG: with state.lock: state.due.remove(_REQUEST_CALL_TAG) - rpc_state = _handle_call(event, state.generic_handlers, - state.thread_pool) + concurrency_exceeded = ( + state.maximum_concurrent_rpcs is not None and + state.active_rpc_count >= state.maximum_concurrent_rpcs) + rpc_state, rpc_future = _handle_call( + event, state.generic_handlers, state.thread_pool, + concurrency_exceeded) if rpc_state is not None: state.rpc_states.add(rpc_state) + if rpc_future is not None: + state.active_rpc_count += 1 + rpc_future.add_done_callback( + lambda unused_future: _on_call_completed(state)) if state.stage is _ServerStage.STARTED: _request_call(state) elif _stop_serving(state): @@ -749,12 +768,13 @@ def _start(state): class Server(grpc.Server): - def __init__(self, thread_pool, generic_handlers, options): + def __init__(self, thread_pool, generic_handlers, options, + maximum_concurrent_rpcs): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(_common.channel_args(options)) server.register_completion_queue(completion_queue) self._state = _ServerState(completion_queue, server, generic_handlers, - thread_pool) + thread_pool, maximum_concurrent_rpcs) def add_generic_rpc_handlers(self, generic_rpc_handlers): _add_generic_handlers(self._state, generic_rpc_handlers) |