# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that closing a channel (directly or via `with`) cancels its RPCs."""

import collections
import logging
import threading
import time
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Pause between requests fed into the stream in test_many_concurrent_closes.
_BEAT = 0.5
# How long each closer thread sleeps before calling channel.close().
_SOME_TIME = 5
# Total time test_many_concurrent_closes keeps streaming requests; it is
# longer than _SOME_TIME so the concurrent closes land mid-stream.
_MORE_TIME = 10


class _MethodHandler(grpc.RpcMethodHandler):
    """Bidirectional-streaming handler that yields each request doubled."""

    request_streaming = True
    response_streaming = True
    request_deserializer = None
    response_serializer = None

    def stream_stream(self, request_iterator, servicer_context):
        # No deserializer is configured and the tests send bytes, so
        # `request * 2` repeats the payload (e.g. b'abc' -> b'abcabc').
        for request in request_iterator:
            yield request * 2


_METHOD_HANDLER = _MethodHandler()


class _GenericHandler(grpc.GenericRpcHandler):
    """Serves every method name with the single doubling handler above."""

    def service(self, handler_call_details):
        return _METHOD_HANDLER


_GENERIC_HANDLER = _GenericHandler()


class _Pipe(object):
    """Blocking, condition-guarded iterator used as a client request stream.

    Iteration blocks until a value is added or the pipe is closed; once the
    pipe is closed and all queued values are consumed, iteration stops.
    """

    def __init__(self, values):
        self._condition = threading.Condition()
        # deque makes removal from the front O(1) (list.pop(0) is O(n)).
        self._values = collections.deque(values)
        self._open = True

    def __iter__(self):
        return self

    def _next(self):
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.popleft()
            else:
                raise StopIteration()

    def next(self):
        # Python 2 iterator protocol.
        return self._next()

    def __next__(self):
        # Python 3 iterator protocol.
        return self._next()

    def add(self, value):
        """Queue `value` and wake one waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Mark the pipe closed; iteration ends once the queue is drained."""
        with self._condition:
            self._open = False
            # Closing is a state change every waiter must observe.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()


class ChannelCloseTest(unittest.TestCase):
    """Checks that in-flight calls are CANCELLED when their channel closes."""

    def setUp(self):
        self._server = test_common.test_server(
            max_workers=test_constants.THREAD_CONCURRENCY)
        self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,))
        self._port = self._server.add_insecure_port('[::]:0')
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def test_close_immediately_after_call_invocation(self):
        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
        multi_callable = channel.stream_stream('Meffod')
        request_iterator = _Pipe(())
        response_iterator = multi_callable(request_iterator)
        # Close before any response has been read.
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_close_while_call_active(self):
        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
        multi_callable = channel.stream_stream('Meffod')
        request_iterator = _Pipe((b'abc',))
        response_iterator = multi_callable(request_iterator)
        # Read one response so the call is known to be in flight.
        next(response_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_call_active(self):
        with grpc.insecure_channel('localhost:{}'.format(
                self._port)) as channel:  # pylint: disable=bad-continuation
            multi_callable = channel.stream_stream('Meffod')
            request_iterator = _Pipe((b'abc',))
            response_iterator = multi_callable(request_iterator)
            next(response_iterator)
        # Leaving the `with` closed the channel while the request stream
        # was still open.
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_many_calls_active(self):
        with grpc.insecure_channel('localhost:{}'.format(
                self._port)) as channel:  # pylint: disable=bad-continuation
            multi_callable = channel.stream_stream('Meffod')
            request_iterators = tuple(
                _Pipe((b'abc',))
                for _ in range(test_constants.THREAD_CONCURRENCY))
            response_iterators = []
            for request_iterator in request_iterators:
                response_iterator = multi_callable(request_iterator)
                next(response_iterator)
                response_iterators.append(response_iterator)
        for request_iterator in request_iterators:
            request_iterator.close()

        for response_iterator in response_iterators:
            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_many_concurrent_closes(self):
        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
        multi_callable = channel.stream_stream('Meffod')
        request_iterator = _Pipe((b'abc',))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        start = time.time()
        end = start + _MORE_TIME

        def sleep_some_time_then_close():
            time.sleep(_SOME_TIME)
            channel.close()

        # Many threads race to close the same channel mid-stream.
        close_threads = []
        for _ in range(test_constants.THREAD_CONCURRENCY):
            close_thread = threading.Thread(target=sleep_some_time_then_close)
            close_thread.start()
            close_threads.append(close_thread)
        # Keep the stream busy until after the closes have fired.
        while True:
            request_iterator.add(b'def')
            time.sleep(_BEAT)
            if end < time.time():
                break
        request_iterator.close()
        # Reap the closer threads so none outlive the test.
        for close_thread in close_threads:
            close_thread.join()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)