diff options
25 files changed, 1167 insertions, 74 deletions
diff --git a/.pylintrc-tests b/.pylintrc-tests index ebe9d507cd..e68755c674 100644 --- a/.pylintrc-tests +++ b/.pylintrc-tests @@ -20,6 +20,8 @@ notes=FIXME,XXX [MESSAGES CONTROL] +extension-pkg-whitelist=grpc._cython.cygrpc + disable= # These suppressions are specific to tests: # diff --git a/src/core/lib/gprpp/fork.cc b/src/core/lib/gprpp/fork.cc index f6d9a87d2c..0288c39680 100644 --- a/src/core/lib/gprpp/fork.cc +++ b/src/core/lib/gprpp/fork.cc @@ -157,11 +157,11 @@ class ThreadState { } // namespace void Fork::GlobalInit() { - if (!overrideEnabled_) { + if (!override_enabled_) { #ifdef GRPC_ENABLE_FORK_SUPPORT - supportEnabled_ = true; + support_enabled_ = true; #else - supportEnabled_ = false; + support_enabled_ = false; #endif bool env_var_set = false; char* env = gpr_getenv("GRPC_ENABLE_FORK_SUPPORT"); @@ -172,7 +172,7 @@ void Fork::GlobalInit() { "False", "FALSE", "0"}; for (size_t i = 0; i < GPR_ARRAY_SIZE(truthy); i++) { if (0 == strcmp(env, truthy[i])) { - supportEnabled_ = true; + support_enabled_ = true; env_var_set = true; break; } @@ -180,7 +180,7 @@ void Fork::GlobalInit() { if (!env_var_set) { for (size_t i = 0; i < GPR_ARRAY_SIZE(falsey); i++) { if (0 == strcmp(env, falsey[i])) { - supportEnabled_ = false; + support_enabled_ = false; env_var_set = true; break; } @@ -189,72 +189,79 @@ void Fork::GlobalInit() { gpr_free(env); } } - if (supportEnabled_) { - execCtxState_ = grpc_core::New<internal::ExecCtxState>(); - threadState_ = grpc_core::New<internal::ThreadState>(); + if (support_enabled_) { + exec_ctx_state_ = grpc_core::New<internal::ExecCtxState>(); + thread_state_ = grpc_core::New<internal::ThreadState>(); } } void Fork::GlobalShutdown() { - if (supportEnabled_) { - grpc_core::Delete(execCtxState_); - grpc_core::Delete(threadState_); + if (support_enabled_) { + grpc_core::Delete(exec_ctx_state_); + grpc_core::Delete(thread_state_); } } -bool Fork::Enabled() { return supportEnabled_; } +bool Fork::Enabled() { return support_enabled_; } // 
Testing Only void Fork::Enable(bool enable) { - overrideEnabled_ = true; - supportEnabled_ = enable; + override_enabled_ = true; + support_enabled_ = enable; } void Fork::IncExecCtxCount() { - if (supportEnabled_) { - execCtxState_->IncExecCtxCount(); + if (support_enabled_) { + exec_ctx_state_->IncExecCtxCount(); } } void Fork::DecExecCtxCount() { - if (supportEnabled_) { - execCtxState_->DecExecCtxCount(); + if (support_enabled_) { + exec_ctx_state_->DecExecCtxCount(); } } +void Fork::SetResetChildPollingEngineFunc(Fork::child_postfork_func func) { + reset_child_polling_engine_ = func; +} +Fork::child_postfork_func Fork::GetResetChildPollingEngineFunc() { + return reset_child_polling_engine_; +} + bool Fork::BlockExecCtx() { - if (supportEnabled_) { - return execCtxState_->BlockExecCtx(); + if (support_enabled_) { + return exec_ctx_state_->BlockExecCtx(); } return false; } void Fork::AllowExecCtx() { - if (supportEnabled_) { - execCtxState_->AllowExecCtx(); + if (support_enabled_) { + exec_ctx_state_->AllowExecCtx(); } } void Fork::IncThreadCount() { - if (supportEnabled_) { - threadState_->IncThreadCount(); + if (support_enabled_) { + thread_state_->IncThreadCount(); } } void Fork::DecThreadCount() { - if (supportEnabled_) { - threadState_->DecThreadCount(); + if (support_enabled_) { + thread_state_->DecThreadCount(); } } void Fork::AwaitThreads() { - if (supportEnabled_) { - threadState_->AwaitThreads(); + if (support_enabled_) { + thread_state_->AwaitThreads(); } } -internal::ExecCtxState* Fork::execCtxState_ = nullptr; -internal::ThreadState* Fork::threadState_ = nullptr; -bool Fork::supportEnabled_ = false; -bool Fork::overrideEnabled_ = false; - +internal::ExecCtxState* Fork::exec_ctx_state_ = nullptr; +internal::ThreadState* Fork::thread_state_ = nullptr; +bool Fork::support_enabled_ = false; +bool Fork::override_enabled_ = false; +Fork::child_postfork_func Fork::reset_child_polling_engine_ = nullptr; } // namespace grpc_core diff --git 
a/src/core/lib/gprpp/fork.h b/src/core/lib/gprpp/fork.h index 123e22c4c6..5a7404f0d9 100644 --- a/src/core/lib/gprpp/fork.h +++ b/src/core/lib/gprpp/fork.h @@ -33,6 +33,8 @@ class ThreadState; class Fork { public: + typedef void (*child_postfork_func)(void); + static void GlobalInit(); static void GlobalShutdown(); @@ -46,6 +48,12 @@ class Fork { // Decrement the count of active ExecCtxs static void DecExecCtxCount(); + // Provide a function that will be invoked in the child's postfork handler to + // reset the polling engine's internal state. + static void SetResetChildPollingEngineFunc( + child_postfork_func reset_child_polling_engine); + static child_postfork_func GetResetChildPollingEngineFunc(); + // Check if there is a single active ExecCtx // (the one used to invoke this function). If there are more, // return false. Otherwise, return true and block creation of @@ -68,10 +76,11 @@ class Fork { static void Enable(bool enable); private: - static internal::ExecCtxState* execCtxState_; - static internal::ThreadState* threadState_; - static bool supportEnabled_; - static bool overrideEnabled_; + static internal::ExecCtxState* exec_ctx_state_; + static internal::ThreadState* thread_state_; + static bool support_enabled_; + static bool override_enabled_; + static child_postfork_func reset_child_polling_engine_; }; } // namespace grpc_core diff --git a/src/core/lib/iomgr/ev_epoll1_linux.cc b/src/core/lib/iomgr/ev_epoll1_linux.cc index 66e0f1fd6d..aa5016bd8f 100644 --- a/src/core/lib/iomgr/ev_epoll1_linux.cc +++ b/src/core/lib/iomgr/ev_epoll1_linux.cc @@ -131,6 +131,13 @@ static void epoll_set_shutdown() { * Fd Declarations */ +/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ +struct grpc_fork_fd_list { + grpc_fd* fd; + grpc_fd* next; + grpc_fd* prev; +}; + struct grpc_fd { int fd; @@ -141,6 +148,9 @@ struct grpc_fd { struct grpc_fd* freelist_next; grpc_iomgr_object iomgr_object; + + /* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ + grpc_fork_fd_list* fork_fd_list; }; 
static void fd_global_init(void); @@ -256,6 +266,10 @@ static bool append_error(grpc_error** composite, grpc_error* error, static grpc_fd* fd_freelist = nullptr; static gpr_mu fd_freelist_mu; +/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ +static grpc_fd* fork_fd_list_head = nullptr; +static gpr_mu fork_fd_list_mu; + static void fd_global_init(void) { gpr_mu_init(&fd_freelist_mu); } static void fd_global_shutdown(void) { @@ -269,6 +283,38 @@ static void fd_global_shutdown(void) { gpr_mu_destroy(&fd_freelist_mu); } +static void fork_fd_list_add_grpc_fd(grpc_fd* fd) { + if (grpc_core::Fork::Enabled()) { + gpr_mu_lock(&fork_fd_list_mu); + fd->fork_fd_list = + static_cast<grpc_fork_fd_list*>(gpr_malloc(sizeof(grpc_fork_fd_list))); + fd->fork_fd_list->next = fork_fd_list_head; + fd->fork_fd_list->prev = nullptr; + if (fork_fd_list_head != nullptr) { + fork_fd_list_head->fork_fd_list->prev = fd; + } + fork_fd_list_head = fd; + gpr_mu_unlock(&fork_fd_list_mu); + } +} + +static void fork_fd_list_remove_grpc_fd(grpc_fd* fd) { + if (grpc_core::Fork::Enabled()) { + gpr_mu_lock(&fork_fd_list_mu); + if (fork_fd_list_head == fd) { + fork_fd_list_head = fd->fork_fd_list->next; + } + if (fd->fork_fd_list->prev != nullptr) { + fd->fork_fd_list->prev->fork_fd_list->next = fd->fork_fd_list->next; + } + if (fd->fork_fd_list->next != nullptr) { + fd->fork_fd_list->next->fork_fd_list->prev = fd->fork_fd_list->prev; + } + gpr_free(fd->fork_fd_list); + gpr_mu_unlock(&fork_fd_list_mu); + } +} + static grpc_fd* fd_create(int fd, const char* name, bool track_err) { grpc_fd* new_fd = nullptr; @@ -295,6 +341,7 @@ static grpc_fd* fd_create(int fd, const char* name, bool track_err) { char* fd_name; gpr_asprintf(&fd_name, "%s fd=%d", name, fd); grpc_iomgr_register_object(&new_fd->iomgr_object, fd_name); + fork_fd_list_add_grpc_fd(new_fd); #ifndef NDEBUG if (grpc_trace_fd_refcount.enabled()) { gpr_log(GPR_DEBUG, "FD %d %p create %s", fd, new_fd, fd_name); @@ -361,6 +408,7 @@ static void 
fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd, GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_REF(error)); grpc_iomgr_unregister_object(&fd->iomgr_object); + fork_fd_list_remove_grpc_fd(fd); fd->read_closure->DestroyEvent(); fd->write_closure->DestroyEvent(); fd->error_closure->DestroyEvent(); @@ -1190,6 +1238,10 @@ static void shutdown_engine(void) { fd_global_shutdown(); pollset_global_shutdown(); epoll_set_shutdown(); + if (grpc_core::Fork::Enabled()) { + gpr_mu_destroy(&fork_fd_list_mu); + grpc_core::Fork::SetResetChildPollingEngineFunc(nullptr); + } } static const grpc_event_engine_vtable vtable = { @@ -1227,6 +1279,21 @@ static const grpc_event_engine_vtable vtable = { shutdown_engine, }; +/* Called by the child process's post-fork handler to close open fds, including + * the global epoll fd. This allows gRPC to shutdown in the child process + * without interfering with connections or RPCs ongoing in the parent. */ +static void reset_event_manager_on_fork() { + gpr_mu_lock(&fork_fd_list_mu); + while (fork_fd_list_head != nullptr) { + close(fork_fd_list_head->fd); + fork_fd_list_head->fd = -1; + fork_fd_list_head = fork_fd_list_head->fork_fd_list->next; + } + gpr_mu_unlock(&fork_fd_list_mu); + shutdown_engine(); + grpc_init_epoll1_linux(true); +} + /* It is possible that GLIBC has epoll but the underlying kernel doesn't. 
* Create epoll_fd (epoll_set_init() takes care of that) to make sure epoll * support is available */ @@ -1248,6 +1315,11 @@ const grpc_event_engine_vtable* grpc_init_epoll1_linux(bool explicit_request) { return nullptr; } + if (grpc_core::Fork::Enabled()) { + gpr_mu_init(&fork_fd_list_mu); + grpc_core::Fork::SetResetChildPollingEngineFunc( + reset_event_manager_on_fork); + } return &vtable; } diff --git a/src/core/lib/iomgr/fork_posix.cc b/src/core/lib/iomgr/fork_posix.cc index b37384b8db..a5b61fb4ce 100644 --- a/src/core/lib/iomgr/fork_posix.cc +++ b/src/core/lib/iomgr/fork_posix.cc @@ -84,6 +84,11 @@ void grpc_postfork_child() { if (!skipped_handler) { grpc_core::Fork::AllowExecCtx(); grpc_core::ExecCtx exec_ctx; + grpc_core::Fork::child_postfork_func reset_polling_engine = + grpc_core::Fork::GetResetChildPollingEngineFunc(); + if (reset_polling_engine != nullptr) { + reset_polling_engine(); + } grpc_timer_manager_set_threading(true); grpc_executor_set_threading(true); } diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index e9246991df..6876601785 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -111,6 +111,10 @@ class _RPCState(object): # prior to termination of the RPC. 
self.cancelled = False self.callbacks = [] + self.fork_epoch = cygrpc.get_fork_epoch() + + def reset_postfork_child(self): + self.condition = threading.Condition() def _abort(state, code, details): @@ -166,21 +170,30 @@ def _event_handler(state, response_deserializer): done = not state.due for callback in callbacks: callback() - return done + return done and state.fork_epoch >= cygrpc.get_fork_epoch() return handle_event def _consume_request_iterator(request_iterator, state, call, request_serializer, event_handler): + if cygrpc.is_fork_support_enabled(): + condition_wait_timeout = 1.0 + else: + condition_wait_timeout = None def consume_request_iterator(): # pylint: disable=too-many-branches while True: + return_from_user_request_generator_invoked = False try: + # The thread may die in user-code. Do not block fork for this. + cygrpc.enter_user_request_generator() request = next(request_iterator) except StopIteration: break except Exception: # pylint: disable=broad-except + cygrpc.return_from_user_request_generator() + return_from_user_request_generator_invoked = True code = grpc.StatusCode.UNKNOWN details = 'Exception iterating requests!' 
_LOGGER.exception(details) @@ -188,6 +201,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, details) _abort(state, code, details) return + finally: + if not return_from_user_request_generator_invoked: + cygrpc.return_from_user_request_generator() serialized_request = _common.serialize(request, request_serializer) with state.condition: if state.code is None and not state.cancelled: @@ -208,7 +224,8 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, else: return while True: - state.condition.wait() + state.condition.wait(condition_wait_timeout) + cygrpc.block_if_fork_in_progress(state) if state.code is None: if cygrpc.OperationType.send_message not in state.due: break @@ -224,8 +241,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, if operating: state.due.add(cygrpc.OperationType.send_close_from_client) - consumption_thread = threading.Thread(target=consume_request_iterator) - consumption_thread.daemon = True + consumption_thread = cygrpc.ForkManagedThread( + target=consume_request_iterator) + consumption_thread.setDaemon(True) consumption_thread.start() @@ -671,13 +689,20 @@ class _ChannelCallState(object): self.lock = threading.Lock() self.channel = channel self.managed_calls = 0 + self.threading = False + + def reset_postfork_child(self): + self.managed_calls = 0 def _run_channel_spin_thread(state): def channel_spin(): while True: + cygrpc.block_if_fork_in_progress(state) event = state.channel.next_call_event() + if event.completion_type == cygrpc.CompletionType.queue_timeout: + continue call_completed = event.tag(event) if call_completed: with state.lock: @@ -685,8 +710,8 @@ def _run_channel_spin_thread(state): if state.managed_calls == 0: return - channel_spin_thread = threading.Thread(target=channel_spin) - channel_spin_thread.daemon = True + channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin) + channel_spin_thread.setDaemon(True) 
channel_spin_thread.start() @@ -742,6 +767,13 @@ class _ChannelConnectivityState(object): self.callbacks_and_connectivities = [] self.delivering = False + def reset_postfork_child(self): + self.polling = False + self.connectivity = None + self.try_to_connect = False + self.callbacks_and_connectivities = [] + self.delivering = False + def _deliveries(state): callbacks_needing_update = [] @@ -758,6 +790,7 @@ def _deliver(state, initial_connectivity, initial_callbacks): callbacks = initial_callbacks while True: for callback in callbacks: + cygrpc.block_if_fork_in_progress(state) callable_util.call_logging_exceptions( callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE, connectivity) @@ -771,7 +804,7 @@ def _deliver(state, initial_connectivity, initial_callbacks): def _spawn_delivery(state, callbacks): - delivering_thread = threading.Thread( + delivering_thread = cygrpc.ForkManagedThread( target=_deliver, args=( state, state.connectivity, @@ -799,6 +832,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect): while True: event = channel.watch_connectivity_state(connectivity, time.time() + 0.2) + cygrpc.block_if_fork_in_progress(state) with state.lock: if not state.callbacks_and_connectivities and not state.try_to_connect: state.polling = False @@ -826,10 +860,10 @@ def _moot(state): def _subscribe(state, callback, try_to_connect): with state.lock: if not state.callbacks_and_connectivities and not state.polling: - polling_thread = threading.Thread( + polling_thread = cygrpc.ForkManagedThread( target=_poll_connectivity, args=(state, state.channel, bool(try_to_connect))) - polling_thread.daemon = True + polling_thread.setDaemon(True) polling_thread.start() state.polling = True state.callbacks_and_connectivities.append([callback, None]) @@ -876,6 +910,7 @@ class Channel(grpc.Channel): _common.encode(target), _options(options), credentials) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = 
_ChannelConnectivityState(self._channel) + cygrpc.fork_register_channel(self) def subscribe(self, callback, try_to_connect=None): _subscribe(self._connectivity_state, callback, try_to_connect) @@ -919,6 +954,11 @@ class Channel(grpc.Channel): self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') _moot(self._connectivity_state) + def _close_on_fork(self): + self._channel.close_on_fork(cygrpc.StatusCode.cancelled, + 'Channel closed due to fork') + _moot(self._connectivity_state) + def __enter__(self): return self @@ -939,4 +979,5 @@ class Channel(grpc.Channel): # for as long as they are in use and to close them after using them, # then deletion of this grpc._channel.Channel instance can be made to # effect closure of the underlying cygrpc.Channel instance. + cygrpc.fork_unregister_channel(self) _moot(self._connectivity_state) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi index a0de862d94..24e85b08e7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi @@ -19,7 +19,7 @@ cdef class Call: def __cinit__(self): # Create an *empty* call - grpc_init() + fork_handlers_and_grpc_init() self.c_call = NULL self.references = [] diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index f067d76fab..ced32abba1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -40,6 +40,7 @@ cdef class _ChannelState: # field and just use the NULLness of c_channel as an indication that the # channel is closed. 
cdef object open + cdef object closed_reason # A dict from _BatchOperationTag to _CallState cdef dict integrated_call_states diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index aa187e88a6..a81ff4d823 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -15,6 +15,7 @@ cimport cpython import threading +import time _INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( 'Internal gRPC call error %d. ' + @@ -83,6 +84,7 @@ cdef class _ChannelState: self.integrated_call_states = {} self.segregated_call_states = set() self.connectivity_due = set() + self.closed_reason = None cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): @@ -142,10 +144,10 @@ cdef _cancel( _check_and_raise_call_error_no_metadata(c_call_error) -cdef BatchOperationEvent _next_call_event( +cdef _next_call_event( _ChannelState channel_state, grpc_completion_queue *c_completion_queue, - on_success): - tag, event = _latent_event(c_completion_queue, None) + on_success, deadline): + tag, event = _latent_event(c_completion_queue, deadline) with channel_state.condition: on_success(tag) channel_state.condition.notify_all() @@ -229,8 +231,7 @@ cdef void _call( call_state.due.update(started_tags) on_success(started_tags) else: - raise ValueError('Cannot invoke RPC on closed channel!') - + raise ValueError('Cannot invoke RPC: %s' % channel_state.closed_reason) cdef void _process_integrated_call_tag( _ChannelState state, _BatchOperationTag tag) except *: cdef _CallState call_state = state.integrated_call_states.pop(tag) @@ -302,7 +303,7 @@ cdef class SegregatedCall: _process_segregated_call_tag( self._channel_state, self._call_state, self._c_completion_queue, tag) return _next_call_event( - self._channel_state, self._c_completion_queue, on_success) + self._channel_state, self._c_completion_queue, on_success, None) cdef SegregatedCall _segregated_call( 
@@ -346,7 +347,7 @@ cdef object _watch_connectivity_state( state.c_connectivity_completion_queue, <cpython.PyObject *>tag) state.connectivity_due.add(tag) else: - raise ValueError('Cannot invoke RPC on closed channel!') + raise ValueError('Cannot invoke RPC: %s' % state.closed_reason) completed_tag, event = _latent_event( state.c_connectivity_completion_queue, None) with state.condition: @@ -355,12 +356,15 @@ cdef object _watch_connectivity_state( return event -cdef _close(_ChannelState state, grpc_status_code code, object details): +cdef _close(Channel channel, grpc_status_code code, object details, + drain_calls): + cdef _ChannelState state = channel._state cdef _CallState call_state encoded_details = _encode(details) with state.condition: if state.open: state.open = False + state.closed_reason = details for call_state in set(state.integrated_call_states.values()): grpc_call_cancel_with_status( call_state.c_call, code, encoded_details, NULL) @@ -370,12 +374,19 @@ cdef _close(_ChannelState state, grpc_status_code code, object details): # TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity # watching. 
- while state.integrated_call_states: - state.condition.wait() - while state.segregated_call_states: - state.condition.wait() - while state.connectivity_due: - state.condition.wait() + if drain_calls: + while not _calls_drained(state): + event = channel.next_call_event() + if event.completion_type == CompletionType.queue_timeout: + continue + event.tag(event) + else: + while state.integrated_call_states: + state.condition.wait() + while state.segregated_call_states: + state.condition.wait() + while state.connectivity_due: + state.condition.wait() _destroy_c_completion_queue(state.c_call_completion_queue) _destroy_c_completion_queue(state.c_connectivity_completion_queue) @@ -390,13 +401,17 @@ cdef _close(_ChannelState state, grpc_status_code code, object details): state.condition.wait() +cdef _calls_drained(_ChannelState state): + return not (state.integrated_call_states or state.segregated_call_states or + state.connectivity_due) + cdef class Channel: def __cinit__( self, bytes target, object arguments, ChannelCredentials channel_credentials): arguments = () if arguments is None else tuple(arguments) - grpc_init() + fork_handlers_and_grpc_init() self._state = _ChannelState() self._vtable.copy = &_copy_pointer self._vtable.destroy = &_destroy_pointer @@ -435,9 +450,14 @@ cdef class Channel: def next_call_event(self): def on_success(tag): - _process_integrated_call_tag(self._state, tag) - return _next_call_event( - self._state, self._state.c_call_completion_queue, on_success) + if tag is not None: + _process_integrated_call_tag(self._state, tag) + if is_fork_support_enabled(): + queue_deadline = time.time() + 1.0 + else: + queue_deadline = None + return _next_call_event(self._state, self._state.c_call_completion_queue, + on_success, queue_deadline) def segregated_call( self, int flags, method, host, object deadline, object metadata, @@ -452,11 +472,14 @@ cdef class Channel: return grpc_channel_check_connectivity_state( self._state.c_channel, try_to_connect) else: - 
raise ValueError('Cannot invoke RPC on closed channel!') + raise ValueError('Cannot invoke RPC: %s' % self._state.closed_reason) def watch_connectivity_state( self, grpc_connectivity_state last_observed_state, object deadline): return _watch_connectivity_state(self._state, last_observed_state, deadline) def close(self, code, details): - _close(self._state, code, details) + _close(self, code, details, False) + + def close_on_fork(self, code, details): + _close(self, code, details, True) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index a2d765546a..141116df5d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -71,7 +71,7 @@ cdef class CompletionQueue: def __cinit__(self, shutdown_cq=False): cdef grpc_completion_queue_attributes c_attrs - grpc_init() + fork_handlers_and_grpc_init() if shutdown_cq: c_attrs.version = 1 c_attrs.cq_completion_type = GRPC_CQ_NEXT diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 0a25218e19..e3c1c8215c 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -21,7 +21,7 @@ from libc.stdint cimport uintptr_t def _spawn_callback_in_thread(cb_func, args): - threading.Thread(target=cb_func, args=args).start() + ForkManagedThread(target=cb_func, args=args).start() async_callback_func = _spawn_callback_in_thread @@ -114,7 +114,7 @@ cdef class ChannelCredentials: cdef class SSLSessionCacheLRU: def __cinit__(self, capacity): - grpc_init() + fork_handlers_and_grpc_init() self._cache = grpc_ssl_session_cache_create_lru(capacity) def __int__(self): @@ -172,7 +172,7 @@ cdef class CompositeChannelCredentials(ChannelCredentials): cdef class ServerCertificateConfig: def __cinit__(self): - 
grpc_init() + fork_handlers_and_grpc_init() self.c_cert_config = NULL self.c_pem_root_certs = NULL self.c_ssl_pem_key_cert_pairs = NULL @@ -187,7 +187,7 @@ cdef class ServerCertificateConfig: cdef class ServerCredentials: def __cinit__(self): - grpc_init() + fork_handlers_and_grpc_init() self.c_credentials = NULL self.references = [] self.initial_cert_config = None diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi new file mode 100644 index 0000000000..a925bdd2e6 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi @@ -0,0 +1,29 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef extern from "pthread.h" nogil: + int pthread_atfork( + void (*prepare)() nogil, + void (*parent)() nogil, + void (*child)() nogil) + + +cdef void __prefork() nogil + + +cdef void __postfork_parent() nogil + + +cdef void __postfork_child() nogil
\ No newline at end of file diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi new file mode 100644 index 0000000000..1176258da8 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi @@ -0,0 +1,203 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import threading + +_LOGGER = logging.getLogger(__name__) + +_AWAIT_THREADS_TIMEOUT_SECONDS = 5 + +_TRUE_VALUES = ['yes', 'Yes', 'YES', 'true', 'True', 'TRUE', '1'] + +# This flag enables experimental support within gRPC Python for applications +# that will fork() without exec(). When enabled, gRPC Python will attempt to +# pause all of its internally created threads before the fork syscall proceeds. +# +# For this to be successful, the application must not have multiple threads of +# its own calling into gRPC when fork is invoked. Any callbacks from gRPC +# Python-spawned threads into user code (e.g., callbacks for asynchronous RPCs) +# must not block and should execute quickly. +# +# This flag is not supported on Windows. 
+_GRPC_ENABLE_FORK_SUPPORT = ( + os.environ.get('GRPC_ENABLE_FORK_SUPPORT', '0') + .lower() in _TRUE_VALUES) + +_GRPC_POLL_STRATEGY = os.environ.get('GRPC_POLL_STRATEGY') + +cdef void __prefork() nogil: + with gil: + with _fork_state.fork_in_progress_condition: + _fork_state.fork_in_progress = True + if not _fork_state.active_thread_count.await_zero_threads( + _AWAIT_THREADS_TIMEOUT_SECONDS): + _LOGGER.error( + 'Failed to shutdown gRPC Python threads prior to fork. ' + 'Behavior after fork will be undefined.') + + +cdef void __postfork_parent() nogil: + with gil: + with _fork_state.fork_in_progress_condition: + _fork_state.fork_in_progress = False + _fork_state.fork_in_progress_condition.notify_all() + + +cdef void __postfork_child() nogil: + with gil: + # Thread could be holding the fork_in_progress_condition inside of + # block_if_fork_in_progress() when fork occurs. Reset the lock here. + _fork_state.fork_in_progress_condition = threading.Condition() + # A thread in return_from_user_request_generator() may hold this lock + # when fork occurs. + _fork_state.active_thread_count = _ActiveThreadCount() + for state_to_reset in _fork_state.postfork_states_to_reset: + state_to_reset.reset_postfork_child() + _fork_state.fork_epoch += 1 + for channel in _fork_state.channels: + channel._close_on_fork() + # TODO(ericgribkoff) Check and abort if core is not shutdown + with _fork_state.fork_in_progress_condition: + _fork_state.fork_in_progress = False + + +def fork_handlers_and_grpc_init(): + grpc_init() + if _GRPC_ENABLE_FORK_SUPPORT: + # TODO(ericgribkoff) epoll1 is default for grpcio distribution. Decide whether to expose + # grpc_get_poll_strategy_name() from ev_posix.cc to get actual polling choice. 
+ if _GRPC_POLL_STRATEGY is not None and _GRPC_POLL_STRATEGY != "epoll1": + _LOGGER.error( + 'gRPC Python fork support is only compatible with the epoll1 ' + 'polling engine') + return + with _fork_state.fork_handler_registered_lock: + if not _fork_state.fork_handler_registered: + pthread_atfork(&__prefork, &__postfork_parent, &__postfork_child) + _fork_state.fork_handler_registered = True + + +class ForkManagedThread(object): + def __init__(self, target, args=()): + if _GRPC_ENABLE_FORK_SUPPORT: + def managed_target(*args): + try: + target(*args) + finally: + _fork_state.active_thread_count.decrement() + self._thread = threading.Thread(target=managed_target, args=args) + else: + self._thread = threading.Thread(target=target, args=args) + + def setDaemon(self, daemonic): + self._thread.daemon = daemonic + + def start(self): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.active_thread_count.increment() + self._thread.start() + + def join(self): + self._thread.join() + + +def block_if_fork_in_progress(postfork_state_to_reset=None): + if _GRPC_ENABLE_FORK_SUPPORT: + with _fork_state.fork_in_progress_condition: + if not _fork_state.fork_in_progress: + return + if postfork_state_to_reset is not None: + _fork_state.postfork_states_to_reset.append(postfork_state_to_reset) + _fork_state.active_thread_count.decrement() + _fork_state.fork_in_progress_condition.wait() + _fork_state.active_thread_count.increment() + + +def enter_user_request_generator(): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.active_thread_count.decrement() + + +def return_from_user_request_generator(): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.active_thread_count.increment() + block_if_fork_in_progress() + + +def get_fork_epoch(): + return _fork_state.fork_epoch + + +def is_fork_support_enabled(): + return _GRPC_ENABLE_FORK_SUPPORT + + +def fork_register_channel(channel): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.channels.add(channel) + + +def fork_unregister_channel(channel): + if 
_GRPC_ENABLE_FORK_SUPPORT: + _fork_state.channels.remove(channel) + + +class _ActiveThreadCount(object): + def __init__(self): + self._num_active_threads = 0 + self._condition = threading.Condition() + + def increment(self): + with self._condition: + self._num_active_threads += 1 + + def decrement(self): + with self._condition: + self._num_active_threads -= 1 + if self._num_active_threads == 0: + self._condition.notify_all() + + def await_zero_threads(self, timeout_secs): + end_time = time.time() + timeout_secs + wait_time = timeout_secs + with self._condition: + while True: + if self._num_active_threads > 0: + self._condition.wait(wait_time) + if self._num_active_threads == 0: + return True + # Thread count may have increased before this re-obtains the + # lock after a notify(). Wait again until timeout_secs has + # elapsed. + wait_time = end_time - time.time() + if wait_time <= 0: + return False + + +class _ForkState(object): + def __init__(self): + self.fork_in_progress_condition = threading.Condition() + self.fork_in_progress = False + self.postfork_states_to_reset = [] + self.fork_handler_registered_lock = threading.Lock() + self.fork_handler_registered = False + self.active_thread_count = _ActiveThreadCount() + self.fork_epoch = 0 + self.channels = set() + + +_fork_state = _ForkState() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi new file mode 100644 index 0000000000..8dc1ef3b1a --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi @@ -0,0 +1,63 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import threading
+
+# No-op implementations for Windows.
+
+def fork_handlers_and_grpc_init():
+    """Initialize gRPC; no fork handlers are installed on Windows."""
+    grpc_init()
+
+
+class ForkManagedThread(object):
+    """Plain threading.Thread wrapper with no fork bookkeeping."""
+
+    def __init__(self, target, args=()):
+        plain_thread = threading.Thread(target=target, args=args)
+        self._thread = plain_thread
+
+    def setDaemon(self, daemonic):
+        self._thread.daemon = daemonic
+
+    def start(self):
+        self._thread.start()
+
+    def join(self):
+        self._thread.join()
+
+
+def block_if_fork_in_progress(postfork_state_to_reset=None):
+    """No-op on Windows."""
+
+
+def enter_user_request_generator():
+    """No-op on Windows."""
+
+
+def return_from_user_request_generator():
+    """No-op on Windows."""
+
+
+def get_fork_epoch():
+    """Always 0 on Windows."""
+    return 0
+
+
+def is_fork_support_enabled():
+    """Always False on Windows."""
+    return False
+
+
+def fork_register_channel(channel):
+    """No-op on Windows."""
+
+
+def fork_unregister_channel(channel):
+    """No-op on Windows."""
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
index 37b98ebbdb..fe98d559f3 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
@@ -127,7 +127,7 @@ class CompressionLevel: cdef class CallDetails: def __cinit__(self): - grpc_init() + fork_handlers_and_grpc_init() with nogil: grpc_call_details_init(&self.c_details)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
index da3dd21244..db59d468dc 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
@@ -60,7 +60,7 @@ cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapp cdef class Server: def __cinit__(self, object arguments): - grpc_init() + fork_handlers_and_grpc_init() self.references = [] self.registered_completion_queues = [] self._vtable.copy = &_copy_pointer
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd
index 0cc26bc0d0..8258b857bc 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pxd
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd
@@ -31,3 +31,6 @@ include "_cygrpc/time.pxd.pxi" include "_cygrpc/_hooks.pxd.pxi" include "_cygrpc/grpc_gevent.pxd.pxi" + +IF UNAME_SYSNAME != "Windows": + include "_cygrpc/fork_posix.pxd.pxi"
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx
index 3cac406687..026f7ba2e3 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pyx
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx
@@ -39,6 +39,11 @@ include "_cygrpc/_hooks.pyx.pxi" include "_cygrpc/grpc_gevent.pyx.pxi" +IF UNAME_SYSNAME == "Windows": + include "_cygrpc/fork_windows.pyx.pxi" +ELSE: + include "_cygrpc/fork_posix.pyx.pxi" + # # initialize gRPC #
diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py
index a23c980017..0dfbf3180b 100644
--- a/src/python/grpcio_tests/commands.py
+++ b/src/python/grpcio_tests/commands.py
@@ -202,3 +202,28 @@ class RunInterop(test.test): from tests.interop import client sys.argv[1:] = self.args.split() client.test_interoperability()
+
+
+class RunFork(test.test):
+    """Command that forwards its --args string to the fork test client."""
+
+    description = 'run fork test client'
+    user_options = [('args=', 'a', 'pass-thru arguments for the client')]
+
+    def initialize_options(self):
+        self.args = ''
+
+    def finalize_options(self):
+        # distutils requires this override.
+        pass
+
+    def run(self):
+        # Fetch install requirements first, then test requirements.
+        for attribute in ('install_requires', 'tests_require'):
+            requirements = getattr(self.distribution, attribute)
+            if requirements:
+                self.distribution.fetch_build_eggs(requirements)
+        # We import here to ensure that our setuptools parent has had a chance to
+        # edit the Python system path.
+        from tests.fork import client
+        sys.argv[1:] = self.args.split()
+        client.test_fork()
diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py
index a94c0963ec..61c98fa038 100644
--- a/src/python/grpcio_tests/setup.py
+++ b/src/python/grpcio_tests/setup.py
@@ -52,6 +52,7 @@ COMMAND_CLASS = { 'preprocess': commands.GatherProto, 'build_package_protos': grpc_tools.command.BuildPackageProtos, 'build_py': commands.BuildPy, + 'run_fork': commands.RunFork, 'run_interop': commands.RunInterop, 'test_lite': commands.TestLite, 'test_gevent': commands.TestGevent,
diff --git a/src/python/grpcio_tests/tests/fork/__init__.py b/src/python/grpcio_tests/tests/fork/__init__.py
new file mode 100644
index 0000000000..9a26bac010
--- /dev/null
+++ b/src/python/grpcio_tests/tests/fork/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/python/grpcio_tests/tests/fork/client.py b/src/python/grpcio_tests/tests/fork/client.py
new file mode 100644
index 0000000000..9a32629ed5
--- /dev/null
+++ b/src/python/grpcio_tests/tests/fork/client.py
@@ -0,0 +1,76 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The Python implementation of the gRPC fork support test client."""
+
+import argparse
+import logging
+import sys
+
+from tests.fork import methods
+
+
+def _args():
+    """Parse the command-line flags of the fork test client."""
+
+    def parse_bool(value):
+        if value == 'true':
+            return True
+        if value == 'false':
+            return False
+        raise argparse.ArgumentTypeError('Only true/false allowed')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--server_host',
+        default="localhost",
+        type=str,
+        help='the host to which to connect')
+    parser.add_argument(
+        '--server_port',
+        type=int,
+        required=True,
+        help='the port to which to connect')
+    # "all" is handled by test_fork(); the previous default named a test case
+    # that methods.TestCase does not define, so the default run always failed.
+    parser.add_argument(
+        '--test_case',
+        default='all',
+        type=str,
+        help='the test case to execute')
+    parser.add_argument(
+        '--use_tls',
+        default=False,
+        type=parse_bool,
+        help='require a secure connection')
+    return parser.parse_args()
+
+
+def _test_case_from_arg(test_case_arg):
+    """Return the methods.TestCase whose value equals test_case_arg.
+
+    Raises:
+      ValueError: if no test case has that value.
+    """
+    for test_case in methods.TestCase:
+        if test_case_arg == test_case.value:
+            return test_case
+    raise ValueError('No test case "%s"!' % test_case_arg)
+
+
+def test_fork():
+    """Run the test case selected by --test_case, or every one for "all"."""
+    logging.basicConfig(level=logging.INFO)
+    args = _args()
+    if args.test_case == "all":
+        for test_case in methods.TestCase:
+            test_case.run_test(args)
+    else:
+        test_case = _test_case_from_arg(args.test_case)
+        test_case.run_test(args)
+
+
+if __name__ == '__main__':
+    test_fork()
diff --git a/src/python/grpcio_tests/tests/fork/methods.py b/src/python/grpcio_tests/tests/fork/methods.py
new file mode 100644
index 0000000000..889ef13cb2
--- /dev/null
+++ b/src/python/grpcio_tests/tests/fork/methods.py
@@ -0,0 +1,445 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementations of fork support test methods."""
+
+import enum
+import json
+import logging
+import multiprocessing
+import os
+import threading
+import time
+
+import grpc
+
+from six.moves import queue
+
+from src.proto.grpc.testing import empty_pb2
+from src.proto.grpc.testing import messages_pb2
+from src.proto.grpc.testing import test_pb2_grpc
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _channel(args):
+    """Open a secure or insecure channel to args.server_host:args.server_port."""
+    target = '{}:{}'.format(args.server_host, args.server_port)
+    if args.use_tls:
+        channel_credentials = grpc.ssl_channel_credentials()
+        channel = grpc.secure_channel(target, channel_credentials)
+    else:
+        channel = grpc.insecure_channel(target)
+    return channel
+
+
+def _validate_payload_type_and_length(response, expected_type, expected_length):
+    """Raise ValueError unless the payload has the expected type and size."""
+    # Compare by value: identity comparison on integers is unreliable.
+    if response.payload.type != expected_type:
+        raise ValueError('expected payload type %s, got %s' %
+                         (expected_type, response.payload.type))
+    elif len(response.payload.body) != expected_length:
+        raise ValueError('expected payload body size %d, got %d' %
+                         (expected_length, len(response.payload.body)))
+
+
+def _async_unary(stub):
+    """Issue UnaryCall through its future API and validate the response."""
+    size = 314159
+    request = messages_pb2.SimpleRequest(
+        response_type=messages_pb2.COMPRESSABLE,
+        response_size=size,
+        payload=messages_pb2.Payload(body=b'\x00' * 271828))
+    response_future = stub.UnaryCall.future(request)
+    response = response_future.result()
+    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
+
+
+def _blocking_unary(stub):
+    """Issue a blocking UnaryCall and validate the response."""
+    size = 314159
+    request = messages_pb2.SimpleRequest(
+        response_type=messages_pb2.COMPRESSABLE,
+        response_size=size,
+        payload=messages_pb2.Payload(body=b'\x00' * 271828))
+    response = stub.UnaryCall(request)
+    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
+
+
+class _Pipe(object):
+    """Blocking iterator fed by add() and ended by close()."""
+
+    def __init__(self):
+        self._condition = threading.Condition()
+        self._values = []
+        self._open = True
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.next()
+
+    def next(self):
+        with self._condition:
+            # Block until a value is queued or the pipe is closed.
+            while not self._values and self._open:
+                self._condition.wait()
+            if self._values:
+                return self._values.pop(0)
+            else:
+                raise StopIteration()
+
+    def add(self, value):
+        with self._condition:
+            self._values.append(value)
+            self._condition.notify()
+
+    def close(self):
+        with self._condition:
+            self._open = False
+            # Wake every blocked reader so each one observes the closed state.
+            self._condition.notify_all()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.close()
+
+
+class _ChildProcess(object):
+    """Runs task(*args) in a multiprocessing.Process and reports failures.
+
+    Exceptions raised by the task are put on a multiprocessing.Queue;
+    finish() joins the process and raises ValueError on a non-zero exit
+    code or a queued exception.
+    """
+
+    def __init__(self, task, args=None):
+        if args is None:
+            args = ()
+        self._exceptions = multiprocessing.Queue()
+
+        def record_exceptions():
+            try:
+                task(*args)
+            except Exception as e:  # pylint: disable=broad-except
+                self._exceptions.put(e)
+
+        self._process = multiprocessing.Process(target=record_exceptions)
+
+    def start(self):
+        self._process.start()
+
+    def finish(self):
+        self._process.join()
+        if self._process.exitcode != 0:
+            raise ValueError('Child process failed with exitcode %d' %
+                             self._process.exitcode)
+        try:
+            exception = self._exceptions.get(block=False)
+            raise ValueError('Child process failed: %s' % exception)
+        except queue.Empty:
+            pass
+
+
+def _async_unary_same_channel(channel):
+
+    def child_target():
+        try:
+            _async_unary(stub)
+            raise Exception(
+                'Child should not be able to re-use channel after fork')
+        except ValueError:
+            # Expected: the inherited channel is unusable in the child.
+            pass
+
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    _async_unary(stub)
+    child_process = _ChildProcess(child_target)
+    child_process.start()
+    _async_unary(stub)
+    child_process.finish()
+
+
+def _async_unary_new_channel(channel, args):
+
+    def child_target():
+        child_channel = _channel(args)
+        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+        _async_unary(child_stub)
+        child_channel.close()
+
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    _async_unary(stub)
+    child_process = _ChildProcess(child_target)
+    child_process.start()
+    _async_unary(stub)
+    child_process.finish()
+
+
+def _blocking_unary_same_channel(channel):
+
+    def child_target():
+        try:
+            _blocking_unary(stub)
+            raise Exception(
+                'Child should not be able to re-use channel after fork')
+        except ValueError:
+            # Expected: the inherited channel is unusable in the child.
+            pass
+
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    _blocking_unary(stub)
+    child_process = _ChildProcess(child_target)
+    child_process.start()
+    child_process.finish()
+
+
+def _blocking_unary_new_channel(channel, args):
+
+    def child_target():
+        child_channel = _channel(args)
+        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+        _blocking_unary(child_stub)
+        child_channel.close()
+
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    _blocking_unary(stub)
+    child_process = _ChildProcess(child_target)
+    child_process.start()
+    _blocking_unary(stub)
+    child_process.finish()
+
+
+# Verify that the fork channel registry can handle already closed channels
+def _close_channel_before_fork(channel, args):
+
+    def child_target():
+        new_channel.close()
+        child_channel = _channel(args)
+        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+        _blocking_unary(child_stub)
+        child_channel.close()
+
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    _blocking_unary(stub)
+    channel.close()
+
+    new_channel = _channel(args)
+    new_stub = test_pb2_grpc.TestServiceStub(new_channel)
+    child_process = _ChildProcess(child_target)
+    child_process.start()
+    _blocking_unary(new_stub)
+    child_process.finish()
+
+
+def _connectivity_watch(channel, args):
+
+    def child_target():
+
+        def child_connectivity_callback(state):
+            child_states.append(state)
+
+        child_states = []
+        child_channel = _channel(args)
+        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+        child_channel.subscribe(child_connectivity_callback)
+        _async_unary(child_stub)
+        if (len(child_states) < 2 or
+                child_states[-1] != grpc.ChannelConnectivity.READY):
+            raise ValueError('Channel did not move to READY')
+        if len(parent_states) > 1:
+            raise ValueError('Received connectivity updates on parent callback')
+        child_channel.unsubscribe(child_connectivity_callback)
+        child_channel.close()
+
+    def parent_connectivity_callback(state):
+        parent_states.append(state)
+
+    parent_states = []
+    channel.subscribe(parent_connectivity_callback)
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    child_process = _ChildProcess(child_target)
+    child_process.start()
+    _async_unary(stub)
+    if (len(parent_states) < 2 or
+            parent_states[-1] != grpc.ChannelConnectivity.READY):
+        raise ValueError('Channel did not move to READY')
+    channel.unsubscribe(parent_connectivity_callback)
+    child_process.finish()
+
+    # Need to unsubscribe or _channel.py in _poll_connectivity triggers a
+    # "Cannot invoke RPC on closed channel!" error.
+    # TODO(ericgribkoff) Fix issue with channel.close() and connectivity polling
+    channel.unsubscribe(parent_connectivity_callback)
+
+
+def _ping_pong_with_child_processes_after_first_response(
+        channel, args, child_target, run_after_close=True):
+    """Drive a FullDuplexCall and fork child processes while it is in flight.
+
+    A child running child_target(parent_bidi_call, channel, args) is started
+    after every response and, from the second request on, after the request
+    is queued; when run_after_close is true one more is started after the
+    request pipe is closed. All children are finished before returning.
+    """
+    request_response_sizes = (
+        31415,
+        9,
+        2653,
+        58979,
+    )
+    request_payload_sizes = (
+        27182,
+        8,
+        1828,
+        45904,
+    )
+    stub = test_pb2_grpc.TestServiceStub(channel)
+    pipe = _Pipe()
+    parent_bidi_call = stub.FullDuplexCall(pipe)
+    child_processes = []
+    first_message_received = False
+    for response_size, payload_size in zip(request_response_sizes,
+                                           request_payload_sizes):
+        request = messages_pb2.StreamingOutputCallRequest(
+            response_type=messages_pb2.COMPRESSABLE,
+            response_parameters=(
+                messages_pb2.ResponseParameters(size=response_size),),
+            payload=messages_pb2.Payload(body=b'\x00' * payload_size))
+        pipe.add(request)
+        if first_message_received:
+            child_process = _ChildProcess(child_target,
+                                          (parent_bidi_call, channel, args))
+            child_process.start()
+            child_processes.append(child_process)
+        response = next(parent_bidi_call)
+        first_message_received = True
+        child_process = _ChildProcess(child_target,
+                                      (parent_bidi_call, channel, args))
+        child_process.start()
+        child_processes.append(child_process)
+        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
+                                          response_size)
+    pipe.close()
+    if run_after_close:
+        child_process = _ChildProcess(child_target,
+                                      (parent_bidi_call, channel, args))
+        child_process.start()
+        child_processes.append(child_process)
+    for child_process in child_processes:
+        child_process.finish()
+
+
+def _in_progress_bidi_continue_call(channel):
+
+    def child_target(parent_bidi_call, parent_channel, args):
+        stub = test_pb2_grpc.TestServiceStub(parent_channel)
+        try:
+            _async_unary(stub)
+            raise Exception(
+                'Child should not be able to re-use channel after fork')
+        except ValueError:
+            # Expected: the inherited channel is unusable in the child.
+            pass
+        inherited_code = parent_bidi_call.code()
+        inherited_details = parent_bidi_call.details()
+        if inherited_code != grpc.StatusCode.CANCELLED:
+            raise ValueError(
+                'Expected inherited code CANCELLED, got %s' % inherited_code)
+        if inherited_details != 'Channel closed due to fork':
+            raise ValueError(
+                'Expected inherited details Channel closed due to fork, got %s'
+                % inherited_details)
+
+    # Don't run child_target after closing the parent call, as the call may have
+    # received a status from the server before fork occurs.
+    _ping_pong_with_child_processes_after_first_response(
+        channel, None, child_target, run_after_close=False)
+
+
+def _in_progress_bidi_same_channel_async_call(channel):
+
+    def child_target(parent_bidi_call, parent_channel, args):
+        stub = test_pb2_grpc.TestServiceStub(parent_channel)
+        try:
+            _async_unary(stub)
+            raise Exception(
+                'Child should not be able to re-use channel after fork')
+        except ValueError:
+            # Expected: the inherited channel is unusable in the child.
+            pass
+
+    _ping_pong_with_child_processes_after_first_response(
+        channel, None, child_target)
+
+
+def _in_progress_bidi_same_channel_blocking_call(channel):
+
+    def child_target(parent_bidi_call, parent_channel, args):
+        stub = test_pb2_grpc.TestServiceStub(parent_channel)
+        try:
+            _blocking_unary(stub)
+            raise Exception(
+                'Child should not be able to re-use channel after fork')
+        except ValueError:
+            # Expected: the inherited channel is unusable in the child.
+            pass
+
+    _ping_pong_with_child_processes_after_first_response(
+        channel, None, child_target)
+
+
+def _in_progress_bidi_new_channel_async_call(channel, args):
+
+    def child_target(parent_bidi_call, parent_channel, args):
+        channel = _channel(args)
+        stub = test_pb2_grpc.TestServiceStub(channel)
+        _async_unary(stub)
+
+    _ping_pong_with_child_processes_after_first_response(
+        channel, args, child_target)
+
+
+def _in_progress_bidi_new_channel_blocking_call(channel, args):
+
+    def child_target(parent_bidi_call, parent_channel, args):
+        channel = _channel(args)
+        stub = test_pb2_grpc.TestServiceStub(channel)
+        _blocking_unary(stub)
+
+    _ping_pong_with_child_processes_after_first_response(
+        channel, args, child_target)
+
+
+@enum.unique
+class TestCase(enum.Enum):
+    """Fork test cases; each value is the name accepted on the command line."""
+
+    CONNECTIVITY_WATCH = 'connectivity_watch'
+    CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork'
+    ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel'
+    ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel'
+    BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel'
+    BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel'
+    IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call'
+    IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call'
+    IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call'
+    IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call'
+    IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call'
+
+    def run_test(self, args):
+        """Open a channel from args, run this test case, and close the channel.
+
+        Raises:
+          NotImplementedError: if this member has no test function.
+        """
+        _LOGGER.info("Running %s", self)
+        channel = _channel(args)
+        # Close the channel even when the test case raises.
+        try:
+            if self is TestCase.ASYNC_UNARY_SAME_CHANNEL:
+                _async_unary_same_channel(channel)
+            elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL:
+                _async_unary_new_channel(channel, args)
+            elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL:
+                _blocking_unary_same_channel(channel)
+            elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL:
+                _blocking_unary_new_channel(channel, args)
+            elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK:
+                _close_channel_before_fork(channel, args)
+            elif self is TestCase.CONNECTIVITY_WATCH:
+                _connectivity_watch(channel, args)
+            elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL:
+                _in_progress_bidi_continue_call(channel)
+            elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL:
+                _in_progress_bidi_same_channel_async_call(channel)
+            elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL:
+                _in_progress_bidi_same_channel_blocking_call(channel)
+            elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL:
+                _in_progress_bidi_new_channel_async_call(channel, args)
+            elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL:
+                _in_progress_bidi_new_channel_blocking_call(channel, args)
+            else:
+                raise NotImplementedError(
+                    'Test case "%s" not implemented!' % self.name)
+        finally:
+            channel.close()
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index ebc41c63f0..76d5d22d57 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -32,6 +32,8 @@ "unit._credentials_test.CredentialsTest", "unit._cython._cancel_many_calls_test.CancelManyCallsTest", "unit._cython._channel_test.ChannelTest", + "unit._cython._fork_test.ForkPosixTester", + "unit._cython._fork_test.ForkWindowsTester", "unit._cython._no_messages_server_completion_queue_per_call_test.Test", "unit._cython._no_messages_single_server_completion_queue_test.Test", "unit._cython._read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest",
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
new file mode 100644
index 0000000000..aeb02458a7
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
@@ -0,0 +1,68 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+import unittest
+
+from grpc._cython import cygrpc
+
+
+def _get_number_active_threads():
+    """Return the current value of cygrpc's active-thread counter."""
+    return cygrpc._fork_state.active_thread_count._num_active_threads
+
+
+@unittest.skipIf(os.name == 'nt', 'Posix-specific tests')
+class ForkPosixTester(unittest.TestCase):
+
+    def setUp(self):
+        # Save the module-level flag so tearDown can restore it; leaving it
+        # forced to True would leak into tests that run afterwards.
+        self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT
+        cygrpc._GRPC_ENABLE_FORK_SUPPORT = True
+
+    def tearDown(self):
+        cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag
+
+    def testForkManagedThread(self):
+        # An assertion raised inside the thread would not fail the test, so
+        # the callback only records the count and the check runs here.
+        observed_counts = []
+
+        def cb():
+            observed_counts.append(_get_number_active_threads())
+
+        thread = cygrpc.ForkManagedThread(cb)
+        thread.start()
+        thread.join()
+        self.assertEqual([1], observed_counts)
+        self.assertEqual(0, _get_number_active_threads())
+
+    def testForkManagedThreadThrowsException(self):
+        observed_counts = []
+
+        def cb():
+            observed_counts.append(_get_number_active_threads())
+            raise Exception("expected exception")
+
+        thread = cygrpc.ForkManagedThread(cb)
+        thread.start()
+        thread.join()
+        self.assertEqual([1], observed_counts)
+        self.assertEqual(0, _get_number_active_threads())
+
+
+@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests')
+class ForkWindowsTester(unittest.TestCase):
+
+    def testForkManagedThreadIsNoOp(self):
+
+        def cb():
+            pass
+
+        thread = cygrpc.ForkManagedThread(cb)
+        thread.start()
+        thread.join()
+
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)
|