# Copyright 2016 gRPC authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Service-side implementation of gRPC Python.""" import collections import enum import logging import threading import time import six import grpc from grpc import _common from grpc import _interceptor from grpc._cython import cygrpc from grpc.framework.foundation import callable_util _LOGGER = logging.getLogger(__name__) _SHUTDOWN_TAG = 'shutdown' _REQUEST_CALL_TAG = 'request_call' _RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server' _SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata' _RECEIVE_MESSAGE_TOKEN = 'receive_message' _SEND_MESSAGE_TOKEN = 'send_message' _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ( 'send_initial_metadata * send_message') _SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server' _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ( 'send_initial_metadata * send_status_from_server') _OPEN = 'open' _CLOSED = 'closed' _CANCELLED = 'cancelled' _EMPTY_FLAGS = 0 _DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 def _serialized_request(request_event): return request_event.batch_operations[0].message() def _application_code(code): cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code) return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code def _completion_code(state): if state.code is None: return cygrpc.StatusCode.ok else: return _application_code(state.code) def _abortion_code(state, code): if state.code is None: return code else: return _application_code(state.code) def _details(state): return b'' if state.details is None else state.details class _HandlerCallDetails( collections.namedtuple('_HandlerCallDetails', ( 'method', 'invocation_metadata', )), grpc.HandlerCallDetails): pass class _RPCState(object): def __init__(self): self.condition = threading.Condition() self.due = set() self.request = None self.client = _OPEN self.initial_metadata_allowed = True self.disable_next_compression = False self.trailing_metadata = None self.code = None self.details = None self.statused = False self.rpc_errors = [] self.callbacks = [] self.abortion = None def _raise_rpc_error(state): rpc_error = grpc.RpcError() state.rpc_errors.append(rpc_error) raise rpc_error def _possibly_finish_call(state, token): state.due.remove(token) if (state.client is _CANCELLED or state.statused) and not state.due: callbacks = state.callbacks state.callbacks = None return state, callbacks else: return None, () def _send_status_from_server(state, token): def send_status_from_server(unused_send_status_from_server_event): with state.condition: return _possibly_finish_call(state, token) return send_status_from_server def _abort(state, call, code, details): if state.client is not _CANCELLED: effective_code = _abortion_code(state, code) effective_details = details if state.details is None else state.details if state.initial_metadata_allowed: operations = ( cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), cygrpc.SendStatusFromServerOperation( state.trailing_metadata, effective_code, effective_details, _EMPTY_FLAGS), ) token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN else: operations = (cygrpc.SendStatusFromServerOperation( state.trailing_metadata, effective_code, effective_details, _EMPTY_FLAGS),) token = _SEND_STATUS_FROM_SERVER_TOKEN call.start_server_batch(operations, _send_status_from_server(state, token)) state.statused = True state.due.add(token) def _receive_close_on_server(state): def receive_close_on_server(receive_close_on_server_event): with state.condition: if receive_close_on_server_event.batch_operations[0].cancelled(): state.client = _CANCELLED elif state.client is _OPEN: state.client = _CLOSED state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN) return receive_close_on_server def _receive_message(state, call, request_deserializer): def receive_message(receive_message_event): serialized_request = _serialized_request(receive_message_event) if serialized_request is None: with state.condition: if state.client is _OPEN: state.client = _CLOSED state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) else: request = _common.deserialize(serialized_request, request_deserializer) with state.condition: if request is None: _abort(state, call, cygrpc.StatusCode.internal, b'Exception deserializing request!') else: state.request = request state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) return receive_message def _send_initial_metadata(state): def send_initial_metadata(unused_send_initial_metadata_event): with state.condition: return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN) return send_initial_metadata def _send_message(state, token): def send_message(unused_send_message_event): with state.condition: state.condition.notify_all() return _possibly_finish_call(state, token) return send_message class _Context(grpc.ServicerContext): def __init__(self, rpc_event, state, request_deserializer): self._rpc_event = rpc_event self._state = state self._request_deserializer = request_deserializer def is_active(self): with self._state.condition: return self._state.client is not _CANCELLED and not self._state.statused def time_remaining(self): return max(self._rpc_event.call_details.deadline - time.time(), 0) def cancel(self): self._rpc_event.call.cancel() def add_callback(self, callback): with self._state.condition: if self._state.callbacks is None: return False else: self._state.callbacks.append(callback) return True def disable_next_message_compression(self): with self._state.condition: self._state.disable_next_compression = True def invocation_metadata(self): return self._rpc_event.invocation_metadata def peer(self): return _common.decode(self._rpc_event.call.peer()) def peer_identities(self): return cygrpc.peer_identities(self._rpc_event.call) def peer_identity_key(self): id_key = cygrpc.peer_identity_key(self._rpc_event.call) return id_key if id_key is None else _common.decode(id_key) def auth_context(self): return { _common.decode(key): value for key, value in six.iteritems( cygrpc.auth_context(self._rpc_event.call)) } def send_initial_metadata(self, initial_metadata): with self._state.condition: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) else: if self._state.initial_metadata_allowed: operation = cygrpc.SendInitialMetadataOperation( initial_metadata, _EMPTY_FLAGS) self._rpc_event.call.start_server_batch( (operation,), _send_initial_metadata(self._state)) self._state.initial_metadata_allowed = False self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) else: raise ValueError('Initial metadata no longer allowed!') def set_trailing_metadata(self, trailing_metadata): with self._state.condition: self._state.trailing_metadata = trailing_metadata def abort(self, code, details): # treat OK like other invalid arguments: fail the RPC if code == grpc.StatusCode.OK: _LOGGER.error( 'abort() called with StatusCode.OK; returning UNKNOWN') code = grpc.StatusCode.UNKNOWN details = '' with self._state.condition: self._state.code = code self._state.details = _common.encode(details) self._state.abortion = Exception() raise self._state.abortion def abort_with_status(self, status): self._state.trailing_metadata = status.trailing_metadata self.abort(status.code, status.details) def set_code(self, code): with self._state.condition: self._state.code = code def set_details(self, details): with self._state.condition: self._state.details = _common.encode(details) class _RequestIterator(object): def __init__(self, state, call, request_deserializer): self._state = state self._call = call self._request_deserializer = request_deserializer def _raise_or_start_receive_message(self): if self._state.client is _CANCELLED: _raise_rpc_error(self._state) elif self._state.client is _CLOSED or self._state.statused: raise StopIteration() else: self._call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), _receive_message(self._state, self._call, self._request_deserializer)) self._state.due.add(_RECEIVE_MESSAGE_TOKEN) def _look_for_request(self): if self._state.client is _CANCELLED: _raise_rpc_error(self._state) elif (self._state.request is None and _RECEIVE_MESSAGE_TOKEN not in self._state.due): raise StopIteration() else: request = self._state.request self._state.request = None return request raise AssertionError() # should never run def _next(self): with self._state.condition: self._raise_or_start_receive_message() while True: self._state.condition.wait() request = self._look_for_request() if request is not None: return request def __iter__(self): return self def __next__(self): return self._next() def next(self): return self._next() def _unary_request(rpc_event, state, request_deserializer): def unary_request(): with state.condition: if state.client is _CANCELLED or state.statused: return None else: rpc_event.call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), _receive_message(state, rpc_event.call, request_deserializer)) state.due.add(_RECEIVE_MESSAGE_TOKEN) while True: state.condition.wait() if state.request is None: if state.client is _CLOSED: details = '"{}" requires exactly one request message.'.format( rpc_event.call_details.method) _abort(state, rpc_event.call, cygrpc.StatusCode.unimplemented, _common.encode(details)) return None elif state.client is _CANCELLED: return None else: request = state.request state.request = None return request return unary_request def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): context = _Context(rpc_event, state, request_deserializer) try: return behavior(argument, context), True except Exception as exception: # pylint: disable=broad-except with state.condition: if exception is state.abortion: _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, b'RPC Aborted') elif exception not in state.rpc_errors: details = 'Exception calling application: {}'.format(exception) _LOGGER.exception(details) _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, _common.encode(details)) return None, False def _take_response_from_response_iterator(rpc_event, state, response_iterator): try: return next(response_iterator), True except StopIteration: return None, True except Exception as exception: # pylint: disable=broad-except with state.condition: if exception is state.abortion: _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, b'RPC Aborted') elif exception not in state.rpc_errors: details = 'Exception iterating responses: {}'.format(exception) _LOGGER.exception(details) _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, _common.encode(details)) return None, False def _serialize_response(rpc_event, state, response, response_serializer): serialized_response = _common.serialize(response, response_serializer) if serialized_response is None: with state.condition: _abort(state, rpc_event.call, cygrpc.StatusCode.internal, b'Failed to serialize response!') return None else: return serialized_response def _send_response(rpc_event, state, serialized_response): with state.condition: if state.client is _CANCELLED or state.statused: return False else: if state.initial_metadata_allowed: operations = ( cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), cygrpc.SendMessageOperation(serialized_response, _EMPTY_FLAGS), ) state.initial_metadata_allowed = False token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN else: operations = (cygrpc.SendMessageOperation( serialized_response, _EMPTY_FLAGS),) token = _SEND_MESSAGE_TOKEN rpc_event.call.start_server_batch(operations, _send_message(state, token)) state.due.add(token) while True: state.condition.wait() if token not in state.due: return state.client is not _CANCELLED and not state.statused def _status(rpc_event, state, serialized_response): with state.condition: if state.client is not _CANCELLED: code = _completion_code(state) details = _details(state) operations = [ cygrpc.SendStatusFromServerOperation( state.trailing_metadata, code, details, _EMPTY_FLAGS), ] if state.initial_metadata_allowed: operations.append( cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS)) if serialized_response is not None: operations.append( cygrpc.SendMessageOperation(serialized_response, _EMPTY_FLAGS)) rpc_event.call.start_server_batch( operations, _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) state.statused = True state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk, request_deserializer, response_serializer): cygrpc.install_census_context_from_call(rpc_event.call) try: argument = argument_thunk() if argument is not None: response, proceed = _call_behavior(rpc_event, state, behavior, argument, request_deserializer) if proceed: serialized_response = _serialize_response( rpc_event, state, response, response_serializer) if serialized_response is not None: _status(rpc_event, state, serialized_response) finally: cygrpc.uninstall_context() def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, request_deserializer, response_serializer): cygrpc.install_census_context_from_call(rpc_event.call) try: argument = argument_thunk() if argument is not None: response_iterator, proceed = _call_behavior( rpc_event, state, behavior, argument, request_deserializer) if proceed: while True: response, proceed = _take_response_from_response_iterator( rpc_event, state, response_iterator) if proceed: if response is None: _status(rpc_event, state, None) break else: serialized_response = _serialize_response( rpc_event, state, response, response_serializer) if serialized_response is not None: proceed = _send_response( rpc_event, state, serialized_response) if not proceed: break else: break else: break finally: cygrpc.uninstall_context() def _handle_unary_unary(rpc_event, state, method_handler, thread_pool): unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) return thread_pool.submit(_unary_response_in_pool, rpc_event, state, method_handler.unary_unary, unary_request, method_handler.request_deserializer, method_handler.response_serializer) def _handle_unary_stream(rpc_event, state, method_handler, thread_pool): unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) return thread_pool.submit(_stream_response_in_pool, rpc_event, state, method_handler.unary_stream, unary_request, method_handler.request_deserializer, method_handler.response_serializer) def _handle_stream_unary(rpc_event, state, method_handler, thread_pool): request_iterator = _RequestIterator(state, rpc_event.call, method_handler.request_deserializer) return thread_pool.submit( _unary_response_in_pool, rpc_event, state, method_handler.stream_unary, lambda: request_iterator, method_handler.request_deserializer, method_handler.response_serializer) def _handle_stream_stream(rpc_event, state, method_handler, thread_pool): request_iterator = _RequestIterator(state, rpc_event.call, method_handler.request_deserializer) return thread_pool.submit( _stream_response_in_pool, rpc_event, state, method_handler.stream_stream, lambda: request_iterator, method_handler.request_deserializer, method_handler.response_serializer) def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): def query_handlers(handler_call_details): for generic_handler in generic_handlers: method_handler = generic_handler.service(handler_call_details) if method_handler is not None: return method_handler return None handler_call_details = _HandlerCallDetails( _common.decode(rpc_event.call_details.method), rpc_event.invocation_metadata) if interceptor_pipeline is not None: return interceptor_pipeline.execute(query_handlers, handler_call_details) else: return query_handlers(handler_call_details) def _reject_rpc(rpc_event, status, details): operations = ( cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), cygrpc.SendStatusFromServerOperation(None, status, details, _EMPTY_FLAGS), ) rpc_state = _RPCState() rpc_event.call.start_server_batch(operations, lambda ignored_event: (rpc_state, (),)) return rpc_state def _handle_with_method_handler(rpc_event, method_handler, thread_pool): state = _RPCState() with state.condition: rpc_event.call.start_server_batch( (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), _receive_close_on_server(state)) state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) if method_handler.request_streaming: if method_handler.response_streaming: return state, _handle_stream_stream(rpc_event, state, method_handler, thread_pool) else: return state, _handle_stream_unary(rpc_event, state, method_handler, thread_pool) else: if method_handler.response_streaming: return state, _handle_unary_stream(rpc_event, state, method_handler, thread_pool) else: return state, _handle_unary_unary(rpc_event, state, method_handler, thread_pool) def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool, concurrency_exceeded): if not rpc_event.success: return None, None if rpc_event.call_details.method is not None: try: method_handler = _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline) except Exception as exception: # pylint: disable=broad-except details = 'Exception servicing handler: {}'.format(exception) _LOGGER.exception(details) return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown, b'Error in service handler!'), None if method_handler is None: return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented, b'Method not found!'), None elif concurrency_exceeded: return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted, b'Concurrent RPC limit exceeded!'), None else: return _handle_with_method_handler(rpc_event, method_handler, thread_pool) else: return None, None @enum.unique class _ServerStage(enum.Enum): STOPPED = 'stopped' STARTED = 'started' GRACE = 'grace' class _ServerState(object): # pylint: disable=too-many-arguments def __init__(self, completion_queue, server, generic_handlers, interceptor_pipeline, thread_pool, maximum_concurrent_rpcs): self.lock = threading.RLock() self.completion_queue = completion_queue self.server = server self.generic_handlers = list(generic_handlers) self.interceptor_pipeline = interceptor_pipeline self.thread_pool = thread_pool self.stage = _ServerStage.STOPPED self.shutdown_events = None self.maximum_concurrent_rpcs = maximum_concurrent_rpcs self.active_rpc_count = 0 # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields. self.rpc_states = set() self.due = set() # A "volatile" flag to interrupt the daemon serving thread self.server_deallocated = False def _add_generic_handlers(state, generic_handlers): with state.lock: state.generic_handlers.extend(generic_handlers) def _add_insecure_port(state, address): with state.lock: return state.server.add_http2_port(address) def _add_secure_port(state, address, server_credentials): with state.lock: return state.server.add_http2_port(address, server_credentials._credentials) def _request_call(state): state.server.request_call(state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG) state.due.add(_REQUEST_CALL_TAG) # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. def _stop_serving(state): if not state.rpc_states and not state.due: state.server.destroy() for shutdown_event in state.shutdown_events: shutdown_event.set() state.stage = _ServerStage.STOPPED return True else: return False def _on_call_completed(state): with state.lock: state.active_rpc_count -= 1 def _process_event_and_continue(state, event): should_continue = True if event.tag is _SHUTDOWN_TAG: with state.lock: state.due.remove(_SHUTDOWN_TAG) if _stop_serving(state): should_continue = False elif event.tag is _REQUEST_CALL_TAG: with state.lock: state.due.remove(_REQUEST_CALL_TAG) concurrency_exceeded = ( state.maximum_concurrent_rpcs is not None and state.active_rpc_count >= state.maximum_concurrent_rpcs) rpc_state, rpc_future = _handle_call( event, state.generic_handlers, state.interceptor_pipeline, state.thread_pool, concurrency_exceeded) if rpc_state is not None: state.rpc_states.add(rpc_state) if rpc_future is not None: state.active_rpc_count += 1 rpc_future.add_done_callback( lambda unused_future: _on_call_completed(state)) if state.stage is _ServerStage.STARTED: _request_call(state) elif _stop_serving(state): should_continue = False else: rpc_state, callbacks = event.tag(event) for callback in callbacks: callable_util.call_logging_exceptions(callback, 'Exception calling callback!') if rpc_state is not None: with state.lock: state.rpc_states.remove(rpc_state) if _stop_serving(state): should_continue = False return should_continue def _serve(state): while True: timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S event = state.completion_queue.poll(timeout) if state.server_deallocated: _begin_shutdown_once(state) if event.completion_type != cygrpc.CompletionType.queue_timeout: if not _process_event_and_continue(state, event): return # We want to force the deletion of the previous event # ~before~ we poll again; if the event has a reference # to a shutdown Call object, this can induce spinlock. event = None def _begin_shutdown_once(state): with state.lock: if state.stage is _ServerStage.STARTED: state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) state.stage = _ServerStage.GRACE state.shutdown_events = [] state.due.add(_SHUTDOWN_TAG) def _stop(state, grace): with state.lock: if state.stage is _ServerStage.STOPPED: shutdown_event = threading.Event() shutdown_event.set() return shutdown_event else: _begin_shutdown_once(state) shutdown_event = threading.Event() state.shutdown_events.append(shutdown_event) if grace is None: state.server.cancel_all_calls() else: def cancel_all_calls_after_grace(): shutdown_event.wait(timeout=grace) with state.lock: state.server.cancel_all_calls() thread = threading.Thread(target=cancel_all_calls_after_grace) thread.start() return shutdown_event shutdown_event.wait() return shutdown_event def _start(state): with state.lock: if state.stage is not _ServerStage.STOPPED: raise ValueError('Cannot start already-started server!') state.server.start() state.stage = _ServerStage.STARTED _request_call(state) thread = threading.Thread(target=_serve, args=(state,)) thread.daemon = True thread.start() def _validate_generic_rpc_handlers(generic_rpc_handlers): for generic_rpc_handler in generic_rpc_handlers: service_attribute = getattr(generic_rpc_handler, 'service', None) if service_attribute is None: raise AttributeError( '"{}" must conform to grpc.GenericRpcHandler type but does ' 'not have "service" method!'.format(generic_rpc_handler)) class _Server(grpc.Server): # pylint: disable=too-many-arguments def __init__(self, thread_pool, generic_handlers, interceptors, options, maximum_concurrent_rpcs): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(options) server.register_completion_queue(completion_queue) self._state = _ServerState(completion_queue, server, generic_handlers, _interceptor.service_pipeline(interceptors), thread_pool, maximum_concurrent_rpcs) def add_generic_rpc_handlers(self, generic_rpc_handlers): _validate_generic_rpc_handlers(generic_rpc_handlers) _add_generic_handlers(self._state, generic_rpc_handlers) def add_insecure_port(self, address): return _add_insecure_port(self._state, _common.encode(address)) def add_secure_port(self, address, server_credentials): return _add_secure_port(self._state, _common.encode(address), server_credentials) def start(self): _start(self._state) def stop(self, grace): return _stop(self._state, grace) def __del__(self): # We can not grab a lock in __del__(), so set a flag to signal the # serving daemon thread (if it exists) to initiate shutdown. self._state.server_deallocated = True def create_server(thread_pool, generic_rpc_handlers, interceptors, options, maximum_concurrent_rpcs): _validate_generic_rpc_handlers(generic_rpc_handlers) return _Server(thread_pool, generic_rpc_handlers, interceptors, options, maximum_concurrent_rpcs)