diff options
Diffstat (limited to 'src')
24 files changed, 444 insertions, 148 deletions
diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc index a9a5965ed1..3c4f0d6552 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc @@ -361,7 +361,9 @@ void lb_token_destroy(void* token) { } } int lb_token_cmp(void* token1, void* token2) { - return GPR_ICMP(token1, token2); + // Always indicate a match, since we don't want this channel arg to + // affect the subchannel's key in the index. + return 0; } const grpc_arg_pointer_vtable lb_token_arg_vtable = { lb_token_copy, lb_token_destroy, lb_token_cmp}; @@ -422,7 +424,7 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { grpc_resolved_address addr; ParseServer(server, &addr); // LB token processing. - void* lb_token; + grpc_mdelem lb_token; if (server->has_load_balance_token) { const size_t lb_token_max_length = GPR_ARRAY_SIZE(server->load_balance_token); @@ -430,9 +432,7 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { strnlen(server->load_balance_token, lb_token_max_length); grpc_slice lb_token_mdstr = grpc_slice_from_copied_buffer( server->load_balance_token, lb_token_length); - lb_token = - (void*)grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr) - .payload; + lb_token = grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr); } else { char* uri = grpc_sockaddr_to_uri(&addr); gpr_log(GPR_INFO, @@ -440,14 +440,16 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { "be used instead", uri); gpr_free(uri); - lb_token = (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload; + lb_token = GRPC_MDELEM_LB_TOKEN_EMPTY; } // Add address. grpc_arg arg = grpc_channel_arg_pointer_create( - const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), lb_token, - &lb_token_arg_vtable); + const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), + (void*)lb_token.payload, &lb_token_arg_vtable); grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); addresses.emplace_back(addr, args); + // Clean up. + GRPC_MDELEM_UNREF(lb_token); } return addresses; } @@ -525,8 +527,7 @@ void GrpcLb::BalancerCallState::Orphan() { void GrpcLb::BalancerCallState::StartQuery() { GPR_ASSERT(lb_call_ != nullptr); if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, - "[grpclb %p] Starting LB call (lb_calld: %p, lb_call: %p)", + gpr_log(GPR_INFO, "[grpclb %p] lb_calld=%p: Starting LB call %p", grpclb_policy_.get(), this, lb_call_); } // Create the ops. @@ -670,8 +671,9 @@ void GrpcLb::BalancerCallState::SendClientLoadReportLocked() { grpc_call_error call_error = grpc_call_start_batch_and_execute( lb_call_, &op, 1, &client_load_report_closure_); if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { - gpr_log(GPR_ERROR, "[grpclb %p] call_error=%d", grpclb_policy_.get(), - call_error); + gpr_log(GPR_ERROR, + "[grpclb %p] lb_calld=%p call_error=%d sending client load report", + grpclb_policy_.get(), this, call_error); GPR_ASSERT(GRPC_CALL_OK == call_error); } } @@ -732,15 +734,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( &initial_response->client_stats_report_interval)); if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Received initial LB response message; " - "client load reporting interval = %" PRId64 " milliseconds", - grpclb_policy, lb_calld->client_stats_report_interval_); + "[grpclb %p] lb_calld=%p: Received initial LB response " + "message; client load reporting interval = %" PRId64 + " milliseconds", + grpclb_policy, lb_calld, + lb_calld->client_stats_report_interval_); } } else if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Received initial LB response message; client load " - "reporting NOT enabled", - grpclb_policy); + "[grpclb %p] lb_calld=%p: Received initial LB response message; " + "client load reporting NOT enabled", + grpclb_policy, lb_calld); } grpc_grpclb_initial_response_destroy(initial_response); lb_calld->seen_initial_response_ = true; @@ -750,15 +754,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( GPR_ASSERT(lb_calld->lb_call_ != nullptr); if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Serverlist with %" PRIuPTR " servers received", - grpclb_policy, serverlist->num_servers); + "[grpclb %p] lb_calld=%p: Serverlist with %" PRIuPTR + " servers received", + grpclb_policy, lb_calld, serverlist->num_servers); for (size_t i = 0; i < serverlist->num_servers; ++i) { grpc_resolved_address addr; ParseServer(serverlist->servers[i], &addr); char* ipport; grpc_sockaddr_to_string(&ipport, &addr, false); - gpr_log(GPR_INFO, "[grpclb %p] Serverlist[%" PRIuPTR "]: %s", - grpclb_policy, i, ipport); + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Serverlist[%" PRIuPTR "]: %s", + grpclb_policy, lb_calld, i, ipport); gpr_free(ipport); } } @@ -778,9 +784,9 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( if (grpc_grpclb_serverlist_equals(grpclb_policy->serverlist_, serverlist)) { if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Incoming server list identical to current, " - "ignoring.", - grpclb_policy); + "[grpclb %p] lb_calld=%p: Incoming server list identical to " + "current, ignoring.", + grpclb_policy, lb_calld); } grpc_grpclb_destroy_serverlist(serverlist); } else { // New serverlist. @@ -806,8 +812,9 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( char* response_slice_str = grpc_dump_slice(response_slice, GPR_DUMP_ASCII | GPR_DUMP_HEX); gpr_log(GPR_ERROR, - "[grpclb %p] Invalid LB response received: '%s'. Ignoring.", - grpclb_policy, response_slice_str); + "[grpclb %p] lb_calld=%p: Invalid LB response received: '%s'. " + "Ignoring.", + grpclb_policy, lb_calld, response_slice_str); gpr_free(response_slice_str); } grpc_slice_unref_internal(response_slice); @@ -838,9 +845,9 @@ void GrpcLb::BalancerCallState::OnBalancerStatusReceivedLocked( char* status_details = grpc_slice_to_c_string(lb_calld->lb_call_status_details_); gpr_log(GPR_INFO, - "[grpclb %p] Status from LB server received. Status = %d, details " - "= '%s', (lb_calld: %p, lb_call: %p), error '%s'", - grpclb_policy, lb_calld->lb_call_status_, status_details, lb_calld, + "[grpclb %p] lb_calld=%p: Status from LB server received. " + "Status = %d, details = '%s', (lb_call: %p), error '%s'", + grpclb_policy, lb_calld, lb_calld->lb_call_status_, status_details, lb_calld->lb_call_, grpc_error_string(error)); gpr_free(status_details); } @@ -1592,6 +1599,10 @@ void GrpcLb::CreateRoundRobinPolicyLocked(const Args& args) { this); return; } + if (grpc_lb_glb_trace.enabled()) { + gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this, + rr_policy_.get()); + } // TODO(roth): We currently track this ref manually. Once the new // ClosureRef API is done, pass the RefCountedPtr<> along with the closure. auto self = Ref(DEBUG_LOCATION, "on_rr_reresolution_requested"); @@ -1685,10 +1696,6 @@ void GrpcLb::CreateOrUpdateRoundRobinPolicyLocked() { lb_policy_args.client_channel_factory = client_channel_factory(); lb_policy_args.args = args; CreateRoundRobinPolicyLocked(lb_policy_args); - if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this, - rr_policy_.get()); - } } grpc_channel_args_destroy(args); } diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index 6802851392..fdea7511cc 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -85,7 +85,7 @@ class grpc_httpcli_ssl_channel_security_connector final return handshaker_factory_; } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* /*auth_context*/, grpc_closure* on_peer_checked) override { grpc_error* error = GRPC_ERROR_NONE; diff --git a/src/core/lib/iomgr/tcp_posix.cc b/src/core/lib/iomgr/tcp_posix.cc index cfcb190d60..c268c18664 100644 --- a/src/core/lib/iomgr/tcp_posix.cc +++ b/src/core/lib/iomgr/tcp_posix.cc @@ -126,6 +126,7 @@ struct grpc_tcp { int bytes_counter; bool socket_ts_enabled; /* True if timestamping options are set on the socket */ + bool ts_capable; /* Cache whether we can set timestamping options */ gpr_atm stop_error_notification; /* Set to 1 if we do not want to be notified on errors anymore */ @@ -589,7 +590,7 @@ ssize_t tcp_send(int fd, const struct msghdr* msg) { */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, grpc_error** error); + ssize_t* sent_length); /** The callback function to be invoked when we get an error on the socket. */ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); @@ -597,13 +598,11 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); #ifdef GRPC_LINUX_ERRQUEUE static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, - grpc_error** error) { + ssize_t* sent_length) { if (!tcp->socket_ts_enabled) { uint32_t opt = grpc_core::kTimestampingSocketOptions; if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING, static_cast<void*>(&opt), sizeof(opt)) != 0) { - *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "setsockopt"), tcp); grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); if (grpc_tcp_trace.enabled()) { gpr_log(GPR_ERROR, "Failed to set timestamping options on the socket."); @@ -784,8 +783,7 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) { #else /* GRPC_LINUX_ERRQUEUE */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, - grpc_error** error) { + ssize_t* sent_length) { gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform"); GPR_ASSERT(0); return false; @@ -804,7 +802,7 @@ void tcp_shutdown_buffer_list(grpc_tcp* tcp) { gpr_mu_lock(&tcp->tb_mu); grpc_core::TracedBuffer::Shutdown( &tcp->tb_head, tcp->outgoing_buffer_arg, - GRPC_ERROR_CREATE_FROM_STATIC_STRING("endpoint destroyed")); + GRPC_ERROR_CREATE_FROM_STATIC_STRING("TracedBuffer list shutdown")); gpr_mu_unlock(&tcp->tb_mu); tcp->outgoing_buffer_arg = nullptr; } @@ -820,7 +818,7 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { struct msghdr msg; struct iovec iov[MAX_WRITE_IOVEC]; msg_iovlen_type iov_size; - ssize_t sent_length; + ssize_t sent_length = 0; size_t sending_length; size_t trailing; size_t unwind_slice_idx; @@ -855,13 +853,19 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { msg.msg_iov = iov; msg.msg_iovlen = iov_size; msg.msg_flags = 0; + bool tried_sending_message = false; if (tcp->outgoing_buffer_arg != nullptr) { - if (!tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length, - error)) { + if (!tcp->ts_capable || + !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length)) { + /* We could not set socket options to collect Fathom timestamps. + * Fallback on writing without timestamps. */ + tcp->ts_capable = false; tcp_shutdown_buffer_list(tcp); - return true; /* something went wrong with timestamps */ + } else { + tried_sending_message = true; } - } else { + } + if (!tried_sending_message) { msg.msg_control = nullptr; msg.msg_controllen = 0; @@ -1117,6 +1121,7 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, tcp->is_first_read = true; tcp->bytes_counter = -1; tcp->socket_ts_enabled = false; + tcp->ts_capable = true; tcp->outgoing_buffer_arg = nullptr; /* paired with unref in grpc_tcp_destroy */ gpr_ref_init(&tcp->refcount, 1); diff --git a/src/core/lib/security/credentials/jwt/jwt_verifier.cc b/src/core/lib/security/credentials/jwt/jwt_verifier.cc index c7d1b36ff0..cdef0f322a 100644 --- a/src/core/lib/security/credentials/jwt/jwt_verifier.cc +++ b/src/core/lib/security/credentials/jwt/jwt_verifier.cc @@ -31,7 +31,9 @@ #include <grpc/support/sync.h> extern "C" { +#include <openssl/bn.h> #include <openssl/pem.h> +#include <openssl/rsa.h> } #include "src/core/lib/gpr/string.h" diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index 6db70ef172..3ad0cc353c 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -48,7 +48,7 @@ void alts_set_rpc_protocol_versions( GRPC_PROTOCOL_VERSION_MIN_MINOR); } -void atls_check_peer(tsi_peer peer, +void alts_check_peer(tsi_peer peer, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) { *auth_context = @@ -93,10 +93,10 @@ class grpc_alts_channel_security_connector final handshake_manager, grpc_security_handshaker_create(handshaker, this)); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { - atls_check_peer(peer, auth_context, on_peer_checked); + alts_check_peer(peer, auth_context, on_peer_checked); } int cmp(const grpc_security_connector* other_sc) const override { @@ -151,10 +151,10 @@ class grpc_alts_server_security_connector final handshake_manager, grpc_security_handshaker_create(handshaker, this)); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { - atls_check_peer(peer, auth_context, on_peer_checked); + alts_check_peer(peer, auth_context, on_peer_checked); } int cmp(const grpc_security_connector* other) const override { diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc index d2cdaaac77..e3b8affb36 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.cc +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc @@ -71,7 +71,7 @@ class grpc_fake_channel_security_connector final if (target_name_override_ != nullptr) gpr_free(target_name_override_); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override; @@ -250,7 +250,8 @@ end: } void grpc_fake_channel_security_connector::check_peer( - tsi_peer peer, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) { fake_check_peer(this, peer, auth_context, on_peer_checked); fake_secure_name_check(); @@ -265,7 +266,7 @@ class grpc_fake_server_security_connector std::move(server_creds)) {} ~grpc_fake_server_security_connector() override = default; - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { fake_check_peer(this, peer, auth_context, on_peer_checked); diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc index 7a59e54e9a..7cc482c16c 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.cc +++ b/src/core/lib/security/security_connector/local/local_security_connector.cc @@ -32,12 +32,16 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/sockaddr_utils.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" #include "src/core/lib/security/credentials/local/local_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/tsi/local_transport_security.h" #define GRPC_UDS_URI_PATTERN "unix:" -#define GRPC_UDS_URL_SCHEME "unix" #define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local" namespace { @@ -55,18 +59,59 @@ grpc_core::RefCountedPtr<grpc_auth_context> local_auth_context_create() { } void local_check_peer(grpc_security_connector* sc, tsi_peer peer, + grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, - grpc_closure* on_peer_checked) { + grpc_closure* on_peer_checked, + grpc_local_connect_type type) { + int fd = grpc_endpoint_get_fd(ep); + grpc_resolved_address resolved_addr; + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = GRPC_MAX_SOCKADDR_SIZE; + bool is_endpoint_local = false; + if (getsockname(fd, reinterpret_cast<grpc_sockaddr*>(resolved_addr.addr), + &resolved_addr.len) == 0) { + grpc_resolved_address addr_normalized; + grpc_resolved_address* addr = + grpc_sockaddr_is_v4mapped(&resolved_addr, &addr_normalized) + ? &addr_normalized + : &resolved_addr; + grpc_sockaddr* sock_addr = reinterpret_cast<grpc_sockaddr*>(&addr->addr); + // UDS + if (type == UDS && grpc_is_unix_socket(addr)) { + is_endpoint_local = true; + // IPV4 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast<const grpc_sockaddr_in*>(sock_addr); + if (grpc_htonl(addr4->sin_addr.s_addr) == INADDR_LOOPBACK) { + is_endpoint_local = true; + } + // IPv6 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast<const grpc_sockaddr_in6*>(addr); + if (memcmp(&addr6->sin6_addr, &in6addr_loopback, + sizeof(in6addr_loopback)) == 0) { + is_endpoint_local = true; + } + } + } + grpc_error* error = GRPC_ERROR_NONE; + if (!is_endpoint_local) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Endpoint is neither UDS or TCP loopback address."); + GRPC_CLOSURE_SCHED(on_peer_checked, error); + return; + } /* Create an auth context which is necessary to pass the santiy check in - * {client, server}_auth_filter that verifies if the pepp's auth context is + * {client, server}_auth_filter that verifies if the peer's auth context is * obtained during handshakes. The auth context is only checked for its * existence and not actually used. */ *auth_context = local_auth_context_create(); - grpc_error* error = *auth_context != nullptr - ? GRPC_ERROR_NONE - : GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Could not create local auth context"); + error = *auth_context != nullptr ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not create local auth context"); GRPC_CLOSURE_SCHED(on_peer_checked, error); } @@ -77,8 +122,7 @@ class grpc_local_channel_security_connector final grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, const char* target_name) - : grpc_channel_security_connector(GRPC_UDS_URL_SCHEME, - std::move(channel_creds), + : grpc_channel_security_connector(nullptr, std::move(channel_creds), std::move(request_metadata_creds)), target_name_(gpr_strdup(target_name)) {} @@ -102,10 +146,13 @@ class grpc_local_channel_security_connector final return strcmp(target_name_, other->target_name_); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { - local_check_peer(this, peer, auth_context, on_peer_checked); + grpc_local_credentials* creds = + reinterpret_cast<grpc_local_credentials*>(mutable_channel_creds()); + local_check_peer(this, peer, ep, auth_context, on_peer_checked, + creds->connect_type()); } bool check_call_host(const char* host, grpc_auth_context* auth_context, @@ -134,8 +181,7 @@ class grpc_local_server_security_connector final public: grpc_local_server_security_connector( grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) - : grpc_server_security_connector(GRPC_UDS_URL_SCHEME, - std::move(server_creds)) {} + : grpc_server_security_connector(nullptr, std::move(server_creds)) {} ~grpc_local_server_security_connector() override = default; void add_handshakers(grpc_pollset_set* interested_parties, @@ -147,10 +193,13 @@ class grpc_local_server_security_connector final handshake_manager, grpc_security_handshaker_create(handshaker, this)); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { - local_check_peer(this, peer, auth_context, on_peer_checked); + grpc_local_server_credentials* creds = + static_cast<grpc_local_server_credentials*>(mutable_server_creds()); + local_check_peer(this, peer, ep, auth_context, on_peer_checked, + creds->connect_type()); } int cmp(const grpc_security_connector* other) const override { @@ -171,23 +220,18 @@ grpc_local_channel_security_connector_create( "Invalid arguments to grpc_local_channel_security_connector_create()"); return nullptr; } - // Check if local_connect_type is UDS. Only UDS is supported for now. + // Perform sanity check on UDS address. For TCP local connection, the check + // will be done during check_peer procedure. grpc_local_credentials* creds = static_cast<grpc_local_credentials*>(channel_creds.get()); - if (creds->connect_type() != UDS) { - gpr_log(GPR_ERROR, - "Invalid local channel type to " - "grpc_local_channel_security_connector_create()"); - return nullptr; - } - // Check if target_name is a valid UDS address. const grpc_arg* server_uri_arg = grpc_channel_args_find(args, GRPC_ARG_SERVER_URI); const char* server_uri_str = grpc_channel_arg_get_string(server_uri_arg); - if (strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, + if (creds->connect_type() == UDS && + strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, strlen(GRPC_UDS_URI_PATTERN)) != 0) { gpr_log(GPR_ERROR, - "Invalid target_name to " + "Invalid UDS target name to " "grpc_local_channel_security_connector_create()"); return nullptr; } @@ -204,15 +248,6 @@ grpc_local_server_security_connector_create( "Invalid arguments to grpc_local_server_security_connector_create()"); return nullptr; } - // Check if local_connect_type is UDS. Only UDS is supported for now. - const grpc_local_server_credentials* creds = - static_cast<const grpc_local_server_credentials*>(server_creds.get()); - if (creds->connect_type() != UDS) { - gpr_log(GPR_ERROR, - "Invalid local server type to " - "grpc_local_server_security_connector_create()"); - return nullptr; - } return grpc_core::MakeRefCounted<grpc_local_server_security_connector>( std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index d90aa8c4da..74b0ef21a6 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -56,7 +56,8 @@ class grpc_security_connector /* Check the peer. Callee takes ownership of the peer object. When done, sets *auth_context and invokes on_peer_checked. */ virtual void check_peer( - tsi_peer peer, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) GRPC_ABSTRACT; /* Compares two security connectors. */ diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 14b2c4030f..7414ab1a37 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -146,7 +146,7 @@ class grpc_ssl_channel_security_connector final grpc_security_handshaker_create(tsi_hs, this)); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { const char* target_name = overridden_target_name_ != nullptr @@ -299,7 +299,7 @@ class grpc_ssl_server_security_connector grpc_security_handshaker_create(tsi_hs, this)); } - void check_peer(tsi_peer peer, + void check_peer(tsi_peer peer, grpc_endpoint* ep, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, grpc_closure* on_peer_checked) override { grpc_error* error = ssl_check_peer(nullptr, &peer, auth_context); diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 48d6901e88..01831dab10 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -231,7 +231,8 @@ static grpc_error* check_peer_locked(security_handshaker* h) { return grpc_set_tsi_error_result( GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); } - h->connector->check_peer(peer, &h->auth_context, &h->on_peer_checked); + h->connector->check_peer(peer, h->args->endpoint, &h->auth_context, + &h->on_peer_checked); return GRPC_ERROR_NONE; } diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index efaf733503..fb6ea19210 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -156,9 +156,13 @@ static unsigned long openssl_thread_id_cb(void) { #endif static void init_openssl(void) { +#if OPENSSL_API_COMPAT >= 0x10100000L + OPENSSL_init_ssl(0, NULL); +#else SSL_library_init(); SSL_load_error_strings(); OpenSSL_add_all_algorithms(); +#endif #if OPENSSL_VERSION_NUMBER < 0x10100000 if (!CRYPTO_get_locking_callback()) { int num_locks = CRYPTO_num_locks(); @@ -1649,7 +1653,11 @@ tsi_result tsi_create_ssl_client_handshaker_factory_with_options( return TSI_INVALID_ARGUMENT; } +#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L + ssl_context = SSL_CTX_new(TLS_method()); +#else ssl_context = SSL_CTX_new(TLSv1_2_method()); +#endif if (ssl_context == nullptr) { gpr_log(GPR_ERROR, "Could not create ssl context."); return TSI_INVALID_ARGUMENT; @@ -1806,7 +1814,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options( for (i = 0; i < options->num_key_cert_pairs; i++) { do { +#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L + impl->ssl_contexts[i] = SSL_CTX_new(TLS_method()); +#else impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method()); +#endif if (impl->ssl_contexts[i] == nullptr) { gpr_log(GPR_ERROR, "Could not create ssl context."); result = TSI_OUT_OF_RESOURCES; diff --git a/src/cpp/ext/filters/census/context.cc b/src/cpp/ext/filters/census/context.cc index 78fc69a805..160590353a 100644 --- a/src/cpp/ext/filters/census/context.cc +++ b/src/cpp/ext/filters/census/context.cc @@ -28,6 +28,9 @@ using ::opencensus::trace::SpanContext; void GenerateServerContext(absl::string_view tracing, absl::string_view stats, absl::string_view primary_role, absl::string_view method, CensusContext* context) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + context->~CensusContext(); GrpcTraceContext trace_ctxt; if (TraceContextEncoding::Decode(tracing, &trace_ctxt) != TraceContextEncoding::kEncodeDecodeFailure) { @@ -42,6 +45,9 @@ void GenerateServerContext(absl::string_view tracing, absl::string_view stats, void GenerateClientContext(absl::string_view method, CensusContext* ctxt, CensusContext* parent_ctxt) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + ctxt->~CensusContext(); if (parent_ctxt != nullptr) { SpanContext span_ctxt = parent_ctxt->Context(); Span span = parent_ctxt->Span(); diff --git a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets index 5f76c03ce5..3fe1ccc918 100644 --- a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets +++ b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets @@ -22,9 +22,8 @@ <Target Name="gRPC_ResolvePluginFullPath" AfterTargets="Protobuf_ResolvePlatform"> <PropertyGroup> <!-- TODO(kkm): Do not use Protobuf_PackagedToolsPath, roll gRPC's own. --> - <!-- TODO(kkm): Do not package windows x64 builds (#13098). --> <gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' " - >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath> + >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath> <gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' " >$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/$(gRPC_PluginFileName)</gRPC_PluginFullPath> </PropertyGroup> diff --git a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets index 1d233d23a8..26f9efb5a8 100644 --- a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets +++ b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets @@ -74,9 +74,8 @@ <!-- Next try OS and CPU resolved by ProtoToolsPlatform. --> <Protobuf_ToolsOs Condition=" '$(Protobuf_ToolsOs)' == '' ">$(_Protobuf_ToolsOs)</Protobuf_ToolsOs> <Protobuf_ToolsCpu Condition=" '$(Protobuf_ToolsCpu)' == '' ">$(_Protobuf_ToolsCpu)</Protobuf_ToolsCpu> - <!-- TODO(kkm): Do not package windows x64 builds (#13098). --> <Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' " - >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\protoc.exe</Protobuf_ProtocFullPath> + >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\protoc.exe</Protobuf_ProtocFullPath> <Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' " >$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/protoc</Protobuf_ProtocFullPath> </PropertyGroup> diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index 141116df5d..3c33b46dbb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -49,7 +49,7 @@ cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline): cdef _interpret_event(grpc_event c_event): cdef _Tag tag if c_event.type == GRPC_QUEUE_TIMEOUT: - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + # TODO(ericgribkoff) Do not coopt ConnectivityEvent here. return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) elif c_event.type == GRPC_QUEUE_SHUTDOWN: # NOTE(nathaniel): For now we coopt ConnectivityEvent here. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index ce701724fd..e89e02b171 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -128,7 +128,10 @@ cdef class Server: with nogil: grpc_server_cancel_all_calls(self.c_server) - def __dealloc__(self): + # TODO(https://github.com/grpc/grpc/issues/17515) Determine what, if any, + # portion of this is safe to call from __dealloc__, and potentially remove + # backup_shutdown_queue. + def destroy(self): if self.c_server != NULL: if not self.is_started: pass @@ -146,4 +149,8 @@ cdef class Server: while not self.is_shutdown: time.sleep(0) grpc_server_destroy(self.c_server) - grpc_shutdown() + self.c_server = NULL + + def __dealloc(self): + if self.c_server == NULL: + grpc_shutdown() diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 3bbfa47da5..eb750ef1a8 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -48,7 +48,7 @@ _CANCELLED = 'cancelled' _EMPTY_FLAGS = 0 -_UNEXPECTED_EXIT_SERVER_GRACE = 1.0 +_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 def _serialized_request(request_event): @@ -676,6 +676,9 @@ class _ServerState(object): self.rpc_states = set() self.due = set() + # A "volatile" flag to interrupt the daemon serving thread + self.server_deallocated = False + def _add_generic_handlers(state, generic_handlers): with state.lock: @@ -702,6 +705,7 @@ def _request_call(state): # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. def _stop_serving(state): if not state.rpc_states and not state.due: + state.server.destroy() for shutdown_event in state.shutdown_events: shutdown_event.set() state.stage = _ServerStage.STOPPED @@ -715,49 +719,69 @@ def _on_call_completed(state): state.active_rpc_count -= 1 -def _serve(state): - while True: - event = state.completion_queue.poll() - if event.tag is _SHUTDOWN_TAG: +def _process_event_and_continue(state, event): + should_continue = True + if event.tag is _SHUTDOWN_TAG: + with state.lock: + state.due.remove(_SHUTDOWN_TAG) + if _stop_serving(state): + should_continue = False + elif event.tag is _REQUEST_CALL_TAG: + with state.lock: + state.due.remove(_REQUEST_CALL_TAG) + concurrency_exceeded = ( + state.maximum_concurrent_rpcs is not None and + state.active_rpc_count >= state.maximum_concurrent_rpcs) + rpc_state, rpc_future = _handle_call( + event, state.generic_handlers, state.interceptor_pipeline, + state.thread_pool, concurrency_exceeded) + if rpc_state is not None: + state.rpc_states.add(rpc_state) + if rpc_future is not None: + state.active_rpc_count += 1 + rpc_future.add_done_callback( + lambda unused_future: _on_call_completed(state)) + if state.stage is _ServerStage.STARTED: + _request_call(state) + elif _stop_serving(state): + should_continue = False + else: + rpc_state, callbacks = event.tag(event) + for callback in callbacks: + callable_util.call_logging_exceptions(callback, + 'Exception calling callback!') + if rpc_state is not None: with state.lock: - state.due.remove(_SHUTDOWN_TAG) + state.rpc_states.remove(rpc_state) if _stop_serving(state): - return - elif event.tag is _REQUEST_CALL_TAG: - with state.lock: - state.due.remove(_REQUEST_CALL_TAG) - concurrency_exceeded = ( - state.maximum_concurrent_rpcs is not None and - state.active_rpc_count >= state.maximum_concurrent_rpcs) - rpc_state, rpc_future = _handle_call( - event, state.generic_handlers, state.interceptor_pipeline, - state.thread_pool, concurrency_exceeded) - if rpc_state is not None: - state.rpc_states.add(rpc_state) - if rpc_future is not None: - state.active_rpc_count += 1 - rpc_future.add_done_callback( - lambda unused_future: _on_call_completed(state)) - if state.stage is _ServerStage.STARTED: - _request_call(state) - elif _stop_serving(state): - return - else: - rpc_state, callbacks = event.tag(event) - for callback in callbacks: - callable_util.call_logging_exceptions( - callback, 'Exception calling callback!') - if rpc_state is not None: - with state.lock: - state.rpc_states.remove(rpc_state) - if _stop_serving(state): - return + should_continue = False + return should_continue + + +def _serve(state): + while True: + timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S + event = state.completion_queue.poll(timeout) + if state.server_deallocated: + _begin_shutdown_once(state) + if event.completion_type != cygrpc.CompletionType.queue_timeout: + if not _process_event_and_continue(state, event): + return # We want to force the deletion of the previous event # ~before~ we poll again; if the event has a reference # to a shutdown Call object, this can induce spinlock. event = None +def _begin_shutdown_once(state): + with state.lock: + if state.stage is _ServerStage.STARTED: + state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) + state.stage = _ServerStage.GRACE + state.shutdown_events = [] + state.due.add(_SHUTDOWN_TAG) + + def _stop(state, grace): with state.lock: if state.stage is _ServerStage.STOPPED: @@ -765,11 +789,7 @@ def _stop(state, grace): shutdown_event.set() return shutdown_event else: - if state.stage is _ServerStage.STARTED: - state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) - state.stage = _ServerStage.GRACE - state.shutdown_events = [] - state.due.add(_SHUTDOWN_TAG) + _begin_shutdown_once(state) shutdown_event = threading.Event() state.shutdown_events.append(shutdown_event) if grace is None: @@ -840,7 +860,9 @@ class _Server(grpc.Server): return _stop(self._state, grace) def __del__(self): - _stop(self._state, None) + # We can not grab a lock in __del__(), so set a flag to signal the + # serving daemon thread (if it exists) to initiate shutdown. + self._state.server_deallocated = True def create_server(thread_pool, generic_rpc_handlers, interceptors, options, diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index 18413abab0..496bcfbcbf 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -22,7 +22,6 @@ import re import shutil import subprocess import sys -import traceback import setuptools from setuptools.command import build_ext diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py index f9cb9d0cec..800b865da6 100644 --- a/src/python/grpcio_tests/setup.py +++ b/src/python/grpcio_tests/setup.py @@ -37,13 +37,19 @@ PACKAGE_DIRECTORIES = { } INSTALL_REQUIRES = ( - 'coverage>=4.0', 'enum34>=1.0.4', + 'coverage>=4.0', + 'enum34>=1.0.4', 'grpcio>={version}'.format(version=grpc_version.VERSION), - 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), + # TODO(https://github.com/pypa/warehouse/issues/5196) + # Re-enable it once we got the name back + # 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), 'grpcio-status>={version}'.format(version=grpc_version.VERSION), 'grpcio-tools>={version}'.format(version=grpc_version.VERSION), 'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION), - 'oauth2client>=1.4.7', 'protobuf>=3.6.0', 'six>=1.10', 'google-auth>=1.0.0', + 'oauth2client>=1.4.7', + 'protobuf>=3.6.0', + 'six>=1.10', + 'google-auth>=1.0.0', 'requests>=2.14.2') if not PY3: diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py index 8ca5189522..c63ff5cd84 100644 --- a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py +++ b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -88,11 +88,10 @@ def _generate_channel_server_pairs(n): def _close_channel_server_pairs(pairs): for pair in pairs: pair.server.stop(None) - # TODO(ericgribkoff) This del should not be required - del pair.server pair.channel.close() +@unittest.skip('https://github.com/pypa/warehouse/issues/5196') class ChannelzServicerTest(unittest.TestCase): def _send_successful_unary_unary(self, idx): diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index b27e6f2693..f202a3f932 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -57,6 +57,7 @@ "unit._reconnect_test.ReconnectTest", "unit._resource_exhausted_test.ResourceExhaustedTest", "unit._rpc_test.RPCTest", + "unit._server_shutdown_test.ServerShutdown", "unit._server_ssl_cert_config_test.ServerSSLCertConfigFetcherParamsChecks", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestCertConfigReuse", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithClientAuth", diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel index 4f850220f8..1b462ec67a 100644 --- a/src/python/grpcio_tests/tests/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel @@ -28,6 +28,7 @@ GRPCIO_TESTS_UNIT = [ # TODO(ghostwriternr): To be added later. # "_server_ssl_cert_config_test.py", "_server_test.py", + "_server_shutdown_test.py", "_session_cache_test.py", ] @@ -50,6 +51,11 @@ py_library( ) py_library( + name = "_server_shutdown_scenarios", + srcs = ["_server_shutdown_scenarios.py"], +) + +py_library( name = "_thread_pool", srcs = ["_thread_pool.py"], ) @@ -70,6 +76,7 @@ py_library( ":resources", ":test_common", ":_exit_scenarios", + ":_server_shutdown_scenarios", ":_thread_pool", ":_from_grpc_import_star", "//src/python/grpcio_tests/tests/unit/framework/common", diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py new file mode 100644 index 0000000000..1d1fdba11e --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -0,0 +1,97 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines a number of module-scope gRPC scenarios to test server shutdown.""" + +import argparse +import os +import threading +import time +import logging + +import grpc +from tests.unit import test_common + +from concurrent import futures +from six.moves import queue + +WAIT_TIME = 1000 + +REQUEST = b'request' +RESPONSE = b'response' + +SERVER_RAISES_EXCEPTION = 'server_raises_exception' +SERVER_DEALLOCATED = 'server_deallocated' +SERVER_FORK_CAN_EXIT = 'server_fork_can_exit' + +FORK_EXIT = '/test/ForkExit' + + +def fork_and_exit(request, servicer_context): + pid = os.fork() + if pid == 0: + os._exit(0) + return RESPONSE + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == FORK_EXIT: + return grpc.unary_unary_rpc_method_handler(fork_and_exit) + else: + return None + + +def run_server(port_queue): + server = test_common.test_server() + port = server.add_insecure_port('[::]:0') + port_queue.put(port) + server.add_generic_rpc_handlers((GenericHandler(),)) + server.start() + # threading.Event.wait() does not exhibit the bug identified in + # https://github.com/grpc/grpc/issues/17093, sleep instead + time.sleep(WAIT_TIME) + + +def run_test(args): + if args.scenario == SERVER_RAISES_EXCEPTION: + server = test_common.test_server() + server.start() + raise Exception() + elif args.scenario == SERVER_DEALLOCATED: + server = test_common.test_server() + server.start() + server.__del__() + while server._state.stage != grpc._server._ServerStage.STOPPED: + pass + elif args.scenario == SERVER_FORK_CAN_EXIT: + port_queue = queue.Queue() + thread = threading.Thread(target=run_server, args=(port_queue,)) + thread.daemon = True + thread.start() + port = port_queue.get() + channel = grpc.insecure_channel('localhost:%d' % port) + multi_callable = channel.unary_unary(FORK_EXIT) + result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) + os.wait() + else: + raise ValueError('unknown test scenario') + + +if __name__ == '__main__': + logging.basicConfig() + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + args = parser.parse_args() + run_test(args) diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py new file mode 100644 index 0000000000..47446d65a5 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py @@ -0,0 +1,90 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests clean shutdown of server on various interpreter exit conditions. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import subprocess +import sys +import threading +import unittest +import logging + +from tests.unit import _server_shutdown_scenarios + +SCENARIO_FILE = os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '_server_shutdown_scenarios.py')) +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: # pylint: disable=broad-except + pass + + +atexit.register(cleanup_processes) + + +def wait(process): + with process_lock: + processes.append(process) + process.wait() + + +class ServerShutdown(unittest.TestCase): + + # Currently we shut down a server (if possible) after the Python server + # instance is garbage collected. This behavior may change in the future. + def test_deallocated_server_stops(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_DEALLOCATED], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + def test_server_exception_exits(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_RAISES_EXCEPTION], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + @unittest.skipIf(os.name == 'nt', 'fork not supported on windows') + def test_server_fork_can_exit(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_FORK_CAN_EXIT], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) |