# Copyright 2017 gRPC authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test of gRPC Python interceptors.""" import collections import itertools import threading import unittest import logging from concurrent import futures import grpc from grpc.framework.foundation import logging_pool from tests.unit import test_common from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_control _SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 _DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 _DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] _UNARY_UNARY = '/test/UnaryUnary' _UNARY_STREAM = '/test/UnaryStream' _STREAM_UNARY = '/test/StreamUnary' _STREAM_STREAM = '/test/StreamStream' class _Callback(object): def __init__(self): self._condition = threading.Condition() self._value = None self._called = False def __call__(self, value): with self._condition: self._value = value self._called = True self._condition.notify_all() def value(self): with self._condition: while not self._called: self._condition.wait() return self._value class _Handler(object): def __init__(self, control): self._control = control def handle_unary_unary(self, request, servicer_context): self._control.control() if servicer_context is not None: servicer_context.set_trailing_metadata((( 'testkey', 'testvalue', ),)) return request def handle_unary_stream(self, request, servicer_context): for _ in range(test_constants.STREAM_LENGTH): self._control.control() yield request self._control.control() if servicer_context is not None: servicer_context.set_trailing_metadata((( 'testkey', 'testvalue', ),)) def handle_stream_unary(self, request_iterator, servicer_context): if servicer_context is not None: servicer_context.invocation_metadata() self._control.control() response_elements = [] for request in request_iterator: self._control.control() response_elements.append(request) self._control.control() if servicer_context is not None: servicer_context.set_trailing_metadata((( 'testkey', 'testvalue', ),)) return b''.join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): self._control.control() if servicer_context is not None: servicer_context.set_trailing_metadata((( 'testkey', 'testvalue', ),)) for request in request_iterator: self._control.control() yield request self._control.control() class _MethodHandler(grpc.RpcMethodHandler): def __init__(self, request_streaming, response_streaming, request_deserializer, response_serializer, unary_unary, unary_stream, stream_unary, stream_stream): self.request_streaming = request_streaming self.response_streaming = response_streaming self.request_deserializer = request_deserializer self.response_serializer = response_serializer self.unary_unary = unary_unary self.unary_stream = unary_stream self.stream_unary = stream_unary self.stream_stream = stream_stream class _GenericHandler(grpc.GenericRpcHandler): def __init__(self, handler): self._handler = handler def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: return _MethodHandler(False, False, None, None, self._handler.handle_unary_unary, None, None, None) elif handler_call_details.method == _UNARY_STREAM: return _MethodHandler(False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, self._handler.handle_unary_stream, None, None) elif handler_call_details.method == _STREAM_UNARY: return _MethodHandler(True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None, self._handler.handle_stream_unary, None) elif handler_call_details.method == _STREAM_STREAM: return _MethodHandler(True, True, None, None, None, None, None, self._handler.handle_stream_stream) else: return None def _unary_unary_multi_callable(channel): return channel.unary_unary(_UNARY_UNARY) def _unary_stream_multi_callable(channel): return channel.unary_stream( _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE) def _stream_unary_multi_callable(channel): return channel.stream_unary( _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE) def _stream_stream_multi_callable(channel): return channel.stream_stream(_STREAM_STREAM) class _ClientCallDetails( collections.namedtuple( '_ClientCallDetails', ('method', 'timeout', 'metadata', 'credentials')), grpc.ClientCallDetails): pass class _GenericClientInterceptor( grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): def __init__(self, interceptor_function): self._fn = interceptor_function def intercept_unary_unary(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( client_call_details, iter((request,)), False, False) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response def intercept_unary_stream(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( client_call_details, iter((request,)), False, True) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it def intercept_stream_unary(self, continuation, client_call_details, request_iterator): new_details, new_request_iterator, postprocess = self._fn( client_call_details, request_iterator, True, False) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response def intercept_stream_stream(self, continuation, client_call_details, request_iterator): new_details, new_request_iterator, postprocess = self._fn( client_call_details, request_iterator, True, True) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it class _LoggingInterceptor( grpc.ServerInterceptor, grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): def __init__(self, tag, record): self.tag = tag self.record = record def intercept_service(self, continuation, handler_call_details): self.record.append(self.tag + ':intercept_service') return continuation(handler_call_details) def intercept_unary_unary(self, continuation, client_call_details, request): self.record.append(self.tag + ':intercept_unary_unary') return continuation(client_call_details, request) def intercept_unary_stream(self, continuation, client_call_details, request): self.record.append(self.tag + ':intercept_unary_stream') return continuation(client_call_details, request) def intercept_stream_unary(self, continuation, client_call_details, request_iterator): self.record.append(self.tag + ':intercept_stream_unary') return continuation(client_call_details, request_iterator) def intercept_stream_stream(self, continuation, client_call_details, request_iterator): self.record.append(self.tag + ':intercept_stream_stream') return continuation(client_call_details, request_iterator) class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor): def intercept_unary_unary(self, ignored_continuation, ignored_client_call_details, ignored_request): raise test_control.Defect() def _wrap_request_iterator_stream_interceptor(wrapper): def intercept_call(client_call_details, request_iterator, request_streaming, ignored_response_streaming): if request_streaming: return client_call_details, wrapper(request_iterator), None else: return client_call_details, request_iterator, None return _GenericClientInterceptor(intercept_call) def _append_request_header_interceptor(header, value): def intercept_call(client_call_details, request_iterator, ignored_request_streaming, ignored_response_streaming): metadata = [] if client_call_details.metadata: metadata = list(client_call_details.metadata) metadata.append(( header, value, )) client_call_details = _ClientCallDetails( client_call_details.method, client_call_details.timeout, metadata, client_call_details.credentials) return client_call_details, request_iterator, None return _GenericClientInterceptor(intercept_call) class _GenericServerInterceptor(grpc.ServerInterceptor): def __init__(self, fn): self._fn = fn def intercept_service(self, continuation, handler_call_details): return self._fn(continuation, handler_call_details) def _filter_server_interceptor(condition, interceptor): def intercept_service(continuation, handler_call_details): if condition(handler_call_details): return interceptor.intercept_service(continuation, handler_call_details) return continuation(handler_call_details) return _GenericServerInterceptor(intercept_service) class InterceptorTest(unittest.TestCase): def setUp(self): self._control = test_control.PauseFailControl() self._handler = _Handler(self._control) self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) self._record = [] conditional_interceptor = _filter_server_interceptor( lambda x: ('secret', '42') in x.invocation_metadata, _LoggingInterceptor('s3', self._record)) self._server = grpc.server( self._server_pool, options=(('grpc.so_reuseport', 0),), interceptors=( _LoggingInterceptor('s1', self._record), conditional_interceptor, _LoggingInterceptor('s2', self._record), )) port = self._server.add_insecure_port('[::]:0') self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.start() self._channel = grpc.insecure_channel('localhost:%d' % port) def tearDown(self): self._server.stop(None) self._server_pool.shutdown(wait=True) self._channel.close() def testTripleRequestMessagesClientInterceptor(self): def triple(request_iterator): while True: try: item = next(request_iterator) yield item yield item yield item except StopIteration: break interceptor = _wrap_request_iterator_stream_interceptor(triple) channel = grpc.intercept_channel(self._channel, interceptor) requests = tuple( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) multi_callable = _stream_stream_multi_callable(channel) response_iterator = multi_callable( iter(requests), metadata=( ('test', 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) responses = tuple(response_iterator) self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH) multi_callable = _stream_stream_multi_callable(self._channel) response_iterator = multi_callable( iter(requests), metadata=( ('test', 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) responses = tuple(response_iterator) self.assertEqual(len(responses), test_constants.STREAM_LENGTH) def testDefectiveClientInterceptor(self): interceptor = _DefectiveClientInterceptor() defective_channel = grpc.intercept_channel(self._channel, interceptor) request = b'\x07\x08' multi_callable = _unary_unary_multi_callable(defective_channel) call_future = multi_callable.future( request, metadata=(('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) self.assertIsNotNone(call_future.exception()) self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) def testInterceptedHeaderManipulationWithServerSideVerification(self): request = b'\x07\x08' channel = grpc.intercept_channel(self._channel, _append_request_header_interceptor( 'secret', '42')) channel = grpc.intercept_channel(channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) self._record[:] = [] multi_callable = _unary_unary_multi_callable(channel) multi_callable.with_call( request, metadata=( ('test', 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) self.assertSequenceEqual(self._record, [ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', 's1:intercept_service', 's3:intercept_service', 's2:intercept_service' ]) def testInterceptedUnaryRequestBlockingUnaryResponse(self): request = b'\x07\x08' self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _unary_unary_multi_callable(channel) multi_callable( request, metadata=(('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) self.assertSequenceEqual(self._record, [ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): request = b'\x07\x08' channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) self._record[:] = [] multi_callable = _unary_unary_multi_callable(channel) multi_callable.with_call( request, metadata=( ('test', 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) self.assertSequenceEqual(self._record, [ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedUnaryRequestFutureUnaryResponse(self): request = b'\x07\x08' self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _unary_unary_multi_callable(channel) response_future = multi_callable.future( request, metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),)) response_future.result() self.assertSequenceEqual(self._record, [ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedUnaryRequestStreamResponse(self): request = b'\x37\x58' self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _unary_stream_multi_callable(channel) response_iterator = multi_callable( request, metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) tuple(response_iterator) self.assertSequenceEqual(self._record, [ 'c1:intercept_unary_stream', 'c2:intercept_unary_stream', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedStreamRequestBlockingUnaryResponse(self): requests = tuple( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _stream_unary_multi_callable(channel) multi_callable( request_iterator, metadata=(('test', 'InterceptedStreamRequestBlockingUnaryResponse'),)) self.assertSequenceEqual(self._record, [ 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): requests = tuple( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _stream_unary_multi_callable(channel) multi_callable.with_call( request_iterator, metadata=( ('test', 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) self.assertSequenceEqual(self._record, [ 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedStreamRequestFutureUnaryResponse(self): requests = tuple( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _stream_unary_multi_callable(channel) response_future = multi_callable.future( request_iterator, metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) response_future.result() self.assertSequenceEqual(self._record, [ 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', 's1:intercept_service', 's2:intercept_service' ]) def testInterceptedStreamRequestStreamResponse(self): requests = tuple( b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel(self._channel, _LoggingInterceptor( 'c1', self._record), _LoggingInterceptor( 'c2', self._record)) multi_callable = _stream_stream_multi_callable(channel) response_iterator = multi_callable( request_iterator, metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) tuple(response_iterator) self.assertSequenceEqual(self._record, [ 'c1:intercept_stream_stream', 'c2:intercept_stream_stream', 's1:intercept_service', 's2:intercept_service' ]) if __name__ == '__main__': logging.basicConfig() unittest.main(verbosity=2)