diff options
author | Nathaniel Manista <nathaniel@google.com> | 2018-04-20 02:49:34 +0000 |
---|---|---|
committer | Nathaniel Manista <nathaniel@google.com> | 2018-05-02 18:24:47 +0000 |
commit | ca7ba4d0ac3ab452c5db60befc8be37fd6e2339b (patch) | |
tree | 813d64334b810569555737be39f3d81bc9f4dc7c | |
parent | d3eaf64687416c19177ba6a078f9f11599c063b9 (diff) |
Keep Core memory inside cygrpc.Channel objects
This removes invocation-side completion queues from the _cygrpc API.
Invocation-side calls are changed to no longer share the same lifetime
as Core calls.
Illegal metadata is now detected on invocation rather than at the start
of a batch (so passing illegal metadata to a response-streaming method
will now raise an exception immediately rather than later on when
attempting to read the first response message).
It is no longer possible to create a call without immediately starting
at least one batch of operations on it. Only tests are affected by this
change; there are no real use cases in which one wants to start a call
but wait a little while before learning that the server has rejected
it.
It is now required that code above cygrpc.Channel spend threads on
next_event whenever events are pending. A cygrpc.Channel.close method
is introduced, but it merely blocks until the cygrpc.Channel's
completion queues are drained; it does not itself drain them.
Noteworthy here is that we drop the cygrpc.Channel.__dealloc__ method.
It is not the same as __del__ (which is not something that can be added
to cygrpc.Channel) and there is no guarantee that __dealloc__ will be
called at all or that it will be called while the cygrpc.Channel
instance's Python attributes are intact (in testing, I saw both in
different environments). This commit does not knowingly break any
garbage-collection-based memory management working (or "happening to
appear to work in some circumstances"), though if it does, the proper
remedy is to call grpc.Channel.close... which is the objective towards
which this commit builds.
13 files changed, 911 insertions, 528 deletions
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 2eff08aa57..6604f8f35c 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -79,27 +79,6 @@ def _wait_once_until(condition, until): condition.wait(timeout=remaining) -_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( - 'Internal gRPC call error %d. ' + - 'Please report to https://github.com/grpc/grpc/issues') - - -def _check_call_error(call_error, metadata): - if call_error == cygrpc.CallError.invalid_metadata: - raise ValueError('metadata was invalid: %s' % metadata) - elif call_error != cygrpc.CallError.ok: - raise ValueError(_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error) - - -def _call_error_set_RPCstate(state, call_error, metadata): - if call_error == cygrpc.CallError.invalid_metadata: - _abort(state, grpc.StatusCode.INTERNAL, - 'metadata was invalid: %s' % metadata) - else: - _abort(state, grpc.StatusCode.INTERNAL, - _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error) - - class _RPCState(object): def __init__(self, due, initial_metadata, trailing_metadata, code, details): @@ -163,7 +142,7 @@ def _handle_event(event, state, response_deserializer): return callbacks -def _event_handler(state, call, response_deserializer): +def _event_handler(state, response_deserializer): def handle_event(event): with state.condition: @@ -172,40 +151,47 @@ def _event_handler(state, call, response_deserializer): done = not state.due for callback in callbacks: callback() - return call if done else None + return done return handle_event -def _consume_request_iterator(request_iterator, state, call, - request_serializer): - event_handler = _event_handler(state, call, None) +def _consume_request_iterator(request_iterator, state, call, request_serializer, + event_handler): - def consume_request_iterator(): + def consume_request_iterator(): # pylint: disable=too-many-branches while True: try: request = next(request_iterator) except StopIteration: break except Exception: # pylint: disable=broad-except - logging.exception("Exception iterating requests!") - call.cancel() - _abort(state, grpc.StatusCode.UNKNOWN, - "Exception iterating requests!") + code = grpc.StatusCode.UNKNOWN + details = 'Exception iterating requests!' + logging.exception(details) + call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) + _abort(state, code, details) return serialized_request = _common.serialize(request, request_serializer) with state.condition: if state.code is None and not state.cancelled: if serialized_request is None: - call.cancel() + code = grpc.StatusCode.INTERNAL # pylint: disable=redefined-variable-type details = 'Exception serializing request!' - _abort(state, grpc.StatusCode.INTERNAL, details) + call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) + _abort(state, code, details) return else: operations = (cygrpc.SendMessageOperation( serialized_request, _EMPTY_FLAGS),) - call.start_client_batch(operations, event_handler) - state.due.add(cygrpc.OperationType.send_message) + operating = call.operate(operations, event_handler) + if operating: + state.due.add(cygrpc.OperationType.send_message) + else: + return while True: state.condition.wait() if state.code is None: @@ -219,15 +205,19 @@ def _consume_request_iterator(request_iterator, state, call, if state.code is None: operations = ( cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),) - call.start_client_batch(operations, event_handler) - state.due.add(cygrpc.OperationType.send_close_from_client) + operating = call.operate(operations, event_handler) + if operating: + state.due.add(cygrpc.OperationType.send_close_from_client) def stop_consumption_thread(timeout): # pylint: disable=unused-argument with state.condition: if state.code is None: - call.cancel() + code = grpc.StatusCode.CANCELLED + details = 'Consumption thread cleaned up!' + call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) state.cancelled = True - _abort(state, grpc.StatusCode.CANCELLED, 'Cancelled!') + _abort(state, code, details) state.condition.notify_all() consumption_thread = _common.CleanupThread( @@ -247,9 +237,12 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): def cancel(self): with self._state.condition: if self._state.code is None: - self._call.cancel() + code = grpc.StatusCode.CANCELLED + details = 'Locally cancelled by application!' + self._call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details) self._state.cancelled = True - _abort(self._state, grpc.StatusCode.CANCELLED, 'Cancelled!') + _abort(self._state, code, details) self._state.condition.notify_all() return False @@ -318,12 +311,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): def _next(self): with self._state.condition: if self._state.code is None: - event_handler = _event_handler(self._state, self._call, + event_handler = _event_handler(self._state, self._response_deserializer) - self._call.start_client_batch( + operating = self._call.operate( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), event_handler) - self._state.due.add(cygrpc.OperationType.receive_message) + if operating: + self._state.due.add(cygrpc.OperationType.receive_message) elif self._state.code is grpc.StatusCode.OK: raise StopIteration() else: @@ -408,9 +402,12 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): def __del__(self): with self._state.condition: if self._state.code is None: - self._call.cancel() - self._state.cancelled = True self._state.code = grpc.StatusCode.CANCELLED + self._state.details = 'Cancelled upon garbage collection!' + self._state.cancelled = True + self._call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code], + self._state.details) self._state.condition.notify_all() @@ -437,6 +434,24 @@ def _end_unary_response_blocking(state, call, with_call, deadline): raise _Rendezvous(state, None, None, deadline) +def _stream_unary_invocation_operationses(metadata): + return ( + ( + cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + + +def _stream_unary_invocation_operationses_and_tags(metadata): + return tuple(( + operations, + None, + ) for operations in _stream_unary_invocation_operationses(metadata)) + + class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): def __init__(self, channel, managed_call, method, request_serializer, @@ -448,8 +463,8 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._response_deserializer = response_deserializer def _prepare(self, request, timeout, metadata): - deadline, serialized_request, rendezvous = (_start_unary_request( - request, timeout, self._request_serializer)) + deadline, serialized_request, rendezvous = _start_unary_request( + request, timeout, self._request_serializer) if serialized_request is None: return None, None, None, rendezvous else: @@ -467,48 +482,38 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): def _blocking(self, request, timeout, metadata, credentials): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata) - if rendezvous: + if state is None: raise rendezvous else: - completion_queue = cygrpc.CompletionQueue() - call = self._channel.create_call(None, 0, completion_queue, - self._method, None, deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - call_error = call.start_client_batch(operations, None) - _check_call_error(call_error, metadata) - _handle_event(completion_queue.poll(), state, - self._response_deserializer) - return state, call, deadline + call = self._channel.segregated_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, (( + operations, + None, + ),)) + event = call.next_event() + _handle_event(event, state, self._response_deserializer) + return state, call, def __call__(self, request, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request, timeout, metadata, - credentials) - return _end_unary_response_blocking(state, call, False, deadline) + state, call, = self._blocking(request, timeout, metadata, credentials) + return _end_unary_response_blocking(state, call, False, None) def with_call(self, request, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request, timeout, metadata, - credentials) - return _end_unary_response_blocking(state, call, True, deadline) + state, call, = self._blocking(request, timeout, metadata, credentials) + return _end_unary_response_blocking(state, call, True, None) def future(self, request, timeout=None, metadata=None, credentials=None): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata) - if rendezvous: - return rendezvous + if state is None: + raise rendezvous else: - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, - self._response_deserializer) - with state.condition: - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + (operations,), event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -524,34 +529,27 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._response_deserializer = response_deserializer def __call__(self, request, timeout=None, metadata=None, credentials=None): - deadline, serialized_request, rendezvous = (_start_unary_request( - request, timeout, self._request_serializer)) + deadline, serialized_request, rendezvous = _start_unary_request( + request, timeout, self._request_serializer) if serialized_request is None: raise rendezvous else: state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, - self._response_deserializer) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), - event_handler) - operations = ( + operationses = ( + ( cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + operationses, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -569,49 +567,38 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): def _blocking(self, request_iterator, timeout, metadata, credentials): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) - completion_queue = cygrpc.CompletionQueue() - call = self._channel.create_call(None, 0, completion_queue, - self._method, None, deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), None) - operations = ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, None) - _check_call_error(call_error, metadata) - _consume_request_iterator(request_iterator, state, call, - self._request_serializer) + call = self._channel.segregated_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + _stream_unary_invocation_operationses_and_tags(metadata)) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, None) while True: - event = completion_queue.poll() + event = call.next_event() with state.condition: _handle_event(event, state, self._response_deserializer) state.condition.notify_all() if not state.due: break - return state, call, deadline + return state, call, def __call__(self, request_iterator, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request_iterator, timeout, - metadata, credentials) - return _end_unary_response_blocking(state, call, False, deadline) + state, call, = self._blocking(request_iterator, timeout, metadata, + credentials) + return _end_unary_response_blocking(state, call, False, None) def with_call(self, request_iterator, timeout=None, metadata=None, credentials=None): - state, call, deadline = self._blocking(request_iterator, timeout, - metadata, credentials) - return _end_unary_response_blocking(state, call, True, deadline) + state, call, = self._blocking(request_iterator, timeout, metadata, + credentials) + return _end_unary_response_blocking(state, call, True, None) def future(self, request_iterator, @@ -620,27 +607,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): credentials=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, self._response_deserializer) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), - event_handler) - operations = ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() - _consume_request_iterator(request_iterator, state, call, - self._request_serializer) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, + _stream_unary_invocation_operationses(metadata), event_handler) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -661,26 +634,20 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): credentials=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) - call, drive_call = self._managed_call(None, 0, self._method, None, - deadline) - if credentials is not None: - call.set_credentials(credentials._credentials) - event_handler = _event_handler(state, call, self._response_deserializer) - with state.condition: - call.start_client_batch( - (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), - event_handler) - operations = ( + operationses = ( + ( cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) - call_error = call.start_client_batch(operations, event_handler) - if call_error != cygrpc.CallError.ok: - _call_error_set_RPCstate(state, call_error, metadata) - return _Rendezvous(state, None, None, deadline) - drive_call() - _consume_request_iterator(request_iterator, state, call, - self._request_serializer) + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + 0, self._method, None, deadline, metadata, None + if credentials is None else credentials._credentials, operationses, + event_handler) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -689,28 +656,25 @@ class _ChannelCallState(object): def __init__(self, channel): self.lock = threading.Lock() self.channel = channel - self.completion_queue = cygrpc.CompletionQueue() - self.managed_calls = None + self.managed_calls = 0 def _run_channel_spin_thread(state): def channel_spin(): while True: - event = state.completion_queue.poll() - completed_call = event.tag(event) - if completed_call is not None: + event = state.channel.next_call_event() + call_completed = event.tag(event) + if call_completed: with state.lock: - state.managed_calls.remove(completed_call) - if not state.managed_calls: - state.managed_calls = None + state.managed_calls -= 1 + if state.managed_calls == 0: return def stop_channel_spin(timeout): # pylint: disable=unused-argument with state.lock: - if state.managed_calls is not None: - for call in state.managed_calls: - call.cancel() + state.channel.close(cygrpc.StatusCode.cancelled, + 'Channel spin thread cleaned up!') channel_spin_thread = _common.CleanupThread( stop_channel_spin, target=channel_spin) @@ -719,37 +683,41 @@ def _run_channel_spin_thread(state): def _channel_managed_call_management(state): - def create(parent, flags, method, host, deadline): - """Creates a managed cygrpc.Call and a function to call to drive it. - - If operations are successfully added to the returned cygrpc.Call, the - returned function must be called. If operations are not successfully added - to the returned cygrpc.Call, the returned function must not be called. - - Args: - parent: A cygrpc.Call to be used as the parent of the created call. - flags: An integer bitfield of call flags. - method: The RPC method. - host: A host string for the created call. - deadline: A float to be the deadline of the created call or None if the - call is to have an infinite deadline. - - Returns: - A cygrpc.Call with which to conduct an RPC and a function to call if - operations are successfully started on the call. - """ - call = state.channel.create_call(parent, flags, state.completion_queue, - method, host, deadline) - - def drive(): - with state.lock: - if state.managed_calls is None: - state.managed_calls = set((call,)) - _run_channel_spin_thread(state) - else: - state.managed_calls.add(call) + # pylint: disable=too-many-arguments + def create(flags, method, host, deadline, metadata, credentials, + operationses, event_handler): + """Creates a cygrpc.IntegratedCall. - return call, drive + Args: + flags: An integer bitfield of call flags. + method: The RPC method. + host: A host string for the created call. + deadline: A float to be the deadline of the created call or None if + the call is to have an infinite deadline. + metadata: The metadata for the call or None. + credentials: A cygrpc.CallCredentials or None. + operationses: An iterable of iterables of cygrpc.Operations to be + started on the call. + event_handler: A behavior to call to handle the events resultant from + the operations on the call. + + Returns: + A cygrpc.IntegratedCall with which to conduct an RPC. + """ + operationses_and_tags = tuple(( + operations, + event_handler, + ) for operations in operationses) + with state.lock: + call = state.channel.integrated_call(flags, method, host, deadline, + metadata, credentials, + operationses_and_tags) + if state.managed_calls == 0: + state.managed_calls = 1 + _run_channel_spin_thread(state) + else: + state.managed_calls += 1 + return call return create @@ -819,12 +787,9 @@ def _poll_connectivity(state, channel, initial_try_to_connect): callback_and_connectivity[1] = state.connectivity if callbacks: _spawn_delivery(state, callbacks) - completion_queue = cygrpc.CompletionQueue() while True: - channel.watch_connectivity_state(connectivity, - time.time() + 0.2, completion_queue, - None) - event = completion_queue.poll() + event = channel.watch_connectivity_state(connectivity, + time.time() + 0.2) with state.lock: if not state.callbacks_and_connectivities and not state.try_to_connect: state.polling = False diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index 1ba76b7f83..eefc685c0b 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -13,9 +13,59 @@ # limitations under the License. +cdef _check_call_error_no_metadata(c_call_error) + + +cdef _check_and_raise_call_error_no_metadata(c_call_error) + + +cdef _check_call_error(c_call_error, metadata) + + +cdef class _CallState: + + cdef grpc_call *c_call + cdef set due + + +cdef class _ChannelState: + + cdef object condition + cdef grpc_channel *c_channel + # A boolean field indicating that the channel is open (if True) or is being + # closed (i.e. a call to close is currently executing) or is closed (if + # False). + # TODO(https://github.com/grpc/grpc/issues/3064): Eliminate "is being closed" + # a state in which condition may be acquired by any thread, eliminate this + # field and just use the NULLness of c_channel as an indication that the + # channel is closed. + cdef object open + + # A dict from _BatchOperationTag to _CallState + cdef dict integrated_call_states + cdef grpc_completion_queue *c_call_completion_queue + + # A set of _CallState + cdef set segregated_call_states + + cdef set connectivity_due + cdef grpc_completion_queue *c_connectivity_completion_queue + + +cdef class IntegratedCall: + + cdef _ChannelState _channel_state + cdef _CallState _call_state + + +cdef class SegregatedCall: + + cdef _ChannelState _channel_state + cdef _CallState _call_state + cdef grpc_completion_queue *_c_completion_queue + + cdef class Channel: cdef grpc_arg_pointer_vtable _vtable - cdef grpc_channel *c_channel - cdef list references - cdef readonly _ArgumentsProcessor _arguments_processor + cdef _ChannelState _state diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index a3966497bc..72e74e84ae 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -14,82 +14,439 @@ cimport cpython +import threading + +_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( + 'Internal gRPC call error %d. ' + + 'Please report to https://github.com/grpc/grpc/issues') + + +cdef str _call_error_metadata(metadata): + return 'metadata was invalid: %s' % metadata + + +cdef str _call_error_no_metadata(c_call_error): + return _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % c_call_error + + +cdef str _call_error(c_call_error, metadata): + if c_call_error == GRPC_CALL_ERROR_INVALID_METADATA: + return _call_error_metadata(metadata) + else: + return _call_error_no_metadata(c_call_error) + + +cdef _check_call_error_no_metadata(c_call_error): + if c_call_error != GRPC_CALL_OK: + return _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % c_call_error + else: + return None + + +cdef _check_and_raise_call_error_no_metadata(c_call_error): + error = _check_call_error_no_metadata(c_call_error) + if error is not None: + raise ValueError(error) + + +cdef _check_call_error(c_call_error, metadata): + if c_call_error == GRPC_CALL_ERROR_INVALID_METADATA: + return _call_error_metadata(metadata) + else: + return _check_call_error_no_metadata(c_call_error) + + +cdef void _raise_call_error_no_metadata(c_call_error) except *: + raise ValueError(_call_error_no_metadata(c_call_error)) + + +cdef void _raise_call_error(c_call_error, metadata) except *: + raise ValueError(_call_error(c_call_error, metadata)) + + +cdef _destroy_c_completion_queue(grpc_completion_queue *c_completion_queue): + grpc_completion_queue_shutdown(c_completion_queue) + grpc_completion_queue_destroy(c_completion_queue) + + +cdef class _CallState: + + def __cinit__(self): + self.due = set() + + +cdef class _ChannelState: + + def __cinit__(self): + self.condition = threading.Condition() + self.open = True + self.integrated_call_states = {} + self.segregated_call_states = set() + self.connectivity_due = set() + + +cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): + cdef grpc_call_error c_call_error + cdef _BatchOperationTag tag = _BatchOperationTag(user_tag, operations, None) + tag.prepare() + cpython.Py_INCREF(tag) + with nogil: + c_call_error = grpc_call_start_batch( + c_call, tag.c_ops, tag.c_nops, <cpython.PyObject *>tag, NULL) + return c_call_error, tag + + +cdef object _operate_from_integrated_call( + _ChannelState channel_state, _CallState call_state, object operations, + object user_tag): + cdef grpc_call_error c_call_error + cdef _BatchOperationTag tag + with channel_state.condition: + if call_state.due: + c_call_error, tag = _operate(call_state.c_call, operations, user_tag) + if c_call_error == GRPC_CALL_OK: + call_state.due.add(tag) + channel_state.integrated_call_states[tag] = call_state + return True + else: + _raise_call_error_no_metadata(c_call_error) + else: + return False + + +cdef object _operate_from_segregated_call( + _ChannelState channel_state, _CallState call_state, object operations, + object user_tag): + cdef grpc_call_error c_call_error + cdef _BatchOperationTag tag + with channel_state.condition: + if call_state.due: + c_call_error, tag = _operate(call_state.c_call, operations, user_tag) + if c_call_error == GRPC_CALL_OK: + call_state.due.add(tag) + return True + else: + _raise_call_error_no_metadata(c_call_error) + else: + return False + + +cdef _cancel( + _ChannelState channel_state, _CallState call_state, grpc_status_code code, + str details): + cdef grpc_call_error c_call_error + with channel_state.condition: + if call_state.due: + c_call_error = grpc_call_cancel_with_status( + call_state.c_call, code, _encode(details), NULL) + _check_and_raise_call_error_no_metadata(c_call_error) + + +cdef BatchOperationEvent _next_call_event( + _ChannelState channel_state, grpc_completion_queue *c_completion_queue, + on_success): + tag, event = _latent_event(c_completion_queue, None) + with channel_state.condition: + on_success(tag) + channel_state.condition.notify_all() + return event + + +# TODO(https://github.com/grpc/grpc/issues/14569): This could be a lot simpler. +cdef void _call( + _ChannelState channel_state, _CallState call_state, + grpc_completion_queue *c_completion_queue, on_success, int flags, method, + host, object deadline, CallCredentials credentials, + object operationses_and_user_tags, object metadata) except *: + """Invokes an RPC. + + Args: + channel_state: A _ChannelState with its "open" attribute set to True. RPCs + may not be invoked on a closed channel. + call_state: An empty _CallState to be altered (specifically assigned a + c_call and having its due set populated) if the RPC invocation is + successful. + c_completion_queue: A grpc_completion_queue to be used for the call's + operations. + on_success: A behavior to be called if attempting to start operations for + the call succeeds. If called the behavior will be called while holding the + channel_state condition and passed the tags associated with operations + that were successfully started for the call. + flags: Flags to be passed to gRPC Core as part of call creation. + method: The fully-qualified name of the RPC method being invoked. + host: A "host" string to be passed to gRPC Core as part of call creation. + deadline: A float for the deadline of the RPC, or None if the RPC is to have + no deadline. + credentials: A _CallCredentials for the RPC or None. + operationses_and_user_tags: A sequence of length-two sequences the first + element of which is a sequence of Operations and the second element of + which is an object to be used as a tag. A SendInitialMetadataOperation + must be present in the first element of this value. + metadata: The metadata for this call. + """ + cdef grpc_slice method_slice + cdef grpc_slice host_slice + cdef grpc_slice *host_slice_ptr + cdef grpc_call_credentials *c_call_credentials + cdef grpc_call_error c_call_error + cdef tuple error_and_wrapper_tag + cdef _BatchOperationTag wrapper_tag + with channel_state.condition: + if channel_state.open: + method_slice = _slice_from_bytes(method) + if host is None: + host_slice_ptr = NULL + else: + host_slice = _slice_from_bytes(host) + host_slice_ptr = &host_slice + call_state.c_call = grpc_channel_create_call( + channel_state.c_channel, NULL, flags, + c_completion_queue, method_slice, host_slice_ptr, + _timespec_from_time(deadline), NULL) + grpc_slice_unref(method_slice) + if host_slice_ptr: + grpc_slice_unref(host_slice) + if credentials is not None: + c_call_credentials = credentials.c() + c_call_error = grpc_call_set_credentials( + call_state.c_call, c_call_credentials) + grpc_call_credentials_release(c_call_credentials) + if c_call_error != GRPC_CALL_OK: + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + _raise_call_error_no_metadata(c_call_error) + started_tags = set() + for operations, user_tag in operationses_and_user_tags: + c_call_error, tag = _operate(call_state.c_call, operations, user_tag) + if c_call_error == GRPC_CALL_OK: + started_tags.add(tag) + else: + grpc_call_cancel(call_state.c_call, NULL) + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + _raise_call_error(c_call_error, metadata) + else: + call_state.due.update(started_tags) + on_success(started_tags) + else: + raise ValueError('Cannot invoke RPC on closed channel!') + +cdef void _process_integrated_call_tag( + _ChannelState state, _BatchOperationTag tag) except *: + cdef _CallState call_state = state.integrated_call_states.pop(tag) + call_state.due.remove(tag) + if not call_state.due: + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + + +cdef class IntegratedCall: + + def __cinit__(self, _ChannelState channel_state, _CallState call_state): + self._channel_state = channel_state + self._call_state = call_state + + def operate(self, operations, tag): + return _operate_from_integrated_call( + self._channel_state, self._call_state, operations, tag) + + def cancel(self, code, details): + _cancel(self._channel_state, self._call_state, code, details) + + +cdef IntegratedCall _integrated_call( + _ChannelState state, int flags, method, host, object deadline, + object metadata, CallCredentials credentials, operationses_and_user_tags): + call_state = _CallState() + + def on_success(started_tags): + for started_tag in started_tags: + state.integrated_call_states[started_tag] = call_state + + _call( + state, call_state, state.c_call_completion_queue, on_success, flags, + method, host, deadline, credentials, operationses_and_user_tags, metadata) + + return IntegratedCall(state, call_state) + + +cdef object _process_segregated_call_tag( + _ChannelState state, _CallState call_state, + grpc_completion_queue *c_completion_queue, _BatchOperationTag tag): + call_state.due.remove(tag) + if not call_state.due: + grpc_call_unref(call_state.c_call) + call_state.c_call = NULL + state.segregated_call_states.remove(call_state) + _destroy_c_completion_queue(c_completion_queue) + return True + else: + return False + + +cdef class SegregatedCall: + + def __cinit__(self, _ChannelState channel_state, _CallState call_state): + self._channel_state = channel_state + self._call_state = call_state + + def operate(self, operations, tag): + return _operate_from_segregated_call( + self._channel_state, self._call_state, operations, tag) + + def cancel(self, code, details): + _cancel(self._channel_state, self._call_state, code, details) + + def next_event(self): + def on_success(tag): + _process_segregated_call_tag( + self._channel_state, self._call_state, self._c_completion_queue, tag) + return _next_call_event( + self._channel_state, self._c_completion_queue, on_success) + + +cdef SegregatedCall _segregated_call( + _ChannelState state, int flags, method, host, object deadline, + object metadata, CallCredentials credentials, operationses_and_user_tags): + cdef _CallState call_state = _CallState() + cdef grpc_completion_queue *c_completion_queue = ( + grpc_completion_queue_create_for_next(NULL)) + cdef SegregatedCall segregated_call + + def on_success(started_tags): + state.segregated_call_states.add(call_state) + + try: + _call( + state, call_state, c_completion_queue, on_success, flags, method, host, + deadline, credentials, operationses_and_user_tags, metadata) + except: + _destroy_c_completion_queue(c_completion_queue) + raise + + segregated_call = SegregatedCall(state, call_state) + segregated_call._c_completion_queue = c_completion_queue + return segregated_call + + +cdef object _watch_connectivity_state( + _ChannelState state, grpc_connectivity_state last_observed_state, + object deadline): + cdef _ConnectivityTag tag = _ConnectivityTag(object()) + with state.condition: + if state.open: + cpython.Py_INCREF(tag) + grpc_channel_watch_connectivity_state( + state.c_channel, last_observed_state, _timespec_from_time(deadline), + state.c_connectivity_completion_queue, <cpython.PyObject *>tag) + state.connectivity_due.add(tag) + else: + raise ValueError('Cannot invoke RPC on closed channel!') + completed_tag, event = _latent_event( + state.c_connectivity_completion_queue, None) + with state.condition: + state.connectivity_due.remove(completed_tag) + state.condition.notify_all() + return event + + +cdef _close(_ChannelState state, grpc_status_code code, object details): + cdef _CallState call_state + encoded_details = _encode(details) + with state.condition: + if state.open: + state.open = False + for call_state in set(state.integrated_call_states.values()): + grpc_call_cancel_with_status( + call_state.c_call, code, encoded_details, NULL) + for call_state in state.segregated_call_states: + grpc_call_cancel_with_status( + call_state.c_call, code, encoded_details, NULL) + # TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity + # watching. + + while state.integrated_call_states: + state.condition.wait() + while state.segregated_call_states: + state.condition.wait() + while state.connectivity_due: + state.condition.wait() + + _destroy_c_completion_queue(state.c_call_completion_queue) + _destroy_c_completion_queue(state.c_connectivity_completion_queue) + grpc_channel_destroy(state.c_channel) + state.c_channel = NULL + grpc_shutdown() + state.condition.notify_all() + else: + # Another call to close already completed in the past or is currently + # being executed in another thread. + while state.c_channel != NULL: + state.condition.wait() + cdef class Channel: - def __cinit__(self, bytes target, object arguments, - ChannelCredentials channel_credentials=None): + def __cinit__( + self, bytes target, object arguments, + ChannelCredentials channel_credentials): grpc_init() + self._state = _ChannelState() self._vtable.copy = &_copy_pointer self._vtable.destroy = &_destroy_pointer self._vtable.cmp = &_compare_pointer cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor( arguments) cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable) - self.references = [] - c_target = target if channel_credentials is None: - self.c_channel = grpc_insecure_channel_create(c_target, c_arguments, NULL) + self._state.c_channel = grpc_insecure_channel_create( + <char *>target, c_arguments, NULL) else: c_channel_credentials = channel_credentials.c() - self.c_channel = grpc_secure_channel_create( - c_channel_credentials, c_target, c_arguments, NULL) + self._state.c_channel = grpc_secure_channel_create( + c_channel_credentials, <char *>target, c_arguments, NULL) grpc_channel_credentials_release(c_channel_credentials) - arguments_processor.un_c() - self.references.append(target) - self.references.append(arguments) - - def create_call(self, Call parent, int flags, - CompletionQueue queue not None, - method, host, object deadline): - if queue.is_shutting_down: - raise ValueError("queue must not be shutting down or shutdown") - cdef grpc_slice method_slice = _slice_from_bytes(method) - cdef grpc_slice host_slice - cdef grpc_slice *host_slice_ptr = NULL - if host is not None: - host_slice = _slice_from_bytes(host) - host_slice_ptr = &host_slice - cdef Call operation_call = Call() - operation_call.references = [self, queue] - cdef grpc_call *parent_call = NULL - if parent is not None: - parent_call = parent.c_call - operation_call.c_call = grpc_channel_create_call( - self.c_channel, parent_call, flags, - queue.c_completion_queue, method_slice, host_slice_ptr, - _timespec_from_time(deadline), NULL) - grpc_slice_unref(method_slice) - if host_slice_ptr: - grpc_slice_unref(host_slice) - return operation_call + self._state.c_call_completion_queue = ( + grpc_completion_queue_create_for_next(NULL)) + self._state.c_connectivity_completion_queue = ( + grpc_completion_queue_create_for_next(NULL)) + + def target(self): + cdef char *c_target + with self._state.condition: + c_target = grpc_channel_get_target(self._state.c_channel) + target = <bytes>c_target + gpr_free(c_target) + return target + + def integrated_call( + self, int flags, method, host, object deadline, object metadata, + CallCredentials credentials, operationses_and_tags): + return _integrated_call( + self._state, flags, method, host, deadline, metadata, credentials, + operationses_and_tags) + + def next_call_event(self): + def on_success(tag): + _process_integrated_call_tag(self._state, tag) + return _next_call_event( + self._state, self._state.c_call_completion_queue, on_success) + + def segregated_call( + self, int flags, method, host, object deadline, object metadata, + CallCredentials credentials, operationses_and_tags): + return _segregated_call( + self._state, flags, method, host, deadline, metadata, credentials, + operationses_and_tags) def check_connectivity_state(self, bint try_to_connect): - cdef grpc_connectivity_state result - with nogil: - result = grpc_channel_check_connectivity_state(self.c_channel, - try_to_connect) - return result + with self._state.condition: + return grpc_channel_check_connectivity_state( + self._state.c_channel, try_to_connect) def watch_connectivity_state( - self, grpc_connectivity_state last_observed_state, - object deadline, CompletionQueue queue not None, tag): - cdef _ConnectivityTag connectivity_tag = _ConnectivityTag(tag) - cpython.Py_INCREF(connectivity_tag) - grpc_channel_watch_connectivity_state( - self.c_channel, last_observed_state, _timespec_from_time(deadline), - queue.c_completion_queue, <cpython.PyObject *>connectivity_tag) + self, grpc_connectivity_state last_observed_state, object deadline): + return _watch_connectivity_state(self._state, last_observed_state, deadline) - def target(self): - cdef char *target = NULL - with nogil: - target = grpc_channel_get_target(self.c_channel) - result = <bytes>target - with nogil: - gpr_free(target) - return result - - def __dealloc__(self): - if self.c_channel != NULL: - grpc_channel_destroy(self.c_channel) - grpc_shutdown() + def close(self, code, details): + _close(self._state, code, details) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi index 5ea0287b81..9f06ce086e 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi @@ -13,10 +13,16 @@ # limitations under the License. +cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline) + + +cdef _interpret_event(grpc_event c_event) + + cdef class CompletionQueue: cdef grpc_completion_queue *c_completion_queue cdef bint is_shutting_down cdef bint is_shutdown - cdef _interpret_event(self, grpc_event event) + cdef _interpret_event(self, grpc_event c_event) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index 40496d1124..a2d765546a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -20,6 +20,53 @@ import time cdef int _INTERRUPT_CHECK_PERIOD_MS = 200 +cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline): + cdef gpr_timespec c_increment + cdef gpr_timespec c_timeout + cdef gpr_timespec c_deadline + c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN) + if deadline is None: + c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME) + else: + c_deadline = _timespec_from_time(deadline) + + with nogil: + while True: + c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment) + if gpr_time_cmp(c_timeout, c_deadline) > 0: + c_timeout = c_deadline + c_event = grpc_completion_queue_next(c_completion_queue, c_timeout, NULL) + if (c_event.type != GRPC_QUEUE_TIMEOUT or + gpr_time_cmp(c_timeout, c_deadline) == 0): + break + + # Handle any signals + with gil: + cpython.PyErr_CheckSignals() + return c_event + + +cdef _interpret_event(grpc_event c_event): + cdef _Tag tag + if c_event.type == GRPC_QUEUE_TIMEOUT: + # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) + elif c_event.type == GRPC_QUEUE_SHUTDOWN: + # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + return None, ConnectivityEvent(GRPC_QUEUE_SHUTDOWN, False, None) + else: + tag = <_Tag>c_event.tag + # We receive event tags only after they've been inc-ref'd elsewhere in + # the code. + cpython.Py_DECREF(tag) + return tag, tag.event(c_event) + + +cdef _latent_event(grpc_completion_queue *c_completion_queue, object deadline): + cdef grpc_event c_event = _next(c_completion_queue, deadline) + return _interpret_event(c_event) + + cdef class CompletionQueue: def __cinit__(self, shutdown_cq=False): @@ -36,48 +83,16 @@ cdef class CompletionQueue: self.is_shutting_down = False self.is_shutdown = False - cdef _interpret_event(self, grpc_event event): - cdef _Tag tag = None - if event.type == GRPC_QUEUE_TIMEOUT: - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. - return ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) - elif event.type == GRPC_QUEUE_SHUTDOWN: + cdef _interpret_event(self, grpc_event c_event): + unused_tag, event = _interpret_event(c_event) + if event.completion_type == GRPC_QUEUE_SHUTDOWN: self.is_shutdown = True - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. - return ConnectivityEvent(GRPC_QUEUE_TIMEOUT, True, None) - else: - tag = <_Tag>event.tag - # We receive event tags only after they've been inc-ref'd elsewhere in - # the code. - cpython.Py_DECREF(tag) - return tag.event(event) + return event + # We name this 'poll' to avoid problems with CPython's expectations for + # 'special' methods (like next and __next__). def poll(self, deadline=None): - # We name this 'poll' to avoid problems with CPython's expectations for - # 'special' methods (like next and __next__). - cdef gpr_timespec c_increment - cdef gpr_timespec c_timeout - cdef gpr_timespec c_deadline - if deadline is None: - c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME) - else: - c_deadline = _timespec_from_time(deadline) - with nogil: - c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN) - - while True: - c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment) - if gpr_time_cmp(c_timeout, c_deadline) > 0: - c_timeout = c_deadline - event = grpc_completion_queue_next( - self.c_completion_queue, c_timeout, NULL) - if event.type != GRPC_QUEUE_TIMEOUT or gpr_time_cmp(c_timeout, c_deadline) == 0: - break; - - # Handle any signals - with gil: - cpython.PyErr_CheckSignals() - return self._interpret_event(event) + return self._interpret_event(_next(self.c_completion_queue, deadline)) def shutdown(self): with nogil: diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 4f8868d346..578a3d79ad 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from grpc.framework.foundation import logging_pool from tests.unit.framework.common import test_constants +from tests.unit._cython import test_utilities _EMPTY_FLAGS = 0 _EMPTY_METADATA = () @@ -30,6 +31,8 @@ _RECEIVE_MESSAGE_TAG = 'receive_message' _SERVER_COMPLETE_CALL_TAG = 'server_complete_call' _SUCCESS_CALL_FRACTION = 1.0 / 8.0 +_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) +_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS class _State(object): @@ -150,7 +153,8 @@ class CancelManyCallsTest(unittest.TestCase): server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None, + None) state = _State() @@ -165,31 +169,33 @@ class CancelManyCallsTest(unittest.TestCase): client_condition = threading.Condition() client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver(client_condition, client_completion_queue, - client_due) - client_driver.start() with client_condition: client_calls = [] for index in range(test_constants.RPC_CONCURRENCY): - client_call = channel.create_call(None, _EMPTY_FLAGS, - client_completion_queue, - b'/twinkies', None, None) - operations = ( - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(b'\x45\x56', _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_client_batch(operations, tag) + client_call = channel.integrated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, + None, (( + ( + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(b'\x45\x56', + _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation( + _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + tag, + ),)) client_due.add(tag) client_calls.append(client_call) + client_events_future = test_utilities.SimpleFuture( + lambda: tuple(channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS))) + with state.condition: while True: if state.parked_handlers < test_constants.THREAD_CONCURRENCY: @@ -201,12 +207,14 @@ class CancelManyCallsTest(unittest.TestCase): state.condition.notify_all() break - client_driver.events( - test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) + client_events_future.result() with client_condition: for client_call in client_calls: - client_call.cancel() + client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!') + for _ in range(_UNSUCCESSFUL_CALLS): + channel.next_call_event() + channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!') with state.condition: server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py index 7305d0fa3f..d95286071d 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py @@ -21,25 +21,20 @@ from grpc._cython import cygrpc from tests.unit.framework.common import test_constants -def _channel_and_completion_queue(): - channel = cygrpc.Channel(b'localhost:54321', ()) - completion_queue = cygrpc.CompletionQueue() - return channel, completion_queue +def _channel(): + return cygrpc.Channel(b'localhost:54321', (), None) -def _connectivity_loop(channel, completion_queue): +def _connectivity_loop(channel): for _ in range(100): connectivity = channel.check_connectivity_state(True) - channel.watch_connectivity_state(connectivity, - time.time() + 0.2, completion_queue, - None) - completion_queue.poll() + channel.watch_connectivity_state(connectivity, time.time() + 0.2) def _create_loop_destroy(): - channel, completion_queue = _channel_and_completion_queue() - _connectivity_loop(channel, completion_queue) - completion_queue.shutdown() + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') def _in_parallel(behavior, arguments): @@ -55,12 +50,9 @@ def _in_parallel(behavior, arguments): class ChannelTest(unittest.TestCase): def test_single_channel_lonely_connectivity(self): - channel, completion_queue = _channel_and_completion_queue() - _in_parallel(_connectivity_loop, ( - channel, - completion_queue, - )) - completion_queue.shutdown() + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') def test_multiple_channels_lonely_connectivity(self): _in_parallel(_create_loop_destroy, ()) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index 7fd3d19b4e..d8210f36f8 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -100,7 +100,8 @@ class RpcTest(object): self.server.register_completion_queue(self.server_completion_queue) port = self.server.add_http2_port(b'[::]:0') self.server.start() - self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), []) + self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [], + None) self._server_shutdown_tag = 'server_shutdown_tag' self.server_condition = threading.Condition() diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py index 7caa98f72d..8a721788f4 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from tests.unit._cython import _common +from tests.unit._cython import test_utilities class Test(_common.RpcTest, unittest.TestCase): @@ -41,31 +42,27 @@ class Test(_common.RpcTest, unittest.TestCase): server_request_call_tag, }) - client_call = self.channel.create_call(None, _common.EMPTY_FLAGS, - self.client_completion_queue, - b'/twinkies', None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with self.client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - self.assertEqual(cygrpc.CallError.ok, - client_receive_initial_metadata_start_batch_result) - client_complete_rpc_start_batch_result = client_call.start_client_batch( + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [( [ - cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - self.assertEqual(cygrpc.CallError.ok, - client_complete_rpc_start_batch_result) - self.client_driver.add_due({ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag, - client_complete_rpc_tag, - }) + )]) + client_call.operate([ + cygrpc.SendInitialMetadataOperation(_common.INVOCATION_METADATA, + _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), + ], client_complete_rpc_tag) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(),]) server_request_call_event = self.server_driver.event_with_tag( server_request_call_tag) @@ -96,20 +93,23 @@ class Test(_common.RpcTest, unittest.TestCase): server_complete_rpc_event = server_call_driver.event_with_tag( server_complete_rpc_tag) - client_receive_initial_metadata_event = self.client_driver.event_with_tag( - client_receive_initial_metadata_tag) - client_complete_rpc_event = self.client_driver.event_with_tag( - client_complete_rpc_tag) + client_events = client_events_future.result() + if client_events[0].tag is client_receive_initial_metadata_tag: + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] + else: + client_complete_rpc_event = client_events[0] + client_receive_initial_metadata_event = client_events[1] return ( _common.OperationResult(server_request_call_start_batch_result, server_request_call_event.completion_type, server_request_call_event.success), _common.OperationResult( - client_receive_initial_metadata_start_batch_result, + cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.success), - _common.OperationResult(client_complete_rpc_start_batch_result, + _common.OperationResult(cygrpc.CallError.ok, client_complete_rpc_event.completion_type, client_complete_rpc_event.success), _common.OperationResult( diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py index 8582a39c01..47f39ebce2 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from tests.unit._cython import _common +from tests.unit._cython import test_utilities class Test(_common.RpcTest, unittest.TestCase): @@ -36,28 +37,31 @@ class Test(_common.RpcTest, unittest.TestCase): server_request_call_tag, }) - client_call = self.channel.create_call(None, _common.EMPTY_FLAGS, - self.client_completion_queue, - b'/twinkies', None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with self.client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - client_complete_rpc_start_batch_result = client_call.start_client_batch( - [ - cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - self.client_driver.add_due({ - client_receive_initial_metadata_tag, - client_complete_rpc_tag, - }) - + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation( + _common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation( + _common.EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + ]) + client_call.operate([ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(),]) server_request_call_event = self.server_driver.event_with_tag( server_request_call_tag) @@ -87,20 +91,19 @@ class Test(_common.RpcTest, unittest.TestCase): server_complete_rpc_event = self.server_driver.event_with_tag( server_complete_rpc_tag) - client_receive_initial_metadata_event = self.client_driver.event_with_tag( - client_receive_initial_metadata_tag) - client_complete_rpc_event = self.client_driver.event_with_tag( - client_complete_rpc_tag) + client_events = client_events_future.result() + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] return ( _common.OperationResult(server_request_call_start_batch_result, server_request_call_event.completion_type, server_request_call_event.success), _common.OperationResult( - client_receive_initial_metadata_start_batch_result, + cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.success), - _common.OperationResult(client_complete_rpc_start_batch_result, + _common.OperationResult(cygrpc.CallError.ok, client_complete_rpc_event.completion_type, client_complete_rpc_event.success), _common.OperationResult( diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index bc63b54879..8a903bfaf9 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -17,6 +17,7 @@ import threading import unittest from grpc._cython import cygrpc +from tests.unit._cython import test_utilities _EMPTY_FLAGS = 0 _EMPTY_METADATA = () @@ -118,7 +119,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set()) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(), + None) server_shutdown_tag = 'server_shutdown_tag' server_driver = _ServerDriver(server_completion_queue, @@ -127,10 +129,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): client_condition = threading.Condition() client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver(client_condition, client_completion_queue, - client_due) - client_driver.start() server_call_condition = threading.Condition() server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' @@ -154,25 +152,28 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_completion_queue, server_rpc_tag) - client_call = channel.create_call(None, _EMPTY_FLAGS, - client_completion_queue, b'/twinkies', - None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - client_due.add(client_receive_initial_metadata_tag) - client_complete_rpc_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], client_complete_rpc_tag)) - client_due.add(client_complete_rpc_tag) + client_call = channel.segregated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, ( + ( + [ + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + client_receive_initial_metadata_tag, + ), + ( + [ + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + )) + client_receive_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) server_rpc_event = server_driver.first_event() @@ -208,19 +209,20 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_complete_rpc_tag) server_call_driver.events() - with client_condition: - client_receive_first_message_tag = 'client_receive_first_message_tag' - client_receive_first_message_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], client_receive_first_message_tag)) - client_due.add(client_receive_first_message_tag) - client_receive_first_message_event = client_driver.event_with_tag( - client_receive_first_message_tag) + client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result( + ) + + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_call.operate([ + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], client_receive_first_message_tag) + client_receive_first_message_event = client_call.next_event() - client_call_cancel_result = client_call.cancel() - client_driver.events() + client_call_cancel_result = client_call.cancel( + cygrpc.StatusCode.cancelled, 'Cancelled during test!') + client_complete_rpc_event = client_call.next_event() + channel.close(cygrpc.StatusCode.unknown, 'Channel closed!') server.shutdown(server_completion_queue, server_shutdown_tag) server.cancel_all_calls() server_driver.events() @@ -228,11 +230,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): self.assertEqual(cygrpc.CallError.ok, request_call_result) self.assertEqual(cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, - client_receive_initial_metadata_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, - client_complete_rpc_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result) self.assertIs(server_rpc_tag, server_rpc_event.tag) self.assertEqual(cygrpc.CompletionType.operation_complete, server_rpc_event.completion_type) diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index 23f5ef605d..724a690746 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -51,8 +51,8 @@ class TypeSmokeTest(unittest.TestCase): del server def testChannelUpDown(self): - channel = cygrpc.Channel(b'[::]:0', None) - del channel + channel = cygrpc.Channel(b'[::]:0', None, None) + channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!') def test_metadata_plugin_call_credentials_up_down(self): cygrpc.MetadataPluginCallCredentials(_metadata_plugin, @@ -121,7 +121,7 @@ class ServerClientMixin(object): client_credentials) else: self.client_channel = cygrpc.Channel('localhost:{}'.format( - self.port).encode(), set()) + self.port).encode(), set(), None) if host_override: self.host_argument = None # default host self.expected_host = host_override @@ -131,17 +131,20 @@ class ServerClientMixin(object): self.expected_host = self.host_argument def tearDownMixin(self): + self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!') + del self.client_channel del self.server del self.client_completion_queue del self.server_completion_queue - def _perform_operations(self, operations, call, queue, deadline, - description): - """Perform the list of operations with given call, queue, and deadline. + def _perform_queue_operations(self, operations, call, queue, deadline, + description): + """Perform the operations with given call, queue, and deadline. - Invocation errors are reported with as an exception with `description` in - the message. Performs the operations asynchronously, returning a future. - """ + Invocation errors are reported with as an exception with `description` + in the message. Performs the operations asynchronously, returning a + future. + """ def performer(): tag = object() @@ -185,9 +188,6 @@ class ServerClientMixin(object): self.assertEqual(cygrpc.CallError.ok, request_call_result) client_call_tag = object() - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, self.host_argument, - DEADLINE) client_initial_metadata = ( ( CLIENT_METADATA_ASCII_KEY, @@ -198,18 +198,24 @@ class ServerClientMixin(object): CLIENT_METADATA_BIN_VALUE, ), ) - client_start_batch_result = client_call.start_client_batch([ - cygrpc.SendInitialMetadataOperation(client_initial_metadata, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], client_call_tag) - self.assertEqual(cygrpc.CallError.ok, client_start_batch_result) - client_event_future = test_utilities.CompletionQueuePollFuture( - self.client_completion_queue, DEADLINE) + client_call = self.client_channel.integrated_call( + 0, METHOD, self.host_argument, DEADLINE, client_initial_metadata, + None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + client_initial_metadata, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_call_tag, + ), + ]) + client_event_future = test_utilities.SimpleFuture( + self.client_channel.next_call_event) request_event = self.server_completion_queue.poll(deadline=DEADLINE) self.assertEqual(cygrpc.CompletionType.operation_complete, @@ -304,66 +310,76 @@ class ServerClientMixin(object): del client_call del server_call - def test6522(self): + def test_6522(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 METHOD = b'twinkies' empty_metadata = () + # Prologue server_request_tag = object() self.server.request_call(self.server_completion_queue, self.server_completion_queue, server_request_tag) - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, self.host_argument, - DEADLINE) - - # Prologue - def perform_client_operations(operations, description): - return self._perform_operations(operations, client_call, - self.client_completion_queue, - DEADLINE, description) - - client_event_future = perform_client_operations([ - cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], "Client prologue") + client_call = self.client_channel.segregated_call( + 0, METHOD, self.host_argument, DEADLINE, None, None, ([( + [ + cygrpc.SendInitialMetadataOperation(empty_metadata, + _EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + object(), + ), ( + [ + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + object(), + )])) + + client_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) request_event = self.server_completion_queue.poll(deadline=DEADLINE) server_call = request_event.call def perform_server_operations(operations, description): - return self._perform_operations(operations, server_call, - self.server_completion_queue, - DEADLINE, description) + return self._perform_queue_operations(operations, server_call, + self.server_completion_queue, + DEADLINE, description) server_event_future = perform_server_operations([ cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), ], "Server prologue") - client_event_future.result() # force completion + client_initial_metadata_event_future.result() # force completion server_event_future.result() # Messaging for _ in range(10): - client_event_future = perform_client_operations([ + client_call.operate([ cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), ], "Client message") + client_message_event_future = test_utilities.SimpleFuture( + client_call.next_event) server_event_future = perform_server_operations([ cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), ], "Server receive") - client_event_future.result() # force completion + client_message_event_future.result() # force completion server_event_future.result() # Epilogue - client_event_future = perform_client_operations([ + client_call.operate([ cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS) ], "Client epilogue") + # One for ReceiveStatusOnClient, one for SendCloseFromClient. + client_events_future = test_utilities.SimpleFuture( + lambda: { + client_call.next_event(), + client_call.next_event(),}) server_event_future = perform_server_operations([ cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), @@ -371,7 +387,7 @@ class ServerClientMixin(object): empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS) ], "Server epilogue") - client_event_future.result() # force completion + client_events_future.result() # force completion server_event_future.result() diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index 4edf0fc4ad..f153089a24 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -81,29 +81,16 @@ class InvalidMetadataTest(unittest.TestCase): request = b'\x07\x08' metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_future = self._unary_unary.future(request, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - response_future.result() - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_future.details(), expected_error_details) - self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._unary_unary.future(request, metadata=metadata) def testUnaryRequestStreamResponse(self): request = b'\x37\x58' metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_iterator = self._unary_stream(request, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_iterator.details(), expected_error_details) - self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._unary_stream(request, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) def testStreamRequestBlockingUnaryResponse(self): request_iterator = ( @@ -129,32 +116,18 @@ class InvalidMetadataTest(unittest.TestCase): b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_future = self._stream_unary.future( - request_iterator, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - response_future.result() - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_future.details(), expected_error_details) - self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._stream_unary.future(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) def testStreamRequestStreamResponse(self): request_iterator = ( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_iterator = self._stream_stream( - request_iterator, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_iterator.details(), expected_error_details) - self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._stream_stream(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) if __name__ == '__main__': |