diff options
Diffstat (limited to 'src/python/grpcio/grpc/_channel.py')
-rw-r--r-- | src/python/grpcio/grpc/_channel.py | 59 |
1 files changed, 50 insertions, 9 deletions
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index e9246991df..6876601785 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -111,6 +111,10 @@ class _RPCState(object): # prior to termination of the RPC. self.cancelled = False self.callbacks = [] + self.fork_epoch = cygrpc.get_fork_epoch() + + def reset_postfork_child(self): + self.condition = threading.Condition() def _abort(state, code, details): @@ -166,21 +170,30 @@ def _event_handler(state, response_deserializer): done = not state.due for callback in callbacks: callback() - return done + return done and state.fork_epoch >= cygrpc.get_fork_epoch() return handle_event def _consume_request_iterator(request_iterator, state, call, request_serializer, event_handler): + if cygrpc.is_fork_support_enabled(): + condition_wait_timeout = 1.0 + else: + condition_wait_timeout = None def consume_request_iterator(): # pylint: disable=too-many-branches while True: + return_from_user_request_generator_invoked = False try: + # The thread may die in user-code. Do not block fork for this. + cygrpc.enter_user_request_generator() request = next(request_iterator) except StopIteration: break except Exception: # pylint: disable=broad-except + cygrpc.return_from_user_request_generator() + return_from_user_request_generator_invoked = True code = grpc.StatusCode.UNKNOWN details = 'Exception iterating requests!' _LOGGER.exception(details) @@ -188,6 +201,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, details) _abort(state, code, details) return + finally: + if not return_from_user_request_generator_invoked: + cygrpc.return_from_user_request_generator() serialized_request = _common.serialize(request, request_serializer) with state.condition: if state.code is None and not state.cancelled: @@ -208,7 +224,8 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, else: return while True: - state.condition.wait() + state.condition.wait(condition_wait_timeout) + cygrpc.block_if_fork_in_progress(state) if state.code is None: if cygrpc.OperationType.send_message not in state.due: break @@ -224,8 +241,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, if operating: state.due.add(cygrpc.OperationType.send_close_from_client) - consumption_thread = threading.Thread(target=consume_request_iterator) - consumption_thread.daemon = True + consumption_thread = cygrpc.ForkManagedThread( + target=consume_request_iterator) + consumption_thread.setDaemon(True) consumption_thread.start() @@ -671,13 +689,20 @@ class _ChannelCallState(object): self.lock = threading.Lock() self.channel = channel self.managed_calls = 0 + self.threading = False + + def reset_postfork_child(self): + self.managed_calls = 0 def _run_channel_spin_thread(state): def channel_spin(): while True: + cygrpc.block_if_fork_in_progress(state) event = state.channel.next_call_event() + if event.completion_type == cygrpc.CompletionType.queue_timeout: + continue call_completed = event.tag(event) if call_completed: with state.lock: @@ -685,8 +710,8 @@ def _run_channel_spin_thread(state): if state.managed_calls == 0: return - channel_spin_thread = threading.Thread(target=channel_spin) - channel_spin_thread.daemon = True + channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin) + channel_spin_thread.setDaemon(True) channel_spin_thread.start() @@ -742,6 +767,13 @@ class _ChannelConnectivityState(object): self.callbacks_and_connectivities = [] self.delivering = False + def reset_postfork_child(self): + self.polling = False + self.connectivity = None + self.try_to_connect = False + self.callbacks_and_connectivities = [] + self.delivering = False + def _deliveries(state): callbacks_needing_update = [] @@ -758,6 +790,7 @@ def _deliver(state, initial_connectivity, initial_callbacks): callbacks = initial_callbacks while True: for callback in callbacks: + cygrpc.block_if_fork_in_progress(state) callable_util.call_logging_exceptions( callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE, connectivity) @@ -771,7 +804,7 @@ def _deliver(state, initial_connectivity, initial_callbacks): def _spawn_delivery(state, callbacks): - delivering_thread = threading.Thread( + delivering_thread = cygrpc.ForkManagedThread( target=_deliver, args=( state, state.connectivity, @@ -799,6 +832,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect): while True: event = channel.watch_connectivity_state(connectivity, time.time() + 0.2) + cygrpc.block_if_fork_in_progress(state) with state.lock: if not state.callbacks_and_connectivities and not state.try_to_connect: state.polling = False @@ -826,10 +860,10 @@ def _moot(state): def _subscribe(state, callback, try_to_connect): with state.lock: if not state.callbacks_and_connectivities and not state.polling: - polling_thread = threading.Thread( + polling_thread = cygrpc.ForkManagedThread( target=_poll_connectivity, args=(state, state.channel, bool(try_to_connect))) - polling_thread.daemon = True + polling_thread.setDaemon(True) polling_thread.start() state.polling = True state.callbacks_and_connectivities.append([callback, None]) @@ -876,6 +910,7 @@ class Channel(grpc.Channel): _common.encode(target), _options(options), credentials) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel) + cygrpc.fork_register_channel(self) def subscribe(self, callback, try_to_connect=None): _subscribe(self._connectivity_state, callback, try_to_connect) @@ -919,6 +954,11 @@ class Channel(grpc.Channel): self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') _moot(self._connectivity_state) + def _close_on_fork(self): + self._channel.close_on_fork(cygrpc.StatusCode.cancelled, + 'Channel closed due to fork') + _moot(self._connectivity_state) + def __enter__(self): return self @@ -939,4 +979,5 @@ class Channel(grpc.Channel): # for as long as they are in use and to close them after using them, # then deletion of this grpc._channel.Channel instance can be made to # effect closure of the underlying cygrpc.Channel instance. + cygrpc.fork_unregister_channel(self) _moot(self._connectivity_state) |