diff options
Diffstat (limited to 'src/python')
13 files changed, 911 insertions, 528 deletions
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 2eff08aa57..6604f8f35c 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -79,27 +79,6 @@ def _wait_once_until(condition, until): condition.wait(timeout=remaining) -_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( - 'Internal gRPC call error %d. ' + - 'Please report to https://github.com/grpc/grpc/issues') - - -def _check_call_error(call_error, metadata): - if call_error == cygrpc.CallError.invalid_metadata: - raise ValueError('metadata was invalid: %s' % metadata) - elif call_error != cygrpc.CallError.ok: - raise ValueError(_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error) - - -def _call_error_set_RPCstate(state, call_error, metadata): - if call_error == cygrpc.CallError.invalid_metadata: - _abort(state, grpc.StatusCode.INTERNAL, - 'metadata was invalid: %s' % metadata) - else: - _abort(state, grpc.StatusCode.INTERNAL, - _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error) - - class _RPCState(object): def __init__(self, due, initial_metadata, trailing_metadata, code, details): @@ -163,7 +142,7 @@ def _handle_event(event, state, response_deserializer): return callbacks -def _event_handler(state, call, response_deserializer): +def _event_handler(state, response_deserializer): def handle_event(event): with state.condition: @@ -172,40 +151,47 @@ def _event_handler(state, call, response_deserializer): done = not state.due for callback in callbacks: callback() - return call if done else None + return done return handle_event -def _consume_request_iterator(request_iterator, state, call, - request_serializer): - event_handler = _event_handler(state, call, None) +def _consume_request_iterator(request_iterator, state, call, request_serializer, + event_handler): - def consume_request_iterator(): + def consume_request_iterator(): # pylint: disable=too-many-branches while True: try: request = next(request_iterator) except StopIteration: break except Exception: # pylint: disable=broad-except - logging.exception("Exception iterating requests!") - call.cancel() - _abort(state, grpc.StatusCode.UNKNOWN, - "Exception iterating requests!") + code = grpc.StatusCode.UNKNOWN + details = 'Exception iterating requests!' + logging.exception(details) + call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) + _abort(state, code, details) return serialized_request = _common.serialize(request, request_serializer) with state.condition: if state.code is None and not state.cancelled: if serialized_request is None: - call.cancel() + code = grpc.StatusCode.INTERNAL # pylint: disable=redefined-variable-type details = 'Exception serializing request!' - _abort(state, grpc.StatusCode.INTERNAL, details) + call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) + _abort(state, code, details) return else: operations = (cygrpc.SendMessageOperation( serialized_request, _EMPTY_FLAGS),) - call.start_client_batch(operations, event_handler) - state.due.add(cygrpc.OperationType.send_message) + operating = call.operate(operations, event_handler) + if operating: + state.due.add(cygrpc.OperationType.send_message) + else: + return while True: state.condition.wait() if state.code is None: @@ -219,15 +205,19 @@ def _consume_request_iterator(request_iterator, state, call, if state.code is None: operations = ( cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),) - call.start_client_batch(operations, event_handler) - state.due.add(cygrpc.OperationType.send_close_from_client) + operating = call.operate(operations, event_handler) + if operating: + state.due.add(cygrpc.OperationType.send_close_from_client) def stop_consumption_thread(timeout): # pylint: disable=unused-argument with state.condition: if state.code is None: - call.cancel() + code = grpc.StatusCode.CANCELLED + details = 'Consumption thread cleaned up!' + call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) state.cancelled = True - _abort(state, grpc.StatusCode.CANCELLED, 'Cancelled!') + _abort(state, code, details) state.condition.notify_all() consumption_thread = _common.CleanupThread( @@ -247,9 +237,12 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): def cancel(self): with self._state.condition: if self._state.code is None: - self._call.cancel() + code = grpc.StatusCode.CANCELLED + details = 'Locally cancelled by application!' + self._call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details) self._state.cancelled = True - _abort(self._state, grpc.StatusCode.CANCELLED, 'Cancelled!') + _abort(self._state, code, details) self._state.condition.notify_all() return False @@ -318,12 +311,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): def _next(self): with self._state.condition: if self._state.code is None: - event_handler = _event_handler(self._state, self._call, + event_handler = _event_handler(self._state, self._response_deserializer) - self._call.start_client_batch( + operating = self._call.operate( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), event_handler) - self._state.due.add(cygrpc.OperationType.receive_message) + if operating: + self._state.due.add(cygrpc.OperationType.receive_message) elif self._state.code is grpc.StatusCode.OK: raise StopIteration() else: @@ -408,9 +402,12 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): def __del__(self): with self._state.condition: if self._state.code is None: - self._call.cancel() - self._state.cancelled = True self._state.code = grpc.StatusCode.CANCELLED + self._state.details = 'Cancelled upon garbage collection!' + self._state.cancelled = True + self._call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code], + self._state.details) self._state.condition.notify_all() @@ -437,6 +434,24 @@ def _end_unary_response_blocking(state, call, with_call, deadline): raise _Rendezvous(state, None, None, deadline) +def _stream_unary_invocation_operationses(metadata): + return ( + ( + cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + + +def _stream_unary_invocation_operationses_and_tags(metadata): + return tuple(( + operations, + None, + ) for operations in _stream_unary_invocation_operationses(metadata)) + + class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): def __init__(self, channel, managed_call, method, request_serializer, @@ -448,8 +463,8 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._response_deserializer = response_deserializer def _prepare(self, request, timeout, metadata): - deadline, serialized_request, rendezvous = (_start_unary_request( - request, timeout, self._request_serializer)) + deadline, serialized_request, rendezvous = _start_unary_request( + request, timeout, self._request_serializer) if serialized_request is None: return None, None, None, rendezvous else: @@ -467,48 +482,38 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): def _blocking(self, request, timeout, metadata, credentials): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata) - if rendezvous: + if state is None: raise rendezvous else: - completion_queue = cygrpc.CompletionQueue() - call = self._channel.create_call(None, 0, completion_queue, - self._method, None, deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - call_error = call.start_client_batch(operations, None) - _check_call_error(call_error, metadata) - _handle_event(completion_queue.poll(), state, - self._response_deserializer) - return state, call, deadline + call = self._channel.segregated_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, (( + operations, + None, + ),)) + event = call.next_event() + _handle_event(event, state, self._response_deserializer) + return state, call, def __call__(self, request, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request, timeout, metadata, - credentials) - return _end_unary_response_blocking(state, call, False, deadline) + state, call, = self._blocking(request, timeout, metadata, credentials) + return _end_unary_response_blocking(state, call, False, None) def with_call(self, request, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request, timeout, metadata, - credentials) - return _end_unary_response_blocking(state, call, True, deadline) + state, call, = self._blocking(request, timeout, metadata, credentials) + return _end_unary_response_blocking(state, call, True, None) def future(self, request, timeout=None, metadata=None, credentials=None): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata) - if rendezvous: - return rendezvous + if state is None: + raise rendezvous else: - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, - self._response_deserializer) - with state.condition: - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + (operations,), event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -524,34 +529,27 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._response_deserializer = response_deserializer def __call__(self, request, timeout=None, metadata=None, credentials=None): - deadline, serialized_request, rendezvous = (_start_unary_request( - request, timeout, self._request_serializer)) + deadline, serialized_request, rendezvous = _start_unary_request( + request, timeout, self._request_serializer) if serialized_request is None: raise rendezvous else: state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, - self._response_deserializer) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), - event_handler) - operations = ( + operationses = ( + ( cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + operationses, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -569,49 +567,38 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): def _blocking(self, request_iterator, timeout, metadata, credentials): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) - completion_queue = cygrpc.CompletionQueue() - call = self._channel.create_call(None, 0, completion_queue, - self._method, None, deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), None) - operations = ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, None) - _check_call_error(call_error, metadata) - _consume_request_iterator(request_iterator, state, call, - self._request_serializer) + call = self._channel.segregated_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + _stream_unary_invocation_operationses_and_tags(metadata)) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, None) while True: - event = completion_queue.poll() + event = call.next_event() with state.condition: _handle_event(event, state, self._response_deserializer) state.condition.notify_all() if not state.due: break - return state, call, deadline + return state, call, def __call__(self, request_iterator, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request_iterator, timeout, - metadata, credentials) - return _end_unary_response_blocking(state, call, False, deadline) + state, call, = self._blocking(request_iterator, timeout, metadata, + credentials) + return _end_unary_response_blocking(state, call, False, None) def with_call(self, request_iterator, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request_iterator, timeout, - metadata, credentials) - return _end_unary_response_blocking(state, call, True, deadline) + state, call, = self._blocking(request_iterator, timeout, metadata, + credentials) + return _end_unary_response_blocking(state, call, True, None) def future(self, request_iterator, @@ -620,27 +607,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): credentials=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, self._response_deserializer) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), - event_handler) - operations = ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() - _consume_request_iterator(request_iterator, state, call, - self._request_serializer) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + _stream_unary_invocation_operationses(metadata), event_handler) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -661,26 +634,20 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): credentials=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, self._response_deserializer) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), - event_handler) - operations = ( + operationses = ( + ( cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() - _consume_request_iterator(request_iterator, state, call, - self._request_serializer) + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, operationses, + event_handler) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -689,28 +656,25 @@ class _ChannelCallState(object): def __init__(self, channel): self.lock = threading.Lock() self.channel = channel - self.completion_queue = cygrpc.CompletionQueue() - self.managed_calls = None + self.managed_calls = 0 def _run_channel_spin_thread(state): def channel_spin(): while True: - event = state.completion_queue.poll() - completed_call = event.tag(event) - if completed_call is not None: + event = state.channel.next_call_event() + call_completed = event.tag(event) + if call_completed: with state.lock: - state.managed_calls.remove(completed_call) - if not state.managed_calls: - state.managed_calls = None + state.managed_calls -= 1 + if state.managed_calls == 0: return def stop_channel_spin(timeout): # pylint: disable=unused-argument with state.lock: - if state.managed_calls is not None: - for call in state.managed_calls: - call.cancel() + state.channel.close(cygrpc.StatusCode.cancelled, + 'Channel spin thread cleaned up!') channel_spin_thread = _common.CleanupThread( stop_channel_spin, target=channel_spin) @@ -719,37 +683,41 @@ def _run_channel_spin_thread(state): def _channel_managed_call_management(state): - def create(parent, flags, method, host, deadline): - """Creates a managed cygrpc.Call and a function to call to drive it. - - If operations are successfully added to the returned cygrpc.Call, the - returned function must be called. If operations are not successfully added - to the returned cygrpc.Call, the returned function must not be called. - - Args: - parent: A cygrpc.Call to be used as the parent of the created call. - flags: An integer bitfield of call flags. - method: The RPC method. - host: A host string for the created call. - deadline: A float to be the deadline of the created call or None if the - call is to have an infinite deadline. - - Returns: - A cygrpc.Call with which to conduct an RPC and a function to call if - operations are successfully started on the call. - """ - call = state.channel.create_call(parent, flags, state.completion_queue, - method, host, deadline) - - def drive(): - with state.lock: - if state.managed_calls is None: - state.managed_calls = set((call,)) - _run_channel_spin_thread(state) - else: - state.managed_calls.add(call) + # pylint: disable=too-many-arguments + def create(flags, method, host, deadline, metadata, credentials, + operationses, event_handler): + """Creates a cygrpc.IntegratedCall. - return call, drive + Args: + flags: An integer bitfield of call flags. + method: The RPC method. + host: A host string for the created call. + deadline: A float to be the deadline of the created call or None if + the call is to have an infinite deadline. + metadata: The metadata for the call or None. + credentials: A cygrpc.CallCredentials or None. + operationses: An iterable of iterables of cygrpc.Operations to be + started on the call. + event_handler: A behavior to call to handle the events resultant from + the operations on the call. + + Returns: + A cygrpc.IntegratedCall with which to conduct an RPC. + """ + operationses_and_tags = tuple(( + operations, + event_handler, + ) for operations in operationses) + with state.lock: + call = state.channel.integrated_call(flags, method, host, deadline, + metadata, credentials, + operationses_and_tags) + if state.managed_calls == 0: + state.managed_calls = 1 + _run_channel_spin_thread(state) + else: + state.managed_calls += 1 + return call return create @@ -819,12 +787,9 @@ def _poll_connectivity(state, channel, initial_try_to_connect): callback_and_connectivity[1] = state.connectivity if callbacks: _spawn_delivery(state, callbacks) - completion_queue = cygrpc.CompletionQueue() while True: - channel.watch_connectivity_state(connectivity, - time.time() + 0.2, completion_queue, - None) - event = completion_queue.poll() + event = channel.watch_connectivity_state(connectivity, + time.time() + 0.2) with state.lock: if not state.callbacks_and_connectivities and not state.try_to_connect: state.polling = False diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index 1ba76b7f83..eefc685c0b 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -13,9 +13,59 @@ # limitations under the License. +cdef _check_call_error_no_metadata(c_call_error) + + +cdef _check_and_raise_call_error_no_metadata(c_call_error) + + +cdef _check_call_error(c_call_error, metadata) + + +cdef class _CallState: + + cdef grpc_call *c_call + cdef set due + + +cdef class _ChannelState: + + cdef object condition + cdef grpc_channel *c_channel + # A boolean field indicating that the channel is open (if True) or is being + # closed (i.e. a call to close is currently executing) or is closed (if + # False). + # TODO(https://github.com/grpc/grpc/issues/3064): Eliminate "is being closed" + # a state in which condition may be acquired by any thread, eliminate this + # field and just use the NULLness of c_channel as an indication that the + # channel is closed. + cdef object open + + # A dict from _BatchOperationTag to _CallState + cdef dict integrated_call_states + cdef grpc_completion_queue *c_call_completion_queue + + # A set of _CallState + cdef set segregated_call_states + + cdef set connectivity_due + cdef grpc_completion_queue *c_connectivity_completion_queue + + +cdef class IntegratedCall: + + cdef _ChannelState _channel_state + cdef _CallState _call_state + + +cdef class SegregatedCall: + + cdef _ChannelState _channel_state + cdef _CallState _call_state + cdef grpc_completion_queue *_c_completion_queue + + cdef class Channel: cdef grpc_arg_pointer_vtable _vtable - cdef grpc_channel *c_channel - cdef list references - cdef readonly _ArgumentsProcessor _arguments_processor + cdef _ChannelState _state diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index a3966497bc..72e74e84ae 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -14,82 +14,439 @@ cimport cpython +import threading + +_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( + 'Internal gRPC call error %d. ' + + 'Please report to https://github.com/grpc/grpc/issues') + + +cdef str _call_error_metadata(metadata): + return 'metadata was invalid: %s' % metadata + + +cdef str _call_error_no_metadata(c_call_error): + return _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % c_call_error + + +cdef str _call_error(c_call_error, metadata): + if c_call_error == GRPC_CALL_ERROR_INVALID_METADATA: + return _call_error_metadata(metadata) + else: + return _call_error_no_metadata(c_call_error) + + +cdef _check_call_error_no_metadata(c_call_error): + if c_call_error != GRPC_CALL_OK: + return _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % c_call_error + else: + return None + + +cdef _check_and_raise_call_error_no_metadata(c_call_error): + error = _check_call_error_no_metadata(c_call_error) + if error is not None: + raise ValueError(error) + + +cdef _check_call_error(c_call_error, metadata): + if c_call_error == GRPC_CALL_ERROR_INVALID_METADATA: + return _call_error_metadata(metadata) + else: + return _check_call_error_no_metadata(c_call_error) + + +cdef void _raise_call_error_no_metadata(c_call_error) except *: + raise ValueError(_call_error_no_metadata(c_call_error)) + + +cdef void _raise_call_error(c_call_error, metadata) except *: + raise ValueError(_call_error(c_call_error, metadata)) + + +cdef _destroy_c_completion_queue(grpc_completion_queue *c_completion_queue): + grpc_completion_queue_shutdown(c_completion_queue) + grpc_completion_queue_destroy(c_completion_queue) + + +cdef class _CallState: + + def __cinit__(self): + self.due = set() + + +cdef class _ChannelState: + + def __cinit__(self): + self.condition = threading.Condition() + self.open = True + self.integrated_call_states = {} + self.segregated_call_states = set() + self.connectivity_due = set() + + +cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): + cdef grpc_call_error c_call_error + cdef _BatchOperationTag tag = _BatchOperationTag(user_tag, operations, None) + tag.prepare() + cpython.Py_INCREF(tag) + with nogil: + c_call_error = grpc_call_start_batch( + c_call, tag.c_ops, tag.c_nops, <cpython.PyObject *>tag, NULL) + return c_call_error, tag + + +cdef object _operate_from_integrated_call( + _ChannelState channel_state, _CallState call_state, object operations, + object user_tag): + cdef grpc_call_error c_call_error + cdef _BatchOperationTag tag + with channel_state.condition: + if call_state.due: + c_call_error, tag = _operate(call_state.c_call, operations, user_tag) + if c_call_error == GRPC_CALL_OK: + call_state.due.add(tag) + channel_state.integrated_call_states[tag] = call_state + return True + else: + _raise_call_error_no_metadata(c_call_error) + else: + return False + + +cdef object _operate_from_segregated_call( + _ChannelState channel_state, _CallState call_state, object operations, + object user_tag): + cdef grpc_call_error c_call_error + cdef _BatchOperationTag tag + with channel_state.condition: + if call_state.due: + c_call_error, tag = _operate(call_state.c_call, operations, user_tag) + if c_call_error == GRPC_CALL_OK: + call_state.due.add(tag) + return True + else: + _raise_call_error_no_metadata(c_call_error) + else: + return False + + +cdef _cancel( + _ChannelState channel_state, _CallState call_state, grpc_status_code code, + str details): + cdef grpc_call_error c_call_error + with channel_state.condition: + if call_state.due: + c_call_error = grpc_call_cancel_with_status( + call_state.c_call, code, _encode(details), NULL) + _check_and_raise_call_error_no_metadata(c_call_error) + + +cdef BatchOperationEvent _next_call_event( + _ChannelState channel_state, grpc_completion_queue *c_completion_queue, + on_success): + tag, event = _latent_event(c_completion_queue, None) + with channel_state.condition: + on_success(tag) + channel_state.condition.notify_all() + return event + + +# TODO(https://github.com/grpc/grpc/issues/14569): This could be a lot simpler. +cdef void _call( + _ChannelState channel_state, _CallState call_state, + grpc_completion_queue *c_completion_queue, on_success, int flags, method, + host, object deadline, CallCredentials credentials, + object operationses_and_user_tags, object metadata) except *: + """Invokes an RPC. + + Args: + channel_state: A _ChannelState with its "open" attribute set to True. RPCs + may not be invoked on a closed channel. + call_state: An empty _CallState to be altered (specifically assigned a + c_call and having its due set populated) if the RPC invocation is + successful. + c_completion_queue: A grpc_completion_queue to be used for the call's + operations. + on_success: A behavior to be called if attempting to start operations for + the call succeeds. If called the behavior will be called while holding the + channel_state condition and passed the tags associated with operations + that were successfully started for the call. + flags: Flags to be passed to gRPC Core as part of call creation. + method: The fully-qualified name of the RPC method being invoked. + host: A "host" string to be passed to gRPC Core as part of call creation. + deadline: A float for the deadline of the RPC, or None if the RPC is to have + no deadline. + credentials: A _CallCredentials for the RPC or None. + operationses_and_user_tags: A sequence of length-two sequences the first + element of which is a sequence of Operations and the second element of + which is an object to be used as a tag. A SendInitialMetadataOperation + must be present in the first element of this value. + metadata: The metadata for this call. + """ + cdef grpc_slice method_slice + cdef grpc_slice host_slice + cdef grpc_slice *host_slice_ptr + cdef grpc_call_credentials *c_call_credentials + cdef grpc_call_error c_call_error + cdef tuple error_and_wrapper_tag + cdef _BatchOperationTag wrapper_tag + with channel_state.condition: + if channel_state.open: + method_slice = _slice_from_bytes(method) + if host is None: + host_slice_ptr = NULL + else: + host_slice = _slice_from_bytes(host) + host_slice_ptr = &host_slice + call_state.c_call = grpc_channel_create_call( + channel_state.c_channel, NULL, flags, + c_completion_queue, method_slice, host_slice_ptr, + _timespec_from_time(deadline), NULL) + grpc_slice_unref(method_slice) + if host_slice_ptr: + grpc_slice_unref(host_slice) + if credentials is not None: + c_call_credentials = credentials.c() + c_call_error = grpc_call_set_credentials( + call_state.c_call, c_call_credentials) + grpc_call_credentials_release(c_call_credentials) + if c_call_error != GRPC_CALL_OK: + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + _raise_call_error_no_metadata(c_call_error) + started_tags = set() + for operations, user_tag in operationses_and_user_tags: + c_call_error, tag = _operate(call_state.c_call, operations, user_tag) + if c_call_error == GRPC_CALL_OK: + started_tags.add(tag) + else: + grpc_call_cancel(call_state.c_call, NULL) + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + _raise_call_error(c_call_error, metadata) + else: + call_state.due.update(started_tags) + on_success(started_tags) + else: + raise ValueError('Cannot invoke RPC on closed channel!') + +cdef void _process_integrated_call_tag( + _ChannelState state, _BatchOperationTag tag) except *: + cdef _CallState call_state = state.integrated_call_states.pop(tag) + call_state.due.remove(tag) + if not call_state.due: + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + + +cdef class IntegratedCall: + + def __cinit__(self, _ChannelState channel_state, _CallState call_state): + self._channel_state = channel_state + self._call_state = call_state + + def operate(self, operations, tag): + return _operate_from_integrated_call( + self._channel_state, self._call_state, operations, tag) + + def cancel(self, code, details): + _cancel(self._channel_state, self._call_state, code, details) + + +cdef IntegratedCall _integrated_call( + _ChannelState state, int flags, method, host, object deadline, + object metadata, CallCredentials credentials, operationses_and_user_tags): + call_state = _CallState() + + def on_success(started_tags): + for started_tag in started_tags: + state.integrated_call_states[started_tag] = call_state + + _call( + state, call_state, state.c_call_completion_queue, on_success, flags, + method, host, deadline, credentials, operationses_and_user_tags, metadata) + + return IntegratedCall(state, call_state) + + +cdef object _process_segregated_call_tag( + _ChannelState state, _CallState call_state, + grpc_completion_queue *c_completion_queue, _BatchOperationTag tag): + call_state.due.remove(tag) + if not call_state.due: + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + state.segregated_call_states.remove(call_state) + _destroy_c_completion_queue(c_completion_queue) + return True + else: + return False + + +cdef class SegregatedCall: + + def __cinit__(self, _ChannelState channel_state, _CallState call_state): + self._channel_state = channel_state + self._call_state = call_state + + def operate(self, operations, tag): + return _operate_from_segregated_call( + self._channel_state, self._call_state, operations, tag) + + def cancel(self, code, details): + _cancel(self._channel_state, self._call_state, code, details) + + def next_event(self): + def on_success(tag): + _process_segregated_call_tag( + self._channel_state, self._call_state, self._c_completion_queue, tag) + return _next_call_event( + self._channel_state, self._c_completion_queue, on_success) + + +cdef SegregatedCall _segregated_call( + _ChannelState state, int flags, method, host, object deadline, + object metadata, CallCredentials credentials, operationses_and_user_tags): + cdef _CallState call_state = _CallState() + cdef grpc_completion_queue *c_completion_queue = ( + grpc_completion_queue_create_for_next(NULL)) + cdef SegregatedCall segregated_call + + def on_success(started_tags): + state.segregated_call_states.add(call_state) + + try: + _call( + state, call_state, c_completion_queue, on_success, flags, method, host, + deadline, credentials, operationses_and_user_tags, metadata) + except: + _destroy_c_completion_queue(c_completion_queue) + raise + + segregated_call = SegregatedCall(state, call_state) + segregated_call._c_completion_queue = c_completion_queue + return segregated_call + + +cdef object _watch_connectivity_state( + _ChannelState state, grpc_connectivity_state last_observed_state, + object deadline): + cdef _ConnectivityTag tag = _ConnectivityTag(object()) + with state.condition: + if state.open: + cpython.Py_INCREF(tag) + grpc_channel_watch_connectivity_state( + state.c_channel, last_observed_state, _timespec_from_time(deadline), + state.c_connectivity_completion_queue, <cpython.PyObject *>tag) + state.connectivity_due.add(tag) + else: + raise ValueError('Cannot invoke RPC on closed channel!') + completed_tag, event = _latent_event( + state.c_connectivity_completion_queue, None) + with state.condition: + state.connectivity_due.remove(completed_tag) + state.condition.notify_all() + return event + + +cdef _close(_ChannelState state, grpc_status_code code, object details): + cdef _CallState call_state + encoded_details = _encode(details) + with state.condition: + if state.open: + state.open = False + for call_state in set(state.integrated_call_states.values()): + grpc_call_cancel_with_status( + call_state.c_call, code, encoded_details, NULL) + for call_state in state.segregated_call_states: + grpc_call_cancel_with_status( + call_state.c_call, code, encoded_details, NULL) + # TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity + # watching. + + while state.integrated_call_states: + state.condition.wait() + while state.segregated_call_states: + state.condition.wait() + while state.connectivity_due: + state.condition.wait() + + _destroy_c_completion_queue(state.c_call_completion_queue) + _destroy_c_completion_queue(state.c_connectivity_completion_queue) + grpc_channel_destroy(state.c_channel) + state.c_channel = NULL + grpc_shutdown() + state.condition.notify_all() + else: + # Another call to close already completed in the past or is currently + # being executed in another thread. + while state.c_channel != NULL: + state.condition.wait() + cdef class Channel: - def __cinit__(self, bytes target, object arguments, - ChannelCredentials channel_credentials=None): + def __cinit__( + self, bytes target, object arguments, + ChannelCredentials channel_credentials): grpc_init() + self._state = _ChannelState() self._vtable.copy = &_copy_pointer self._vtable.destroy = &_destroy_pointer self._vtable.cmp = &_compare_pointer cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor( arguments) cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable) - self.references = [] - c_target = target if channel_credentials is None: - self.c_channel = grpc_insecure_channel_create(c_target, c_arguments, NULL) + self._state.c_channel = grpc_insecure_channel_create( + <char *>target, c_arguments, NULL) else: c_channel_credentials = channel_credentials.c() - self.c_channel = grpc_secure_channel_create( - c_channel_credentials, c_target, c_arguments, NULL) + self._state.c_channel = grpc_secure_channel_create( + c_channel_credentials, <char *>target, c_arguments, NULL) grpc_channel_credentials_release(c_channel_credentials) - arguments_processor.un_c() - self.references.append(target) - self.references.append(arguments) - - def create_call(self, Call parent, int flags, - CompletionQueue queue not None, - method, host, object deadline): - if queue.is_shutting_down: - raise ValueError("queue must not be shutting down or shutdown") - cdef grpc_slice method_slice = _slice_from_bytes(method) - cdef grpc_slice host_slice - cdef grpc_slice *host_slice_ptr = NULL - if host is not None: - host_slice = _slice_from_bytes(host) - host_slice_ptr = &host_slice - cdef Call operation_call = Call() - operation_call.references = [self, queue] - cdef grpc_call *parent_call = NULL - if parent is not None: - parent_call = parent.c_call - operation_call.c_call = grpc_channel_create_call( - self.c_channel, parent_call, flags, - queue.c_completion_queue, method_slice, host_slice_ptr, - _timespec_from_time(deadline), NULL) - grpc_slice_unref(method_slice) - if host_slice_ptr: - grpc_slice_unref(host_slice) - return operation_call + self._state.c_call_completion_queue = ( + grpc_completion_queue_create_for_next(NULL)) + self._state.c_connectivity_completion_queue = ( + grpc_completion_queue_create_for_next(NULL)) + + def target(self): + cdef char *c_target + with self._state.condition: + c_target = grpc_channel_get_target(self._state.c_channel) + target = <bytes>c_target + gpr_free(c_target) + return target + + def integrated_call( + self, int flags, method, host, object deadline, object metadata, + CallCredentials credentials, operationses_and_tags): + return _integrated_call( + self._state, flags, method, host, deadline, metadata, credentials, + operationses_and_tags) + + def next_call_event(self): + def on_success(tag): + _process_integrated_call_tag(self._state, tag) + return _next_call_event( + self._state, self._state.c_call_completion_queue, on_success) + + def segregated_call( + self, int flags, method, host, object deadline, object metadata, + CallCredentials credentials, operationses_and_tags): + return _segregated_call( + self._state, flags, method, host, deadline, metadata, credentials, + operationses_and_tags) def check_connectivity_state(self, bint try_to_connect): - cdef grpc_connectivity_state result - with nogil: - result = grpc_channel_check_connectivity_state(self.c_channel, - try_to_connect) - return result + with self._state.condition: + return grpc_channel_check_connectivity_state( + self._state.c_channel, try_to_connect) def watch_connectivity_state( - self, grpc_connectivity_state last_observed_state, - object deadline, CompletionQueue queue not None, tag): - cdef _ConnectivityTag connectivity_tag = _ConnectivityTag(tag) - cpython.Py_INCREF(connectivity_tag) - grpc_channel_watch_connectivity_state( - self.c_channel, last_observed_state, _timespec_from_time(deadline), - queue.c_completion_queue, <cpython.PyObject *>connectivity_tag) + self, grpc_connectivity_state last_observed_state, object deadline): + return _watch_connectivity_state(self._state, last_observed_state, deadline) - def target(self): - cdef char *target = NULL - with nogil: - target = grpc_channel_get_target(self.c_channel) - result = <bytes>target - with nogil: - gpr_free(target) - return result - - def __dealloc__(self): - if self.c_channel != NULL: - grpc_channel_destroy(self.c_channel) - grpc_shutdown() + def close(self, code, details): + _close(self._state, code, details) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi index 5ea0287b81..9f06ce086e 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi @@ -13,10 +13,16 @@ # limitations under the License. +cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline) + + +cdef _interpret_event(grpc_event c_event) + + cdef class CompletionQueue: cdef grpc_completion_queue *c_completion_queue cdef bint is_shutting_down cdef bint is_shutdown - cdef _interpret_event(self, grpc_event event) + cdef _interpret_event(self, grpc_event c_event) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index 40496d1124..a2d765546a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -20,6 +20,53 @@ import time cdef int _INTERRUPT_CHECK_PERIOD_MS = 200 +cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline): + cdef gpr_timespec c_increment + cdef gpr_timespec c_timeout + cdef gpr_timespec c_deadline + c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN) + if deadline is None: + c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME) + else: + c_deadline = _timespec_from_time(deadline) + + with nogil: + while True: + c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment) + if gpr_time_cmp(c_timeout, c_deadline) > 0: + c_timeout = c_deadline + c_event = grpc_completion_queue_next(c_completion_queue, c_timeout, NULL) + if (c_event.type != GRPC_QUEUE_TIMEOUT or + gpr_time_cmp(c_timeout, c_deadline) == 0): + break + + # Handle any signals + with gil: + cpython.PyErr_CheckSignals() + return c_event + + +cdef _interpret_event(grpc_event c_event): + cdef _Tag tag + if c_event.type == GRPC_QUEUE_TIMEOUT: + # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) + elif c_event.type == GRPC_QUEUE_SHUTDOWN: + # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + return None, ConnectivityEvent(GRPC_QUEUE_SHUTDOWN, False, None) + else: + tag = <_Tag>c_event.tag + # We receive event tags only after they've been inc-ref'd elsewhere in + # the code. + cpython.Py_DECREF(tag) + return tag, tag.event(c_event) + + +cdef _latent_event(grpc_completion_queue *c_completion_queue, object deadline): + cdef grpc_event c_event = _next(c_completion_queue, deadline) + return _interpret_event(c_event) + + cdef class CompletionQueue: def __cinit__(self, shutdown_cq=False): @@ -36,48 +83,16 @@ cdef class CompletionQueue: self.is_shutting_down = False self.is_shutdown = False - cdef _interpret_event(self, grpc_event event): - cdef _Tag tag = None - if event.type == GRPC_QUEUE_TIMEOUT: - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. - return ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) - elif event.type == GRPC_QUEUE_SHUTDOWN: + cdef _interpret_event(self, grpc_event c_event): + unused_tag, event = _interpret_event(c_event) + if event.completion_type == GRPC_QUEUE_SHUTDOWN: self.is_shutdown = True - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. - return ConnectivityEvent(GRPC_QUEUE_TIMEOUT, True, None) - else: - tag = <_Tag>event.tag - # We receive event tags only after they've been inc-ref'd elsewhere in - # the code. - cpython.Py_DECREF(tag) - return tag.event(event) + return event + # We name this 'poll' to avoid problems with CPython's expectations for + # 'special' methods (like next and __next__). def poll(self, deadline=None): - # We name this 'poll' to avoid problems with CPython's expectations for - # 'special' methods (like next and __next__). - cdef gpr_timespec c_increment - cdef gpr_timespec c_timeout - cdef gpr_timespec c_deadline - if deadline is None: - c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME) - else: - c_deadline = _timespec_from_time(deadline) - with nogil: - c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN) - - while True: - c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment) - if gpr_time_cmp(c_timeout, c_deadline) > 0: - c_timeout = c_deadline - event = grpc_completion_queue_next( - self.c_completion_queue, c_timeout, NULL) - if event.type != GRPC_QUEUE_TIMEOUT or gpr_time_cmp(c_timeout, c_deadline) == 0: - break; - - # Handle any signals - with gil: - cpython.PyErr_CheckSignals() - return self._interpret_event(event) + return self._interpret_event(_next(self.c_completion_queue, deadline)) def shutdown(self): with nogil: diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 4f8868d346..578a3d79ad 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from grpc.framework.foundation import logging_pool from tests.unit.framework.common import test_constants +from tests.unit._cython import test_utilities _EMPTY_FLAGS = 0 _EMPTY_METADATA = () @@ -30,6 +31,8 @@ _RECEIVE_MESSAGE_TAG = 'receive_message' _SERVER_COMPLETE_CALL_TAG = 'server_complete_call' _SUCCESS_CALL_FRACTION = 1.0 / 8.0 +_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) +_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS class _State(object): @@ -150,7 +153,8 @@ class CancelManyCallsTest(unittest.TestCase): server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None, + None) state = _State() @@ -165,31 +169,33 @@ class CancelManyCallsTest(unittest.TestCase): client_condition = threading.Condition() client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver(client_condition, client_completion_queue, - client_due) - client_driver.start() with client_condition: client_calls = [] for index in range(test_constants.RPC_CONCURRENCY): - client_call = channel.create_call(None, _EMPTY_FLAGS, - client_completion_queue, - b'/twinkies', None, None) - operations = ( - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(b'\x45\x56', _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_client_batch(operations, tag) + client_call = channel.integrated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, + None, (( + ( + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(b'\x45\x56', + _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation( + _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + tag, + ),)) client_due.add(tag) client_calls.append(client_call) + client_events_future = test_utilities.SimpleFuture( + lambda: tuple(channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS))) + with state.condition: while True: if state.parked_handlers < test_constants.THREAD_CONCURRENCY: @@ -201,12 +207,14 @@ class CancelManyCallsTest(unittest.TestCase): state.condition.notify_all() break - client_driver.events( - test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) + client_events_future.result() with client_condition: for client_call in client_calls: - client_call.cancel() + client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!') + for _ in range(_UNSUCCESSFUL_CALLS): + channel.next_call_event() + channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!') with state.condition: server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py index 7305d0fa3f..d95286071d 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py @@ -21,25 +21,20 @@ from grpc._cython import cygrpc from tests.unit.framework.common import test_constants -def _channel_and_completion_queue(): - channel = cygrpc.Channel(b'localhost:54321', ()) - completion_queue = cygrpc.CompletionQueue() - return channel, completion_queue +def _channel(): + return cygrpc.Channel(b'localhost:54321', (), None) -def _connectivity_loop(channel, completion_queue): +def _connectivity_loop(channel): for _ in range(100): connectivity = channel.check_connectivity_state(True) - channel.watch_connectivity_state(connectivity, - time.time() + 0.2, completion_queue, - None) - completion_queue.poll() + channel.watch_connectivity_state(connectivity, time.time() + 0.2) def _create_loop_destroy(): - channel, completion_queue = _channel_and_completion_queue() - _connectivity_loop(channel, completion_queue) - completion_queue.shutdown() + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') def _in_parallel(behavior, arguments): @@ -55,12 +50,9 @@ def _in_parallel(behavior, arguments): class ChannelTest(unittest.TestCase): def test_single_channel_lonely_connectivity(self): - channel, completion_queue = _channel_and_completion_queue() - _in_parallel(_connectivity_loop, ( - channel, - completion_queue, - )) - completion_queue.shutdown() + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') def test_multiple_channels_lonely_connectivity(self): _in_parallel(_create_loop_destroy, ()) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index 7fd3d19b4e..d8210f36f8 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -100,7 +100,8 @@ class RpcTest(object): self.server.register_completion_queue(self.server_completion_queue) port = self.server.add_http2_port(b'[::]:0') self.server.start() - self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), []) + self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [], + None) self._server_shutdown_tag = 'server_shutdown_tag' self.server_condition = threading.Condition() diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py index 7caa98f72d..8a721788f4 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from tests.unit._cython import _common +from tests.unit._cython import test_utilities class Test(_common.RpcTest, unittest.TestCase): @@ -41,31 +42,27 @@ class Test(_common.RpcTest, unittest.TestCase): server_request_call_tag, }) - client_call = self.channel.create_call(None, _common.EMPTY_FLAGS, - self.client_completion_queue, - b'/twinkies', None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with self.client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - self.assertEqual(cygrpc.CallError.ok, - client_receive_initial_metadata_start_batch_result) - client_complete_rpc_start_batch_result = client_call.start_client_batch( + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [( [ - cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - self.assertEqual(cygrpc.CallError.ok, - client_complete_rpc_start_batch_result) - self.client_driver.add_due({ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag, - client_complete_rpc_tag, - }) + )]) + client_call.operate([ + cygrpc.SendInitialMetadataOperation(_common.INVOCATION_METADATA, + _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), + ], client_complete_rpc_tag) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(),]) server_request_call_event = self.server_driver.event_with_tag( server_request_call_tag) @@ -96,20 +93,23 @@ class Test(_common.RpcTest, unittest.TestCase): server_complete_rpc_event = server_call_driver.event_with_tag( server_complete_rpc_tag) - client_receive_initial_metadata_event = self.client_driver.event_with_tag( - client_receive_initial_metadata_tag) - client_complete_rpc_event = self.client_driver.event_with_tag( - client_complete_rpc_tag) + client_events = client_events_future.result() + if client_events[0].tag is client_receive_initial_metadata_tag: + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] + else: + client_complete_rpc_event = client_events[0] + client_receive_initial_metadata_event = client_events[1] return ( _common.OperationResult(server_request_call_start_batch_result, server_request_call_event.completion_type, server_request_call_event.success), _common.OperationResult( - client_receive_initial_metadata_start_batch_result, + cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.success), - _common.OperationResult(client_complete_rpc_start_batch_result, + _common.OperationResult(cygrpc.CallError.ok, client_complete_rpc_event.completion_type, client_complete_rpc_event.success), _common.OperationResult( diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py index 8582a39c01..47f39ebce2 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from tests.unit._cython import _common +from tests.unit._cython import test_utilities class Test(_common.RpcTest, unittest.TestCase): @@ -36,28 +37,31 @@ class Test(_common.RpcTest, unittest.TestCase): server_request_call_tag, }) - client_call = self.channel.create_call(None, _common.EMPTY_FLAGS, - self.client_completion_queue, - b'/twinkies', None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with self.client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - client_complete_rpc_start_batch_result = client_call.start_client_batch( - [ - cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - self.client_driver.add_due({ - client_receive_initial_metadata_tag, - client_complete_rpc_tag, - }) - + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation( + _common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation( + _common.EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + ]) + client_call.operate([ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(),]) server_request_call_event = self.server_driver.event_with_tag( server_request_call_tag) @@ -87,20 +91,19 @@ class Test(_common.RpcTest, unittest.TestCase): server_complete_rpc_event = self.server_driver.event_with_tag( server_complete_rpc_tag) - client_receive_initial_metadata_event = self.client_driver.event_with_tag( - client_receive_initial_metadata_tag) - client_complete_rpc_event = self.client_driver.event_with_tag( - client_complete_rpc_tag) + client_events = client_events_future.result() + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] return ( _common.OperationResult(server_request_call_start_batch_result, server_request_call_event.completion_type, server_request_call_event.success), _common.OperationResult( - client_receive_initial_metadata_start_batch_result, + cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.success), - _common.OperationResult(client_complete_rpc_start_batch_result, + _common.OperationResult(cygrpc.CallError.ok, client_complete_rpc_event.completion_type, client_complete_rpc_event.success), _common.OperationResult( diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index bc63b54879..8a903bfaf9 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -17,6 +17,7 @@ import threading import unittest from grpc._cython import cygrpc +from tests.unit._cython import test_utilities _EMPTY_FLAGS = 0 _EMPTY_METADATA = () @@ -118,7 +119,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set()) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(), + None) server_shutdown_tag = 'server_shutdown_tag' server_driver = _ServerDriver(server_completion_queue, @@ -127,10 +129,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): client_condition = threading.Condition() client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver(client_condition, client_completion_queue, - client_due) - client_driver.start() server_call_condition = threading.Condition() server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' @@ -154,25 +152,28 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_completion_queue, server_rpc_tag) - client_call = channel.create_call(None, _EMPTY_FLAGS, - client_completion_queue, b'/twinkies', - None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - client_due.add(client_receive_initial_metadata_tag) - client_complete_rpc_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], client_complete_rpc_tag)) - client_due.add(client_complete_rpc_tag) + client_call = channel.segregated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, ( + ( + [ + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + client_receive_initial_metadata_tag, + ), + ( + [ + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + )) + client_receive_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) server_rpc_event = server_driver.first_event() @@ -208,19 +209,20 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_complete_rpc_tag) server_call_driver.events() - with client_condition: - client_receive_first_message_tag = 'client_receive_first_message_tag' - client_receive_first_message_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], client_receive_first_message_tag)) - client_due.add(client_receive_first_message_tag) - client_receive_first_message_event = client_driver.event_with_tag( - client_receive_first_message_tag) + client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result( + ) + + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_call.operate([ + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], client_receive_first_message_tag) + client_receive_first_message_event = client_call.next_event() - client_call_cancel_result = client_call.cancel() - client_driver.events() + client_call_cancel_result = client_call.cancel( + cygrpc.StatusCode.cancelled, 'Cancelled during test!') + client_complete_rpc_event = client_call.next_event() + channel.close(cygrpc.StatusCode.unknown, 'Channel closed!') server.shutdown(server_completion_queue, server_shutdown_tag) server.cancel_all_calls() server_driver.events() @@ -228,11 +230,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): self.assertEqual(cygrpc.CallError.ok, request_call_result) self.assertEqual(cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, - client_receive_initial_metadata_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, - client_complete_rpc_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result) self.assertIs(server_rpc_tag, server_rpc_event.tag) self.assertEqual(cygrpc.CompletionType.operation_complete, server_rpc_event.completion_type) diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index 23f5ef605d..724a690746 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -51,8 +51,8 @@ class TypeSmokeTest(unittest.TestCase): del server def testChannelUpDown(self): - channel = cygrpc.Channel(b'[::]:0', None) - del channel + channel = cygrpc.Channel(b'[::]:0', None, None) + channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!') def test_metadata_plugin_call_credentials_up_down(self): cygrpc.MetadataPluginCallCredentials(_metadata_plugin, @@ -121,7 +121,7 @@ class ServerClientMixin(object): client_credentials) else: self.client_channel = cygrpc.Channel('localhost:{}'.format( - self.port).encode(), set()) + self.port).encode(), set(), None) if host_override: self.host_argument = None # default host self.expected_host = host_override @@ -131,17 +131,20 @@ class ServerClientMixin(object): self.expected_host = self.host_argument def tearDownMixin(self): + self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!') + del self.client_channel del self.server del self.client_completion_queue del self.server_completion_queue - def _perform_operations(self, operations, call, queue, deadline, - description): - """Perform the list of operations with given call, queue, and deadline. + def _perform_queue_operations(self, operations, call, queue, deadline, + description): + """Perform the operations with given call, queue, and deadline. - Invocation errors are reported with as an exception with `description` in - the message. Performs the operations asynchronously, returning a future. - """ + Invocation errors are reported with as an exception with `description` + in the message. Performs the operations asynchronously, returning a + future. + """ def performer(): tag = object() @@ -185,9 +188,6 @@ class ServerClientMixin(object): self.assertEqual(cygrpc.CallError.ok, request_call_result) client_call_tag = object() - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, self.host_argument, - DEADLINE) client_initial_metadata = ( ( CLIENT_METADATA_ASCII_KEY, @@ -198,18 +198,24 @@ class ServerClientMixin(object): CLIENT_METADATA_BIN_VALUE, ), ) - client_start_batch_result = client_call.start_client_batch([ - cygrpc.SendInitialMetadataOperation(client_initial_metadata, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], client_call_tag) - self.assertEqual(cygrpc.CallError.ok, client_start_batch_result) - client_event_future = test_utilities.CompletionQueuePollFuture( - self.client_completion_queue, DEADLINE) + client_call = self.client_channel.integrated_call( + 0, METHOD, self.host_argument, DEADLINE, client_initial_metadata, + None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + client_initial_metadata, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_call_tag, + ), + ]) + client_event_future = test_utilities.SimpleFuture( + self.client_channel.next_call_event) request_event = self.server_completion_queue.poll(deadline=DEADLINE) self.assertEqual(cygrpc.CompletionType.operation_complete, @@ -304,66 +310,76 @@ class ServerClientMixin(object): del client_call del server_call - def test6522(self): + def test_6522(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 METHOD = b'twinkies' empty_metadata = () + # Prologue server_request_tag = object() self.server.request_call(self.server_completion_queue, self.server_completion_queue, server_request_tag) - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, self.host_argument, - DEADLINE) - - # Prologue - def perform_client_operations(operations, description): - return self._perform_operations(operations, client_call, - self.client_completion_queue, - DEADLINE, description) - - client_event_future = perform_client_operations([ - cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], "Client prologue") + client_call = self.client_channel.segregated_call( + 0, METHOD, self.host_argument, DEADLINE, None, None, ([( + [ + cygrpc.SendInitialMetadataOperation(empty_metadata, + _EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + object(), + ), ( + [ + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + object(), + )])) + + client_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) request_event = self.server_completion_queue.poll(deadline=DEADLINE) server_call = request_event.call def perform_server_operations(operations, description): - return self._perform_operations(operations, server_call, - self.server_completion_queue, - DEADLINE, description) + return self._perform_queue_operations(operations, server_call, + self.server_completion_queue, + DEADLINE, description) server_event_future = perform_server_operations([ cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), ], "Server prologue") - client_event_future.result() # force completion + client_initial_metadata_event_future.result() # force completion server_event_future.result() # Messaging for _ in range(10): - client_event_future = perform_client_operations([ + client_call.operate([ cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), ], "Client message") + client_message_event_future = test_utilities.SimpleFuture( + client_call.next_event) server_event_future = perform_server_operations([ cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), ], "Server receive") - client_event_future.result() # force completion + client_message_event_future.result() # force completion server_event_future.result() # Epilogue - client_event_future = perform_client_operations([ + client_call.operate([ cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS) ], "Client epilogue") + # One for ReceiveStatusOnClient, one for SendCloseFromClient. + client_events_future = test_utilities.SimpleFuture( + lambda: { + client_call.next_event(), + client_call.next_event(),}) server_event_future = perform_server_operations([ cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), @@ -371,7 +387,7 @@ class ServerClientMixin(object): empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS) ], "Server epilogue") - client_event_future.result() # force completion + client_events_future.result() # force completion server_event_future.result() diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index 4edf0fc4ad..f153089a24 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -81,29 +81,16 @@ class InvalidMetadataTest(unittest.TestCase): request = b'\x07\x08' metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_future = self._unary_unary.future(request, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - response_future.result() - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_future.details(), expected_error_details) - self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._unary_unary.future(request, metadata=metadata) def testUnaryRequestStreamResponse(self): request = b'\x37\x58' metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_iterator = self._unary_stream(request, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_iterator.details(), expected_error_details) - self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._unary_stream(request, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) def testStreamRequestBlockingUnaryResponse(self): request_iterator = ( @@ -129,32 +116,18 @@ class InvalidMetadataTest(unittest.TestCase): b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_future = self._stream_unary.future( - request_iterator, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - response_future.result() - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_future.details(), expected_error_details) - self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._stream_unary.future(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) def testStreamRequestStreamResponse(self): request_iterator = ( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_iterator = self._stream_stream( - request_iterator, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_iterator.details(), expected_error_details) - self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._stream_stream(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) if __name__ == '__main__': |