aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio/grpc/_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/grpcio/grpc/_server.py')
-rw-r--r--src/python/grpcio/grpc/_server.py94
1 files changed, 57 insertions, 37 deletions
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index 84e096d4c0..47838c2c98 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -504,37 +504,37 @@ def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
- thread_pool.submit(_unary_response_in_pool, rpc_event, state,
- method_handler.unary_unary, unary_request,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
+ method_handler.unary_unary, unary_request,
+ method_handler.request_deserializer,
+ method_handler.response_serializer)
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
- thread_pool.submit(_stream_response_in_pool, rpc_event, state,
- method_handler.unary_stream, unary_request,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
+ method_handler.unary_stream, unary_request,
+ method_handler.request_deserializer,
+ method_handler.response_serializer)
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
- thread_pool.submit(_unary_response_in_pool, rpc_event, state,
- method_handler.stream_unary, lambda: request_iterator,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(
+ _unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
+ lambda: request_iterator, method_handler.request_deserializer,
+ method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
- thread_pool.submit(_stream_response_in_pool, rpc_event, state,
- method_handler.stream_stream, lambda: request_iterator,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(
+ _stream_response_in_pool, rpc_event, state,
+ method_handler.stream_stream, lambda: request_iterator,
+ method_handler.request_deserializer, method_handler.response_serializer)
def _find_method_handler(rpc_event, generic_handlers):
@@ -549,13 +549,12 @@ def _find_method_handler(rpc_event, generic_handlers):
return None
-def _handle_unrecognized_method(rpc_event):
+def _reject_rpc(rpc_event, status, details):
operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
- _common.EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
- b'Method not found!', _EMPTY_FLAGS),)
+ _common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),)
rpc_state = _RPCState()
rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
@@ -572,33 +571,37 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
if method_handler.response_streaming:
- _handle_stream_stream(rpc_event, state, method_handler,
- thread_pool)
+ return state, _handle_stream_stream(rpc_event, state,
+ method_handler, thread_pool)
else:
- _handle_stream_unary(rpc_event, state, method_handler,
- thread_pool)
+ return state, _handle_stream_unary(rpc_event, state,
+ method_handler, thread_pool)
else:
if method_handler.response_streaming:
- _handle_unary_stream(rpc_event, state, method_handler,
- thread_pool)
+ return state, _handle_unary_stream(rpc_event, state,
+ method_handler, thread_pool)
else:
- _handle_unary_unary(rpc_event, state, method_handler,
- thread_pool)
- return state
+ return state, _handle_unary_unary(rpc_event, state,
+ method_handler, thread_pool)
-def _handle_call(rpc_event, generic_handlers, thread_pool):
+def _handle_call(rpc_event, generic_handlers, thread_pool,
+ concurrency_exceeded):
if not rpc_event.success:
- return None
+ return None, None
if rpc_event.request_call_details.method is not None:
method_handler = _find_method_handler(rpc_event, generic_handlers)
if method_handler is None:
- return _handle_unrecognized_method(rpc_event)
+ return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
+ b'Method not found!'), None
+ elif concurrency_exceeded:
+ return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
+ b'Concurrent RPC limit exceeded!'), None
else:
return _handle_with_method_handler(rpc_event, method_handler,
thread_pool)
else:
- return None
+ return None, None
@enum.unique
@@ -610,7 +613,8 @@ class _ServerStage(enum.Enum):
class _ServerState(object):
- def __init__(self, completion_queue, server, generic_handlers, thread_pool):
+ def __init__(self, completion_queue, server, generic_handlers, thread_pool,
+ maximum_concurrent_rpcs):
self.lock = threading.Lock()
self.completion_queue = completion_queue
self.server = server
@@ -618,6 +622,8 @@ class _ServerState(object):
self.thread_pool = thread_pool
self.stage = _ServerStage.STOPPED
self.shutdown_events = None
+ self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
+ self.active_rpc_count = 0
# TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
self.rpc_states = set()
@@ -657,6 +663,11 @@ def _stop_serving(state):
return False
+def _on_call_completed(state):
+ with state.lock:
+ state.active_rpc_count -= 1
+
+
def _serve(state):
while True:
event = state.completion_queue.poll()
@@ -668,10 +679,18 @@ def _serve(state):
elif event.tag is _REQUEST_CALL_TAG:
with state.lock:
state.due.remove(_REQUEST_CALL_TAG)
- rpc_state = _handle_call(event, state.generic_handlers,
- state.thread_pool)
+ concurrency_exceeded = (
+ state.maximum_concurrent_rpcs is not None and
+ state.active_rpc_count >= state.maximum_concurrent_rpcs)
+ rpc_state, rpc_future = _handle_call(
+ event, state.generic_handlers, state.thread_pool,
+ concurrency_exceeded)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
+ if rpc_future is not None:
+ state.active_rpc_count += 1
+ rpc_future.add_done_callback(
+ lambda unused_future: _on_call_completed(state))
if state.stage is _ServerStage.STARTED:
_request_call(state)
elif _stop_serving(state):
@@ -749,12 +768,13 @@ def _start(state):
class Server(grpc.Server):
- def __init__(self, thread_pool, generic_handlers, options):
+ def __init__(self, thread_pool, generic_handlers, options,
+ maximum_concurrent_rpcs):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers,
- thread_pool)
+ thread_pool, maximum_concurrent_rpcs)
def add_generic_rpc_handlers(self, generic_rpc_handlers):
_add_generic_handlers(self._state, generic_rpc_handlers)