diff options
Diffstat (limited to 'src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py')
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py | 237 |
1 files changed, 173 insertions, 64 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 6faab94be6..cb59cd3769 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -56,6 +56,7 @@ class _Servicer(object): def __init__(self): self._lock = threading.Lock() + self._abort_call = False self._code = None self._details = None self._exception = False @@ -67,10 +68,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -81,10 +85,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 2): yield _SERIALIZED_RESPONSE if self._exception: @@ -95,14 +102,16 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -113,19 +122,25 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 3): yield object() if self._exception: raise test_control.Defect() + def set_abort_call(self): + with self._lock: + self._abort_call = True + def set_code(self, code): with self._lock: self._code = code @@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase): def testSuccessfulUnaryStream(self): self._servicer.set_details(_DETAILS) - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() - for _ in call: - pass + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testSuccessfulStreamUnary(self): self._servicer.set_details(_DETAILS) @@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): def testSuccessfulStreamStream(self): self._servicer.set_details(_DETAILS) - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() - for _ in call: - pass + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testAbortedUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testAbortedStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeStreamUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError) as exception_context: - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeExceptionStreamUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeReturnNoneUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) |