aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio/grpc/_channel.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/grpcio/grpc/_channel.py')
-rw-r--r--src/python/grpcio/grpc/_channel.py59
1 files changed, 50 insertions, 9 deletions
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
index e9246991df..6876601785 100644
--- a/src/python/grpcio/grpc/_channel.py
+++ b/src/python/grpcio/grpc/_channel.py
@@ -111,6 +111,10 @@ class _RPCState(object):
# prior to termination of the RPC.
self.cancelled = False
self.callbacks = []
+ self.fork_epoch = cygrpc.get_fork_epoch()
+
+ def reset_postfork_child(self):
+ self.condition = threading.Condition()
def _abort(state, code, details):
@@ -166,21 +170,30 @@ def _event_handler(state, response_deserializer):
done = not state.due
for callback in callbacks:
callback()
- return done
+ return done and state.fork_epoch >= cygrpc.get_fork_epoch()
return handle_event
def _consume_request_iterator(request_iterator, state, call, request_serializer,
event_handler):
+ if cygrpc.is_fork_support_enabled():
+ condition_wait_timeout = 1.0
+ else:
+ condition_wait_timeout = None
def consume_request_iterator(): # pylint: disable=too-many-branches
while True:
+ return_from_user_request_generator_invoked = False
try:
+ # The thread may die in user-code. Do not block fork for this.
+ cygrpc.enter_user_request_generator()
request = next(request_iterator)
except StopIteration:
break
except Exception: # pylint: disable=broad-except
+ cygrpc.return_from_user_request_generator()
+ return_from_user_request_generator_invoked = True
code = grpc.StatusCode.UNKNOWN
details = 'Exception iterating requests!'
_LOGGER.exception(details)
@@ -188,6 +201,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
details)
_abort(state, code, details)
return
+ finally:
+ if not return_from_user_request_generator_invoked:
+ cygrpc.return_from_user_request_generator()
serialized_request = _common.serialize(request, request_serializer)
with state.condition:
if state.code is None and not state.cancelled:
@@ -208,7 +224,8 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
else:
return
while True:
- state.condition.wait()
+ state.condition.wait(condition_wait_timeout)
+ cygrpc.block_if_fork_in_progress(state)
if state.code is None:
if cygrpc.OperationType.send_message not in state.due:
break
@@ -224,8 +241,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
if operating:
state.due.add(cygrpc.OperationType.send_close_from_client)
- consumption_thread = threading.Thread(target=consume_request_iterator)
- consumption_thread.daemon = True
+ consumption_thread = cygrpc.ForkManagedThread(
+ target=consume_request_iterator)
+ consumption_thread.setDaemon(True)
consumption_thread.start()
@@ -671,13 +689,20 @@ class _ChannelCallState(object):
self.lock = threading.Lock()
self.channel = channel
self.managed_calls = 0
+ self.threading = False
+
+ def reset_postfork_child(self):
+ self.managed_calls = 0
def _run_channel_spin_thread(state):
def channel_spin():
while True:
+ cygrpc.block_if_fork_in_progress(state)
event = state.channel.next_call_event()
+ if event.completion_type == cygrpc.CompletionType.queue_timeout:
+ continue
call_completed = event.tag(event)
if call_completed:
with state.lock:
@@ -685,8 +710,8 @@ def _run_channel_spin_thread(state):
if state.managed_calls == 0:
return
- channel_spin_thread = threading.Thread(target=channel_spin)
- channel_spin_thread.daemon = True
+ channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin)
+ channel_spin_thread.setDaemon(True)
channel_spin_thread.start()
@@ -742,6 +767,13 @@ class _ChannelConnectivityState(object):
self.callbacks_and_connectivities = []
self.delivering = False
+ def reset_postfork_child(self):
+ self.polling = False
+ self.connectivity = None
+ self.try_to_connect = False
+ self.callbacks_and_connectivities = []
+ self.delivering = False
+
def _deliveries(state):
callbacks_needing_update = []
@@ -758,6 +790,7 @@ def _deliver(state, initial_connectivity, initial_callbacks):
callbacks = initial_callbacks
while True:
for callback in callbacks:
+ cygrpc.block_if_fork_in_progress(state)
callable_util.call_logging_exceptions(
callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE,
connectivity)
@@ -771,7 +804,7 @@ def _deliver(state, initial_connectivity, initial_callbacks):
def _spawn_delivery(state, callbacks):
- delivering_thread = threading.Thread(
+ delivering_thread = cygrpc.ForkManagedThread(
target=_deliver, args=(
state,
state.connectivity,
@@ -799,6 +832,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
while True:
event = channel.watch_connectivity_state(connectivity,
time.time() + 0.2)
+ cygrpc.block_if_fork_in_progress(state)
with state.lock:
if not state.callbacks_and_connectivities and not state.try_to_connect:
state.polling = False
@@ -826,10 +860,10 @@ def _moot(state):
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
- polling_thread = threading.Thread(
+ polling_thread = cygrpc.ForkManagedThread(
target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
- polling_thread.daemon = True
+ polling_thread.setDaemon(True)
polling_thread.start()
state.polling = True
state.callbacks_and_connectivities.append([callback, None])
@@ -876,6 +910,7 @@ class Channel(grpc.Channel):
_common.encode(target), _options(options), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
+ cygrpc.fork_register_channel(self)
def subscribe(self, callback, try_to_connect=None):
_subscribe(self._connectivity_state, callback, try_to_connect)
@@ -919,6 +954,11 @@ class Channel(grpc.Channel):
self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!')
_moot(self._connectivity_state)
+ def _close_on_fork(self):
+ self._channel.close_on_fork(cygrpc.StatusCode.cancelled,
+ 'Channel closed due to fork')
+ _moot(self._connectivity_state)
+
def __enter__(self):
return self
@@ -939,4 +979,5 @@ class Channel(grpc.Channel):
# for as long as they are in use and to close them after using them,
# then deletion of this grpc._channel.Channel instance can be made to
# effect closure of the underlying cygrpc.Channel instance.
+ cygrpc.fork_unregister_channel(self)
_moot(self._connectivity_state)