From 61b8c8920698b949872480540912ba2bf022c9df Mon Sep 17 00:00:00 2001 From: ncteisen Date: Wed, 7 Dec 2016 15:28:35 -0800 Subject: Add check on return value from start_client_batch --- src/python/grpcio/grpc/_channel.py | 46 +++++- src/python/grpcio_tests/tests/tests.json | 1 + .../tests/unit/_invalid_metadata_test.py | 179 +++++++++++++++++++++ 3 files changed, 220 insertions(+), 6 deletions(-) create mode 100644 src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 3ac735a4ec..1298cfbb6f 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -99,6 +99,22 @@ def _wait_once_until(condition, until): else: condition.wait(timeout=remaining) +_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( + 'Internal gRPC call error %d. ' + + 'Please report to https://github.com/grpc/grpc/issues') + +def _check_call_error(call_error, metadata): + if call_error == cygrpc.CallError.invalid_metadata: + raise ValueError('metadata was invalid: %s' % metadata) + elif call_error != cygrpc.CallError.ok: + raise ValueError(_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error) + +def _call_error_set_RPCstate(state, call_error, metadata): + if call_error == cygrpc.CallError.invalid_metadata: + _abort(state, grpc.StatusCode.INTERNAL, 'metadata was invalid: %s' % metadata) + else: + _abort(state, grpc.StatusCode.INTERNAL, + _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error) class _RPCState(object): @@ -472,7 +488,8 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): None, 0, completion_queue, self._method, None, deadline_timespec) if credentials is not None: call.set_credentials(credentials._credentials) - call.start_client_batch(cygrpc.Operations(operations), None) + call_error = call.start_client_batch(cygrpc.Operations(operations), None) + _check_call_error(call_error, metadata) _handle_event(completion_queue.poll(), state, self._response_deserializer) return state, deadline @@ -496,7 +513,11 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): call.set_credentials(credentials._credentials) event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: - call.start_client_batch(cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(cygrpc.Operations(operations), + event_handler) + if call_error != cygrpc.CallError.ok: + _call_error_set_RPCstate(state, call_error, metadata) + return _Rendezvous(state, None, None, deadline) drive_call() return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -536,7 +557,11 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_client_batch(cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(cygrpc.Operations(operations), + event_handler) + if call_error != cygrpc.CallError.ok: + _call_error_set_RPCstate(state, call_error, metadata) + return _Rendezvous(state, None, None, deadline) drive_call() return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -571,7 +596,8 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_client_batch(cygrpc.Operations(operations), None) + call_error = call.start_client_batch(cygrpc.Operations(operations), None) + _check_call_error(call_error, metadata) _consume_request_iterator( request_iterator, state, call, self._request_serializer) while True: @@ -615,7 +641,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_client_batch(cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(cygrpc.Operations(operations), + event_handler) + if call_error != cygrpc.CallError.ok: + _call_error_set_RPCstate(state, call_error, metadata) + return _Rendezvous(state, None, None, deadline) drive_call() _consume_request_iterator( request_iterator, state, call, self._request_serializer) @@ -652,7 +682,11 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_client_batch(cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(cygrpc.Operations(operations), + event_handler) + if call_error != cygrpc.CallError.ok: + _call_error_set_RPCstate(state, call_error, metadata) + return _Rendezvous(state, None, None, deadline) drive_call() _consume_request_iterator( request_iterator, state, call, self._request_serializer) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index d0cf2b5779..0eea66da40 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -25,6 +25,7 @@ "_implementations_test.CallCredentialsTest", "_implementations_test.ChannelCredentialsTest", "_insecure_interop_test.InsecureInteropTest", + "_invalid_metadata_test.InvalidMetadataTest" "_logging_pool_test.LoggingPoolTest", "_metadata_code_details_test.MetadataCodeDetailsTest", "_metadata_test.MetadataTest", diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py new file mode 100644 index 0000000000..0411c58900 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -0,0 +1,179 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test of RPCs made against gRPC Python's application-layer API.""" + +import unittest + +import grpc + +from tests.unit.framework.common import test_constants + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class InvalidMetadataTest(unittest.TestCase): + + def setUp(self): + self._channel = grpc.insecure_channel('localhost:8080') + self._unary_unary = _unary_unary_multi_callable(self._channel) + self._unary_stream = _unary_stream_multi_callable(self._channel) + self._stream_unary = _stream_unary_multi_callable(self._channel) + self._stream_stream = _stream_stream_multi_callable(self._channel) + + def testUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._unary_unary(request, metadata=metadata) + self.assertEqual( + expected_error_details, exception_context.exception.message) + + def testUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._unary_unary.with_call(request, metadata=metadata) + self.assertEqual( + expected_error_details, exception_context.exception.message) + + def testUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + response_future = self._unary_unary.future(request, metadata=metadata) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertEqual( + exception_context.exception.details(), expected_error_details) + self.assertEqual( + exception_context.exception.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(response_future.details(), expected_error_details) + self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + + def testUnaryRequestStreamResponse(self): + request = b'\x37\x58' + metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + response_iterator = self._unary_stream(request, metadata=metadata) + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + self.assertEqual( + exception_context.exception.details(), expected_error_details) + self.assertEqual( + exception_context.exception.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(response_iterator.details(), expected_error_details) + self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + + def testStreamRequestBlockingUnaryResponse(self): + request_iterator = (b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._stream_unary(request_iterator, metadata=metadata) + self.assertEqual( + expected_error_details, exception_context.exception.message) + + def testStreamRequestBlockingUnaryResponseWithCall(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),) + expected_error_details = "metadata was invalid: %s" % metadata + multi_callable = _stream_unary_multi_callable(self._channel) + with self.assertRaises(ValueError) as exception_context: + multi_callable.with_call(request_iterator, metadata=metadata) + self.assertEqual( + expected_error_details, exception_context.exception.message) + + def testStreamRequestFutureUnaryResponse(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + response_future = self._stream_unary.future( + request_iterator, metadata=metadata) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertEqual( + exception_context.exception.details(), expected_error_details) + self.assertEqual( + exception_context.exception.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(response_future.details(), expected_error_details) + self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + + def testStreamRequestStreamResponse(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + response_iterator = self._stream_stream(request_iterator, metadata=metadata) + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + self.assertEqual( + exception_context.exception.details(), expected_error_details) + self.assertEqual( + exception_context.exception.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(response_iterator.details(), expected_error_details) + self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + + +if __name__ == '__main__': + unittest.main(verbosity=2) -- cgit v1.2.3