diff options
Diffstat (limited to 'src/python/grpcio_tests')
11 files changed, 811 insertions, 183 deletions
diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py index 99ca3fd82d..b1b4d7e0c2 100644 --- a/src/python/grpcio_tests/grpc_version.py +++ b/src/python/grpcio_tests/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!! -VERSION='1.9.0.dev0' +VERSION = '1.9.0.dev0' diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 34cbade92c..3bf5308749 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -39,6 +39,7 @@ "unit._cython.cygrpc_test.TypeSmokeTest", "unit._empty_message_test.EmptyMessageTest", "unit._exit_test.ExitTest", + "unit._interceptor_test.InterceptorTest", "unit._invalid_metadata_test.InvalidMetadataTest", "unit._invocation_defects_test.InvocationDefectsTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest", diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index b14e8d5c75..d6f4447532 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -33,18 +33,21 @@ class AllTest(unittest.TestCase): 'AuthMetadataPlugin', 'ServerCertificateConfiguration', 'ServerCredentials', 'UnaryUnaryMultiCallable', 'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', 'Channel', 'ServicerContext', + 'StreamStreamMultiCallable', 'UnaryUnaryClientInterceptor', + 'UnaryStreamClientInterceptor', 'StreamUnaryClientInterceptor', + 'StreamStreamClientInterceptor', 'Channel', 'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler', - 'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', - 'stream_unary_rpc_method_handler', + 'ServiceRpcHandler', 'Server', 'ServerInterceptor', + 'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', + 'stream_unary_rpc_method_handler', 'ClientCallDetails', 'stream_stream_rpc_method_handler', 'method_handlers_generic_handler', 'ssl_channel_credentials', 'metadata_call_credentials', 'access_token_call_credentials', 'composite_call_credentials', 'composite_channel_credentials', 'ssl_server_credentials', 'ssl_server_certificate_configuration', 'dynamic_ssl_server_credentials', 'channel_ready_future', - 'insecure_channel', 'secure_channel', 'server',) + 'insecure_channel', 'secure_channel', 'intercept_channel', + 'server',) six.assertCountEqual(self, expected_grpc_code_elements, _from_grpc_import_star.GRPC_ELEMENTS) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 5b97b7b542..a8a7175cc7 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -22,7 +22,7 @@ from tests.unit.framework.common import test_constants _INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) _EMPTY_FLAGS = 0 -_EMPTY_METADATA = cygrpc.Metadata(()) +_EMPTY_METADATA = () _SERVER_SHUTDOWN_TAG = 'server_shutdown' _REQUEST_CALL_TAG = 'request_call' @@ -65,12 +65,10 @@ class _Handler(object): with self._lock: self._call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),), _RECEIVE_CLOSE_ON_SERVER_TAG) self._call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_message(_EMPTY_FLAGS),), _RECEIVE_MESSAGE_TAG) first_event = self._completion_queue.poll() if _is_cancellation_event(first_event): @@ -84,8 +82,8 @@ class _Handler(object): cygrpc.operation_send_status_from_server( _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', _EMPTY_FLAGS),) - self._call.start_server_batch( - cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG) + self._call.start_server_batch(operations, + _SERVER_COMPLETE_CALL_TAG) self._completion_queue.poll() self._completion_queue.poll() @@ -179,8 +177,7 @@ class CancelManyCallsTest(unittest.TestCase): cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_client_batch( - cygrpc.Operations(operations), tag) + client_call.start_client_batch(operations, tag) client_due.add(tag) client_calls.append(client_call) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index ac66d1db3d..96f0f1589b 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -23,17 +23,14 @@ RPC_COUNT = 4000 INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) EMPTY_FLAGS = 0 -INVOCATION_METADATA = cygrpc.Metadata( - (cygrpc.Metadatum(b'client-md-key', b'client-md-key'), - cygrpc.Metadatum(b'client-md-key-bin', b'\x00\x01' * 3000),)) +INVOCATION_METADATA = (('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01' * 3000),) -INITIAL_METADATA = cygrpc.Metadata( - (cygrpc.Metadatum(b'server-initial-md-key', b'server-initial-md-value'), - cygrpc.Metadatum(b'server-initial-md-key-bin', b'\x00\x02' * 3000),)) +INITIAL_METADATA = (('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02' * 3000),) -TRAILING_METADATA = cygrpc.Metadata( - (cygrpc.Metadatum(b'server-trailing-md-key', b'server-trailing-md-value'), - cygrpc.Metadatum(b'server-trailing-md-key-bin', b'\x00\x03' * 3000),)) +TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03' * 3000),) class QueueDriver(object): diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py index 14cc66675c..d08003af44 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -48,20 +48,19 @@ class Test(_common.RpcTest, unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with self.client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_initial_metadata( - _common.EMPTY_FLAGS), - ]), client_receive_initial_metadata_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_initial_metadata( + _common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag)) client_complete_rpc_start_batch_result = client_call.start_client_batch( - cygrpc.Operations([ + [ cygrpc.operation_send_initial_metadata( _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), cygrpc.operation_send_close_from_client( _common.EMPTY_FLAGS), cygrpc.operation_receive_status_on_client( _common.EMPTY_FLAGS), - ]), client_complete_rpc_tag) + ], client_complete_rpc_tag) self.client_driver.add_due({ client_receive_initial_metadata_tag, client_complete_rpc_tag, diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py index 1e44bcc4dc..d0166a2b29 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -43,20 +43,19 @@ class Test(_common.RpcTest, unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with self.client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_initial_metadata( - _common.EMPTY_FLAGS), - ]), client_receive_initial_metadata_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_initial_metadata( + _common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag)) client_complete_rpc_start_batch_result = client_call.start_client_batch( - cygrpc.Operations([ + [ cygrpc.operation_send_initial_metadata( _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), cygrpc.operation_send_close_from_client( _common.EMPTY_FLAGS), cygrpc.operation_receive_status_on_client( _common.EMPTY_FLAGS), - ]), client_complete_rpc_tag) + ], client_complete_rpc_tag) self.client_driver.add_due({ client_receive_initial_metadata_tag, client_complete_rpc_tag, diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index 0105612b47..1deb15ba03 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -20,7 +20,7 @@ from grpc._cython import cygrpc _INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) _EMPTY_FLAGS = 0 -_EMPTY_METADATA = cygrpc.Metadata(()) +_EMPTY_METADATA = () class _ServerDriver(object): @@ -157,19 +157,17 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), - ]), client_receive_initial_metadata_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + ], client_receive_initial_metadata_tag)) client_due.add(client_receive_initial_metadata_tag) client_complete_rpc_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), - cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), - ]), client_complete_rpc_tag)) + client_call.start_client_batch([ + cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, + _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ], client_complete_rpc_tag)) client_due.add(client_complete_rpc_tag) server_rpc_event = server_driver.first_event() @@ -197,8 +195,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_rpc_event.operation_call.start_server_batch([ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), cygrpc.operation_send_status_from_server( - cygrpc.Metadata(()), cygrpc.StatusCode.ok, - b'test details', _EMPTY_FLAGS), + (), cygrpc.StatusCode.ok, b'test details', + _EMPTY_FLAGS), ], server_complete_rpc_tag)) server_send_second_message_event = server_call_driver.event_with_tag( server_send_second_message_tag) @@ -209,10 +207,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): with client_condition: client_receive_first_message_tag = 'client_receive_first_message_tag' client_receive_first_message_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_message(_EMPTY_FLAGS), - ]), client_receive_first_message_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ], client_receive_first_message_tag)) client_due.add(client_receive_first_message_tag) client_receive_first_message_event = client_driver.event_with_tag( client_receive_first_message_tag) diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index da94cf8028..4eda685486 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -29,50 +29,12 @@ _EMPTY_FLAGS = 0 def _metadata_plugin(context, callback): - callback( - cygrpc.Metadata([ - cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY, - _CALL_CREDENTIALS_METADATA_VALUE) - ]), cygrpc.StatusCode.ok, b'') + callback(((_CALL_CREDENTIALS_METADATA_KEY, + _CALL_CREDENTIALS_METADATA_VALUE,),), cygrpc.StatusCode.ok, b'') class TypeSmokeTest(unittest.TestCase): - def testStringsInUtilitiesUpDown(self): - self.assertEqual(0, cygrpc.StatusCode.ok) - metadatum = cygrpc.Metadatum(b'a', b'b') - self.assertEqual(b'a', metadatum.key) - self.assertEqual(b'b', metadatum.value) - metadata = cygrpc.Metadata([metadatum]) - self.assertEqual(1, len(metadata)) - self.assertEqual(metadatum.key, metadata[0].key) - - def testMetadataIteration(self): - metadata = cygrpc.Metadata( - [cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')]) - iterator = iter(metadata) - metadatum = next(iterator) - self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, b'a') - self.assertEqual(metadatum.value, b'b') - metadatum = next(iterator) - self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, b'c') - self.assertEqual(metadatum.value, b'd') - with self.assertRaises(StopIteration): - next(iterator) - - def testOperationsIteration(self): - operations = cygrpc.Operations( - [cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)]) - iterator = iter(operations) - operation = next(iterator) - self.assertIsInstance(operation, cygrpc.Operation) - # `Operation`s are write-only structures; can't directly debug anything out - # of them. Just check that we stop iterating. - with self.assertRaises(StopIteration): - next(iterator) - def testOperationFlags(self): operation = cygrpc.operation_send_message(b'asdf', cygrpc.WriteFlag.no_compress) @@ -182,8 +144,7 @@ class ServerClientMixin(object): def performer(): tag = object() try: - call_result = call.start_client_batch( - cygrpc.Operations(operations), tag) + call_result = call.start_client_batch(operations, tag) self.assertEqual(cygrpc.CallError.ok, call_result) event = queue.poll(deadline) self.assertEqual(cygrpc.CompletionType.operation_complete, @@ -200,14 +161,14 @@ class ServerClientMixin(object): def test_echo(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 - CLIENT_METADATA_ASCII_KEY = b'key' - CLIENT_METADATA_ASCII_VALUE = b'val' - CLIENT_METADATA_BIN_KEY = b'key-bin' + CLIENT_METADATA_ASCII_KEY = 'key' + CLIENT_METADATA_ASCII_VALUE = 'val' + CLIENT_METADATA_BIN_KEY = 'key-bin' CLIENT_METADATA_BIN_VALUE = b'\0' * 1000 - SERVER_INITIAL_METADATA_KEY = b'init_me_me_me' - SERVER_INITIAL_METADATA_VALUE = b'whodawha?' - SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought' - SERVER_TRAILING_METADATA_VALUE = b'zomg it is' + SERVER_INITIAL_METADATA_KEY = 'init_me_me_me' + SERVER_INITIAL_METADATA_VALUE = 'whodawha?' + SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought' + SERVER_TRAILING_METADATA_VALUE = 'zomg it is' SERVER_STATUS_CODE = cygrpc.StatusCode.ok SERVER_STATUS_DETAILS = b'our work is never over' REQUEST = b'in death a member of project mayhem has a name' @@ -227,11 +188,9 @@ class ServerClientMixin(object): client_call = self.client_channel.create_call( None, 0, self.client_completion_queue, METHOD, self.host_argument, cygrpc_deadline) - client_initial_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, - CLIENT_METADATA_ASCII_VALUE), - cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE) - ]) + client_initial_metadata = ( + (CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE,), + (CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE,),) client_start_batch_result = client_call.start_client_batch([ cygrpc.operation_send_initial_metadata(client_initial_metadata, _EMPTY_FLAGS), @@ -263,14 +222,10 @@ class ServerClientMixin(object): server_call_tag = object() server_call = request_event.operation_call - server_initial_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY, - SERVER_INITIAL_METADATA_VALUE) - ]) - server_trailing_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY, - SERVER_TRAILING_METADATA_VALUE) - ]) + server_initial_metadata = ( + (SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE,),) + server_trailing_metadata = ( + (SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE,),) server_start_batch_result = server_call.start_server_batch([ cygrpc.operation_send_initial_metadata( server_initial_metadata, @@ -347,7 +302,7 @@ class ServerClientMixin(object): METHOD = b'twinkies' cygrpc_deadline = cygrpc.Timespec(DEADLINE) - empty_metadata = cygrpc.Metadata([]) + empty_metadata = () server_request_tag = object() self.server.request_call(self.server_completion_queue, diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py new file mode 100644 index 0000000000..cf875ed7da --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -0,0 +1,571 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of gRPC Python interceptors.""" + +import collections +import itertools +import threading +import unittest +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + request_deserializer, response_serializer, unary_unary, + unary_stream, stream_unary, stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, None, None, + self._handler.handle_unary_unary, None, None, + None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + else: + return None + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class _ClientCallDetails( + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', + 'credentials')), grpc.ClientCallDetails): + pass + + +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + + def __init__(self, interceptor_function): + self._fn = interceptor_function + + def intercept_unary_unary(self, continuation, client_call_details, request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_unary_stream(self, continuation, client_call_details, + request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + +class _LoggingInterceptor( + grpc.ServerInterceptor, grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor): + + def __init__(self, tag, record): + self.tag = tag + self.record = record + + def intercept_service(self, continuation, handler_call_details): + self.record.append(self.tag + ':intercept_service') + return continuation(handler_call_details) + + def intercept_unary_unary(self, continuation, client_call_details, request): + self.record.append(self.tag + ':intercept_unary_unary') + return continuation(client_call_details, request) + + def intercept_unary_stream(self, continuation, client_call_details, + request): + self.record.append(self.tag + ':intercept_unary_stream') + return continuation(client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_unary') + return continuation(client_call_details, request_iterator) + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_stream') + return continuation(client_call_details, request_iterator) + + +class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor): + + def intercept_unary_unary(self, ignored_continuation, + ignored_client_call_details, ignored_request): + raise test_control.Defect() + + +def _wrap_request_iterator_stream_interceptor(wrapper): + + def intercept_call(client_call_details, request_iterator, request_streaming, + ignored_response_streaming): + if request_streaming: + return client_call_details, wrapper(request_iterator), None + else: + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +def _append_request_header_interceptor(header, value): + + def intercept_call(client_call_details, request_iterator, + ignored_request_streaming, ignored_response_streaming): + metadata = [] + if client_call_details.metadata: + metadata = list(client_call_details.metadata) + metadata.append((header, value,)) + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +class _GenericServerInterceptor(grpc.ServerInterceptor): + + def __init__(self, fn): + self._fn = fn + + def intercept_service(self, continuation, handler_call_details): + return self._fn(continuation, handler_call_details) + + +def _filter_server_interceptor(condition, interceptor): + + def intercept_service(continuation, handler_call_details): + if condition(handler_call_details): + return interceptor.intercept_service(continuation, + handler_call_details) + return continuation(handler_call_details) + + return _GenericServerInterceptor(intercept_service) + + +class InterceptorTest(unittest.TestCase): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + self._record = [] + conditional_interceptor = _filter_server_interceptor( + lambda x: ('secret', '42') in x.invocation_metadata, + _LoggingInterceptor('s3', self._record)) + + self._server = grpc.server( + self._server_pool, + interceptors=(_LoggingInterceptor('s1', self._record), + conditional_interceptor, + _LoggingInterceptor('s2', self._record),)) + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._server_pool.shutdown(wait=True) + + def testTripleRequestMessagesClientInterceptor(self): + + def triple(request_iterator): + while True: + try: + item = next(request_iterator) + yield item + yield item + yield item + except StopIteration: + break + + interceptor = _wrap_request_iterator_stream_interceptor(triple) + channel = grpc.intercept_channel(self._channel, interceptor) + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), test_constants.STREAM_LENGTH) + + def testDefectiveClientInterceptor(self): + interceptor = _DefectiveClientInterceptor() + defective_channel = grpc.intercept_channel(self._channel, interceptor) + + request = b'\x07\x08' + + multi_callable = _unary_unary_multi_callable(defective_channel) + call_future = multi_callable.future( + request, + metadata=( + ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertIsNotNone(call_future.exception()) + self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) + + def testInterceptedHeaderManipulationWithServerSideVerification(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, _append_request_header_interceptor('secret', '42')) + channel = grpc.intercept_channel( + channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's3:intercept_service', + 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable( + request, + metadata=( + ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + response_future = multi_callable.future( + request, + metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestStreamResponse(self): + request = b'\x37\x58' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_stream_multi_callable(channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_stream', 'c2:intercept_unary_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable( + request_iterator, + metadata=( + ('test', 'InterceptedStreamRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable.with_call( + request_iterator, + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestStreamResponse(self): + requests = tuple(b'\x77\x58' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_stream', 'c2:intercept_stream_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 6faab94be6..cb59cd3769 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -56,6 +56,7 @@ class _Servicer(object): def __init__(self): self._lock = threading.Lock() + self._abort_call = False self._code = None self._details = None self._exception = False @@ -67,10 +68,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -81,10 +85,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 2): yield _SERIALIZED_RESPONSE if self._exception: @@ -95,14 +102,16 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -113,19 +122,25 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 3): yield object() if self._exception: raise test_control.Defect() + def set_abort_call(self): + with self._lock: + self._abort_call = True + def set_code(self, code): with self._lock: self._code = code @@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase): def testSuccessfulUnaryStream(self): self._servicer.set_details(_DETAILS) - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() - for _ in call: - pass + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testSuccessfulStreamUnary(self): self._servicer.set_details(_DETAILS) @@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): def testSuccessfulStreamStream(self): self._servicer.set_details(_DETAILS) - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() - for _ in call: - pass + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testAbortedUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testAbortedStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeStreamUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError) as exception_context: - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeExceptionStreamUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeReturnNoneUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) |