diff options
Diffstat (limited to 'src/python')
31 files changed, 1094 insertions, 85 deletions
diff --git a/src/python/grpcio/commands.py b/src/python/grpcio/commands.py index 4c2ebaeaea..0a3097111f 100644 --- a/src/python/grpcio/commands.py +++ b/src/python/grpcio/commands.py @@ -265,8 +265,17 @@ class BuildExt(build_ext.build_ext): os.path.join(target_path, 'libgpr.a'), os.path.join(target_path, 'libgrpc.a') ] + # Running make separately for Mac means we lose all + # Extension.define_macros configured in setup.py. Re-add the macro + # for gRPC Core's fork handlers. + # TODO(ericgribkoff) Decide what to do about the other missing core + # macros, including GRPC_ENABLE_FORK_SUPPORT, which defaults to 1 + # on Linux but remains unset on Mac. + extra_defines = [ + 'EXTRA_DEFINES="GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK=1"' + ] make_process = subprocess.Popen( - ['make'] + targets, + ['make'] + extra_defines + targets, stdout=subprocess.PIPE, stderr=subprocess.PIPE) make_out, make_err = make_process.communicate() diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index e9246991df..6876601785 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -111,6 +111,10 @@ class _RPCState(object): # prior to termination of the RPC. self.cancelled = False self.callbacks = [] + self.fork_epoch = cygrpc.get_fork_epoch() + + def reset_postfork_child(self): + self.condition = threading.Condition() def _abort(state, code, details): @@ -166,21 +170,30 @@ def _event_handler(state, response_deserializer): done = not state.due for callback in callbacks: callback() - return done + return done and state.fork_epoch >= cygrpc.get_fork_epoch() return handle_event def _consume_request_iterator(request_iterator, state, call, request_serializer, event_handler): + if cygrpc.is_fork_support_enabled(): + condition_wait_timeout = 1.0 + else: + condition_wait_timeout = None def consume_request_iterator(): # pylint: disable=too-many-branches while True: + return_from_user_request_generator_invoked = False try: + # The thread may die in user-code. Do not block fork for this. + cygrpc.enter_user_request_generator() request = next(request_iterator) except StopIteration: break except Exception: # pylint: disable=broad-except + cygrpc.return_from_user_request_generator() + return_from_user_request_generator_invoked = True code = grpc.StatusCode.UNKNOWN details = 'Exception iterating requests!' _LOGGER.exception(details) @@ -188,6 +201,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, details) _abort(state, code, details) return + finally: + if not return_from_user_request_generator_invoked: + cygrpc.return_from_user_request_generator() serialized_request = _common.serialize(request, request_serializer) with state.condition: if state.code is None and not state.cancelled: @@ -208,7 +224,8 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, else: return while True: - state.condition.wait() + state.condition.wait(condition_wait_timeout) + cygrpc.block_if_fork_in_progress(state) if state.code is None: if cygrpc.OperationType.send_message not in state.due: break @@ -224,8 +241,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, if operating: state.due.add(cygrpc.OperationType.send_close_from_client) - consumption_thread = threading.Thread(target=consume_request_iterator) - consumption_thread.daemon = True + consumption_thread = cygrpc.ForkManagedThread( + target=consume_request_iterator) + consumption_thread.setDaemon(True) consumption_thread.start() @@ -671,13 +689,20 @@ class _ChannelCallState(object): self.lock = threading.Lock() self.channel = channel self.managed_calls = 0 + self.threading = False + + def reset_postfork_child(self): + self.managed_calls = 0 def _run_channel_spin_thread(state): def channel_spin(): while True: + cygrpc.block_if_fork_in_progress(state) event = state.channel.next_call_event() + if event.completion_type == cygrpc.CompletionType.queue_timeout: + continue call_completed = event.tag(event) if call_completed: with state.lock: @@ -685,8 +710,8 @@ def _run_channel_spin_thread(state): if state.managed_calls == 0: return - channel_spin_thread = threading.Thread(target=channel_spin) - channel_spin_thread.daemon = True + channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin) + channel_spin_thread.setDaemon(True) channel_spin_thread.start() @@ -742,6 +767,13 @@ class _ChannelConnectivityState(object): self.callbacks_and_connectivities = [] self.delivering = False + def reset_postfork_child(self): + self.polling = False + self.connectivity = None + self.try_to_connect = False + self.callbacks_and_connectivities = [] + self.delivering = False + def _deliveries(state): callbacks_needing_update = [] @@ -758,6 +790,7 @@ def _deliver(state, initial_connectivity, initial_callbacks): callbacks = initial_callbacks while True: for callback in callbacks: + cygrpc.block_if_fork_in_progress(state) callable_util.call_logging_exceptions( callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE, connectivity) @@ -771,7 +804,7 @@ def _deliver(state, initial_connectivity, initial_callbacks): def _spawn_delivery(state, callbacks): - delivering_thread = threading.Thread( + delivering_thread = cygrpc.ForkManagedThread( target=_deliver, args=( state, state.connectivity, @@ -799,6 +832,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect): while True: event = channel.watch_connectivity_state(connectivity, time.time() + 0.2) + cygrpc.block_if_fork_in_progress(state) with state.lock: if not state.callbacks_and_connectivities and not state.try_to_connect: state.polling = False @@ -826,10 +860,10 @@ def _moot(state): def _subscribe(state, callback, try_to_connect): with state.lock: if not state.callbacks_and_connectivities and not state.polling: - polling_thread = threading.Thread( + polling_thread = cygrpc.ForkManagedThread( target=_poll_connectivity, args=(state, state.channel, bool(try_to_connect))) - polling_thread.daemon = True + polling_thread.setDaemon(True) polling_thread.start() state.polling = True state.callbacks_and_connectivities.append([callback, None]) @@ -876,6 +910,7 @@ class Channel(grpc.Channel): _common.encode(target), _options(options), credentials) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel) + cygrpc.fork_register_channel(self) def subscribe(self, callback, try_to_connect=None): _subscribe(self._connectivity_state, callback, try_to_connect) @@ -919,6 +954,11 @@ class Channel(grpc.Channel): self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') _moot(self._connectivity_state) + def _close_on_fork(self): + self._channel.close_on_fork(cygrpc.StatusCode.cancelled, + 'Channel closed due to fork') + _moot(self._connectivity_state) + def __enter__(self): return self @@ -939,4 +979,5 @@ class Channel(grpc.Channel): # for as long as they are in use and to close them after using them, # then deletion of this grpc._channel.Channel instance can be made to # effect closure of the underlying cygrpc.Channel instance. + cygrpc.fork_unregister_channel(self) _moot(self._connectivity_state) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi index a0de862d94..24e85b08e7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi @@ -19,7 +19,7 @@ cdef class Call: def __cinit__(self): # Create an *empty* call - grpc_init() + fork_handlers_and_grpc_init() self.c_call = NULL self.references = [] diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index f067d76fab..ced32abba1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -40,6 +40,7 @@ cdef class _ChannelState: # field and just use the NULLness of c_channel as an indication that the # channel is closed. cdef object open + cdef object closed_reason # A dict from _BatchOperationTag to _CallState cdef dict integrated_call_states diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index aa187e88a6..a81ff4d823 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -15,6 +15,7 @@ cimport cpython import threading +import time _INTERNAL_CALL_ERROR_MESSAGE_FORMAT = ( 'Internal gRPC call error %d. ' + @@ -83,6 +84,7 @@ cdef class _ChannelState: self.integrated_call_states = {} self.segregated_call_states = set() self.connectivity_due = set() + self.closed_reason = None cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): @@ -142,10 +144,10 @@ cdef _cancel( _check_and_raise_call_error_no_metadata(c_call_error) -cdef BatchOperationEvent _next_call_event( +cdef _next_call_event( _ChannelState channel_state, grpc_completion_queue *c_completion_queue, - on_success): - tag, event = _latent_event(c_completion_queue, None) + on_success, deadline): + tag, event = _latent_event(c_completion_queue, deadline) with channel_state.condition: on_success(tag) channel_state.condition.notify_all() @@ -229,8 +231,7 @@ cdef void _call( call_state.due.update(started_tags) on_success(started_tags) else: - raise ValueError('Cannot invoke RPC on closed channel!') - + raise ValueError('Cannot invoke RPC: %s' % channel_state.closed_reason) cdef void _process_integrated_call_tag( _ChannelState state, _BatchOperationTag tag) except *: cdef _CallState call_state = state.integrated_call_states.pop(tag) @@ -302,7 +303,7 @@ cdef class SegregatedCall: _process_segregated_call_tag( self._channel_state, self._call_state, self._c_completion_queue, tag) return _next_call_event( - self._channel_state, self._c_completion_queue, on_success) + self._channel_state, self._c_completion_queue, on_success, None) cdef SegregatedCall _segregated_call( @@ -346,7 +347,7 @@ cdef object _watch_connectivity_state( state.c_connectivity_completion_queue, <cpython.PyObject *>tag) state.connectivity_due.add(tag) else: - raise ValueError('Cannot invoke RPC on closed channel!') + raise ValueError('Cannot invoke RPC: %s' % state.closed_reason) completed_tag, event = _latent_event( state.c_connectivity_completion_queue, None) with state.condition: @@ -355,12 +356,15 @@ cdef object _watch_connectivity_state( return event -cdef _close(_ChannelState state, grpc_status_code code, object details): +cdef _close(Channel channel, grpc_status_code code, object details, + drain_calls): + cdef _ChannelState state = channel._state cdef _CallState call_state encoded_details = _encode(details) with state.condition: if state.open: state.open = False + state.closed_reason = details for call_state in set(state.integrated_call_states.values()): grpc_call_cancel_with_status( call_state.c_call, code, encoded_details, NULL) @@ -370,12 +374,19 @@ cdef _close(_ChannelState state, grpc_status_code code, object details): # TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity # watching. - while state.integrated_call_states: - state.condition.wait() - while state.segregated_call_states: - state.condition.wait() - while state.connectivity_due: - state.condition.wait() + if drain_calls: + while not _calls_drained(state): + event = channel.next_call_event() + if event.completion_type == CompletionType.queue_timeout: + continue + event.tag(event) + else: + while state.integrated_call_states: + state.condition.wait() + while state.segregated_call_states: + state.condition.wait() + while state.connectivity_due: + state.condition.wait() _destroy_c_completion_queue(state.c_call_completion_queue) _destroy_c_completion_queue(state.c_connectivity_completion_queue) @@ -390,13 +401,17 @@ cdef _close(_ChannelState state, grpc_status_code code, object details): state.condition.wait() +cdef _calls_drained(_ChannelState state): + return not (state.integrated_call_states or state.segregated_call_states or + state.connectivity_due) + cdef class Channel: def __cinit__( self, bytes target, object arguments, ChannelCredentials channel_credentials): arguments = () if arguments is None else tuple(arguments) - grpc_init() + fork_handlers_and_grpc_init() self._state = _ChannelState() self._vtable.copy = &_copy_pointer self._vtable.destroy = &_destroy_pointer @@ -435,9 +450,14 @@ cdef class Channel: def next_call_event(self): def on_success(tag): - _process_integrated_call_tag(self._state, tag) - return _next_call_event( - self._state, self._state.c_call_completion_queue, on_success) + if tag is not None: + _process_integrated_call_tag(self._state, tag) + if is_fork_support_enabled(): + queue_deadline = time.time() + 1.0 + else: + queue_deadline = None + return _next_call_event(self._state, self._state.c_call_completion_queue, + on_success, queue_deadline) def segregated_call( self, int flags, method, host, object deadline, object metadata, @@ -452,11 +472,14 @@ cdef class Channel: return grpc_channel_check_connectivity_state( self._state.c_channel, try_to_connect) else: - raise ValueError('Cannot invoke RPC on closed channel!') + raise ValueError('Cannot invoke RPC: %s' % self._state.closed_reason) def watch_connectivity_state( self, grpc_connectivity_state last_observed_state, object deadline): return _watch_connectivity_state(self._state, last_observed_state, deadline) def close(self, code, details): - _close(self._state, code, details) + _close(self, code, details, False) + + def close_on_fork(self, code, details): + _close(self, code, details, True) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index a2d765546a..141116df5d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -71,7 +71,7 @@ cdef class CompletionQueue: def __cinit__(self, shutdown_cq=False): cdef grpc_completion_queue_attributes c_attrs - grpc_init() + fork_handlers_and_grpc_init() if shutdown_cq: c_attrs.version = 1 c_attrs.cq_completion_type = GRPC_CQ_NEXT diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index d2c0389ca6..38fd9e78b2 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -21,7 +21,7 @@ from libc.stdint cimport uintptr_t def _spawn_callback_in_thread(cb_func, args): - threading.Thread(target=cb_func, args=args).start() + ForkManagedThread(target=cb_func, args=args).start() async_callback_func = _spawn_callback_in_thread @@ -48,7 +48,7 @@ cdef int _get_metadata( cdef size_t metadata_count cdef grpc_metadata *c_metadata def callback(metadata, grpc_status_code status, bytes error_details): - if status is StatusCode.ok: + if status == StatusCode.ok: _store_c_metadata(metadata, &c_metadata, &metadata_count) cb(user_data, c_metadata, metadata_count, status, NULL) _release_c_metadata(c_metadata, metadata_count) @@ -114,7 +114,7 @@ cdef class ChannelCredentials: cdef class SSLSessionCacheLRU: def __cinit__(self, capacity): - grpc_init() + fork_handlers_and_grpc_init() self._cache = grpc_ssl_session_cache_create_lru(capacity) def __int__(self): @@ -172,7 +172,7 @@ cdef class CompositeChannelCredentials(ChannelCredentials): cdef class ServerCertificateConfig: def __cinit__(self): - grpc_init() + fork_handlers_and_grpc_init() self.c_cert_config = NULL self.c_pem_root_certs = NULL self.c_ssl_pem_key_cert_pairs = NULL @@ -187,7 +187,7 @@ cdef class ServerCertificateConfig: cdef class ServerCredentials: def __cinit__(self): - grpc_init() + fork_handlers_and_grpc_init() self.c_credentials = NULL self.references = [] self.initial_cert_config = None @@ -282,3 +282,41 @@ def server_credentials_ssl_dynamic_cert_config(initial_cert_config, # C-core assumes ownership of c_options credentials.c_credentials = grpc_ssl_server_credentials_create_with_options(c_options) return credentials + +cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapper( + void* user_data, grpc_ssl_server_certificate_config **config) with gil: + # This is a credentials.ServerCertificateConfig + cdef ServerCertificateConfig cert_config = None + if not user_data: + raise ValueError('internal error: user_data must be specified') + credentials = <ServerCredentials>user_data + if not credentials.initial_cert_config_fetched: + # C-core is asking for the initial cert config + credentials.initial_cert_config_fetched = True + cert_config = credentials.initial_cert_config._certificate_configuration + else: + user_cb = credentials.cert_config_fetcher + try: + cert_config_wrapper = user_cb() + except Exception: + _LOGGER.exception('Error fetching certificate config') + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL + if cert_config_wrapper is None: + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED + elif not isinstance( + cert_config_wrapper, grpc.ServerCertificateConfiguration): + _LOGGER.error( + 'Error fetching certificate configuration: certificate ' + 'configuration must be of type grpc.ServerCertificateConfiguration, ' + 'not %s' % type(cert_config_wrapper).__name__) + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL + else: + cert_config = cert_config_wrapper._certificate_configuration + config[0] = <grpc_ssl_server_certificate_config*>cert_config.c_cert_config + # our caller will assume ownership of memory, so we have to recreate + # a copy of c_cert_config here + cert_config.c_cert_config = grpc_ssl_server_certificate_config_create( + cert_config.c_pem_root_certs, cert_config.c_ssl_pem_key_cert_pairs, + cert_config.c_ssl_pem_key_cert_pairs_count) + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW + diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi new file mode 100644 index 0000000000..a925bdd2e6 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi @@ -0,0 +1,29 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef extern from "pthread.h" nogil: + int pthread_atfork( + void (*prepare)() nogil, + void (*parent)() nogil, + void (*child)() nogil) + + +cdef void __prefork() nogil + + +cdef void __postfork_parent() nogil + + +cdef void __postfork_child() nogil
\ No newline at end of file diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi new file mode 100644 index 0000000000..433ae1f374 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi @@ -0,0 +1,198 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import threading + +_LOGGER = logging.getLogger(__name__) + +_AWAIT_THREADS_TIMEOUT_SECONDS = 5 + +_TRUE_VALUES = ['yes', 'Yes', 'YES', 'true', 'True', 'TRUE', '1'] + +# This flag enables experimental support within gRPC Python for applications +# that will fork() without exec(). When enabled, gRPC Python will attempt to +# pause all of its internally created threads before the fork syscall proceeds. +# +# For this to be successful, the application must not have multiple threads of +# its own calling into gRPC when fork is invoked. Any callbacks from gRPC +# Python-spawned threads into user code (e.g., callbacks for asynchronous RPCs) +# must not block and should execute quickly. +# +# This flag is not supported on Windows. +_GRPC_ENABLE_FORK_SUPPORT = ( + os.environ.get('GRPC_ENABLE_FORK_SUPPORT', '0') + .lower() in _TRUE_VALUES) + +cdef void __prefork() nogil: + with gil: + with _fork_state.fork_in_progress_condition: + _fork_state.fork_in_progress = True + if not _fork_state.active_thread_count.await_zero_threads( + _AWAIT_THREADS_TIMEOUT_SECONDS): + _LOGGER.error( + 'Failed to shutdown gRPC Python threads prior to fork. ' + 'Behavior after fork will be undefined.') + + +cdef void __postfork_parent() nogil: + with gil: + with _fork_state.fork_in_progress_condition: + _fork_state.fork_in_progress = False + _fork_state.fork_in_progress_condition.notify_all() + + +cdef void __postfork_child() nogil: + with gil: + # Thread could be holding the fork_in_progress_condition inside of + # block_if_fork_in_progress() when fork occurs. Reset the lock here. + _fork_state.fork_in_progress_condition = threading.Condition() + # A thread in return_from_user_request_generator() may hold this lock + # when fork occurs. + _fork_state.active_thread_count = _ActiveThreadCount() + for state_to_reset in _fork_state.postfork_states_to_reset: + state_to_reset.reset_postfork_child() + _fork_state.fork_epoch += 1 + for channel in _fork_state.channels: + channel._close_on_fork() + # TODO(ericgribkoff) Check and abort if core is not shutdown + with _fork_state.fork_in_progress_condition: + _fork_state.fork_in_progress = False + if grpc_is_initialized() > 0: + with gil: + _LOGGER.error('Failed to shutdown gRPC Core after fork()') + os._exit(os.EX_USAGE) + + +def fork_handlers_and_grpc_init(): + grpc_init() + if _GRPC_ENABLE_FORK_SUPPORT: + with _fork_state.fork_handler_registered_lock: + if not _fork_state.fork_handler_registered: + pthread_atfork(&__prefork, &__postfork_parent, &__postfork_child) + _fork_state.fork_handler_registered = True + + +class ForkManagedThread(object): + def __init__(self, target, args=()): + if _GRPC_ENABLE_FORK_SUPPORT: + def managed_target(*args): + try: + target(*args) + finally: + _fork_state.active_thread_count.decrement() + self._thread = threading.Thread(target=managed_target, args=args) + else: + self._thread = threading.Thread(target=target, args=args) + + def setDaemon(self, daemonic): + self._thread.daemon = daemonic + + def start(self): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.active_thread_count.increment() + self._thread.start() + + def join(self): + self._thread.join() + + +def block_if_fork_in_progress(postfork_state_to_reset=None): + if _GRPC_ENABLE_FORK_SUPPORT: + with _fork_state.fork_in_progress_condition: + if not _fork_state.fork_in_progress: + return + if postfork_state_to_reset is not None: + _fork_state.postfork_states_to_reset.append(postfork_state_to_reset) + _fork_state.active_thread_count.decrement() + _fork_state.fork_in_progress_condition.wait() + _fork_state.active_thread_count.increment() + + +def enter_user_request_generator(): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.active_thread_count.decrement() + + +def return_from_user_request_generator(): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.active_thread_count.increment() + block_if_fork_in_progress() + + +def get_fork_epoch(): + return _fork_state.fork_epoch + + +def is_fork_support_enabled(): + return _GRPC_ENABLE_FORK_SUPPORT + + +def fork_register_channel(channel): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.channels.add(channel) + + +def fork_unregister_channel(channel): + if _GRPC_ENABLE_FORK_SUPPORT: + _fork_state.channels.remove(channel) + + +class _ActiveThreadCount(object): + def __init__(self): + self._num_active_threads = 0 + self._condition = threading.Condition() + + def increment(self): + with self._condition: + self._num_active_threads += 1 + + def decrement(self): + with self._condition: + self._num_active_threads -= 1 + if self._num_active_threads == 0: + self._condition.notify_all() + + def await_zero_threads(self, timeout_secs): + end_time = time.time() + timeout_secs + wait_time = timeout_secs + with self._condition: + while True: + if self._num_active_threads > 0: + self._condition.wait(wait_time) + if self._num_active_threads == 0: + return True + # Thread count may have increased before this re-obtains the + # lock after a notify(). Wait again until timeout_secs has + # elapsed. + wait_time = end_time - time.time() + if wait_time <= 0: + return False + + +class _ForkState(object): + def __init__(self): + self.fork_in_progress_condition = threading.Condition() + self.fork_in_progress = False + self.postfork_states_to_reset = [] + self.fork_handler_registered_lock = threading.Lock() + self.fork_handler_registered = False + self.active_thread_count = _ActiveThreadCount() + self.fork_epoch = 0 + self.channels = set() + + +_fork_state = _ForkState() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi new file mode 100644 index 0000000000..8dc1ef3b1a --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi @@ -0,0 +1,63 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import threading + +# No-op implementations for Windows. + +def fork_handlers_and_grpc_init(): + grpc_init() + + +class ForkManagedThread(object): + def __init__(self, target, args=()): + self._thread = threading.Thread(target=target, args=args) + + def setDaemon(self, daemonic): + self._thread.daemon = daemonic + + def start(self): + self._thread.start() + + def join(self): + self._thread.join() + + +def block_if_fork_in_progress(postfork_state_to_reset=None): + pass + + +def enter_user_request_generator(): + pass + + +def return_from_user_request_generator(): + pass + + +def get_fork_epoch(): + return 0 + + +def is_fork_support_enabled(): + return False + + +def fork_register_channel(channel): + pass + + +def fork_unregister_channel(channel): + pass diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index bcbfec0c9f..4781219319 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -322,6 +322,7 @@ cdef extern from "grpc/grpc.h": void grpc_init() nogil void grpc_shutdown() nogil + int grpc_is_initialized() nogil ctypedef struct grpc_completion_queue_factory: pass diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index 37b98ebbdb..fe98d559f3 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -127,7 +127,7 @@ class CompressionLevel: cdef class CallDetails: def __cinit__(self): - grpc_init() + fork_handlers_and_grpc_init() with nogil: grpc_call_details_init(&self.c_details) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index da3dd21244..ce701724fd 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -20,47 +20,10 @@ import grpc _LOGGER = logging.getLogger(__name__) -cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapper( - void* user_data, grpc_ssl_server_certificate_config **config) with gil: - # This is a credentials.ServerCertificateConfig - cdef ServerCertificateConfig cert_config = None - if not user_data: - raise ValueError('internal error: user_data must be specified') - credentials = <ServerCredentials>user_data - if not credentials.initial_cert_config_fetched: - # C-core is asking for the initial cert config - credentials.initial_cert_config_fetched = True - cert_config = credentials.initial_cert_config._certificate_configuration - else: - user_cb = credentials.cert_config_fetcher - try: - cert_config_wrapper = user_cb() - except Exception: - _LOGGER.exception('Error fetching certificate config') - return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL - if cert_config_wrapper is None: - return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED - elif not isinstance( - cert_config_wrapper, grpc.ServerCertificateConfiguration): - _LOGGER.error( - 'Error fetching certificate configuration: certificate ' - 'configuration must be of type grpc.ServerCertificateConfiguration, ' - 'not %s' % type(cert_config_wrapper).__name__) - return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL - else: - cert_config = cert_config_wrapper._certificate_configuration - config[0] = <grpc_ssl_server_certificate_config*>cert_config.c_cert_config - # our caller will assume ownership of memory, so we have to recreate - # a copy of c_cert_config here - cert_config.c_cert_config = grpc_ssl_server_certificate_config_create( - cert_config.c_pem_root_certs, cert_config.c_ssl_pem_key_cert_pairs, - cert_config.c_ssl_pem_key_cert_pairs_count) - return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW - cdef class Server: def __cinit__(self, object arguments): - grpc_init() + fork_handlers_and_grpc_init() self.references = [] self.registered_completion_queues = [] self._vtable.copy = &_copy_pointer diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd index 0cc26bc0d0..8258b857bc 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pxd +++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd @@ -31,3 +31,6 @@ include "_cygrpc/time.pxd.pxi" include "_cygrpc/_hooks.pxd.pxi" include "_cygrpc/grpc_gevent.pxd.pxi" + +IF UNAME_SYSNAME != "Windows": + include "_cygrpc/fork_posix.pxd.pxi" diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index 3cac406687..026f7ba2e3 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -39,6 +39,11 @@ include "_cygrpc/_hooks.pyx.pxi" include "_cygrpc/grpc_gevent.pyx.pxi" +IF UNAME_SYSNAME == "Windows": + include "_cygrpc/fork_windows.pyx.pxi" +ELSE: + include "_cygrpc/fork_posix.pyx.pxi" + # # initialize gRPC # diff --git a/src/python/grpcio/grpc/_grpcio_metadata.py b/src/python/grpcio/grpc/_grpcio_metadata.py index c33911ebc1..24e1557578 100644 --- a/src/python/grpcio/grpc/_grpcio_metadata.py +++ b/src/python/grpcio/grpc/_grpcio_metadata.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc/_grpcio_metadata.py.template`!!! -__version__ = """1.15.0.dev0""" +__version__ = """1.16.0.dev0""" diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index a8158311fb..0f68e823d7 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -82,6 +82,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/http/format_request.cc', 'src/core/lib/http/httpcli.cc', 'src/core/lib/http/parser.cc', + 'src/core/lib/iomgr/buffer_list.cc', 'src/core/lib/iomgr/call_combiner.cc', 'src/core/lib/iomgr/combiner.cc', 'src/core/lib/iomgr/endpoint.cc', @@ -102,6 +103,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/iomgr/gethostname_fallback.cc', 'src/core/lib/iomgr/gethostname_host_name_max.cc', 'src/core/lib/iomgr/gethostname_sysconf.cc', + 'src/core/lib/iomgr/internal_errqueue.cc', 'src/core/lib/iomgr/iocp_windows.cc', 'src/core/lib/iomgr/iomgr.cc', 'src/core/lib/iomgr/iomgr_custom.cc', @@ -363,7 +365,7 @@ CORE_SOURCE_FILES = [ 'src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc', 'src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc', 'src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc', - 'src/cpp/ext/filters/census/grpc_context.cc', + 'src/core/ext/filters/census/grpc_context.cc', 'src/core/ext/filters/max_age/max_age_filter.cc', 'src/core/ext/filters/message_size/message_size_filter.cc', 'src/core/ext/filters/http/client_authority_filter.cc', diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py index 9337800a33..6ffe1eb827 100644 --- a/src/python/grpcio/grpc_version.py +++ b/src/python/grpcio/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!! -VERSION = '1.15.0.dev0' +VERSION = '1.16.0.dev0' diff --git a/src/python/grpcio_health_checking/grpc_version.py b/src/python/grpcio_health_checking/grpc_version.py index 3b84f7a4c5..e080bf2cbc 100644 --- a/src/python/grpcio_health_checking/grpc_version.py +++ b/src/python/grpcio_health_checking/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_health_checking/grpc_version.py.template`!!! -VERSION = '1.15.0.dev0' +VERSION = '1.16.0.dev0' diff --git a/src/python/grpcio_reflection/grpc_version.py b/src/python/grpcio_reflection/grpc_version.py index 7b0e48ea23..4b3b95fee9 100644 --- a/src/python/grpcio_reflection/grpc_version.py +++ b/src/python/grpcio_reflection/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_reflection/grpc_version.py.template`!!! -VERSION = '1.15.0.dev0' +VERSION = '1.16.0.dev0' diff --git a/src/python/grpcio_testing/grpc_version.py b/src/python/grpcio_testing/grpc_version.py index df9953fa25..c12aa153a4 100644 --- a/src/python/grpcio_testing/grpc_version.py +++ b/src/python/grpcio_testing/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_testing/grpc_version.py.template`!!! -VERSION = '1.15.0.dev0' +VERSION = '1.16.0.dev0' diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index a23c980017..0dfbf3180b 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -202,3 +202,28 @@ class RunInterop(test.test): from tests.interop import client sys.argv[1:] = self.args.split() client.test_interoperability() + + +class RunFork(test.test): + + description = 'run fork test client' + user_options = [('args=', 'a', 'pass-thru arguments for the client')] + + def initialize_options(self): + self.args = '' + + def finalize_options(self): + # distutils requires this override. + pass + + def run(self): + if self.distribution.install_requires: + self.distribution.fetch_build_eggs( + self.distribution.install_requires) + if self.distribution.tests_require: + self.distribution.fetch_build_eggs(self.distribution.tests_require) + # We import here to ensure that our setuptools parent has had a chance to + # edit the Python system path. + from tests.fork import client + sys.argv[1:] = self.args.split() + client.test_fork() diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py index b2cf129e4f..f4b8a34a46 100644 --- a/src/python/grpcio_tests/grpc_version.py +++ b/src/python/grpcio_tests/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!! -VERSION = '1.15.0.dev0' +VERSION = '1.16.0.dev0' diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py index a94c0963ec..61c98fa038 100644 --- a/src/python/grpcio_tests/setup.py +++ b/src/python/grpcio_tests/setup.py @@ -52,6 +52,7 @@ COMMAND_CLASS = { 'preprocess': commands.GatherProto, 'build_package_protos': grpc_tools.command.BuildPackageProtos, 'build_py': commands.BuildPy, + 'run_fork': commands.RunFork, 'run_interop': commands.RunInterop, 'test_lite': commands.TestLite, 'test_gevent': commands.TestGevent, diff --git a/src/python/grpcio_tests/tests/fork/__init__.py b/src/python/grpcio_tests/tests/fork/__init__.py new file mode 100644 index 0000000000..9a26bac010 --- /dev/null +++ b/src/python/grpcio_tests/tests/fork/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_tests/tests/fork/client.py b/src/python/grpcio_tests/tests/fork/client.py new file mode 100644 index 0000000000..9a32629ed5 --- /dev/null +++ b/src/python/grpcio_tests/tests/fork/client.py @@ -0,0 +1,76 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC interoperability test client.""" + +import argparse +import logging +import sys + +from tests.fork import methods + + +def _args(): + + def parse_bool(value): + if value == 'true': + return True + if value == 'false': + return False + raise argparse.ArgumentTypeError('Only true/false allowed') + + parser = argparse.ArgumentParser() + parser.add_argument( + '--server_host', + default="localhost", + type=str, + help='the host to which to connect') + parser.add_argument( + '--server_port', + type=int, + required=True, + help='the port to which to connect') + parser.add_argument( + '--test_case', + default='large_unary', + type=str, + help='the test case to execute') + parser.add_argument( + '--use_tls', + default=False, + type=parse_bool, + help='require a secure connection') + return parser.parse_args() + + +def _test_case_from_arg(test_case_arg): + for test_case in methods.TestCase: + if test_case_arg == test_case.value: + return test_case + else: + raise ValueError('No test case "%s"!' % test_case_arg) + + +def test_fork(): + logging.basicConfig(level=logging.INFO) + args = _args() + if args.test_case == "all": + for test_case in methods.TestCase: + test_case.run_test(args) + else: + test_case = _test_case_from_arg(args.test_case) + test_case.run_test(args) + + +if __name__ == '__main__': + test_fork() diff --git a/src/python/grpcio_tests/tests/fork/methods.py b/src/python/grpcio_tests/tests/fork/methods.py new file mode 100644 index 0000000000..889ef13cb2 --- /dev/null +++ b/src/python/grpcio_tests/tests/fork/methods.py @@ -0,0 +1,445 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementations of fork support test methods.""" + +import enum +import json +import logging +import multiprocessing +import os +import threading +import time + +import grpc + +from six.moves import queue + +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2_grpc + +_LOGGER = logging.getLogger(__name__) + + +def _channel(args): + target = '{}:{}'.format(args.server_host, args.server_port) + if args.use_tls: + channel_credentials = grpc.ssl_channel_credentials() + channel = grpc.secure_channel(target, channel_credentials) + else: + channel = grpc.insecure_channel(target) + return channel + + +def _validate_payload_type_and_length(response, expected_type, expected_length): + if response.payload.type is not expected_type: + raise ValueError('expected payload type %s, got %s' % + (expected_type, type(response.payload.type))) + elif len(response.payload.body) != expected_length: + raise ValueError('expected payload body size %d, got %d' % + (expected_length, len(response.payload.body))) + + +def _async_unary(stub): + size = 314159 + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=size, + payload=messages_pb2.Payload(body=b'\x00' * 271828)) + response_future = stub.UnaryCall.future(request) + response = response_future.result() + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) + + +def _blocking_unary(stub): + size = 314159 + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=size, + payload=messages_pb2.Payload(body=b'\x00' * 271828)) + response = stub.UnaryCall(request) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) + + +class _Pipe(object): + + def __init__(self): + self._condition = threading.Condition() + self._values = [] + self._open = True + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + with self._condition: + while not self._values and self._open: + self._condition.wait() + if self._values: + return self._values.pop(0) + else: + raise StopIteration() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + +class _ChildProcess(object): + + def __init__(self, task, args=None): + if args is None: + args = () + self._exceptions = multiprocessing.Queue() + + def record_exceptions(): + try: + task(*args) + except Exception as e: # pylint: disable=broad-except + self._exceptions.put(e) + + self._process = multiprocessing.Process(target=record_exceptions) + + def start(self): + self._process.start() + + def finish(self): + self._process.join() + if self._process.exitcode != 0: + raise ValueError('Child process failed with exitcode %d' % + self._process.exitcode) + try: + exception = self._exceptions.get(block=False) + raise ValueError('Child process failed: %s' % exception) + except queue.Empty: + pass + + +def _async_unary_same_channel(channel): + + def child_target(): + try: + _async_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + stub = test_pb2_grpc.TestServiceStub(channel) + _async_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + _async_unary(stub) + child_process.finish() + + +def _async_unary_new_channel(channel, args): + + def child_target(): + child_channel = _channel(args) + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + _async_unary(child_stub) + child_channel.close() + + stub = test_pb2_grpc.TestServiceStub(channel) + _async_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + _async_unary(stub) + child_process.finish() + + +def _blocking_unary_same_channel(channel): + + def child_target(): + try: + _blocking_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + child_process.finish() + + +def _blocking_unary_new_channel(channel, args): + + def child_target(): + child_channel = _channel(args) + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + _blocking_unary(child_stub) + child_channel.close() + + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + _blocking_unary(stub) + child_process.finish() + + +# Verify that the fork channel registry can handle already closed channels +def _close_channel_before_fork(channel, args): + + def child_target(): + new_channel.close() + child_channel = _channel(args) + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + _blocking_unary(child_stub) + child_channel.close() + + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + channel.close() + + new_channel = _channel(args) + new_stub = test_pb2_grpc.TestServiceStub(new_channel) + child_process = _ChildProcess(child_target) + child_process.start() + _blocking_unary(new_stub) + child_process.finish() + + +def _connectivity_watch(channel, args): + + def child_target(): + + def child_connectivity_callback(state): + child_states.append(state) + + child_states = [] + child_channel = _channel(args) + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + child_channel.subscribe(child_connectivity_callback) + _async_unary(child_stub) + if len(child_states + ) < 2 or child_states[-1] != grpc.ChannelConnectivity.READY: + raise ValueError('Channel did not move to READY') + if len(parent_states) > 1: + raise ValueError('Received connectivity updates on parent callback') + child_channel.unsubscribe(child_connectivity_callback) + child_channel.close() + + def parent_connectivity_callback(state): + parent_states.append(state) + + parent_states = [] + channel.subscribe(parent_connectivity_callback) + stub = test_pb2_grpc.TestServiceStub(channel) + child_process = _ChildProcess(child_target) + child_process.start() + _async_unary(stub) + if len(parent_states + ) < 2 or parent_states[-1] != grpc.ChannelConnectivity.READY: + raise ValueError('Channel did not move to READY') + channel.unsubscribe(parent_connectivity_callback) + child_process.finish() + + # Need to unsubscribe or _channel.py in _poll_connectivity triggers a + # "Cannot invoke RPC on closed channel!" error. + # TODO(ericgribkoff) Fix issue with channel.close() and connectivity polling + channel.unsubscribe(parent_connectivity_callback) + + +def _ping_pong_with_child_processes_after_first_response( + channel, args, child_target, run_after_close=True): + request_response_sizes = ( + 31415, + 9, + 2653, + 58979, + ) + request_payload_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + stub = test_pb2_grpc.TestServiceStub(channel) + pipe = _Pipe() + parent_bidi_call = stub.FullDuplexCall(pipe) + child_processes = [] + first_message_received = False + for response_size, payload_size in zip(request_response_sizes, + request_payload_sizes): + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=( + messages_pb2.ResponseParameters(size=response_size),), + payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + pipe.add(request) + if first_message_received: + child_process = _ChildProcess(child_target, + (parent_bidi_call, channel, args)) + child_process.start() + child_processes.append(child_process) + response = next(parent_bidi_call) + first_message_received = True + child_process = _ChildProcess(child_target, + (parent_bidi_call, channel, args)) + child_process.start() + child_processes.append(child_process) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + response_size) + pipe.close() + if run_after_close: + child_process = _ChildProcess(child_target, + (parent_bidi_call, channel, args)) + child_process.start() + child_processes.append(child_process) + for child_process in child_processes: + child_process.finish() + + +def _in_progress_bidi_continue_call(channel): + + def child_target(parent_bidi_call, parent_channel, args): + stub = test_pb2_grpc.TestServiceStub(parent_channel) + try: + _async_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + inherited_code = parent_bidi_call.code() + inherited_details = parent_bidi_call.details() + if inherited_code != grpc.StatusCode.CANCELLED: + raise ValueError( + 'Expected inherited code CANCELLED, got %s' % inherited_code) + if inherited_details != 'Channel closed due to fork': + raise ValueError( + 'Expected inherited details Channel closed due to fork, got %s' + % inherited_details) + + # Don't run child_target after closing the parent call, as the call may have + # received a status from the server before fork occurs. + _ping_pong_with_child_processes_after_first_response( + channel, None, child_target, run_after_close=False) + + +def _in_progress_bidi_same_channel_async_call(channel): + + def child_target(parent_bidi_call, parent_channel, args): + stub = test_pb2_grpc.TestServiceStub(parent_channel) + try: + _async_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + _ping_pong_with_child_processes_after_first_response( + channel, None, child_target) + + +def _in_progress_bidi_same_channel_blocking_call(channel): + + def child_target(parent_bidi_call, parent_channel, args): + stub = test_pb2_grpc.TestServiceStub(parent_channel) + try: + _blocking_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + _ping_pong_with_child_processes_after_first_response( + channel, None, child_target) + + +def _in_progress_bidi_new_channel_async_call(channel, args): + + def child_target(parent_bidi_call, parent_channel, args): + channel = _channel(args) + stub = test_pb2_grpc.TestServiceStub(channel) + _async_unary(stub) + + _ping_pong_with_child_processes_after_first_response( + channel, args, child_target) + + +def _in_progress_bidi_new_channel_blocking_call(channel, args): + + def child_target(parent_bidi_call, parent_channel, args): + channel = _channel(args) + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + + _ping_pong_with_child_processes_after_first_response( + channel, args, child_target) + + +@enum.unique +class TestCase(enum.Enum): + + CONNECTIVITY_WATCH = 'connectivity_watch' + CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork' + ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel' + ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel' + BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel' + BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel' + IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call' + IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call' + IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call' + IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call' + IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call' + + def run_test(self, args): + _LOGGER.info("Running %s", self) + channel = _channel(args) + if self is TestCase.ASYNC_UNARY_SAME_CHANNEL: + _async_unary_same_channel(channel) + elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL: + _async_unary_new_channel(channel, args) + elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL: + _blocking_unary_same_channel(channel) + elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL: + _blocking_unary_new_channel(channel, args) + elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK: + _close_channel_before_fork(channel, args) + elif self is TestCase.CONNECTIVITY_WATCH: + _connectivity_watch(channel, args) + elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL: + _in_progress_bidi_continue_call(channel) + elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL: + _in_progress_bidi_same_channel_async_call(channel) + elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL: + _in_progress_bidi_same_channel_blocking_call(channel) + elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL: + _in_progress_bidi_new_channel_async_call(channel, args) + elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL: + _in_progress_bidi_new_channel_blocking_call(channel, args) + else: + raise NotImplementedError( + 'Test case "%s" not implemented!' % self.name) + channel.close() diff --git a/src/python/grpcio_tests/tests/stress/test_runner.py b/src/python/grpcio_tests/tests/stress/test_runner.py index 764cda17fb..e66eda64a8 100644 --- a/src/python/grpcio_tests/tests/stress/test_runner.py +++ b/src/python/grpcio_tests/tests/stress/test_runner.py @@ -53,5 +53,5 @@ class TestRunner(threading.Thread): except Exception as e: # pylint: disable=broad-except traceback.print_exc() self._exception_queue.put( - Exception("An exception occured during test {}" + Exception("An exception occurred during test {}" .format(test_case), e)) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index ebc41c63f0..76d5d22d57 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -32,6 +32,8 @@ "unit._credentials_test.CredentialsTest", "unit._cython._cancel_many_calls_test.CancelManyCallsTest", "unit._cython._channel_test.ChannelTest", + "unit._cython._fork_test.ForkPosixTester", + "unit._cython._fork_test.ForkWindowsTester", "unit._cython._no_messages_server_completion_queue_per_call_test.Test", "unit._cython._no_messages_single_server_completion_queue_test.Test", "unit._cython._read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest", diff --git a/src/python/grpcio_tests/tests/unit/_channel_args_test.py b/src/python/grpcio_tests/tests/unit/_channel_args_test.py index 1a2d2c0117..869c2f4d2f 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_args_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_args_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests of Channel Args on client/server side.""" +from concurrent import futures import unittest import grpc @@ -39,7 +40,9 @@ class ChannelArgsTest(unittest.TestCase): grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS) def test_server(self): - grpc.server(None, options=TEST_CHANNEL_ARGS) + grpc.server( + futures.ThreadPoolExecutor(max_workers=1), + options=TEST_CHANNEL_ARGS) if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py new file mode 100644 index 0000000000..aeb02458a7 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +import unittest + +from grpc._cython import cygrpc + + +def _get_number_active_threads(): + return cygrpc._fork_state.active_thread_count._num_active_threads + + +@unittest.skipIf(os.name == 'nt', 'Posix-specific tests') +class ForkPosixTester(unittest.TestCase): + + def setUp(self): + cygrpc._GRPC_ENABLE_FORK_SUPPORT = True + + def testForkManagedThread(self): + + def cb(): + self.assertEqual(1, _get_number_active_threads()) + + thread = cygrpc.ForkManagedThread(cb) + thread.start() + thread.join() + self.assertEqual(0, _get_number_active_threads()) + + def testForkManagedThreadThrowsException(self): + + def cb(): + self.assertEqual(1, _get_number_active_threads()) + raise Exception("expected exception") + + thread = cygrpc.ForkManagedThread(cb) + thread.start() + thread.join() + self.assertEqual(0, _get_number_active_threads()) + + +@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests') +class ForkWindowsTester(unittest.TestCase): + + def testForkManagedThreadIsNoOp(self): + + def cb(): + pass + + thread = cygrpc.ForkManagedThread(cb) + thread.start() + thread.join() + + +if __name__ == '__main__': + unittest.main(verbosity=2) |