diff options
-rw-r--r-- | src/python/grpcio/grpc/_server.py | 6 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py | 200 |
2 files changed, 119 insertions, 87 deletions
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 9402941bab..56122fee11 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -277,6 +277,12 @@ class _Context(grpc.ServicerContext): self._state.trailing_metadata = trailing_metadata def abort(self, code, details): + # treat OK like other invalid arguments: fail the RPC + if code == grpc.StatusCode.OK: + logging.error( + 'abort() called with StatusCode.OK; returning UNKNOWN') + code = grpc.StatusCode.UNKNOWN + details = '' with self._state.condition: self._state.code = code self._state.details = _common.encode(details) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index bb6ac70497..ca10bd4dab 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -50,6 +50,12 @@ _SERVER_TRAILING_METADATA = (('server-trailing-md-key', _NON_OK_CODE = grpc.StatusCode.NOT_FOUND _DETAILS = 'Test details!' +# calling abort should always fail an RPC, even for "invalid" codes +_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK) +_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN, + grpc.StatusCode.UNKNOWN) +_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '') + class _Servicer(object): @@ -302,99 +308,119 @@ class MetadataCodeDetailsTest(unittest.TestCase): self.assertEqual(_DETAILS, response_iterator_call.details()) def testAbortedUnaryUnary(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - with self.assertRaises(grpc.RpcError) as exception_context: - self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) - self.assertIs(_NON_OK_CODE, exception_context.exception.code()) - self.assertEqual(_DETAILS, exception_context.exception.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(expected_code, exception_context.exception.code()) + self.assertEqual(expected_details, + exception_context.exception.details()) def testAbortedUnaryStream(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - response_iterator_call = self._unary_stream( - _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = response_iterator_call.initial_metadata() - with self.assertRaises(grpc.RpcError): - self.assertEqual(len(list(response_iterator_call)), 0) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, response_iterator_call.code()) - self.assertEqual(_DETAILS, response_iterator_call.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = \ + response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(expected_code, response_iterator_call.code()) + self.assertEqual(expected_details, response_iterator_call.details()) def testAbortedStreamUnary(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - with self.assertRaises(grpc.RpcError) as exception_context: - self._stream_unary.with_call( - iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) - self.assertIs(_NON_OK_CODE, exception_context.exception.code()) - self.assertEqual(_DETAILS, exception_context.exception.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(expected_code, exception_context.exception.code()) + self.assertEqual(expected_details, + exception_context.exception.details()) def testAbortedStreamStream(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - response_iterator_call = self._stream_stream( - iter([object()] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) - received_initial_metadata = response_iterator_call.initial_metadata() - with self.assertRaises(grpc.RpcError): - self.assertEqual(len(list(response_iterator_call)), 0) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, response_iterator_call.code()) - self.assertEqual(_DETAILS, response_iterator_call.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = \ + response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(expected_code, response_iterator_call.code()) + self.assertEqual(expected_details, response_iterator_call.details()) def testCustomCodeUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) |