diff options
author | Mehrdad Afshari <mmx@google.com> | 2017-12-11 07:14:21 -0800 |
---|---|---|
committer | Mehrdad Afshari <mmx@google.com> | 2017-12-12 12:01:07 -0800 |
commit | fdfaf1b12eb718268bd8186a00429019eb230d12 (patch) | |
tree | 1dcf6e6ea4d2858cdebafdf693bfda54d1152f40 /src | |
parent | 108500f194693f0bf1cf4eebcd7ef292fa611287 (diff) |
Add tests for gRPC Python interceptor machinery
Diffstat (limited to 'src')
-rw-r--r-- | src/python/grpcio_tests/tests/tests.json | 1 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_interceptor_test.py | 571 |
2 files changed, 572 insertions, 0 deletions
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 34cbade92c..3bf5308749 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -39,6 +39,7 @@ "unit._cython.cygrpc_test.TypeSmokeTest", "unit._empty_message_test.EmptyMessageTest", "unit._exit_test.ExitTest", + "unit._interceptor_test.InterceptorTest", "unit._invalid_metadata_test.InvalidMetadataTest", "unit._invocation_defects_test.InvocationDefectsTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest", diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py new file mode 100644 index 0000000000..cf875ed7da --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -0,0 +1,571 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of gRPC Python interceptors.""" + +import collections +import itertools +import threading +import unittest +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + request_deserializer, response_serializer, unary_unary, + unary_stream, stream_unary, stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, None, None, + self._handler.handle_unary_unary, None, None, + None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + else: + return None + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class _ClientCallDetails( + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', + 'credentials')), grpc.ClientCallDetails): + pass + + +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + + def __init__(self, interceptor_function): + self._fn = interceptor_function + + def intercept_unary_unary(self, continuation, client_call_details, request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_unary_stream(self, continuation, client_call_details, + request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + +class _LoggingInterceptor( + grpc.ServerInterceptor, grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor): + + def __init__(self, tag, record): + self.tag = tag + self.record = record + + def intercept_service(self, continuation, handler_call_details): + self.record.append(self.tag + ':intercept_service') + return continuation(handler_call_details) + + def intercept_unary_unary(self, continuation, client_call_details, request): + self.record.append(self.tag + ':intercept_unary_unary') + return continuation(client_call_details, request) + + def intercept_unary_stream(self, continuation, client_call_details, + request): + self.record.append(self.tag + ':intercept_unary_stream') + return continuation(client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_unary') + return continuation(client_call_details, request_iterator) + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_stream') + return continuation(client_call_details, request_iterator) + + +class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor): + + def intercept_unary_unary(self, ignored_continuation, + ignored_client_call_details, ignored_request): + raise test_control.Defect() + + +def _wrap_request_iterator_stream_interceptor(wrapper): + + def intercept_call(client_call_details, request_iterator, request_streaming, + ignored_response_streaming): + if request_streaming: + return client_call_details, wrapper(request_iterator), None + else: + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +def _append_request_header_interceptor(header, value): + + def intercept_call(client_call_details, request_iterator, + ignored_request_streaming, ignored_response_streaming): + metadata = [] + if client_call_details.metadata: + metadata = list(client_call_details.metadata) + metadata.append((header, value,)) + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +class _GenericServerInterceptor(grpc.ServerInterceptor): + + def __init__(self, fn): + self._fn = fn + + def intercept_service(self, continuation, handler_call_details): + return self._fn(continuation, handler_call_details) + + +def _filter_server_interceptor(condition, interceptor): + + def intercept_service(continuation, handler_call_details): + if condition(handler_call_details): + return interceptor.intercept_service(continuation, + handler_call_details) + return continuation(handler_call_details) + + return _GenericServerInterceptor(intercept_service) + + +class InterceptorTest(unittest.TestCase): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + self._record = [] + conditional_interceptor = _filter_server_interceptor( + lambda x: ('secret', '42') in x.invocation_metadata, + _LoggingInterceptor('s3', self._record)) + + self._server = grpc.server( + self._server_pool, + interceptors=(_LoggingInterceptor('s1', self._record), + conditional_interceptor, + _LoggingInterceptor('s2', self._record),)) + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._server_pool.shutdown(wait=True) + + def testTripleRequestMessagesClientInterceptor(self): + + def triple(request_iterator): + while True: + try: + item = next(request_iterator) + yield item + yield item + yield item + except StopIteration: + break + + interceptor = _wrap_request_iterator_stream_interceptor(triple) + channel = grpc.intercept_channel(self._channel, interceptor) + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), test_constants.STREAM_LENGTH) + + def testDefectiveClientInterceptor(self): + interceptor = _DefectiveClientInterceptor() + defective_channel = grpc.intercept_channel(self._channel, interceptor) + + request = b'\x07\x08' + + multi_callable = _unary_unary_multi_callable(defective_channel) + call_future = multi_callable.future( + request, + metadata=( + ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertIsNotNone(call_future.exception()) + self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) + + def testInterceptedHeaderManipulationWithServerSideVerification(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, _append_request_header_interceptor('secret', '42')) + channel = grpc.intercept_channel( + channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's3:intercept_service', + 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable( + request, + metadata=( + ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + response_future = multi_callable.future( + request, + metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestStreamResponse(self): + request = b'\x37\x58' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_stream_multi_callable(channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_stream', 'c2:intercept_unary_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable( + request_iterator, + metadata=( + ('test', 'InterceptedStreamRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable.with_call( + request_iterator, + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestStreamResponse(self): + requests = tuple(b'\x77\x58' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_stream', 'c2:intercept_stream_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + +if __name__ == '__main__': + unittest.main(verbosity=2) |