aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorGravatar Ken Payson <kpayson@google.com>2017-03-07 11:48:35 -0800
committerGravatar Ken Payson <kpayson@google.com>2017-03-17 10:38:50 -0700
commit39a5932097b5a2ed9481cd9660522658ee96fc65 (patch)
treeaf8d2b3b2ee8ad4fef5cf6e4050a2a794d4f27e8 /src
parenteb064ec7b81b60c5e1eb47d6124d0c05056b3097 (diff)
Add max_requests argument to server
If the server is already serving max requests, return RESOURCE_EXHAUSTED
Diffstat (limited to 'src')
-rw-r--r--src/python/grpcio/grpc/__init__.py11
-rw-r--r--src/python/grpcio/grpc/_server.py94
-rw-r--r--src/python/grpcio_tests/tests/tests.json1
-rw-r--r--src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py270
4 files changed, 337 insertions, 39 deletions
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py
index a4481b2ac3..4960df3be9 100644
--- a/src/python/grpcio/grpc/__init__.py
+++ b/src/python/grpcio/grpc/__init__.py
@@ -1273,7 +1273,10 @@ def secure_channel(target, credentials, options=None):
credentials._credentials)
-def server(thread_pool, handlers=None, options=None):
+def server(thread_pool,
+ handlers=None,
+ options=None,
+ maximum_concurrent_rpcs=None):
"""Creates a Server with which RPCs can be serviced.
Args:
@@ -1286,13 +1289,17 @@ def server(thread_pool, handlers=None, options=None):
returned Server is started.
options: A sequence of string-value pairs according to which to configure
the created server.
+ maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
+ will service before returning status RESOURCE_EXHAUSTED, or None to
+ indicate no limit.
Returns:
A Server with which RPCs can be serviced.
"""
from grpc import _server # pylint: disable=cyclic-import
return _server.Server(thread_pool, () if handlers is None else handlers, ()
- if options is None else options)
+ if options is None else options,
+ maximum_concurrent_rpcs)
################################### __all__ #################################
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index 84e096d4c0..47838c2c98 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -504,37 +504,37 @@ def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
- thread_pool.submit(_unary_response_in_pool, rpc_event, state,
- method_handler.unary_unary, unary_request,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
+ method_handler.unary_unary, unary_request,
+ method_handler.request_deserializer,
+ method_handler.response_serializer)
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
- thread_pool.submit(_stream_response_in_pool, rpc_event, state,
- method_handler.unary_stream, unary_request,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
+ method_handler.unary_stream, unary_request,
+ method_handler.request_deserializer,
+ method_handler.response_serializer)
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
- thread_pool.submit(_unary_response_in_pool, rpc_event, state,
- method_handler.stream_unary, lambda: request_iterator,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(
+ _unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
+ lambda: request_iterator, method_handler.request_deserializer,
+ method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
- thread_pool.submit(_stream_response_in_pool, rpc_event, state,
- method_handler.stream_stream, lambda: request_iterator,
- method_handler.request_deserializer,
- method_handler.response_serializer)
+ return thread_pool.submit(
+ _stream_response_in_pool, rpc_event, state,
+ method_handler.stream_stream, lambda: request_iterator,
+ method_handler.request_deserializer, method_handler.response_serializer)
def _find_method_handler(rpc_event, generic_handlers):
@@ -549,13 +549,12 @@ def _find_method_handler(rpc_event, generic_handlers):
return None
-def _handle_unrecognized_method(rpc_event):
+def _reject_rpc(rpc_event, status, details):
operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
- _common.EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
- b'Method not found!', _EMPTY_FLAGS),)
+ _common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),)
rpc_state = _RPCState()
rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
@@ -572,33 +571,37 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
if method_handler.response_streaming:
- _handle_stream_stream(rpc_event, state, method_handler,
- thread_pool)
+ return state, _handle_stream_stream(rpc_event, state,
+ method_handler, thread_pool)
else:
- _handle_stream_unary(rpc_event, state, method_handler,
- thread_pool)
+ return state, _handle_stream_unary(rpc_event, state,
+ method_handler, thread_pool)
else:
if method_handler.response_streaming:
- _handle_unary_stream(rpc_event, state, method_handler,
- thread_pool)
+ return state, _handle_unary_stream(rpc_event, state,
+ method_handler, thread_pool)
else:
- _handle_unary_unary(rpc_event, state, method_handler,
- thread_pool)
- return state
+ return state, _handle_unary_unary(rpc_event, state,
+ method_handler, thread_pool)
-def _handle_call(rpc_event, generic_handlers, thread_pool):
+def _handle_call(rpc_event, generic_handlers, thread_pool,
+ concurrency_exceeded):
if not rpc_event.success:
- return None
+ return None, None
if rpc_event.request_call_details.method is not None:
method_handler = _find_method_handler(rpc_event, generic_handlers)
if method_handler is None:
- return _handle_unrecognized_method(rpc_event)
+ return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
+ b'Method not found!'), None
+ elif concurrency_exceeded:
+ return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
+ b'Concurrent RPC limit exceeded!'), None
else:
return _handle_with_method_handler(rpc_event, method_handler,
thread_pool)
else:
- return None
+ return None, None
@enum.unique
@@ -610,7 +613,8 @@ class _ServerStage(enum.Enum):
class _ServerState(object):
- def __init__(self, completion_queue, server, generic_handlers, thread_pool):
+ def __init__(self, completion_queue, server, generic_handlers, thread_pool,
+ maximum_concurrent_rpcs):
self.lock = threading.Lock()
self.completion_queue = completion_queue
self.server = server
@@ -618,6 +622,8 @@ class _ServerState(object):
self.thread_pool = thread_pool
self.stage = _ServerStage.STOPPED
self.shutdown_events = None
+ self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
+ self.active_rpc_count = 0
# TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
self.rpc_states = set()
@@ -657,6 +663,11 @@ def _stop_serving(state):
return False
+def _on_call_completed(state):
+ with state.lock:
+ state.active_rpc_count -= 1
+
+
def _serve(state):
while True:
event = state.completion_queue.poll()
@@ -668,10 +679,18 @@ def _serve(state):
elif event.tag is _REQUEST_CALL_TAG:
with state.lock:
state.due.remove(_REQUEST_CALL_TAG)
- rpc_state = _handle_call(event, state.generic_handlers,
- state.thread_pool)
+ concurrency_exceeded = (
+ state.maximum_concurrent_rpcs is not None and
+ state.active_rpc_count >= state.maximum_concurrent_rpcs)
+ rpc_state, rpc_future = _handle_call(
+ event, state.generic_handlers, state.thread_pool,
+ concurrency_exceeded)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
+ if rpc_future is not None:
+ state.active_rpc_count += 1
+ rpc_future.add_done_callback(
+ lambda unused_future: _on_call_completed(state))
if state.stage is _ServerStage.STARTED:
_request_call(state)
elif _stop_serving(state):
@@ -749,12 +768,13 @@ def _start(state):
class Server(grpc.Server):
- def __init__(self, thread_pool, generic_handlers, options):
+ def __init__(self, thread_pool, generic_handlers, options,
+ maximum_concurrent_rpcs):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers,
- thread_pool)
+ thread_pool, maximum_concurrent_rpcs)
def add_generic_rpc_handlers(self, generic_rpc_handlers):
_add_generic_handlers(self._state, generic_rpc_handlers)
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index 70d965d3ca..f750b05102 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -31,6 +31,7 @@
"unit._invocation_defects_test.InvocationDefectsTest",
"unit._metadata_code_details_test.MetadataCodeDetailsTest",
"unit._metadata_test.MetadataTest",
+ "unit._resource_exhausted_test.ResourceExhaustedTest",
"unit._rpc_test.RPCTest",
"unit._sanity._sanity_test.Sanity",
"unit._thread_cleanup_test.CleanupThreadTest",
diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
new file mode 100644
index 0000000000..88c82b5541
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
@@ -0,0 +1,270 @@
+# Copyright 2017, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests server responding with RESOURCE_EXHAUSTED."""
+
+import threading
+import unittest
+
+import grpc
+from grpc import _channel
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x00\x00\x00'
+
+_UNARY_UNARY = '/test/UnaryUnary'
+_UNARY_STREAM = '/test/UnaryStream'
+_STREAM_UNARY = '/test/StreamUnary'
+_STREAM_STREAM = '/test/StreamStream'
+
+
+class _TestTrigger(object):
+
+ def __init__(self, total_call_count):
+ self._total_call_count = total_call_count
+ self._pending_calls = 0
+ self._triggered = False
+ self._finish_condition = threading.Condition()
+ self._start_condition = threading.Condition()
+
+ # Wait for all calls be be blocked in their handler
+ def await_calls(self):
+ with self._start_condition:
+ while self._pending_calls < self._total_call_count:
+ self._start_condition.wait()
+
+ # Block in a response handler and wait for a trigger
+ def await_trigger(self):
+ with self._start_condition:
+ self._pending_calls += 1
+ self._start_condition.notify()
+
+ with self._finish_condition:
+ if not self._triggered:
+ self._finish_condition.wait()
+
+ # Finish all response handlers
+ def trigger(self):
+ with self._finish_condition:
+ self._triggered = True
+ self._finish_condition.notify_all()
+
+
+def handle_unary_unary(trigger, request, servicer_context):
+ trigger.await_trigger()
+ return _RESPONSE
+
+
+def handle_unary_stream(trigger, request, servicer_context):
+ trigger.await_trigger()
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
+
+
+def handle_stream_unary(trigger, request_iterator, servicer_context):
+ trigger.await_trigger()
+ # TODO(issue:#6891) We should be able to remove this loop
+ for request in request_iterator:
+ pass
+ return _RESPONSE
+
+
+def handle_stream_stream(trigger, request_iterator, servicer_context):
+ trigger.await_trigger()
+ # TODO(issue:#6891) We should be able to remove this loop,
+ # and replace with return; yield
+ for request in request_iterator:
+ yield _RESPONSE
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, trigger, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = (
+ lambda x, y: handle_stream_stream(trigger, x, y))
+ elif self.request_streaming:
+ self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y)
+ elif self.response_streaming:
+ self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y)
+ else:
+ self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y)
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, trigger):
+ self._trigger = trigger
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(self._trigger, False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(self._trigger, False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(self._trigger, True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(self._trigger, True, True)
+ else:
+ return None
+
+
+class ResourceExhaustedTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ self._server_pool,
+ handlers=(_GenericHandler(self._trigger),),
+ maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testUnaryUnary(self):
+ multi_callable = self._channel.unary_unary(_UNARY_UNARY)
+ futures = []
+ for _ in range(test_constants.THREAD_CONCURRENCY):
+ futures.append(multi_callable.future(_REQUEST))
+
+ self._trigger.await_calls()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(_REQUEST)
+
+ self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+ exception_context.exception.code())
+
+ future_exception = multi_callable.future(_REQUEST)
+ self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+ future_exception.exception().code())
+
+ self._trigger.trigger()
+ for future in futures:
+ self.assertEqual(_RESPONSE, future.result())
+
+ # Ensure a new request can be handled
+ self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
+
+ def testUnaryStream(self):
+ multi_callable = self._channel.unary_stream(_UNARY_STREAM)
+ calls = []
+ for _ in range(test_constants.THREAD_CONCURRENCY):
+ calls.append(multi_callable(_REQUEST))
+
+ self._trigger.await_calls()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(multi_callable(_REQUEST))
+
+ self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+ exception_context.exception.code())
+
+ self._trigger.trigger()
+
+ for call in calls:
+ for response in call:
+ self.assertEqual(_RESPONSE, response)
+
+ # Ensure a new request can be handled
+ new_call = multi_callable(_REQUEST)
+ for response in new_call:
+ self.assertEqual(_RESPONSE, response)
+
+ def testStreamUnary(self):
+ multi_callable = self._channel.stream_unary(_STREAM_UNARY)
+ futures = []
+ request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
+ for _ in range(test_constants.THREAD_CONCURRENCY):
+ futures.append(multi_callable.future(request))
+
+ self._trigger.await_calls()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(request)
+
+ self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+ exception_context.exception.code())
+
+ future_exception = multi_callable.future(request)
+ self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+ future_exception.exception().code())
+
+ self._trigger.trigger()
+
+ for future in futures:
+ self.assertEqual(_RESPONSE, future.result())
+
+ # Ensure a new request can be handled
+ self.assertEqual(_RESPONSE, multi_callable(request))
+
+ def testStreamStream(self):
+ multi_callable = self._channel.stream_stream(_STREAM_STREAM)
+ calls = []
+ request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
+ for _ in range(test_constants.THREAD_CONCURRENCY):
+ calls.append(multi_callable(request))
+
+ self._trigger.await_calls()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(multi_callable(request))
+
+ self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+ exception_context.exception.code())
+
+ self._trigger.trigger()
+
+ for call in calls:
+ for response in call:
+ self.assertEqual(_RESPONSE, response)
+
+ # Ensure a new request can be handled
+ new_call = multi_callable(request)
+ for response in new_call:
+ self.assertEqual(_RESPONSE, response)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)