aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.pylintrc-tests2
-rw-r--r--src/core/lib/gprpp/fork.cc73
-rw-r--r--src/core/lib/gprpp/fork.h17
-rw-r--r--src/core/lib/iomgr/ev_epoll1_linux.cc72
-rw-r--r--src/core/lib/iomgr/fork_posix.cc5
-rw-r--r--src/python/grpcio/grpc/_channel.py59
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi1
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi63
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi8
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi29
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi203
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi63
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/cygrpc.pxd3
-rw-r--r--src/python/grpcio/grpc/_cython/cygrpc.pyx5
-rw-r--r--src/python/grpcio_tests/commands.py25
-rw-r--r--src/python/grpcio_tests/setup.py1
-rw-r--r--src/python/grpcio_tests/tests/fork/__init__.py13
-rw-r--r--src/python/grpcio_tests/tests/fork/client.py76
-rw-r--r--src/python/grpcio_tests/tests/fork/methods.py445
-rw-r--r--src/python/grpcio_tests/tests/tests.json2
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_fork_test.py68
25 files changed, 1167 insertions, 74 deletions
diff --git a/.pylintrc-tests b/.pylintrc-tests
index ebe9d507cd..e68755c674 100644
--- a/.pylintrc-tests
+++ b/.pylintrc-tests
@@ -20,6 +20,8 @@ notes=FIXME,XXX
[MESSAGES CONTROL]
+extension-pkg-whitelist=grpc._cython.cygrpc
+
disable=
# These suppressions are specific to tests:
#
diff --git a/src/core/lib/gprpp/fork.cc b/src/core/lib/gprpp/fork.cc
index f6d9a87d2c..0288c39680 100644
--- a/src/core/lib/gprpp/fork.cc
+++ b/src/core/lib/gprpp/fork.cc
@@ -157,11 +157,11 @@ class ThreadState {
} // namespace
void Fork::GlobalInit() {
- if (!overrideEnabled_) {
+ if (!override_enabled_) {
#ifdef GRPC_ENABLE_FORK_SUPPORT
- supportEnabled_ = true;
+ support_enabled_ = true;
#else
- supportEnabled_ = false;
+ support_enabled_ = false;
#endif
bool env_var_set = false;
char* env = gpr_getenv("GRPC_ENABLE_FORK_SUPPORT");
@@ -172,7 +172,7 @@ void Fork::GlobalInit() {
"False", "FALSE", "0"};
for (size_t i = 0; i < GPR_ARRAY_SIZE(truthy); i++) {
if (0 == strcmp(env, truthy[i])) {
- supportEnabled_ = true;
+ support_enabled_ = true;
env_var_set = true;
break;
}
@@ -180,7 +180,7 @@ void Fork::GlobalInit() {
if (!env_var_set) {
for (size_t i = 0; i < GPR_ARRAY_SIZE(falsey); i++) {
if (0 == strcmp(env, falsey[i])) {
- supportEnabled_ = false;
+ support_enabled_ = false;
env_var_set = true;
break;
}
@@ -189,72 +189,79 @@ void Fork::GlobalInit() {
gpr_free(env);
}
}
- if (supportEnabled_) {
- execCtxState_ = grpc_core::New<internal::ExecCtxState>();
- threadState_ = grpc_core::New<internal::ThreadState>();
+ if (support_enabled_) {
+ exec_ctx_state_ = grpc_core::New<internal::ExecCtxState>();
+ thread_state_ = grpc_core::New<internal::ThreadState>();
}
}
void Fork::GlobalShutdown() {
- if (supportEnabled_) {
- grpc_core::Delete(execCtxState_);
- grpc_core::Delete(threadState_);
+ if (support_enabled_) {
+ grpc_core::Delete(exec_ctx_state_);
+ grpc_core::Delete(thread_state_);
}
}
-bool Fork::Enabled() { return supportEnabled_; }
+bool Fork::Enabled() { return support_enabled_; }
// Testing Only
void Fork::Enable(bool enable) {
- overrideEnabled_ = true;
- supportEnabled_ = enable;
+ override_enabled_ = true;
+ support_enabled_ = enable;
}
void Fork::IncExecCtxCount() {
- if (supportEnabled_) {
- execCtxState_->IncExecCtxCount();
+ if (support_enabled_) {
+ exec_ctx_state_->IncExecCtxCount();
}
}
void Fork::DecExecCtxCount() {
- if (supportEnabled_) {
- execCtxState_->DecExecCtxCount();
+ if (support_enabled_) {
+ exec_ctx_state_->DecExecCtxCount();
}
}
+void Fork::SetResetChildPollingEngineFunc(Fork::child_postfork_func func) {
+ reset_child_polling_engine_ = func;
+}
+Fork::child_postfork_func Fork::GetResetChildPollingEngineFunc() {
+ return reset_child_polling_engine_;
+}
+
bool Fork::BlockExecCtx() {
- if (supportEnabled_) {
- return execCtxState_->BlockExecCtx();
+ if (support_enabled_) {
+ return exec_ctx_state_->BlockExecCtx();
}
return false;
}
void Fork::AllowExecCtx() {
- if (supportEnabled_) {
- execCtxState_->AllowExecCtx();
+ if (support_enabled_) {
+ exec_ctx_state_->AllowExecCtx();
}
}
void Fork::IncThreadCount() {
- if (supportEnabled_) {
- threadState_->IncThreadCount();
+ if (support_enabled_) {
+ thread_state_->IncThreadCount();
}
}
void Fork::DecThreadCount() {
- if (supportEnabled_) {
- threadState_->DecThreadCount();
+ if (support_enabled_) {
+ thread_state_->DecThreadCount();
}
}
void Fork::AwaitThreads() {
- if (supportEnabled_) {
- threadState_->AwaitThreads();
+ if (support_enabled_) {
+ thread_state_->AwaitThreads();
}
}
-internal::ExecCtxState* Fork::execCtxState_ = nullptr;
-internal::ThreadState* Fork::threadState_ = nullptr;
-bool Fork::supportEnabled_ = false;
-bool Fork::overrideEnabled_ = false;
-
+internal::ExecCtxState* Fork::exec_ctx_state_ = nullptr;
+internal::ThreadState* Fork::thread_state_ = nullptr;
+bool Fork::support_enabled_ = false;
+bool Fork::override_enabled_ = false;
+Fork::child_postfork_func Fork::reset_child_polling_engine_ = nullptr;
} // namespace grpc_core
diff --git a/src/core/lib/gprpp/fork.h b/src/core/lib/gprpp/fork.h
index 123e22c4c6..5a7404f0d9 100644
--- a/src/core/lib/gprpp/fork.h
+++ b/src/core/lib/gprpp/fork.h
@@ -33,6 +33,8 @@ class ThreadState;
class Fork {
public:
+ typedef void (*child_postfork_func)(void);
+
static void GlobalInit();
static void GlobalShutdown();
@@ -46,6 +48,12 @@ class Fork {
// Decrement the count of active ExecCtxs
static void DecExecCtxCount();
+ // Provide a function that will be invoked in the child's postfork handler to
+ // reset the polling engine's internal state.
+ static void SetResetChildPollingEngineFunc(
+ child_postfork_func reset_child_polling_engine);
+ static child_postfork_func GetResetChildPollingEngineFunc();
+
// Check if there is a single active ExecCtx
// (the one used to invoke this function). If there are more,
// return false. Otherwise, return true and block creation of
@@ -68,10 +76,11 @@ class Fork {
static void Enable(bool enable);
private:
- static internal::ExecCtxState* execCtxState_;
- static internal::ThreadState* threadState_;
- static bool supportEnabled_;
- static bool overrideEnabled_;
+ static internal::ExecCtxState* exec_ctx_state_;
+ static internal::ThreadState* thread_state_;
+ static bool support_enabled_;
+ static bool override_enabled_;
+ static child_postfork_func reset_child_polling_engine_;
};
} // namespace grpc_core
diff --git a/src/core/lib/iomgr/ev_epoll1_linux.cc b/src/core/lib/iomgr/ev_epoll1_linux.cc
index 66e0f1fd6d..aa5016bd8f 100644
--- a/src/core/lib/iomgr/ev_epoll1_linux.cc
+++ b/src/core/lib/iomgr/ev_epoll1_linux.cc
@@ -131,6 +131,13 @@ static void epoll_set_shutdown() {
* Fd Declarations
*/
+/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */
+struct grpc_fork_fd_list {
+ grpc_fd* fd;
+ grpc_fd* next;
+ grpc_fd* prev;
+};
+
struct grpc_fd {
int fd;
@@ -141,6 +148,9 @@ struct grpc_fd {
struct grpc_fd* freelist_next;
grpc_iomgr_object iomgr_object;
+
+ /* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */
+ grpc_fork_fd_list* fork_fd_list;
};
static void fd_global_init(void);
@@ -256,6 +266,10 @@ static bool append_error(grpc_error** composite, grpc_error* error,
static grpc_fd* fd_freelist = nullptr;
static gpr_mu fd_freelist_mu;
+/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */
+static grpc_fd* fork_fd_list_head = nullptr;
+static gpr_mu fork_fd_list_mu;
+
static void fd_global_init(void) { gpr_mu_init(&fd_freelist_mu); }
static void fd_global_shutdown(void) {
@@ -269,6 +283,38 @@ static void fd_global_shutdown(void) {
gpr_mu_destroy(&fd_freelist_mu);
}
+static void fork_fd_list_add_grpc_fd(grpc_fd* fd) {
+ if (grpc_core::Fork::Enabled()) {
+ gpr_mu_lock(&fork_fd_list_mu);
+ fd->fork_fd_list =
+ static_cast<grpc_fork_fd_list*>(gpr_malloc(sizeof(grpc_fork_fd_list)));
+ fd->fork_fd_list->next = fork_fd_list_head;
+ fd->fork_fd_list->prev = nullptr;
+ if (fork_fd_list_head != nullptr) {
+ fork_fd_list_head->fork_fd_list->prev = fd;
+ }
+ fork_fd_list_head = fd;
+ gpr_mu_unlock(&fork_fd_list_mu);
+ }
+}
+
+static void fork_fd_list_remove_grpc_fd(grpc_fd* fd) {
+ if (grpc_core::Fork::Enabled()) {
+ gpr_mu_lock(&fork_fd_list_mu);
+ if (fork_fd_list_head == fd) {
+ fork_fd_list_head = fd->fork_fd_list->next;
+ }
+ if (fd->fork_fd_list->prev != nullptr) {
+ fd->fork_fd_list->prev->fork_fd_list->next = fd->fork_fd_list->next;
+ }
+ if (fd->fork_fd_list->next != nullptr) {
+ fd->fork_fd_list->next->fork_fd_list->prev = fd->fork_fd_list->prev;
+ }
+ gpr_free(fd->fork_fd_list);
+ gpr_mu_unlock(&fork_fd_list_mu);
+ }
+}
+
static grpc_fd* fd_create(int fd, const char* name, bool track_err) {
grpc_fd* new_fd = nullptr;
@@ -295,6 +341,7 @@ static grpc_fd* fd_create(int fd, const char* name, bool track_err) {
char* fd_name;
gpr_asprintf(&fd_name, "%s fd=%d", name, fd);
grpc_iomgr_register_object(&new_fd->iomgr_object, fd_name);
+ fork_fd_list_add_grpc_fd(new_fd);
#ifndef NDEBUG
if (grpc_trace_fd_refcount.enabled()) {
gpr_log(GPR_DEBUG, "FD %d %p create %s", fd, new_fd, fd_name);
@@ -361,6 +408,7 @@ static void fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd,
GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_REF(error));
grpc_iomgr_unregister_object(&fd->iomgr_object);
+ fork_fd_list_remove_grpc_fd(fd);
fd->read_closure->DestroyEvent();
fd->write_closure->DestroyEvent();
fd->error_closure->DestroyEvent();
@@ -1190,6 +1238,10 @@ static void shutdown_engine(void) {
fd_global_shutdown();
pollset_global_shutdown();
epoll_set_shutdown();
+ if (grpc_core::Fork::Enabled()) {
+ gpr_mu_destroy(&fork_fd_list_mu);
+ grpc_core::Fork::SetResetChildPollingEngineFunc(nullptr);
+ }
}
static const grpc_event_engine_vtable vtable = {
@@ -1227,6 +1279,21 @@ static const grpc_event_engine_vtable vtable = {
shutdown_engine,
};
+/* Called by the child process's post-fork handler to close open fds, including
+ * the global epoll fd. This allows gRPC to shutdown in the child process
+ * without interfering with connections or RPCs ongoing in the parent. */
+static void reset_event_manager_on_fork() {
+ gpr_mu_lock(&fork_fd_list_mu);
+ while (fork_fd_list_head != nullptr) {
+ close(fork_fd_list_head->fd);
+ fork_fd_list_head->fd = -1;
+ fork_fd_list_head = fork_fd_list_head->fork_fd_list->next;
+ }
+ gpr_mu_unlock(&fork_fd_list_mu);
+ shutdown_engine();
+ grpc_init_epoll1_linux(true);
+}
+
/* It is possible that GLIBC has epoll but the underlying kernel doesn't.
* Create epoll_fd (epoll_set_init() takes care of that) to make sure epoll
* support is available */
@@ -1248,6 +1315,11 @@ const grpc_event_engine_vtable* grpc_init_epoll1_linux(bool explicit_request) {
return nullptr;
}
+ if (grpc_core::Fork::Enabled()) {
+ gpr_mu_init(&fork_fd_list_mu);
+ grpc_core::Fork::SetResetChildPollingEngineFunc(
+ reset_event_manager_on_fork);
+ }
return &vtable;
}
diff --git a/src/core/lib/iomgr/fork_posix.cc b/src/core/lib/iomgr/fork_posix.cc
index b37384b8db..a5b61fb4ce 100644
--- a/src/core/lib/iomgr/fork_posix.cc
+++ b/src/core/lib/iomgr/fork_posix.cc
@@ -84,6 +84,11 @@ void grpc_postfork_child() {
if (!skipped_handler) {
grpc_core::Fork::AllowExecCtx();
grpc_core::ExecCtx exec_ctx;
+ grpc_core::Fork::child_postfork_func reset_polling_engine =
+ grpc_core::Fork::GetResetChildPollingEngineFunc();
+ if (reset_polling_engine != nullptr) {
+ reset_polling_engine();
+ }
grpc_timer_manager_set_threading(true);
grpc_executor_set_threading(true);
}
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
index e9246991df..6876601785 100644
--- a/src/python/grpcio/grpc/_channel.py
+++ b/src/python/grpcio/grpc/_channel.py
@@ -111,6 +111,10 @@ class _RPCState(object):
# prior to termination of the RPC.
self.cancelled = False
self.callbacks = []
+ self.fork_epoch = cygrpc.get_fork_epoch()
+
+ def reset_postfork_child(self):
+ self.condition = threading.Condition()
def _abort(state, code, details):
@@ -166,21 +170,30 @@ def _event_handler(state, response_deserializer):
done = not state.due
for callback in callbacks:
callback()
- return done
+ return done and state.fork_epoch >= cygrpc.get_fork_epoch()
return handle_event
def _consume_request_iterator(request_iterator, state, call, request_serializer,
event_handler):
+ if cygrpc.is_fork_support_enabled():
+ condition_wait_timeout = 1.0
+ else:
+ condition_wait_timeout = None
def consume_request_iterator(): # pylint: disable=too-many-branches
while True:
+ return_from_user_request_generator_invoked = False
try:
+ # The thread may die in user-code. Do not block fork for this.
+ cygrpc.enter_user_request_generator()
request = next(request_iterator)
except StopIteration:
break
except Exception: # pylint: disable=broad-except
+ cygrpc.return_from_user_request_generator()
+ return_from_user_request_generator_invoked = True
code = grpc.StatusCode.UNKNOWN
details = 'Exception iterating requests!'
_LOGGER.exception(details)
@@ -188,6 +201,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
details)
_abort(state, code, details)
return
+ finally:
+ if not return_from_user_request_generator_invoked:
+ cygrpc.return_from_user_request_generator()
serialized_request = _common.serialize(request, request_serializer)
with state.condition:
if state.code is None and not state.cancelled:
@@ -208,7 +224,8 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
else:
return
while True:
- state.condition.wait()
+ state.condition.wait(condition_wait_timeout)
+ cygrpc.block_if_fork_in_progress(state)
if state.code is None:
if cygrpc.OperationType.send_message not in state.due:
break
@@ -224,8 +241,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
if operating:
state.due.add(cygrpc.OperationType.send_close_from_client)
- consumption_thread = threading.Thread(target=consume_request_iterator)
- consumption_thread.daemon = True
+ consumption_thread = cygrpc.ForkManagedThread(
+ target=consume_request_iterator)
+ consumption_thread.setDaemon(True)
consumption_thread.start()
@@ -671,13 +689,20 @@ class _ChannelCallState(object):
self.lock = threading.Lock()
self.channel = channel
self.managed_calls = 0
+ self.threading = False
+
+ def reset_postfork_child(self):
+ self.managed_calls = 0
def _run_channel_spin_thread(state):
def channel_spin():
while True:
+ cygrpc.block_if_fork_in_progress(state)
event = state.channel.next_call_event()
+ if event.completion_type == cygrpc.CompletionType.queue_timeout:
+ continue
call_completed = event.tag(event)
if call_completed:
with state.lock:
@@ -685,8 +710,8 @@ def _run_channel_spin_thread(state):
if state.managed_calls == 0:
return
- channel_spin_thread = threading.Thread(target=channel_spin)
- channel_spin_thread.daemon = True
+ channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin)
+ channel_spin_thread.setDaemon(True)
channel_spin_thread.start()
@@ -742,6 +767,13 @@ class _ChannelConnectivityState(object):
self.callbacks_and_connectivities = []
self.delivering = False
+ def reset_postfork_child(self):
+ self.polling = False
+ self.connectivity = None
+ self.try_to_connect = False
+ self.callbacks_and_connectivities = []
+ self.delivering = False
+
def _deliveries(state):
callbacks_needing_update = []
@@ -758,6 +790,7 @@ def _deliver(state, initial_connectivity, initial_callbacks):
callbacks = initial_callbacks
while True:
for callback in callbacks:
+ cygrpc.block_if_fork_in_progress(state)
callable_util.call_logging_exceptions(
callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE,
connectivity)
@@ -771,7 +804,7 @@ def _deliver(state, initial_connectivity, initial_callbacks):
def _spawn_delivery(state, callbacks):
- delivering_thread = threading.Thread(
+ delivering_thread = cygrpc.ForkManagedThread(
target=_deliver, args=(
state,
state.connectivity,
@@ -799,6 +832,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
while True:
event = channel.watch_connectivity_state(connectivity,
time.time() + 0.2)
+ cygrpc.block_if_fork_in_progress(state)
with state.lock:
if not state.callbacks_and_connectivities and not state.try_to_connect:
state.polling = False
@@ -826,10 +860,10 @@ def _moot(state):
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
- polling_thread = threading.Thread(
+ polling_thread = cygrpc.ForkManagedThread(
target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
- polling_thread.daemon = True
+ polling_thread.setDaemon(True)
polling_thread.start()
state.polling = True
state.callbacks_and_connectivities.append([callback, None])
@@ -876,6 +910,7 @@ class Channel(grpc.Channel):
_common.encode(target), _options(options), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
+ cygrpc.fork_register_channel(self)
def subscribe(self, callback, try_to_connect=None):
_subscribe(self._connectivity_state, callback, try_to_connect)
@@ -919,6 +954,11 @@ class Channel(grpc.Channel):
self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!')
_moot(self._connectivity_state)
+ def _close_on_fork(self):
+ self._channel.close_on_fork(cygrpc.StatusCode.cancelled,
+ 'Channel closed due to fork')
+ _moot(self._connectivity_state)
+
def __enter__(self):
return self
@@ -939,4 +979,5 @@ class Channel(grpc.Channel):
# for as long as they are in use and to close them after using them,
# then deletion of this grpc._channel.Channel instance can be made to
# effect closure of the underlying cygrpc.Channel instance.
+ cygrpc.fork_unregister_channel(self)
_moot(self._connectivity_state)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
index a0de862d94..24e85b08e7 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
@@ -19,7 +19,7 @@ cdef class Call:
def __cinit__(self):
# Create an *empty* call
- grpc_init()
+ fork_handlers_and_grpc_init()
self.c_call = NULL
self.references = []
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
index f067d76fab..ced32abba1 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
@@ -40,6 +40,7 @@ cdef class _ChannelState:
# field and just use the NULLness of c_channel as an indication that the
# channel is closed.
cdef object open
+ cdef object closed_reason
# A dict from _BatchOperationTag to _CallState
cdef dict integrated_call_states
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
index aa187e88a6..a81ff4d823 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
@@ -15,6 +15,7 @@
cimport cpython
import threading
+import time
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = (
'Internal gRPC call error %d. ' +
@@ -83,6 +84,7 @@ cdef class _ChannelState:
self.integrated_call_states = {}
self.segregated_call_states = set()
self.connectivity_due = set()
+ self.closed_reason = None
cdef tuple _operate(grpc_call *c_call, object operations, object user_tag):
@@ -142,10 +144,10 @@ cdef _cancel(
_check_and_raise_call_error_no_metadata(c_call_error)
-cdef BatchOperationEvent _next_call_event(
+cdef _next_call_event(
_ChannelState channel_state, grpc_completion_queue *c_completion_queue,
- on_success):
- tag, event = _latent_event(c_completion_queue, None)
+ on_success, deadline):
+ tag, event = _latent_event(c_completion_queue, deadline)
with channel_state.condition:
on_success(tag)
channel_state.condition.notify_all()
@@ -229,8 +231,7 @@ cdef void _call(
call_state.due.update(started_tags)
on_success(started_tags)
else:
- raise ValueError('Cannot invoke RPC on closed channel!')
-
+ raise ValueError('Cannot invoke RPC: %s' % channel_state.closed_reason)
cdef void _process_integrated_call_tag(
_ChannelState state, _BatchOperationTag tag) except *:
cdef _CallState call_state = state.integrated_call_states.pop(tag)
@@ -302,7 +303,7 @@ cdef class SegregatedCall:
_process_segregated_call_tag(
self._channel_state, self._call_state, self._c_completion_queue, tag)
return _next_call_event(
- self._channel_state, self._c_completion_queue, on_success)
+ self._channel_state, self._c_completion_queue, on_success, None)
cdef SegregatedCall _segregated_call(
@@ -346,7 +347,7 @@ cdef object _watch_connectivity_state(
state.c_connectivity_completion_queue, <cpython.PyObject *>tag)
state.connectivity_due.add(tag)
else:
- raise ValueError('Cannot invoke RPC on closed channel!')
+ raise ValueError('Cannot invoke RPC: %s' % state.closed_reason)
completed_tag, event = _latent_event(
state.c_connectivity_completion_queue, None)
with state.condition:
@@ -355,12 +356,15 @@ cdef object _watch_connectivity_state(
return event
-cdef _close(_ChannelState state, grpc_status_code code, object details):
+cdef _close(Channel channel, grpc_status_code code, object details,
+ drain_calls):
+ cdef _ChannelState state = channel._state
cdef _CallState call_state
encoded_details = _encode(details)
with state.condition:
if state.open:
state.open = False
+ state.closed_reason = details
for call_state in set(state.integrated_call_states.values()):
grpc_call_cancel_with_status(
call_state.c_call, code, encoded_details, NULL)
@@ -370,12 +374,19 @@ cdef _close(_ChannelState state, grpc_status_code code, object details):
# TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity
# watching.
- while state.integrated_call_states:
- state.condition.wait()
- while state.segregated_call_states:
- state.condition.wait()
- while state.connectivity_due:
- state.condition.wait()
+ if drain_calls:
+ while not _calls_drained(state):
+ event = channel.next_call_event()
+ if event.completion_type == CompletionType.queue_timeout:
+ continue
+ event.tag(event)
+ else:
+ while state.integrated_call_states:
+ state.condition.wait()
+ while state.segregated_call_states:
+ state.condition.wait()
+ while state.connectivity_due:
+ state.condition.wait()
_destroy_c_completion_queue(state.c_call_completion_queue)
_destroy_c_completion_queue(state.c_connectivity_completion_queue)
@@ -390,13 +401,17 @@ cdef _close(_ChannelState state, grpc_status_code code, object details):
state.condition.wait()
+cdef _calls_drained(_ChannelState state):
+ return not (state.integrated_call_states or state.segregated_call_states or
+ state.connectivity_due)
+
cdef class Channel:
def __cinit__(
self, bytes target, object arguments,
ChannelCredentials channel_credentials):
arguments = () if arguments is None else tuple(arguments)
- grpc_init()
+ fork_handlers_and_grpc_init()
self._state = _ChannelState()
self._vtable.copy = &_copy_pointer
self._vtable.destroy = &_destroy_pointer
@@ -435,9 +450,14 @@ cdef class Channel:
def next_call_event(self):
def on_success(tag):
- _process_integrated_call_tag(self._state, tag)
- return _next_call_event(
- self._state, self._state.c_call_completion_queue, on_success)
+ if tag is not None:
+ _process_integrated_call_tag(self._state, tag)
+ if is_fork_support_enabled():
+ queue_deadline = time.time() + 1.0
+ else:
+ queue_deadline = None
+ return _next_call_event(self._state, self._state.c_call_completion_queue,
+ on_success, queue_deadline)
def segregated_call(
self, int flags, method, host, object deadline, object metadata,
@@ -452,11 +472,14 @@ cdef class Channel:
return grpc_channel_check_connectivity_state(
self._state.c_channel, try_to_connect)
else:
- raise ValueError('Cannot invoke RPC on closed channel!')
+ raise ValueError('Cannot invoke RPC: %s' % self._state.closed_reason)
def watch_connectivity_state(
self, grpc_connectivity_state last_observed_state, object deadline):
return _watch_connectivity_state(self._state, last_observed_state, deadline)
def close(self, code, details):
- _close(self._state, code, details)
+ _close(self, code, details, False)
+
+ def close_on_fork(self, code, details):
+ _close(self, code, details, True)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
index a2d765546a..141116df5d 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
@@ -71,7 +71,7 @@ cdef class CompletionQueue:
def __cinit__(self, shutdown_cq=False):
cdef grpc_completion_queue_attributes c_attrs
- grpc_init()
+ fork_handlers_and_grpc_init()
if shutdown_cq:
c_attrs.version = 1
c_attrs.cq_completion_type = GRPC_CQ_NEXT
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
index 0a25218e19..e3c1c8215c 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
@@ -21,7 +21,7 @@ from libc.stdint cimport uintptr_t
def _spawn_callback_in_thread(cb_func, args):
- threading.Thread(target=cb_func, args=args).start()
+ ForkManagedThread(target=cb_func, args=args).start()
async_callback_func = _spawn_callback_in_thread
@@ -114,7 +114,7 @@ cdef class ChannelCredentials:
cdef class SSLSessionCacheLRU:
def __cinit__(self, capacity):
- grpc_init()
+ fork_handlers_and_grpc_init()
self._cache = grpc_ssl_session_cache_create_lru(capacity)
def __int__(self):
@@ -172,7 +172,7 @@ cdef class CompositeChannelCredentials(ChannelCredentials):
cdef class ServerCertificateConfig:
def __cinit__(self):
- grpc_init()
+ fork_handlers_and_grpc_init()
self.c_cert_config = NULL
self.c_pem_root_certs = NULL
self.c_ssl_pem_key_cert_pairs = NULL
@@ -187,7 +187,7 @@ cdef class ServerCertificateConfig:
cdef class ServerCredentials:
def __cinit__(self):
- grpc_init()
+ fork_handlers_and_grpc_init()
self.c_credentials = NULL
self.references = []
self.initial_cert_config = None
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi
new file mode 100644
index 0000000000..a925bdd2e6
--- /dev/null
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi
@@ -0,0 +1,29 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+cdef extern from "pthread.h" nogil:
+ int pthread_atfork(
+ void (*prepare)() nogil,
+ void (*parent)() nogil,
+ void (*child)() nogil)
+
+
+cdef void __prefork() nogil
+
+
+cdef void __postfork_parent() nogil
+
+
+cdef void __postfork_child() nogil \ No newline at end of file
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi
new file mode 100644
index 0000000000..1176258da8
--- /dev/null
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi
@@ -0,0 +1,203 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import threading
+
+_LOGGER = logging.getLogger(__name__)
+
+_AWAIT_THREADS_TIMEOUT_SECONDS = 5
+
+_TRUE_VALUES = ['yes', 'Yes', 'YES', 'true', 'True', 'TRUE', '1']
+
+# This flag enables experimental support within gRPC Python for applications
+# that will fork() without exec(). When enabled, gRPC Python will attempt to
+# pause all of its internally created threads before the fork syscall proceeds.
+#
+# For this to be successful, the application must not have multiple threads of
+# its own calling into gRPC when fork is invoked. Any callbacks from gRPC
+# Python-spawned threads into user code (e.g., callbacks for asynchronous RPCs)
+# must not block and should execute quickly.
+#
+# This flag is not supported on Windows.
+_GRPC_ENABLE_FORK_SUPPORT = (
+ os.environ.get('GRPC_ENABLE_FORK_SUPPORT', '0')
+ .lower() in _TRUE_VALUES)
+
+_GRPC_POLL_STRATEGY = os.environ.get('GRPC_POLL_STRATEGY')
+
+cdef void __prefork() nogil:
+ with gil:
+ with _fork_state.fork_in_progress_condition:
+ _fork_state.fork_in_progress = True
+ if not _fork_state.active_thread_count.await_zero_threads(
+ _AWAIT_THREADS_TIMEOUT_SECONDS):
+ _LOGGER.error(
+ 'Failed to shutdown gRPC Python threads prior to fork. '
+ 'Behavior after fork will be undefined.')
+
+
+cdef void __postfork_parent() nogil:
+ with gil:
+ with _fork_state.fork_in_progress_condition:
+ _fork_state.fork_in_progress = False
+ _fork_state.fork_in_progress_condition.notify_all()
+
+
+cdef void __postfork_child() nogil:
+ with gil:
+ # Thread could be holding the fork_in_progress_condition inside of
+ # block_if_fork_in_progress() when fork occurs. Reset the lock here.
+ _fork_state.fork_in_progress_condition = threading.Condition()
+ # A thread in return_from_user_request_generator() may hold this lock
+ # when fork occurs.
+ _fork_state.active_thread_count = _ActiveThreadCount()
+ for state_to_reset in _fork_state.postfork_states_to_reset:
+ state_to_reset.reset_postfork_child()
+ _fork_state.fork_epoch += 1
+ for channel in _fork_state.channels:
+ channel._close_on_fork()
+ # TODO(ericgribkoff) Check and abort if core is not shutdown
+ with _fork_state.fork_in_progress_condition:
+ _fork_state.fork_in_progress = False
+
+
+def fork_handlers_and_grpc_init():
+ grpc_init()
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ # TODO(ericgribkoff) epoll1 is default for grpcio distribution. Decide whether to expose
+ # grpc_get_poll_strategy_name() from ev_posix.cc to get actual polling choice.
+ if _GRPC_POLL_STRATEGY is not None and _GRPC_POLL_STRATEGY != "epoll1":
+ _LOGGER.error(
+ 'gRPC Python fork support is only compatible with the epoll1 '
+ 'polling engine')
+ return
+ with _fork_state.fork_handler_registered_lock:
+ if not _fork_state.fork_handler_registered:
+ pthread_atfork(&__prefork, &__postfork_parent, &__postfork_child)
+ _fork_state.fork_handler_registered = True
+
+
+class ForkManagedThread(object):
+ def __init__(self, target, args=()):
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ def managed_target(*args):
+ try:
+ target(*args)
+ finally:
+ _fork_state.active_thread_count.decrement()
+ self._thread = threading.Thread(target=managed_target, args=args)
+ else:
+ self._thread = threading.Thread(target=target, args=args)
+
+ def setDaemon(self, daemonic):
+ self._thread.daemon = daemonic
+
+ def start(self):
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ _fork_state.active_thread_count.increment()
+ self._thread.start()
+
+ def join(self):
+ self._thread.join()
+
+
+def block_if_fork_in_progress(postfork_state_to_reset=None):
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ with _fork_state.fork_in_progress_condition:
+ if not _fork_state.fork_in_progress:
+ return
+ if postfork_state_to_reset is not None:
+ _fork_state.postfork_states_to_reset.append(postfork_state_to_reset)
+ _fork_state.active_thread_count.decrement()
+ _fork_state.fork_in_progress_condition.wait()
+ _fork_state.active_thread_count.increment()
+
+
+def enter_user_request_generator():
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ _fork_state.active_thread_count.decrement()
+
+
+def return_from_user_request_generator():
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ _fork_state.active_thread_count.increment()
+ block_if_fork_in_progress()
+
+
+def get_fork_epoch():
+ return _fork_state.fork_epoch
+
+
+def is_fork_support_enabled():
+ return _GRPC_ENABLE_FORK_SUPPORT
+
+
+def fork_register_channel(channel):
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ _fork_state.channels.add(channel)
+
+
+def fork_unregister_channel(channel):
+ if _GRPC_ENABLE_FORK_SUPPORT:
+ _fork_state.channels.remove(channel)
+
+
+class _ActiveThreadCount(object):
+ def __init__(self):
+ self._num_active_threads = 0
+ self._condition = threading.Condition()
+
+ def increment(self):
+ with self._condition:
+ self._num_active_threads += 1
+
+ def decrement(self):
+ with self._condition:
+ self._num_active_threads -= 1
+ if self._num_active_threads == 0:
+ self._condition.notify_all()
+
+ def await_zero_threads(self, timeout_secs):
+ end_time = time.time() + timeout_secs
+ wait_time = timeout_secs
+ with self._condition:
+ while True:
+ if self._num_active_threads > 0:
+ self._condition.wait(wait_time)
+ if self._num_active_threads == 0:
+ return True
+ # Thread count may have increased before this re-obtains the
+ # lock after a notify(). Wait again until timeout_secs has
+ # elapsed.
+ wait_time = end_time - time.time()
+ if wait_time <= 0:
+ return False
+
+
+class _ForkState(object):
+ def __init__(self):
+ self.fork_in_progress_condition = threading.Condition()
+ self.fork_in_progress = False
+ self.postfork_states_to_reset = []
+ self.fork_handler_registered_lock = threading.Lock()
+ self.fork_handler_registered = False
+ self.active_thread_count = _ActiveThreadCount()
+ self.fork_epoch = 0
+ self.channels = set()
+
+
+_fork_state = _ForkState()
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi
new file mode 100644
index 0000000000..8dc1ef3b1a
--- /dev/null
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi
@@ -0,0 +1,63 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import threading
+
+# No-op implementations for Windows.
+
+def fork_handlers_and_grpc_init():
+ grpc_init()
+
+
+class ForkManagedThread(object):
+ def __init__(self, target, args=()):
+ self._thread = threading.Thread(target=target, args=args)
+
+ def setDaemon(self, daemonic):
+ self._thread.daemon = daemonic
+
+ def start(self):
+ self._thread.start()
+
+ def join(self):
+ self._thread.join()
+
+
+def block_if_fork_in_progress(postfork_state_to_reset=None):
+ pass
+
+
+def enter_user_request_generator():
+ pass
+
+
+def return_from_user_request_generator():
+ pass
+
+
+def get_fork_epoch():
+ return 0
+
+
+def is_fork_support_enabled():
+ return False
+
+
+def fork_register_channel(channel):
+ pass
+
+
+def fork_unregister_channel(channel):
+ pass
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
index 37b98ebbdb..fe98d559f3 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
@@ -127,7 +127,7 @@ class CompressionLevel:
cdef class CallDetails:
def __cinit__(self):
- grpc_init()
+ fork_handlers_and_grpc_init()
with nogil:
grpc_call_details_init(&self.c_details)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
index da3dd21244..db59d468dc 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
@@ -60,7 +60,7 @@ cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapp
cdef class Server:
def __cinit__(self, object arguments):
- grpc_init()
+ fork_handlers_and_grpc_init()
self.references = []
self.registered_completion_queues = []
self._vtable.copy = &_copy_pointer
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd
index 0cc26bc0d0..8258b857bc 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pxd
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd
@@ -31,3 +31,6 @@ include "_cygrpc/time.pxd.pxi"
include "_cygrpc/_hooks.pxd.pxi"
include "_cygrpc/grpc_gevent.pxd.pxi"
+
+IF UNAME_SYSNAME != "Windows":
+ include "_cygrpc/fork_posix.pxd.pxi"
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx
index 3cac406687..026f7ba2e3 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pyx
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx
@@ -39,6 +39,11 @@ include "_cygrpc/_hooks.pyx.pxi"
include "_cygrpc/grpc_gevent.pyx.pxi"
+IF UNAME_SYSNAME == "Windows":
+ include "_cygrpc/fork_windows.pyx.pxi"
+ELSE:
+ include "_cygrpc/fork_posix.pyx.pxi"
+
#
# initialize gRPC
#
diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py
index a23c980017..0dfbf3180b 100644
--- a/src/python/grpcio_tests/commands.py
+++ b/src/python/grpcio_tests/commands.py
@@ -202,3 +202,28 @@ class RunInterop(test.test):
from tests.interop import client
sys.argv[1:] = self.args.split()
client.test_interoperability()
+
+
+class RunFork(test.test):
+
+ description = 'run fork test client'
+ user_options = [('args=', 'a', 'pass-thru arguments for the client')]
+
+ def initialize_options(self):
+ self.args = ''
+
+ def finalize_options(self):
+ # distutils requires this override.
+ pass
+
+ def run(self):
+ if self.distribution.install_requires:
+ self.distribution.fetch_build_eggs(
+ self.distribution.install_requires)
+ if self.distribution.tests_require:
+ self.distribution.fetch_build_eggs(self.distribution.tests_require)
+ # We import here to ensure that our setuptools parent has had a chance to
+ # edit the Python system path.
+ from tests.fork import client
+ sys.argv[1:] = self.args.split()
+ client.test_fork()
diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py
index a94c0963ec..61c98fa038 100644
--- a/src/python/grpcio_tests/setup.py
+++ b/src/python/grpcio_tests/setup.py
@@ -52,6 +52,7 @@ COMMAND_CLASS = {
'preprocess': commands.GatherProto,
'build_package_protos': grpc_tools.command.BuildPackageProtos,
'build_py': commands.BuildPy,
+ 'run_fork': commands.RunFork,
'run_interop': commands.RunInterop,
'test_lite': commands.TestLite,
'test_gevent': commands.TestGevent,
diff --git a/src/python/grpcio_tests/tests/fork/__init__.py b/src/python/grpcio_tests/tests/fork/__init__.py
new file mode 100644
index 0000000000..9a26bac010
--- /dev/null
+++ b/src/python/grpcio_tests/tests/fork/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/python/grpcio_tests/tests/fork/client.py b/src/python/grpcio_tests/tests/fork/client.py
new file mode 100644
index 0000000000..9a32629ed5
--- /dev/null
+++ b/src/python/grpcio_tests/tests/fork/client.py
@@ -0,0 +1,76 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The Python implementation of the GRPC interoperability test client."""
+
+import argparse
+import logging
+import sys
+
+from tests.fork import methods
+
+
+def _args():
+
+ def parse_bool(value):
+ if value == 'true':
+ return True
+ if value == 'false':
+ return False
+ raise argparse.ArgumentTypeError('Only true/false allowed')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--server_host',
+ default="localhost",
+ type=str,
+ help='the host to which to connect')
+ parser.add_argument(
+ '--server_port',
+ type=int,
+ required=True,
+ help='the port to which to connect')
+ parser.add_argument(
+ '--test_case',
+ default='large_unary',
+ type=str,
+ help='the test case to execute')
+ parser.add_argument(
+ '--use_tls',
+ default=False,
+ type=parse_bool,
+ help='require a secure connection')
+ return parser.parse_args()
+
+
+def _test_case_from_arg(test_case_arg):
+ for test_case in methods.TestCase:
+ if test_case_arg == test_case.value:
+ return test_case
+ else:
+ raise ValueError('No test case "%s"!' % test_case_arg)
+
+
+def test_fork():
+ logging.basicConfig(level=logging.INFO)
+ args = _args()
+ if args.test_case == "all":
+ for test_case in methods.TestCase:
+ test_case.run_test(args)
+ else:
+ test_case = _test_case_from_arg(args.test_case)
+ test_case.run_test(args)
+
+
+if __name__ == '__main__':
+ test_fork()
diff --git a/src/python/grpcio_tests/tests/fork/methods.py b/src/python/grpcio_tests/tests/fork/methods.py
new file mode 100644
index 0000000000..889ef13cb2
--- /dev/null
+++ b/src/python/grpcio_tests/tests/fork/methods.py
@@ -0,0 +1,445 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementations of fork support test methods."""
+
+import enum
+import json
+import logging
+import multiprocessing
+import os
+import threading
+import time
+
+import grpc
+
+from six.moves import queue
+
+from src.proto.grpc.testing import empty_pb2
+from src.proto.grpc.testing import messages_pb2
+from src.proto.grpc.testing import test_pb2_grpc
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _channel(args):
+ target = '{}:{}'.format(args.server_host, args.server_port)
+ if args.use_tls:
+ channel_credentials = grpc.ssl_channel_credentials()
+ channel = grpc.secure_channel(target, channel_credentials)
+ else:
+ channel = grpc.insecure_channel(target)
+ return channel
+
+
+def _validate_payload_type_and_length(response, expected_type, expected_length):
+ if response.payload.type is not expected_type:
+ raise ValueError('expected payload type %s, got %s' %
+ (expected_type, type(response.payload.type)))
+ elif len(response.payload.body) != expected_length:
+ raise ValueError('expected payload body size %d, got %d' %
+ (expected_length, len(response.payload.body)))
+
+
+def _async_unary(stub):
+ size = 314159
+ request = messages_pb2.SimpleRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_size=size,
+ payload=messages_pb2.Payload(body=b'\x00' * 271828))
+ response_future = stub.UnaryCall.future(request)
+ response = response_future.result()
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
+
+
+def _blocking_unary(stub):
+ size = 314159
+ request = messages_pb2.SimpleRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_size=size,
+ payload=messages_pb2.Payload(body=b'\x00' * 271828))
+ response = stub.UnaryCall(request)
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
+
+
+class _Pipe(object):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._values = []
+ self._open = True
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.next()
+
+ def next(self):
+ with self._condition:
+ while not self._values and self._open:
+ self._condition.wait()
+ if self._values:
+ return self._values.pop(0)
+ else:
+ raise StopIteration()
+
+ def add(self, value):
+ with self._condition:
+ self._values.append(value)
+ self._condition.notify()
+
+ def close(self):
+ with self._condition:
+ self._open = False
+ self._condition.notify()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+
+class _ChildProcess(object):
+
+ def __init__(self, task, args=None):
+ if args is None:
+ args = ()
+ self._exceptions = multiprocessing.Queue()
+
+ def record_exceptions():
+ try:
+ task(*args)
+ except Exception as e: # pylint: disable=broad-except
+ self._exceptions.put(e)
+
+ self._process = multiprocessing.Process(target=record_exceptions)
+
+ def start(self):
+ self._process.start()
+
+ def finish(self):
+ self._process.join()
+ if self._process.exitcode != 0:
+ raise ValueError('Child process failed with exitcode %d' %
+ self._process.exitcode)
+ try:
+ exception = self._exceptions.get(block=False)
+ raise ValueError('Child process failed: %s' % exception)
+ except queue.Empty:
+ pass
+
+
+def _async_unary_same_channel(channel):
+
+ def child_target():
+ try:
+ _async_unary(stub)
+ raise Exception(
+ 'Child should not be able to re-use channel after fork')
+ except ValueError as expected_value_error:
+ pass
+
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _async_unary(stub)
+ child_process = _ChildProcess(child_target)
+ child_process.start()
+ _async_unary(stub)
+ child_process.finish()
+
+
+def _async_unary_new_channel(channel, args):
+
+ def child_target():
+ child_channel = _channel(args)
+ child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+ _async_unary(child_stub)
+ child_channel.close()
+
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _async_unary(stub)
+ child_process = _ChildProcess(child_target)
+ child_process.start()
+ _async_unary(stub)
+ child_process.finish()
+
+
+def _blocking_unary_same_channel(channel):
+
+ def child_target():
+ try:
+ _blocking_unary(stub)
+ raise Exception(
+ 'Child should not be able to re-use channel after fork')
+ except ValueError as expected_value_error:
+ pass
+
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _blocking_unary(stub)
+ child_process = _ChildProcess(child_target)
+ child_process.start()
+ child_process.finish()
+
+
+def _blocking_unary_new_channel(channel, args):
+
+ def child_target():
+ child_channel = _channel(args)
+ child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+ _blocking_unary(child_stub)
+ child_channel.close()
+
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _blocking_unary(stub)
+ child_process = _ChildProcess(child_target)
+ child_process.start()
+ _blocking_unary(stub)
+ child_process.finish()
+
+
+# Verify that the fork channel registry can handle already closed channels
+def _close_channel_before_fork(channel, args):
+
+ def child_target():
+ new_channel.close()
+ child_channel = _channel(args)
+ child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+ _blocking_unary(child_stub)
+ child_channel.close()
+
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _blocking_unary(stub)
+ channel.close()
+
+ new_channel = _channel(args)
+ new_stub = test_pb2_grpc.TestServiceStub(new_channel)
+ child_process = _ChildProcess(child_target)
+ child_process.start()
+ _blocking_unary(new_stub)
+ child_process.finish()
+
+
+def _connectivity_watch(channel, args):
+
+ def child_target():
+
+ def child_connectivity_callback(state):
+ child_states.append(state)
+
+ child_states = []
+ child_channel = _channel(args)
+ child_stub = test_pb2_grpc.TestServiceStub(child_channel)
+ child_channel.subscribe(child_connectivity_callback)
+ _async_unary(child_stub)
+ if len(child_states
+ ) < 2 or child_states[-1] != grpc.ChannelConnectivity.READY:
+ raise ValueError('Channel did not move to READY')
+ if len(parent_states) > 1:
+ raise ValueError('Received connectivity updates on parent callback')
+ child_channel.unsubscribe(child_connectivity_callback)
+ child_channel.close()
+
+ def parent_connectivity_callback(state):
+ parent_states.append(state)
+
+ parent_states = []
+ channel.subscribe(parent_connectivity_callback)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ child_process = _ChildProcess(child_target)
+ child_process.start()
+ _async_unary(stub)
+ if len(parent_states
+ ) < 2 or parent_states[-1] != grpc.ChannelConnectivity.READY:
+ raise ValueError('Channel did not move to READY')
+ channel.unsubscribe(parent_connectivity_callback)
+ child_process.finish()
+
+ # Need to unsubscribe or _channel.py in _poll_connectivity triggers a
+ # "Cannot invoke RPC on closed channel!" error.
+ # TODO(ericgribkoff) Fix issue with channel.close() and connectivity polling
+ channel.unsubscribe(parent_connectivity_callback)
+
+
+def _ping_pong_with_child_processes_after_first_response(
+ channel, args, child_target, run_after_close=True):
+ request_response_sizes = (
+ 31415,
+ 9,
+ 2653,
+ 58979,
+ )
+ request_payload_sizes = (
+ 27182,
+ 8,
+ 1828,
+ 45904,
+ )
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ pipe = _Pipe()
+ parent_bidi_call = stub.FullDuplexCall(pipe)
+ child_processes = []
+ first_message_received = False
+ for response_size, payload_size in zip(request_response_sizes,
+ request_payload_sizes):
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_parameters=(
+ messages_pb2.ResponseParameters(size=response_size),),
+ payload=messages_pb2.Payload(body=b'\x00' * payload_size))
+ pipe.add(request)
+ if first_message_received:
+ child_process = _ChildProcess(child_target,
+ (parent_bidi_call, channel, args))
+ child_process.start()
+ child_processes.append(child_process)
+ response = next(parent_bidi_call)
+ first_message_received = True
+ child_process = _ChildProcess(child_target,
+ (parent_bidi_call, channel, args))
+ child_process.start()
+ child_processes.append(child_process)
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
+ response_size)
+ pipe.close()
+ if run_after_close:
+ child_process = _ChildProcess(child_target,
+ (parent_bidi_call, channel, args))
+ child_process.start()
+ child_processes.append(child_process)
+ for child_process in child_processes:
+ child_process.finish()
+
+
+def _in_progress_bidi_continue_call(channel):
+
+ def child_target(parent_bidi_call, parent_channel, args):
+ stub = test_pb2_grpc.TestServiceStub(parent_channel)
+ try:
+ _async_unary(stub)
+ raise Exception(
+ 'Child should not be able to re-use channel after fork')
+ except ValueError as expected_value_error:
+ pass
+ inherited_code = parent_bidi_call.code()
+ inherited_details = parent_bidi_call.details()
+ if inherited_code != grpc.StatusCode.CANCELLED:
+ raise ValueError(
+ 'Expected inherited code CANCELLED, got %s' % inherited_code)
+ if inherited_details != 'Channel closed due to fork':
+ raise ValueError(
+ 'Expected inherited details Channel closed due to fork, got %s'
+ % inherited_details)
+
+ # Don't run child_target after closing the parent call, as the call may have
+ # received a status from the server before fork occurs.
+ _ping_pong_with_child_processes_after_first_response(
+ channel, None, child_target, run_after_close=False)
+
+
+def _in_progress_bidi_same_channel_async_call(channel):
+
+ def child_target(parent_bidi_call, parent_channel, args):
+ stub = test_pb2_grpc.TestServiceStub(parent_channel)
+ try:
+ _async_unary(stub)
+ raise Exception(
+ 'Child should not be able to re-use channel after fork')
+ except ValueError as expected_value_error:
+ pass
+
+ _ping_pong_with_child_processes_after_first_response(
+ channel, None, child_target)
+
+
+def _in_progress_bidi_same_channel_blocking_call(channel):
+
+ def child_target(parent_bidi_call, parent_channel, args):
+ stub = test_pb2_grpc.TestServiceStub(parent_channel)
+ try:
+ _blocking_unary(stub)
+ raise Exception(
+ 'Child should not be able to re-use channel after fork')
+ except ValueError as expected_value_error:
+ pass
+
+ _ping_pong_with_child_processes_after_first_response(
+ channel, None, child_target)
+
+
+def _in_progress_bidi_new_channel_async_call(channel, args):
+
+ def child_target(parent_bidi_call, parent_channel, args):
+ channel = _channel(args)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _async_unary(stub)
+
+ _ping_pong_with_child_processes_after_first_response(
+ channel, args, child_target)
+
+
+def _in_progress_bidi_new_channel_blocking_call(channel, args):
+
+ def child_target(parent_bidi_call, parent_channel, args):
+ channel = _channel(args)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+ _blocking_unary(stub)
+
+ _ping_pong_with_child_processes_after_first_response(
+ channel, args, child_target)
+
+
+@enum.unique
+class TestCase(enum.Enum):
+
+ CONNECTIVITY_WATCH = 'connectivity_watch'
+ CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork'
+ ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel'
+ ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel'
+ BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel'
+ BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel'
+ IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call'
+ IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call'
+ IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call'
+ IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call'
+ IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call'
+
+ def run_test(self, args):
+ _LOGGER.info("Running %s", self)
+ channel = _channel(args)
+ if self is TestCase.ASYNC_UNARY_SAME_CHANNEL:
+ _async_unary_same_channel(channel)
+ elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL:
+ _async_unary_new_channel(channel, args)
+ elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL:
+ _blocking_unary_same_channel(channel)
+ elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL:
+ _blocking_unary_new_channel(channel, args)
+ elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK:
+ _close_channel_before_fork(channel, args)
+ elif self is TestCase.CONNECTIVITY_WATCH:
+ _connectivity_watch(channel, args)
+ elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL:
+ _in_progress_bidi_continue_call(channel)
+ elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL:
+ _in_progress_bidi_same_channel_async_call(channel)
+ elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL:
+ _in_progress_bidi_same_channel_blocking_call(channel)
+ elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL:
+ _in_progress_bidi_new_channel_async_call(channel, args)
+ elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL:
+ _in_progress_bidi_new_channel_blocking_call(channel, args)
+ else:
+ raise NotImplementedError(
+ 'Test case "%s" not implemented!' % self.name)
+ channel.close()
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index ebc41c63f0..76d5d22d57 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -32,6 +32,8 @@
"unit._credentials_test.CredentialsTest",
"unit._cython._cancel_many_calls_test.CancelManyCallsTest",
"unit._cython._channel_test.ChannelTest",
+ "unit._cython._fork_test.ForkPosixTester",
+ "unit._cython._fork_test.ForkWindowsTester",
"unit._cython._no_messages_server_completion_queue_per_call_test.Test",
"unit._cython._no_messages_single_server_completion_queue_test.Test",
"unit._cython._read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest",
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
new file mode 100644
index 0000000000..aeb02458a7
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
@@ -0,0 +1,68 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+import unittest
+
+from grpc._cython import cygrpc
+
+
+def _get_number_active_threads():
+ return cygrpc._fork_state.active_thread_count._num_active_threads
+
+
+@unittest.skipIf(os.name == 'nt', 'Posix-specific tests')
+class ForkPosixTester(unittest.TestCase):
+
+ def setUp(self):
+ cygrpc._GRPC_ENABLE_FORK_SUPPORT = True
+
+ def testForkManagedThread(self):
+
+ def cb():
+ self.assertEqual(1, _get_number_active_threads())
+
+ thread = cygrpc.ForkManagedThread(cb)
+ thread.start()
+ thread.join()
+ self.assertEqual(0, _get_number_active_threads())
+
+ def testForkManagedThreadThrowsException(self):
+
+ def cb():
+ self.assertEqual(1, _get_number_active_threads())
+ raise Exception("expected exception")
+
+ thread = cygrpc.ForkManagedThread(cb)
+ thread.start()
+ thread.join()
+ self.assertEqual(0, _get_number_active_threads())
+
+
+@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests')
+class ForkWindowsTester(unittest.TestCase):
+
+ def testForkManagedThreadIsNoOp(self):
+
+ def cb():
+ pass
+
+ thread = cygrpc.ForkManagedThread(cb)
+ thread.start()
+ thread.join()
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)