aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio_tests/tests/unit/_early_ok_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/grpcio_tests/tests/unit/_early_ok_test.py')
-rw-r--r--src/python/grpcio_tests/tests/unit/_early_ok_test.py206
1 files changed, 206 insertions, 0 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_early_ok_test.py b/src/python/grpcio_tests/tests/unit/_early_ok_test.py
new file mode 100644
index 0000000000..041532c661
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_early_ok_test.py
@@ -0,0 +1,206 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests servicers sending OK status without having read all requests.
+
+This is a regression test of https://github.com/grpc/grpc/issues/6891.
+"""
+
+import enum
+import unittest
+
+import six
+
+import grpc
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_RPC_METHOD = '/serffice/Meffod'
+
+
+@enum.unique
+class _MessageCount(enum.Enum):
+
+ ZERO = (
+ 0,
+ 'Zero',
+ )
+ TWO = (
+ 1,
+ 'Two',
+ )
+ MANY = (
+ test_constants.STREAM_LENGTH,
+ 'Many',
+ )
+
+
+@enum.unique
+class _MessageSize(enum.Enum):
+ EMPTY = (
+ 0,
+ 'Empty',
+ )
+ SMALL = (
+ 32,
+ 'Small',
+ ) # Smaller than any flow control window.
+ LARGE = (
+ 3 * 1024 * 1024,
+ 'Large',
+ ) # Larger than any flow control window.
+
+
+_ZERO_MESSAGE = b''
+_SMALL_MESSAGE = b'\x07' * _MessageSize.SMALL.value[0]
+_LARGE_MESSAGE = b'abc' * (_MessageSize.LARGE.value[0] // 3)
+
+
+@enum.unique
+class _ReadRequests(enum.Enum):
+
+ ZERO = (
+ 0,
+ 'Zero',
+ )
+ TWO = (
+ 2,
+ 'Two',
+ )
+
+
+class _Case(object):
+
+ def __init__(self, request_count, request_size, request_reading,
+ response_count, response_size):
+ self.request_count = request_count
+ self.request_size = request_size
+ self.request_reading = request_reading
+ self.response_count = response_count
+ self.response_size = response_size
+
+ def create_test_case_name(self):
+ return '{}{}Requests{}Read{}{}ResponsesEarlyOKTest'.format(
+ self.request_count.value[1], self.request_size.value[1],
+ self.request_reading.value[1], self.response_count.value[1],
+ self.response_size.value[1])
+
+
+def _message(message_size):
+ if message_size is _MessageSize.EMPTY:
+ return _ZERO_MESSAGE
+ elif message_size is _MessageSize.SMALL:
+ return _SMALL_MESSAGE
+ elif message_size is _MessageSize.LARGE:
+ return _LARGE_MESSAGE
+
+
+def _messages_to_send(count, size):
+ for _ in range(count.value[0]):
+ yield _message(size)
+
+
+def _draw_requests(case, request_iterator):
+ for _ in range(
+ min(case.request_count.value[0], case.request_reading.value[0])):
+ next(request_iterator)
+
+
+def _draw_responses(case, response_iterator):
+ for _ in range(case.response_count.value[0]):
+ next(response_iterator)
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, case):
+ self.request_streaming = True
+ self.response_streaming = True
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self._case = case
+
+ def stream_stream(self, request_iterator, servicer_context):
+ _draw_requests(self._case, request_iterator)
+
+ for response in _messages_to_send(self._case.response_count,
+ self._case.response_size):
+ yield response
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, case):
+ self._case = case
+
+ def service(self, handler_call_details):
+ return _MethodHandler(self._case)
+
+
+class _EarlyOkTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server = test_common.test_server()
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.add_generic_rpc_handlers((_GenericHandler(self.case),))
+ self._server.start()
+
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+ self._multi_callable = self._channel.stream_stream(_RPC_METHOD)
+
+ def tearDown(self):
+ self._server.stop(None)
+
+ def test_early_ok(self):
+ requests = _messages_to_send(self.case.request_count,
+ self.case.request_size)
+
+ response_iterator_call = self._multi_callable(requests)
+
+ _draw_responses(self.case, response_iterator_call)
+
+ self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
+
+
+def _cases():
+ for request_count in _MessageCount:
+ for request_size in _MessageSize:
+ for request_reading in _ReadRequests:
+ for response_count in _MessageCount:
+ for response_size in _MessageSize:
+ yield _Case(request_count, request_size,
+ request_reading, response_count,
+ response_size)
+
+
+def _test_case_classes():
+ for case in _cases():
+ yield type(case.create_test_case_name(), (_EarlyOkTest,), {
+ 'case': case,
+ '__module__': _EarlyOkTest.__module__,
+ })
+
+
+def load_tests(loader, tests, pattern):
+ return unittest.TestSuite(
+ tests=tuple(
+ loader.loadTestsFromTestCase(test_case_class)
+ for test_case_class in _test_case_classes()))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)