diff options
author | 2017-12-10 18:27:21 -0800 | |
---|---|---|
committer | 2017-12-10 20:45:18 -0800 | |
commit | 90ab995cb0f3ef29c1f284a4c361a9c2750ef2dd (patch) | |
tree | d317a8a1ed8249c70e08f597af1ee93cb81198ad /src/python/grpcio_tests/tests | |
parent | 9bc44e38296fe5c8001929b61851d0a1ba326a15 (diff) |
Tests for ServicerContext.abort
Diffstat (limited to 'src/python/grpcio_tests/tests')
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py | 166 |
1 files changed, 134 insertions, 32 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 6faab94be6..68df6643e7 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -56,6 +56,7 @@ class _Servicer(object): def __init__(self): self._lock = threading.Lock() + self._abort_call = False self._code = None self._details = None self._exception = False @@ -67,10 +68,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -81,10 +85,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 2): yield _SERIALIZED_RESPONSE if self._exception: @@ -95,14 +102,16 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -113,19 +122,25 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 3): yield object() if self._exception: raise test_control.Defect() + def set_abort_call(self): + with self._lock: + self._abort_call = True + def set_code(self, code): with self._lock: self._code = code @@ -215,8 +230,7 @@ class MetadataCodeDetailsTest(unittest.TestCase): call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) received_initial_metadata = call.initial_metadata() - for _ in call: - pass + list(call) self.assertTrue( test_common.metadata_transmitted( @@ -256,8 +270,7 @@ class MetadataCodeDetailsTest(unittest.TestCase): iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) received_initial_metadata = call.initial_metadata() - for _ in call: - pass + list(call) self.assertTrue( test_common.metadata_transmitted( @@ -271,6 +284,99 @@ class MetadataCodeDetailsTest(unittest.TestCase): self.assertIs(grpc.StatusCode.OK, call.code()) self.assertEqual(_DETAILS, call.details()) + def testAbortedUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, + call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testAbortedStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, + call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + def testCustomCodeUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) @@ -300,8 +406,7 @@ class MetadataCodeDetailsTest(unittest.TestCase): _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) received_initial_metadata = call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(call) self.assertTrue( test_common.metadata_transmitted( @@ -347,8 +452,7 @@ class MetadataCodeDetailsTest(unittest.TestCase): metadata=_CLIENT_METADATA) received_initial_metadata = call.initial_metadata() with self.assertRaises(grpc.RpcError) as exception_context: - for _ in call: - pass + list(call) self.assertTrue( test_common.metadata_transmitted( @@ -394,8 +498,7 @@ class MetadataCodeDetailsTest(unittest.TestCase): _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) received_initial_metadata = call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(call) self.assertTrue( test_common.metadata_transmitted( @@ -443,8 +546,7 @@ class MetadataCodeDetailsTest(unittest.TestCase): metadata=_CLIENT_METADATA) received_initial_metadata = call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(call) self.assertTrue( test_common.metadata_transmitted( |