diff options
6 files changed, 77 insertions, 51 deletions
diff --git a/src/python/grpcio/grpc/framework/core/_ingestion.py b/src/python/grpcio/grpc/framework/core/_ingestion.py index 7b8127f3fc..766d57f931 100644 --- a/src/python/grpcio/grpc/framework/core/_ingestion.py +++ b/src/python/grpcio/grpc/framework/core/_ingestion.py @@ -114,7 +114,7 @@ class _ServiceSubscriptionCreator(_SubscriptionCreator): group, method, self._operation_context, self._output_operator) except base.NoSuchMethodError as e: return _SubscriptionCreation( - _SubscriptionCreation.Kind.REMOTE_ERROR, None, e.code, e.message) + _SubscriptionCreation.Kind.REMOTE_ERROR, None, e.code, e.details) except abandonment.Abandoned: return _SubscriptionCreation( _SubscriptionCreation.Kind.ABANDONED, None, None, None) diff --git a/src/python/grpcio/grpc/framework/core/_transmission.py b/src/python/grpcio/grpc/framework/core/_transmission.py index efef87dd4c..202a71dd71 100644 --- a/src/python/grpcio/grpc/framework/core/_transmission.py +++ b/src/python/grpcio/grpc/framework/core/_transmission.py @@ -29,6 +29,9 @@ """State and behavior for ticket transmission during an operation.""" +import collections +import enum + from grpc.framework.core import _constants from grpc.framework.core import _interfaces from grpc.framework.foundation import callable_util @@ -47,6 +50,31 @@ def _explode_completion(completion): links.Ticket.Termination.COMPLETION) +class _Abort( + collections.namedtuple( + '_Abort', ('kind', 'termination', 'code', 'details',))): + """Tracks whether the operation aborted and what is to be done about it. + + Attributes: + kind: A Kind value describing the overall kind of the _Abort. + termination: A links.Ticket.Termination value to be sent to the other side + of the operation. Only valid if kind is Kind.ABORTED_NOTIFY_NEEDED. + code: A code value to be sent to the other side of the operation. Only + valid if kind is Kind.ABORTED_NOTIFY_NEEDED. + details: A details value to be sent to the other side of the operation. + Only valid if kind is Kind.ABORTED_NOTIFY_NEEDED. + """ + + @enum.unique + class Kind(enum.Enum): + NOT_ABORTED = 'not aborted' + ABORTED_NOTIFY_NEEDED = 'aborted notify needed' + ABORTED_NO_NOTIFY = 'aborted no notify' + +_NOT_ABORTED = _Abort(_Abort.Kind.NOT_ABORTED, None, None, None) +_ABORTED_NO_NOTIFY = _Abort(_Abort.Kind.ABORTED_NO_NOTIFY, None, None, None) + + class TransmissionManager(_interfaces.TransmissionManager): """An _interfaces.TransmissionManager that sends links.Tickets.""" @@ -79,8 +107,7 @@ class TransmissionManager(_interfaces.TransmissionManager): self._initial_metadata = None self._payloads = [] self._completion = None - self._aborted = False - self._abortion_outcome = None + self._abort = _NOT_ABORTED self._transmitting = False def set_expiration_manager(self, expiration_manager): @@ -94,24 +121,15 @@ class TransmissionManager(_interfaces.TransmissionManager): A links.Ticket to be sent to the other side of the operation or None if there is nothing to be sent at this time. """ - if self._aborted: - if self._abortion_outcome is None: - return None - else: - termination = _constants.ABORTION_OUTCOME_TO_TICKET_TERMINATION[ - self._abortion_outcome] - if termination is None: - return None - else: - self._abortion_outcome = None - if self._completion is None: - code, message = None, None - else: - code, message = self._completion.code, self._completion.message - return links.Ticket( - self._operation_id, self._lowest_unused_sequence_number, None, - None, None, None, None, None, None, None, code, message, - termination, None) + if self._abort.kind is _Abort.Kind.ABORTED_NO_NOTIFY: + return None + elif self._abort.kind is _Abort.Kind.ABORTED_NOTIFY_NEEDED: + termination = self._abort.termination + code, details = self._abort.code, self._abort.details + self._abort = _ABORTED_NO_NOTIFY + return links.Ticket( + self._operation_id, self._lowest_unused_sequence_number, None, None, + None, None, None, None, None, None, code, details, termination, None) action = False # TODO(nathaniel): Support other subscriptions. @@ -174,6 +192,7 @@ class TransmissionManager(_interfaces.TransmissionManager): return else: with self._lock: + self._abort = _ABORTED_NO_NOTIFY if self._termination_manager.outcome is None: self._termination_manager.abort(base.Outcome.TRANSMISSION_FAILURE) self._expiration_manager.terminate() @@ -201,6 +220,9 @@ class TransmissionManager(_interfaces.TransmissionManager): def advance(self, initial_metadata, payload, completion, allowance): """See _interfaces.TransmissionManager.advance for specification.""" + if self._abort.kind is not _Abort.Kind.NOT_ABORTED: + return + effective_initial_metadata = initial_metadata effective_payload = payload effective_completion = completion @@ -246,7 +268,9 @@ class TransmissionManager(_interfaces.TransmissionManager): def timeout(self, timeout): """See _interfaces.TransmissionManager.timeout for specification.""" - if self._transmitting: + if self._abort.kind is not _Abort.Kind.NOT_ABORTED: + return + elif self._transmitting: self._timeout = timeout else: ticket = links.Ticket( @@ -257,7 +281,9 @@ class TransmissionManager(_interfaces.TransmissionManager): def allowance(self, allowance): """See _interfaces.TransmissionManager.allowance for specification.""" - if self._transmitting or not self._payloads: + if self._abort.kind is not _Abort.Kind.NOT_ABORTED: + return + elif self._transmitting or not self._payloads: self._remote_allowance += allowance else: self._remote_allowance += allowance - 1 @@ -283,20 +309,17 @@ class TransmissionManager(_interfaces.TransmissionManager): def abort(self, outcome, code, message): """See _interfaces.TransmissionManager.abort for specification.""" - if self._transmitting: - self._aborted, self._abortion_outcome = True, outcome - else: - self._aborted = True - if outcome is not None: - termination = _constants.ABORTION_OUTCOME_TO_TICKET_TERMINATION[ - outcome] - if termination is not None: - if self._completion is None: - code, message = None, None - else: - code, message = self._completion.code, self._completion.message - ticket = links.Ticket( - self._operation_id, self._lowest_unused_sequence_number, None, - None, None, None, None, None, None, None, code, message, - termination, None) - self._transmit(ticket) + if self._abort.kind is _Abort.Kind.NOT_ABORTED: + termination = _constants.ABORTION_OUTCOME_TO_TICKET_TERMINATION.get( + outcome) + if termination is None: + self._abort = _ABORTED_NO_NOTIFY + elif self._transmitting: + self._abort = _Abort( + _Abort.Kind.ABORTED_NOTIFY_NEEDED, termination, code, message) + else: + ticket = links.Ticket( + self._operation_id, self._lowest_unused_sequence_number, None, + None, None, None, None, None, None, None, code, message, + termination, None) + self._transmit(ticket) diff --git a/src/python/grpcio/grpc/framework/crust/_calls.py b/src/python/grpcio/grpc/framework/crust/_calls.py index f9077bedfe..4c6bf16f43 100644 --- a/src/python/grpcio/grpc/framework/crust/_calls.py +++ b/src/python/grpcio/grpc/framework/crust/_calls.py @@ -98,7 +98,7 @@ def blocking_unary_unary( rendezvous, unused_operation_context, unused_outcome = _invoke( end, group, method, timeout, initial_metadata, payload, True) if with_call: - return next(rendezvous, rendezvous) + return next(rendezvous), rendezvous else: return next(rendezvous) diff --git a/src/python/grpcio/grpc/framework/crust/_service.py b/src/python/grpcio/grpc/framework/crust/_service.py index 2455a58f59..6ff7249e75 100644 --- a/src/python/grpcio/grpc/framework/crust/_service.py +++ b/src/python/grpcio/grpc/framework/crust/_service.py @@ -154,7 +154,7 @@ def adapt_multi_method(multi_method, pool): outcome = operation_context.add_termination_callback(rendezvous.set_outcome) if outcome is None: def in_pool(): - request_consumer = multi_method( + request_consumer = multi_method.service( group, method, rendezvous, _ServicerContext(rendezvous)) for request in rendezvous: request_consumer.consume(request) diff --git a/src/python/grpcio/grpc/framework/crust/implementations.py b/src/python/grpcio/grpc/framework/crust/implementations.py index 12f7e79641..d38fab8ba0 100644 --- a/src/python/grpcio/grpc/framework/crust/implementations.py +++ b/src/python/grpcio/grpc/framework/crust/implementations.py @@ -49,12 +49,12 @@ class _BaseServicer(base.Servicer): return adapted_method(output_operator, context) elif self._adapted_multi_method is not None: try: - return self._adapted_multi_method.service( + return self._adapted_multi_method( group, method, output_operator, context) except face.NoSuchMethodError: - raise base.NoSuchMethodError() + raise base.NoSuchMethodError(None, None) else: - raise base.NoSuchMethodError() + raise base.NoSuchMethodError(None, None) class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable): @@ -315,8 +315,11 @@ def servicer(method_implementations, multi_method_implementation, pool): """ adapted_implementations = _adapt_method_implementations( method_implementations, pool) - adapted_multi_method_implementation = _service.adapt_multi_method( - multi_method_implementation, pool) + if multi_method_implementation is None: + adapted_multi_method_implementation = None + else: + adapted_multi_method_implementation = _service.adapt_multi_method( + multi_method_implementation, pool) return _BaseServicer( adapted_implementations, adapted_multi_method_implementation) diff --git a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py index b7dd5d4d17..2d2a081955 100644 --- a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py +++ b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py @@ -82,8 +82,8 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): for test_messages in test_messages_sequence: request = test_messages.request() - response = self._invoker.blocking(group, method)( - request, test_constants.LONG_TIMEOUT) + response, call = self._invoker.blocking(group, method)( + request, test_constants.LONG_TIMEOUT, with_call=True) test_messages.verify(request, response, self) @@ -105,8 +105,8 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): for test_messages in test_messages_sequence: requests = test_messages.requests() - response = self._invoker.blocking(group, method)( - iter(requests), test_constants.LONG_TIMEOUT) + response, call = self._invoker.blocking(group, method)( + iter(requests), test_constants.LONG_TIMEOUT, with_call=True) test_messages.verify(requests, response, self) |