diff options
Diffstat (limited to 'src/core')
42 files changed, 1949 insertions, 237 deletions
diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index bbc5160bec..67dd3a1fa7 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -38,12 +38,12 @@ #include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" #include "src/core/ext/filters/client_channel/retry_throttle.h" -#include "src/core/ext/filters/client_channel/status_util.h" #include "src/core/ext/filters/client_channel/subchannel.h" #include "src/core/ext/filters/deadline/deadline_filter.h" #include "src/core/lib/backoff/backoff.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/channel/status_util.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/gprpp/inlined_vector.h" #include "src/core/lib/gprpp/manual_constructor.h" @@ -303,11 +303,16 @@ static void request_reresolution_locked(void* arg, grpc_error* error) { chand->lb_policy->SetReresolutionClosureLocked(&args->closure); } +// TODO(roth): The logic in this function is very hard to follow. We +// should refactor this so that it's easier to understand, perhaps as +// part of changing the resolver API to more clearly differentiate +// between transient failures and shutdown. static void on_resolver_result_changed_locked(void* arg, grpc_error* error) { channel_data* chand = static_cast<channel_data*>(arg); if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_DEBUG, "chand=%p: got resolver result: error=%s", chand, - grpc_error_string(error)); + gpr_log(GPR_DEBUG, + "chand=%p: got resolver result: resolver_result=%p error=%s", chand, + chand->resolver_result, grpc_error_string(error)); } // Extract the following fields from the resolver result, if non-nullptr. bool lb_policy_updated = false; @@ -423,8 +428,6 @@ static void on_resolver_result_changed_locked(void* arg, grpc_error* error) { } } } - grpc_channel_args_destroy(chand->resolver_result); - chand->resolver_result = nullptr; } if (grpc_client_channel_trace.enabled()) { gpr_log(GPR_DEBUG, @@ -497,6 +500,8 @@ static void on_resolver_result_changed_locked(void* arg, grpc_error* error) { "Channel disconnected", &error, 1)); GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "resolver"); + grpc_channel_args_destroy(chand->resolver_result); + chand->resolver_result = nullptr; } else { // Not shutting down. grpc_connectivity_state state = GRPC_CHANNEL_TRANSIENT_FAILURE; grpc_error* state_error = @@ -515,11 +520,16 @@ static void on_resolver_result_changed_locked(void* arg, grpc_error* error) { chand->exit_idle_when_lb_policy_arrives = false; } watch_lb_policy_locked(chand, chand->lb_policy.get(), state); + } else if (chand->resolver_result == nullptr) { + // Transient failure. + GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); } if (!lb_policy_updated) { set_channel_connectivity_state_locked( chand, state, GRPC_ERROR_REF(state_error), "new_lb+resolver"); } + grpc_channel_args_destroy(chand->resolver_result); + chand->resolver_result = nullptr; chand->resolver->NextLocked(&chand->resolver_result, &chand->on_resolver_result_changed); GRPC_ERROR_UNREF(state_error); @@ -2753,7 +2763,45 @@ static void pick_after_resolver_result_done_locked(void* arg, chand, calld); } async_pick_done_locked(elem, GRPC_ERROR_REF(error)); - } else if (chand->lb_policy != nullptr) { + } else if (chand->resolver == nullptr) { + // Shutting down. + if (grpc_client_channel_trace.enabled()) { + gpr_log(GPR_DEBUG, "chand=%p calld=%p: resolver disconnected", chand, + calld); + } + async_pick_done_locked( + elem, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); + } else if (chand->lb_policy == nullptr) { + // Transient resolver failure. + // If call has wait_for_ready=true, try again; otherwise, fail. + uint32_t send_initial_metadata_flags = + calld->seen_send_initial_metadata + ? calld->send_initial_metadata_flags + : calld->pending_batches[0] + .batch->payload->send_initial_metadata + .send_initial_metadata_flags; + if (send_initial_metadata_flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) { + if (grpc_client_channel_trace.enabled()) { + gpr_log(GPR_DEBUG, + "chand=%p calld=%p: resolver returned but no LB policy; " + "wait_for_ready=true; trying again", + chand, calld); + } + pick_after_resolver_result_start_locked(elem); + } else { + if (grpc_client_channel_trace.enabled()) { + gpr_log(GPR_DEBUG, + "chand=%p calld=%p: resolver returned but no LB policy; " + "wait_for_ready=false; failing", + chand, calld); + } + async_pick_done_locked( + elem, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + } else { if (grpc_client_channel_trace.enabled()) { gpr_log(GPR_DEBUG, "chand=%p calld=%p: resolver returned, doing pick", chand, calld); @@ -2767,30 +2815,6 @@ static void pick_after_resolver_result_done_locked(void* arg, async_pick_done_locked(elem, GRPC_ERROR_NONE); } } - // TODO(roth): It should be impossible for chand->lb_policy to be nullptr - // here, so the rest of this code should never actually be executed. - // However, we have reports of a crash on iOS that triggers this case, - // so we are temporarily adding this to restore branches that were - // removed in https://github.com/grpc/grpc/pull/12297. Need to figure - // out what is actually causing this to occur and then figure out the - // right way to deal with it. - else if (chand->resolver != nullptr) { - // No LB policy, so try again. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_DEBUG, - "chand=%p calld=%p: resolver returned but no LB policy, " - "trying again", - chand, calld); - } - pick_after_resolver_result_start_locked(elem); - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_DEBUG, "chand=%p calld=%p: resolver disconnected", chand, - calld); - } - async_pick_done_locked( - elem, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); - } } static void pick_after_resolver_result_start_locked(grpc_call_element* elem) { diff --git a/src/core/ext/filters/client_channel/method_params.cc b/src/core/ext/filters/client_channel/method_params.cc index 374b87e170..1f116bb67d 100644 --- a/src/core/ext/filters/client_channel/method_params.cc +++ b/src/core/ext/filters/client_channel/method_params.cc @@ -26,7 +26,7 @@ #include <grpc/support/string_util.h> #include "src/core/ext/filters/client_channel/method_params.h" -#include "src/core/ext/filters/client_channel/status_util.h" +#include "src/core/lib/channel/status_util.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/gprpp/memory.h" diff --git a/src/core/ext/filters/client_channel/method_params.h b/src/core/ext/filters/client_channel/method_params.h index 48ece29867..099924edf3 100644 --- a/src/core/ext/filters/client_channel/method_params.h +++ b/src/core/ext/filters/client_channel/method_params.h @@ -21,7 +21,7 @@ #include <grpc/support/port_platform.h> -#include "src/core/ext/filters/client_channel/status_util.h" +#include "src/core/lib/channel/status_util.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/exec_ctx.h" // for grpc_millis diff --git a/src/core/ext/filters/client_channel/resolver.h b/src/core/ext/filters/client_channel/resolver.h index 1685a6c803..cdb5a20ea3 100644 --- a/src/core/ext/filters/client_channel/resolver.h +++ b/src/core/ext/filters/client_channel/resolver.h @@ -53,8 +53,12 @@ class Resolver : public InternallyRefCountedWithTracing<Resolver> { /// Requests a callback when a new result becomes available. /// When the new result is available, sets \a *result to the new result /// and schedules \a on_complete for execution. + /// Upon transient failure, sets \a *result to nullptr and schedules + /// \a on_complete with no error. /// If resolution is fatally broken, sets \a *result to nullptr and /// schedules \a on_complete with an error. + /// TODO(roth): When we have time, improve the way this API represents + /// transient failure vs. shutdown. /// /// Note that the client channel will almost always have a request /// to \a NextLocked() pending. When it gets the callback, it will diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc index c63de3c509..aca7307a2f 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc @@ -28,6 +28,8 @@ #include <grpc/support/alloc.h> #include <grpc/support/string_util.h> +#include <address_sorting/address_sorting.h> + #include "src/core/ext/filters/client_channel/http_connect_handshaker.h" #include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" @@ -458,6 +460,7 @@ void grpc_resolver_dns_ares_init() { /* TODO(zyc): Turn on c-ares based resolver by default after the address sorter and the CNAME support are added. */ if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) { + address_sorting_init(); grpc_error* error = grpc_ares_init(); if (error != GRPC_ERROR_NONE) { GRPC_LOG_IF_ERROR("ares_library_init() failed", error); @@ -475,6 +478,7 @@ void grpc_resolver_dns_ares_init() { void grpc_resolver_dns_ares_shutdown() { char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER"); if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) { + address_sorting_shutdown(); grpc_ares_cleanup(); } gpr_free(resolver_env); diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc index 71b06eb87e..fb2435749d 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc @@ -33,6 +33,7 @@ #include <grpc/support/string_util.h> #include <grpc/support/time.h> +#include <address_sorting/address_sorting.h> #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h" #include "src/core/lib/gpr/host_port.h" @@ -46,6 +47,9 @@ static gpr_once g_basic_init = GPR_ONCE_INIT; static gpr_mu g_init_mu; +grpc_core::TraceFlag grpc_trace_cares_address_sorting(false, + "cares_address_sorting"); + struct grpc_ares_request { /** indicates the DNS server to use, if specified */ struct ares_addr_port_node dns_server_addr; @@ -96,11 +100,60 @@ static void grpc_ares_request_ref(grpc_ares_request* r) { gpr_ref(&r->pending_queries); } +static void log_address_sorting_list(grpc_lb_addresses* lb_addrs, + const char* input_output_str) { + for (size_t i = 0; i < lb_addrs->num_addresses; i++) { + char* addr_str; + if (grpc_sockaddr_to_string(&addr_str, &lb_addrs->addresses[i].address, + true)) { + gpr_log(GPR_DEBUG, "c-ares address sorting: %s[%" PRIuPTR "]=%s", + input_output_str, i, addr_str); + gpr_free(addr_str); + } else { + gpr_log(GPR_DEBUG, + "c-ares address sorting: %s[%" PRIuPTR "]=<unprintable>", + input_output_str, i); + } + } +} + +void grpc_cares_wrapper_address_sorting_sort(grpc_lb_addresses* lb_addrs) { + if (grpc_trace_cares_address_sorting.enabled()) { + log_address_sorting_list(lb_addrs, "input"); + } + address_sorting_sortable* sortables = (address_sorting_sortable*)gpr_zalloc( + sizeof(address_sorting_sortable) * lb_addrs->num_addresses); + for (size_t i = 0; i < lb_addrs->num_addresses; i++) { + sortables[i].user_data = &lb_addrs->addresses[i]; + memcpy(&sortables[i].dest_addr.addr, &lb_addrs->addresses[i].address.addr, + lb_addrs->addresses[i].address.len); + sortables[i].dest_addr.len = lb_addrs->addresses[i].address.len; + } + address_sorting_rfc_6724_sort(sortables, lb_addrs->num_addresses); + grpc_lb_address* sorted_lb_addrs = (grpc_lb_address*)gpr_zalloc( + sizeof(grpc_lb_address) * lb_addrs->num_addresses); + for (size_t i = 0; i < lb_addrs->num_addresses; i++) { + sorted_lb_addrs[i] = *(grpc_lb_address*)sortables[i].user_data; + } + gpr_free(sortables); + gpr_free(lb_addrs->addresses); + lb_addrs->addresses = sorted_lb_addrs; + if (grpc_trace_cares_address_sorting.enabled()) { + log_address_sorting_list(lb_addrs, "output"); + } +} + +/* Allow tests to access grpc_ares_wrapper_address_sorting_sort */ +void grpc_cares_wrapper_test_only_address_sorting_sort( + grpc_lb_addresses* lb_addrs) { + grpc_cares_wrapper_address_sorting_sort(lb_addrs); +} + static void grpc_ares_request_unref(grpc_ares_request* r) { /* If there are no pending queries, invoke on_done callback and destroy the request */ if (gpr_unref(&r->pending_queries)) { - /* TODO(zyc): Sort results with RFC6724 before invoking on_done. */ + grpc_cares_wrapper_address_sorting_sort(*(r->lb_addrs_out)); GRPC_CLOSURE_SCHED(r->on_done, r->error); gpr_mu_destroy(&r->mu); grpc_ares_ev_driver_destroy(r->ev_driver); diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h index 3e8ea01e12..2d84a038d6 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h @@ -26,6 +26,8 @@ #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/iomgr/resolve_address.h" +extern grpc_core::TraceFlag grpc_trace_cares_address_sorting; + typedef struct grpc_ares_request grpc_ares_request; /* Asynchronously resolve \a name. Use \a default_port if a port isn't @@ -64,5 +66,9 @@ grpc_error* grpc_ares_init(void); it has been called the same number of times as grpc_ares_init(). */ void grpc_ares_cleanup(void); +/* Exposed only for testing */ +void grpc_cares_wrapper_test_only_address_sorting_sort( + grpc_lb_addresses* lb_addrs); + #endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_RESOLVER_DNS_C_ARES_GRPC_ARES_WRAPPER_H \ */ diff --git a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc index 4d8958f519..99a33f2277 100644 --- a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc @@ -82,6 +82,8 @@ class FakeResolver : public Resolver { grpc_closure* next_completion_ = nullptr; // target result address for next completion grpc_channel_args** target_result_ = nullptr; + // if true, return failure + bool return_failure_ = false; }; FakeResolver::FakeResolver(const ResolverArgs& args) : Resolver(args.combiner) { @@ -121,12 +123,16 @@ void FakeResolver::RequestReresolutionLocked() { } void FakeResolver::MaybeFinishNextLocked() { - if (next_completion_ != nullptr && next_results_ != nullptr) { - *target_result_ = grpc_channel_args_union(next_results_, channel_args_); + if (next_completion_ != nullptr && + (next_results_ != nullptr || return_failure_)) { + *target_result_ = + return_failure_ ? nullptr + : grpc_channel_args_union(next_results_, channel_args_); grpc_channel_args_destroy(next_results_); next_results_ = nullptr; GRPC_CLOSURE_SCHED(next_completion_, GRPC_ERROR_NONE); next_completion_ = nullptr; + return_failure_ = false; } } @@ -197,6 +203,26 @@ void FakeResolverResponseGenerator::SetReresolutionResponse( GRPC_ERROR_NONE); } +void FakeResolverResponseGenerator::SetFailureLocked(void* arg, + grpc_error* error) { + SetResponseClosureArg* closure_arg = static_cast<SetResponseClosureArg*>(arg); + FakeResolver* resolver = closure_arg->generator->resolver_; + resolver->return_failure_ = true; + resolver->MaybeFinishNextLocked(); + Delete(closure_arg); +} + +void FakeResolverResponseGenerator::SetFailure() { + GPR_ASSERT(resolver_ != nullptr); + SetResponseClosureArg* closure_arg = New<SetResponseClosureArg>(); + closure_arg->generator = this; + GRPC_CLOSURE_SCHED( + GRPC_CLOSURE_INIT(&closure_arg->set_response_closure, SetFailureLocked, + closure_arg, + grpc_combiner_scheduler(resolver_->combiner())), + GRPC_ERROR_NONE); +} + namespace { static void* response_generator_arg_copy(void* p) { diff --git a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h index 858f35851d..e5175f9b7b 100644 --- a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h +++ b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h @@ -56,6 +56,10 @@ class FakeResolverResponseGenerator // resolver will return the last value set via \a SetResponse(). void SetReresolutionResponse(grpc_channel_args* response); + // Tells the resolver to return a transient failure (signalled by + // returning a null result with no error). + void SetFailure(); + // Returns a channel arg containing \a generator. static grpc_arg MakeChannelArg(FakeResolverResponseGenerator* generator); @@ -68,6 +72,7 @@ class FakeResolverResponseGenerator static void SetResponseLocked(void* arg, grpc_error* error); static void SetReresolutionResponseLocked(void* arg, grpc_error* error); + static void SetFailureLocked(void* arg, grpc_error* error); FakeResolver* resolver_ = nullptr; // Do not own. }; diff --git a/src/core/ext/filters/client_channel/subchannel.cc b/src/core/ext/filters/client_channel/subchannel.cc index cae7cc35e3..d7815fb7e1 100644 --- a/src/core/ext/filters/client_channel/subchannel.cc +++ b/src/core/ext/filters/client_channel/subchannel.cc @@ -40,6 +40,7 @@ #include "src/core/lib/debug/stats.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/sockaddr_utils.h" #include "src/core/lib/iomgr/timer.h" #include "src/core/lib/profiling/timers.h" diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index a4d616d778..dc4e002405 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -68,7 +68,7 @@ #define DEFAULT_MIN_SENT_PING_INTERVAL_WITHOUT_DATA_MS 300000 /* 5 minutes */ #define DEFAULT_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS 300000 /* 5 minutes */ -#define DEFAULT_MAX_PINGS_BETWEEN_DATA 0 /* unlimited */ +#define DEFAULT_MAX_PINGS_BETWEEN_DATA 2 #define DEFAULT_MAX_PING_STRIKES 2 static int g_default_client_keepalive_time_ms = @@ -669,6 +669,7 @@ static int init_stream(grpc_transport* gt, grpc_stream* gs, GRPC_CLOSURE_INIT(&s->complete_fetch_locked, complete_fetch_locked, s, grpc_schedule_on_exec_ctx); grpc_slice_buffer_init(&s->unprocessed_incoming_frames_buffer); + s->unprocessed_incoming_frames_buffer_cached_length = 0; grpc_slice_buffer_init(&s->frame_storage); grpc_slice_buffer_init(&s->compressed_data_buffer); grpc_slice_buffer_init(&s->decompressed_data_buffer); @@ -1577,20 +1578,27 @@ static void perform_stream_op_locked(void* stream_op, if (op->recv_message) { GRPC_STATS_INC_HTTP2_OP_RECV_MESSAGE(); - size_t already_received; + size_t before = 0; GPR_ASSERT(s->recv_message_ready == nullptr); GPR_ASSERT(!s->pending_byte_stream); s->recv_message_ready = op_payload->recv_message.recv_message_ready; s->recv_message = op_payload->recv_message.recv_message; if (s->id != 0) { if (!s->read_closed) { - already_received = s->frame_storage.length; + before = s->frame_storage.length + + s->unprocessed_incoming_frames_buffer.length; + } + } + grpc_chttp2_maybe_complete_recv_message(t, s); + if (s->id != 0) { + if (!s->read_closed && s->frame_storage.length == 0) { + size_t after = s->frame_storage.length + + s->unprocessed_incoming_frames_buffer_cached_length; s->flow_control->IncomingByteStreamUpdate(GRPC_HEADER_SIZE_IN_BYTES, - already_received); + before - after); grpc_chttp2_act_on_flowctl_action(s->flow_control->MakeAction(), t, s); } } - grpc_chttp2_maybe_complete_recv_message(t, s); } if (op->recv_trailing_metadata) { @@ -1867,6 +1875,10 @@ void grpc_chttp2_maybe_complete_recv_message(grpc_chttp2_transport* t, } } } + // save the length of the buffer before handing control back to application + // threads. Needed to support correct flow control bookkeeping + s->unprocessed_incoming_frames_buffer_cached_length = + s->unprocessed_incoming_frames_buffer.length; if (error == GRPC_ERROR_NONE && *s->recv_message != nullptr) { null_then_run_closure(&s->recv_message_ready, GRPC_ERROR_NONE); } else if (s->published_metadata[1] != GRPC_METADATA_NOT_PUBLISHED) { @@ -2630,7 +2642,7 @@ static void start_keepalive_ping_locked(void* arg, grpc_error* error) { grpc_chttp2_transport* t = static_cast<grpc_chttp2_transport*>(arg); GRPC_CHTTP2_REF_TRANSPORT(t, "keepalive watchdog"); grpc_timer_init(&t->keepalive_watchdog_timer, - grpc_core::ExecCtx::Get()->Now() + t->keepalive_time, + grpc_core::ExecCtx::Get()->Now() + t->keepalive_timeout, &t->keepalive_watchdog_fired_locked); } diff --git a/src/core/ext/transport/chttp2/transport/internal.h b/src/core/ext/transport/chttp2/transport/internal.h index 6d11e5aa31..ca6e715978 100644 --- a/src/core/ext/transport/chttp2/transport/internal.h +++ b/src/core/ext/transport/chttp2/transport/internal.h @@ -550,6 +550,11 @@ struct grpc_chttp2_stream { grpc_slice_buffer unprocessed_incoming_frames_buffer; grpc_closure* on_next; /* protected by t combiner */ bool pending_byte_stream; /* protected by t combiner */ + // cached length of buffer to be used by the transport thread in cases where + // stream->pending_byte_stream == true. The value is saved before + // application threads are allowed to modify + // unprocessed_incoming_frames_buffer + size_t unprocessed_incoming_frames_buffer_cached_length; grpc_closure reset_byte_stream; grpc_error* byte_stream_error; /* protected by t combiner */ bool received_last_frame; /* protected by t combiner */ diff --git a/src/core/ext/transport/chttp2/transport/parsing.cc b/src/core/ext/transport/chttp2/transport/parsing.cc index 988380b76c..a10c9ada46 100644 --- a/src/core/ext/transport/chttp2/transport/parsing.cc +++ b/src/core/ext/transport/chttp2/transport/parsing.cc @@ -373,8 +373,6 @@ error_handler: /* t->parser = grpc_chttp2_data_parser_parse;*/ t->parser = grpc_chttp2_data_parser_parse; t->parser_data = &s->data_parser; - t->ping_state.pings_before_data_required = - t->ping_policy.max_pings_without_data; t->ping_state.last_ping_sent_time = GRPC_MILLIS_INF_PAST; return GRPC_ERROR_NONE; } else if (grpc_error_get_int(err, GRPC_ERROR_INT_STREAM_ID, nullptr)) { @@ -547,8 +545,6 @@ static grpc_error* init_header_frame_parser(grpc_chttp2_transport* t, (t->incoming_frame_flags & GRPC_CHTTP2_DATA_FLAG_END_STREAM) != 0; } - t->ping_state.pings_before_data_required = - t->ping_policy.max_pings_without_data; t->ping_state.last_ping_sent_time = GRPC_MILLIS_INF_PAST; /* could be a new grpc_chttp2_stream or an existing grpc_chttp2_stream */ diff --git a/src/core/ext/transport/chttp2/transport/writing.cc b/src/core/ext/transport/chttp2/transport/writing.cc index 7471d88aa1..6f32397a3a 100644 --- a/src/core/ext/transport/chttp2/transport/writing.cc +++ b/src/core/ext/transport/chttp2/transport/writing.cc @@ -224,7 +224,7 @@ class WriteContext { grpc_slice_buffer_add( &t_->outbuf, grpc_chttp2_window_update_create(0, transport_announce, &throwaway_stats)); - ResetPingRecvClock(); + ResetPingClock(); } } @@ -269,11 +269,13 @@ class WriteContext { return s; } - void ResetPingRecvClock() { + void ResetPingClock() { if (!t_->is_client) { t_->ping_recv_state.last_ping_recv_time = GRPC_MILLIS_INF_PAST; t_->ping_recv_state.ping_strikes = 0; } + t_->ping_state.pings_before_data_required = + t_->ping_policy.max_pings_without_data; } void IncInitialMetadataWrites() { ++initial_metadata_writes_; } @@ -435,7 +437,7 @@ class StreamWriteContext { }; grpc_chttp2_encode_header(&t_->hpack_compressor, nullptr, 0, s_->send_initial_metadata, &hopt, &t_->outbuf); - write_context_->ResetPingRecvClock(); + write_context_->ResetPingClock(); write_context_->IncInitialMetadataWrites(); } @@ -455,7 +457,7 @@ class StreamWriteContext { grpc_slice_buffer_add( &t_->outbuf, grpc_chttp2_window_update_create(s_->id, stream_announce, &s_->stats.outgoing)); - write_context_->ResetPingRecvClock(); + write_context_->ResetPingClock(); write_context_->IncWindowUpdateWrites(); } @@ -489,7 +491,7 @@ class StreamWriteContext { data_send_context.CompressMoreBytes(); } } - write_context_->ResetPingRecvClock(); + write_context_->ResetPingClock(); if (data_send_context.is_last_frame()) { SentLastFrame(); } @@ -530,7 +532,7 @@ class StreamWriteContext { s_->send_trailing_metadata, &hopt, &t_->outbuf); } write_context_->IncTrailingMetadataWrites(); - write_context_->ResetPingRecvClock(); + write_context_->ResetPingClock(); SentLastFrame(); write_context_->NoteScheduledResults(); diff --git a/src/core/lib/channel/channel_trace.cc b/src/core/lib/channel/channel_trace.cc new file mode 100644 index 0000000000..654300cd32 --- /dev/null +++ b/src/core/lib/channel/channel_trace.cc @@ -0,0 +1,239 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/impl/codegen/port_platform.h> + +#include "src/core/lib/channel/channel_trace.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "src/core/lib/channel/channel_trace_registry.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { + +ChannelTrace::TraceEvent::TraceEvent( + Severity severity, grpc_slice data, + RefCountedPtr<ChannelTrace> referenced_tracer, ReferencedType type) + : severity_(severity), + data_(data), + timestamp_(grpc_millis_to_timespec(grpc_core::ExecCtx::Get()->Now(), + GPR_CLOCK_REALTIME)), + next_(nullptr), + referenced_tracer_(std::move(referenced_tracer)), + referenced_type_(type) {} + +ChannelTrace::TraceEvent::TraceEvent(Severity severity, grpc_slice data) + : severity_(severity), + data_(data), + timestamp_(grpc_millis_to_timespec(grpc_core::ExecCtx::Get()->Now(), + GPR_CLOCK_REALTIME)), + next_(nullptr) {} + +ChannelTrace::TraceEvent::~TraceEvent() { grpc_slice_unref_internal(data_); } + +ChannelTrace::ChannelTrace(size_t max_events) + : channel_uuid_(-1), + num_events_logged_(0), + list_size_(0), + max_list_size_(max_events), + head_trace_(nullptr), + tail_trace_(nullptr) { + if (max_list_size_ == 0) return; // tracing is disabled if max_events == 0 + gpr_mu_init(&tracer_mu_); + channel_uuid_ = grpc_channel_trace_registry_register_channel_trace(this); + time_created_ = grpc_millis_to_timespec(grpc_core::ExecCtx::Get()->Now(), + GPR_CLOCK_REALTIME); +} + +ChannelTrace::~ChannelTrace() { + if (max_list_size_ == 0) return; // tracing is disabled if max_events == 0 + TraceEvent* it = head_trace_; + while (it != nullptr) { + TraceEvent* to_free = it; + it = it->next(); + Delete<TraceEvent>(to_free); + } + grpc_channel_trace_registry_unregister_channel_trace(channel_uuid_); + gpr_mu_destroy(&tracer_mu_); +} + +intptr_t ChannelTrace::GetUuid() const { return channel_uuid_; } + +void ChannelTrace::AddTraceEventHelper(TraceEvent* new_trace_event) { + ++num_events_logged_; + // first event case + if (head_trace_ == nullptr) { + head_trace_ = tail_trace_ = new_trace_event; + } + // regular event add case + else { + tail_trace_->set_next(new_trace_event); + tail_trace_ = tail_trace_->next(); + } + ++list_size_; + // maybe garbage collect the end + if (list_size_ > max_list_size_) { + TraceEvent* to_free = head_trace_; + head_trace_ = head_trace_->next(); + Delete<TraceEvent>(to_free); + --list_size_; + } +} + +void ChannelTrace::AddTraceEvent(Severity severity, grpc_slice data) { + if (max_list_size_ == 0) return; // tracing is disabled if max_events == 0 + AddTraceEventHelper(New<TraceEvent>(severity, data)); +} + +void ChannelTrace::AddTraceEventReferencingChannel( + Severity severity, grpc_slice data, + RefCountedPtr<ChannelTrace> referenced_tracer) { + if (max_list_size_ == 0) return; // tracing is disabled if max_events == 0 + // create and fill up the new event + AddTraceEventHelper( + New<TraceEvent>(severity, data, std::move(referenced_tracer), Channel)); +} + +void ChannelTrace::AddTraceEventReferencingSubchannel( + Severity severity, grpc_slice data, + RefCountedPtr<ChannelTrace> referenced_tracer) { + if (max_list_size_ == 0) return; // tracing is disabled if max_events == 0 + // create and fill up the new event + AddTraceEventHelper(New<TraceEvent>( + severity, data, std::move(referenced_tracer), Subchannel)); +} + +namespace { + +// returns an allocated string that represents tm according to RFC-3339, and, +// more specifically, follows: +// https://developers.google.com/protocol-buffers/docs/proto3#json +// +// "Uses RFC 3339, where generated output will always be Z-normalized and uses +// 0, 3, 6 or 9 fractional digits." +char* fmt_time(gpr_timespec tm) { + char time_buffer[35]; + char ns_buffer[11]; // '.' + 9 digits of precision + struct tm* tm_info = localtime((const time_t*)&tm.tv_sec); + strftime(time_buffer, sizeof(time_buffer), "%Y-%m-%dT%H:%M:%S", tm_info); + snprintf(ns_buffer, 11, ".%09d", tm.tv_nsec); + // This loop trims off trailing zeros by inserting a null character that the + // right point. We iterate in chunks of three because we want 0, 3, 6, or 9 + // fractional digits. + for (int i = 7; i >= 1; i -= 3) { + if (ns_buffer[i] == '0' && ns_buffer[i + 1] == '0' && + ns_buffer[i + 2] == '0') { + ns_buffer[i] = '\0'; + // Edge case in which all fractional digits were 0. + if (i == 1) { + ns_buffer[0] = '\0'; + } + } else { + break; + } + } + char* full_time_str; + gpr_asprintf(&full_time_str, "%s%sZ", time_buffer, ns_buffer); + return full_time_str; +} + +const char* severity_string(ChannelTrace::Severity severity) { + switch (severity) { + case ChannelTrace::Severity::Info: + return "CT_INFO"; + case ChannelTrace::Severity::Warning: + return "CT_WARNING"; + case ChannelTrace::Severity::Error: + return "CT_ERROR"; + default: + GPR_UNREACHABLE_CODE(return "CT_UNKNOWN"); + } +} + +} // anonymous namespace + +void ChannelTrace::TraceEvent::RenderTraceEvent(grpc_json* json) const { + grpc_json* json_iterator = nullptr; + json_iterator = grpc_json_create_child(json_iterator, json, "description", + grpc_slice_to_c_string(data_), + GRPC_JSON_STRING, true); + json_iterator = grpc_json_create_child(json_iterator, json, "severity", + severity_string(severity_), + GRPC_JSON_STRING, false); + json_iterator = + grpc_json_create_child(json_iterator, json, "timestamp", + fmt_time(timestamp_), GRPC_JSON_STRING, true); + if (referenced_tracer_ != nullptr) { + char* uuid_str; + gpr_asprintf(&uuid_str, "%" PRIdPTR, referenced_tracer_->channel_uuid_); + grpc_json* child_ref = grpc_json_create_child( + json_iterator, json, + (referenced_type_ == Channel) ? "channelRef" : "subchannelRef", nullptr, + GRPC_JSON_OBJECT, false); + json_iterator = grpc_json_create_child( + nullptr, child_ref, + (referenced_type_ == Channel) ? "channelId" : "subchannelId", uuid_str, + GRPC_JSON_STRING, true); + json_iterator = child_ref; + } +} + +char* ChannelTrace::RenderTrace() const { + if (!max_list_size_) + return nullptr; // tracing is disabled if max_events == 0 + grpc_json* json = grpc_json_create(GRPC_JSON_OBJECT); + char* num_events_logged_str; + gpr_asprintf(&num_events_logged_str, "%" PRId64, num_events_logged_); + grpc_json* json_iterator = nullptr; + json_iterator = + grpc_json_create_child(json_iterator, json, "numEventsLogged", + num_events_logged_str, GRPC_JSON_STRING, true); + json_iterator = + grpc_json_create_child(json_iterator, json, "creationTime", + fmt_time(time_created_), GRPC_JSON_STRING, true); + grpc_json* events = grpc_json_create_child(json_iterator, json, "events", + nullptr, GRPC_JSON_ARRAY, false); + json_iterator = nullptr; + TraceEvent* it = head_trace_; + while (it != nullptr) { + json_iterator = grpc_json_create_child(json_iterator, events, nullptr, + nullptr, GRPC_JSON_OBJECT, false); + it->RenderTraceEvent(json_iterator); + it = it->next(); + } + char* json_str = grpc_json_dump_to_string(json, 0); + grpc_json_destroy(json); + return json_str; +} + +} // namespace grpc_core diff --git a/src/core/lib/channel/channel_trace.h b/src/core/lib/channel/channel_trace.h new file mode 100644 index 0000000000..1df1e585f2 --- /dev/null +++ b/src/core/lib/channel/channel_trace.h @@ -0,0 +1,133 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_LIB_CHANNEL_CHANNEL_TRACE_H +#define GRPC_CORE_LIB_CHANNEL_CHANNEL_TRACE_H + +#include <grpc/impl/codegen/port_platform.h> + +#include <grpc/grpc.h> +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/json/json.h" + +namespace grpc_core { + +// Object used to hold live data for a channel. This data is exposed via the +// channelz service: +// https://github.com/grpc/proposal/blob/master/A14-channelz.md +class ChannelTrace : public RefCounted<ChannelTrace> { + public: + ChannelTrace(size_t max_events); + ~ChannelTrace(); + + // returns the tracer's uuid + intptr_t GetUuid() const; + + enum Severity { + Unset = 0, // never to be used + Info, // we start at 1 to avoid using proto default values + Warning, + Error + }; + + // Adds a new trace event to the tracing object + // + // TODO(ncteisen): as this call is used more and more throughout the gRPC + // stack, determine if it makes more sense to accept a char* instead of a + // slice. + void AddTraceEvent(Severity severity, grpc_slice data); + + // Adds a new trace event to the tracing object. This trace event refers to a + // an event on a child of the channel. For example, if this channel has + // created a new subchannel, then it would record that with a TraceEvent + // referencing the new subchannel. + // + // TODO(ncteisen): Once channelz is implemented, the events should reference + // the overall channelz object, not just the ChannelTrace object. + // TODO(ncteisen): as this call is used more and more throughout the gRPC + // stack, determine if it makes more sense to accept a char* instead of a + // slice. + void AddTraceEventReferencingChannel( + Severity severity, grpc_slice data, + RefCountedPtr<ChannelTrace> referenced_tracer); + void AddTraceEventReferencingSubchannel( + Severity severity, grpc_slice data, + RefCountedPtr<ChannelTrace> referenced_tracer); + + // Returns the tracing data rendered as a grpc json string. + // The string is owned by the caller and must be freed. + char* RenderTrace() const; + + private: + // Types of objects that can be references by trace events. + enum ReferencedType { Channel, Subchannel }; + // Private class to encapsulate all the data and bookkeeping needed for a + // a trace event. + class TraceEvent { + public: + // Constructor for a TraceEvent that references a different channel. + // TODO(ncteisen): once channelz is implemented, this should reference the + // overall channelz object, not just the ChannelTrace object + TraceEvent(Severity severity, grpc_slice data, + RefCountedPtr<ChannelTrace> referenced_tracer, + ReferencedType type); + + // Constructor for a TraceEvent that does not reverence a different + // channel. + TraceEvent(Severity severity, grpc_slice data); + + ~TraceEvent(); + + // Renders the data inside of this TraceEvent into a json object. This is + // used by the ChannelTrace, when it is rendering itself. + void RenderTraceEvent(grpc_json* json) const; + + // set and get for the next_ pointer. + TraceEvent* next() const { return next_; } + void set_next(TraceEvent* next) { next_ = next; } + + private: + Severity severity_; + grpc_slice data_; + gpr_timespec timestamp_; + TraceEvent* next_; + // the tracer object for the (sub)channel that this trace event refers to. + RefCountedPtr<ChannelTrace> referenced_tracer_; + // the type that the referenced tracer points to. Unused if this trace + // does not point to any channel or subchannel + ReferencedType referenced_type_; + }; // TraceEvent + + // Internal helper to add and link in a trace event + void AddTraceEventHelper(TraceEvent* new_trace_event); + + gpr_mu tracer_mu_; + intptr_t channel_uuid_; + uint64_t num_events_logged_; + size_t list_size_; + size_t max_list_size_; + TraceEvent* head_trace_; + TraceEvent* tail_trace_; + gpr_timespec time_created_; +}; + +} // namespace grpc_core + +#endif /* GRPC_CORE_LIB_CHANNEL_CHANNEL_TRACE_H */ diff --git a/src/core/lib/channel/channel_trace_registry.cc b/src/core/lib/channel/channel_trace_registry.cc new file mode 100644 index 0000000000..6c82431467 --- /dev/null +++ b/src/core/lib/channel/channel_trace_registry.cc @@ -0,0 +1,80 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/impl/codegen/port_platform.h> + +#include "src/core/lib/avl/avl.h" +#include "src/core/lib/channel/channel_trace.h" +#include "src/core/lib/channel/channel_trace_registry.h" +#include "src/core/lib/gpr/useful.h" + +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> + +// file global lock and avl. +static gpr_mu g_mu; +static grpc_avl g_avl; +static gpr_atm g_uuid = 0; + +// avl vtable for uuid (intptr_t) -> ChannelTrace +// this table is only looking, it does not own anything. +static void destroy_intptr(void* not_used, void* user_data) {} +static void* copy_intptr(void* key, void* user_data) { return key; } +static long compare_intptr(void* key1, void* key2, void* user_data) { + return GPR_ICMP(key1, key2); +} + +static void destroy_channel_trace(void* trace, void* user_data) {} +static void* copy_channel_trace(void* trace, void* user_data) { return trace; } +static const grpc_avl_vtable avl_vtable = { + destroy_intptr, copy_intptr, compare_intptr, destroy_channel_trace, + copy_channel_trace}; + +void grpc_channel_trace_registry_init() { + gpr_mu_init(&g_mu); + g_avl = grpc_avl_create(&avl_vtable); +} + +void grpc_channel_trace_registry_shutdown() { + grpc_avl_unref(g_avl, nullptr); + gpr_mu_destroy(&g_mu); +} + +intptr_t grpc_channel_trace_registry_register_channel_trace( + grpc_core::ChannelTrace* channel_trace) { + intptr_t prior = gpr_atm_no_barrier_fetch_add(&g_uuid, 1); + gpr_mu_lock(&g_mu); + g_avl = grpc_avl_add(g_avl, (void*)prior, channel_trace, nullptr); + gpr_mu_unlock(&g_mu); + return prior; +} + +void grpc_channel_trace_registry_unregister_channel_trace(intptr_t uuid) { + gpr_mu_lock(&g_mu); + g_avl = grpc_avl_remove(g_avl, (void*)uuid, nullptr); + gpr_mu_unlock(&g_mu); +} + +grpc_core::ChannelTrace* grpc_channel_trace_registry_get_channel_trace( + intptr_t uuid) { + gpr_mu_lock(&g_mu); + grpc_core::ChannelTrace* ret = static_cast<grpc_core::ChannelTrace*>( + grpc_avl_get(g_avl, (void*)uuid, nullptr)); + gpr_mu_unlock(&g_mu); + return ret; +} diff --git a/src/core/lib/channel/channel_trace_registry.h b/src/core/lib/channel/channel_trace_registry.h new file mode 100644 index 0000000000..391ecba7de --- /dev/null +++ b/src/core/lib/channel/channel_trace_registry.h @@ -0,0 +1,43 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_LIB_CHANNEL_CHANNEL_TRACE_REGISTRY_H +#define GRPC_CORE_LIB_CHANNEL_CHANNEL_TRACE_REGISTRY_H + +#include <grpc/impl/codegen/port_platform.h> + +#include "src/core/lib/channel/channel_trace.h" + +#include <stdint.h> + +// TODO(ncteisen): convert this file to C++ + +void grpc_channel_trace_registry_init(); +void grpc_channel_trace_registry_shutdown(); + +// globally registers a ChannelTrace. Returns its unique uuid +intptr_t grpc_channel_trace_registry_register_channel_trace( + grpc_core::ChannelTrace* channel_trace); +// globally unregisters the ChannelTrace that is associated to uuid. +void grpc_channel_trace_registry_unregister_channel_trace(intptr_t uuid); +// if object with uuid has previously been registered, returns the ChannelTrace +// associated with that uuid. Else returns nullptr. +grpc_core::ChannelTrace* grpc_channel_trace_registry_get_channel_trace( + intptr_t uuid); + +#endif /* GRPC_CORE_LIB_CHANNEL_CHANNEL_TRACE_REGISTRY_H */ diff --git a/src/core/ext/filters/client_channel/status_util.cc b/src/core/lib/channel/status_util.cc index 11f732ab44..563db40846 100644 --- a/src/core/ext/filters/client_channel/status_util.cc +++ b/src/core/lib/channel/status_util.cc @@ -18,7 +18,7 @@ #include <grpc/support/port_platform.h> -#include "src/core/ext/filters/client_channel/status_util.h" +#include "src/core/lib/channel/status_util.h" #include "src/core/lib/gpr/useful.h" diff --git a/src/core/ext/filters/client_channel/status_util.h b/src/core/lib/channel/status_util.h index e018709730..5409de6b3c 100644 --- a/src/core/ext/filters/client_channel/status_util.h +++ b/src/core/lib/channel/status_util.h @@ -16,8 +16,8 @@ * */ -#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_STATUS_UTIL_H -#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_STATUS_UTIL_H +#ifndef GRPC_CORE_LIB_CHANNEL_STATUS_UTIL_H +#define GRPC_CORE_LIB_CHANNEL_STATUS_UTIL_H #include <grpc/support/port_platform.h> @@ -55,4 +55,4 @@ class StatusCodeSet { } // namespace internal } // namespace grpc_core -#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_STATUS_UTIL_H */ +#endif /* GRPC_CORE_LIB_CHANNEL_STATUS_UTIL_H */ diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index 180912383f..0b53d63e77 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -102,8 +102,8 @@ static grpc_security_connector_vtable httpcli_ssl_vtable = { httpcli_ssl_destroy, httpcli_ssl_check_peer, httpcli_ssl_cmp}; static grpc_security_status httpcli_ssl_channel_security_connector_create( - const char* pem_root_certs, const char* secure_peer_name, - grpc_channel_security_connector** sc) { + const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, + const char* secure_peer_name, grpc_channel_security_connector** sc) { tsi_result result = TSI_OK; grpc_httpcli_ssl_channel_security_connector* c; @@ -121,8 +121,12 @@ static grpc_security_status httpcli_ssl_channel_security_connector_create( if (secure_peer_name != nullptr) { c->secure_peer_name = gpr_strdup(secure_peer_name); } - result = tsi_create_ssl_client_handshaker_factory( - nullptr, pem_root_certs, nullptr, nullptr, 0, &c->handshaker_factory); + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + result = tsi_create_ssl_client_handshaker_factory_with_options( + &options, &c->handshaker_factory); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); @@ -169,8 +173,11 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, grpc_millis deadline, void (*on_done)(void* arg, grpc_endpoint* endpoint)) { on_done_closure* c = static_cast<on_done_closure*>(gpr_malloc(sizeof(*c))); - const char* pem_root_certs = grpc_get_default_ssl_roots(); - if (pem_root_certs == nullptr) { + const char* pem_root_certs = + grpc_core::DefaultSslRootStore::GetPemRootCerts(); + const tsi_ssl_root_certs_store* root_store = + grpc_core::DefaultSslRootStore::GetRootStore(); + if (root_store == nullptr) { gpr_log(GPR_ERROR, "Could not get default pem root certs."); on_done(arg, nullptr); gpr_free(c); @@ -180,7 +187,7 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, c->arg = arg; grpc_channel_security_connector* sc = nullptr; GPR_ASSERT(httpcli_ssl_channel_security_connector_create( - pem_root_certs, host, &sc) == GRPC_SECURITY_OK); + pem_root_certs, root_store, host, &sc) == GRPC_SECURITY_OK); grpc_arg channel_arg = grpc_security_connector_to_arg(&sc->base); grpc_channel_args args = {1, &channel_arg}; c->handshake_mgr = grpc_handshake_manager_create(); diff --git a/src/core/lib/iomgr/tcp_client_posix.cc b/src/core/lib/iomgr/tcp_client_posix.cc index ba943d302a..9f19b833da 100644 --- a/src/core/lib/iomgr/tcp_client_posix.cc +++ b/src/core/lib/iomgr/tcp_client_posix.cc @@ -293,7 +293,6 @@ void grpc_tcp_client_create_from_prepared_fd( int err; async_connect* ac; do { - GPR_ASSERT(addr->len < ~(socklen_t)0); err = connect(fd, reinterpret_cast<const grpc_sockaddr*>(addr->addr), addr->len); } while (err < 0 && errno == EINTR); diff --git a/src/core/lib/json/json.cc b/src/core/lib/json/json.cc index 2141db4c5b..816241bbf0 100644 --- a/src/core/lib/json/json.cc +++ b/src/core/lib/json/json.cc @@ -21,6 +21,7 @@ #include <string.h> #include <grpc/support/alloc.h> +#include <grpc/support/log.h> #include "src/core/lib/json/json.h" @@ -46,5 +47,40 @@ void grpc_json_destroy(grpc_json* json) { json->parent->child = json->next; } + if (json->owns_value) { + gpr_free((void*)json->value); + } + gpr_free(json); } + +grpc_json* grpc_json_link_child(grpc_json* parent, grpc_json* child, + grpc_json* sibling) { + // first child case. + if (parent->child == nullptr) { + GPR_ASSERT(sibling == nullptr); + parent->child = child; + return child; + } + if (sibling == nullptr) { + sibling = parent->child; + } + // always find the right most sibling. + while (sibling->next != nullptr) { + sibling = sibling->next; + } + sibling->next = child; + return child; +} + +grpc_json* grpc_json_create_child(grpc_json* sibling, grpc_json* parent, + const char* key, const char* value, + grpc_json_type type, bool owns_value) { + grpc_json* child = grpc_json_create(type); + grpc_json_link_child(parent, child, sibling); + child->owns_value = owns_value; + child->parent = parent; + child->value = value; + child->key = key; + return child; +} diff --git a/src/core/lib/json/json.h b/src/core/lib/json/json.h index 3a62ef9cfb..f93b43048b 100644 --- a/src/core/lib/json/json.h +++ b/src/core/lib/json/json.h @@ -21,6 +21,7 @@ #include <grpc/support/port_platform.h> +#include <stdbool.h> #include <stdlib.h> #include "src/core/lib/json/json_common.h" @@ -37,6 +38,9 @@ typedef struct grpc_json { grpc_json_type type; const char* key; const char* value; + + /* if set, destructor will free value */ + bool owns_value; } grpc_json; /* The next two functions are going to parse the input string, and @@ -67,9 +71,24 @@ char* grpc_json_dump_to_string(grpc_json* json, int indent); /* Use these to create or delete a grpc_json object. * Deletion is recursive. We will not attempt to free any of the strings - * in any of the objects of that tree. + * in any of the objects of that tree, unless the boolean, owns_value, + * is true. */ grpc_json* grpc_json_create(grpc_json_type type); void grpc_json_destroy(grpc_json* json); +/* Links the child json object into the parent's json tree. If the parent + * already has children, then passing in the most recently added child as the + * sibling parameter is an optimization. For if sibling is NULL, this function + * will manually traverse the tree in order to find the right most sibling. + */ +grpc_json* grpc_json_link_child(grpc_json* parent, grpc_json* child, + grpc_json* sibling); + +/* Creates a child json object into the parent's json tree then links it in + * as described above. */ +grpc_json* grpc_json_create_child(grpc_json* sibling, grpc_json* parent, + const char* key, const char* value, + grpc_json_type type, bool owns_value); + #endif /* GRPC_CORE_LIB_JSON_JSON_H */ diff --git a/src/core/lib/profiling/basic_timers.cc b/src/core/lib/profiling/basic_timers.cc index 43384fd0ca..652a498b6e 100644 --- a/src/core/lib/profiling/basic_timers.cc +++ b/src/core/lib/profiling/basic_timers.cc @@ -26,11 +26,11 @@ #include <grpc/support/log.h> #include <grpc/support/sync.h> #include <grpc/support/time.h> +#include <inttypes.h> #include <stdio.h> #include <string.h> #include "src/core/lib/gpr/env.h" -#include "src/core/lib/gprpp/thd.h" typedef enum { BEGIN = '{', END = '}', MARK = '.' } marker_type; @@ -68,7 +68,7 @@ static pthread_cond_t g_cv; static gpr_timer_log_list g_in_progress_logs; static gpr_timer_log_list g_done_logs; static int g_shutdown; -static grpc_core::Thread* g_writing_thread; +static pthread_t g_writing_thread; static __thread int g_thread_id; static int g_next_thread_id; static int g_writing_enabled = 1; @@ -149,7 +149,7 @@ static void write_log(gpr_timer_log* log) { } } -static void writing_thread(void* unused) { +static void* writing_thread(void* unused) { gpr_timer_log* log; pthread_mutex_lock(&g_mu); for (;;) { @@ -164,7 +164,7 @@ static void writing_thread(void* unused) { } if (g_shutdown) { pthread_mutex_unlock(&g_mu); - return; + return NULL; } } } @@ -182,8 +182,7 @@ static void finish_writing(void) { g_shutdown = 1; pthread_cond_signal(&g_cv); pthread_mutex_unlock(&g_mu); - g_writing_thread->Join(); - grpc_core::Delete(g_writing_thread); + pthread_join(g_writing_thread, NULL); gpr_log(GPR_INFO, "flushing logs"); @@ -202,8 +201,12 @@ void gpr_timers_set_log_filename(const char* filename) { } static void init_output() { - g_writing_thread = grpc_core::New<grpc_core::Thread>("timer_output_thread", - writing_thread, nullptr); + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE); + pthread_create(&g_writing_thread, &attr, &writing_thread, NULL); + pthread_attr_destroy(&attr); + atexit(finish_writing); } diff --git a/src/core/lib/profiling/timers.h b/src/core/lib/profiling/timers.h index d0188b5054..7ff72783ec 100644 --- a/src/core/lib/profiling/timers.h +++ b/src/core/lib/profiling/timers.h @@ -82,9 +82,12 @@ class ProfileScope { }; } // namespace grpc -#define GPR_TIMER_SCOPE(tag, important) \ - ::grpc::ProfileScope _profile_scope_##__LINE__((tag), (important), __FILE__, \ - __LINE__) +#define GPR_TIMER_SCOPE_NAME_INTERNAL(prefix, line) prefix##line +#define GPR_TIMER_SCOPE_NAME(prefix, line) \ + GPR_TIMER_SCOPE_NAME_INTERNAL(prefix, line) +#define GPR_TIMER_SCOPE(tag, important) \ + ::grpc::ProfileScope GPR_TIMER_SCOPE_NAME(_profile_scope_, __LINE__)( \ + (tag), (important), __FILE__, __LINE__) #endif /* at least one profiler requested. */ diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index 252b25bc0a..2b6377d3ec 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -24,6 +24,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/surface/api_trace.h" +#include "src/core/tsi/ssl_transport_security.h" #include <grpc/support/alloc.h> #include <grpc/support/log.h> @@ -56,16 +57,22 @@ static grpc_security_status ssl_create_security_connector( grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds); grpc_security_status status = GRPC_SECURITY_OK; const char* overridden_target_name = nullptr; + tsi_ssl_session_cache* ssl_session_cache = nullptr; for (size_t i = 0; args && i < args->num_args; i++) { grpc_arg* arg = &args->args[i]; if (strcmp(arg->key, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG) == 0 && arg->type == GRPC_ARG_STRING) { overridden_target_name = arg->value.string; - break; + } + if (strcmp(arg->key, GRPC_SSL_SESSION_CACHE_ARG) == 0 && + arg->type == GRPC_ARG_POINTER) { + ssl_session_cache = + static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p); } } status = grpc_ssl_channel_security_connector_create( - creds, call_creds, &c->config, target, overridden_target_name, sc); + creds, call_creds, &c->config, target, overridden_target_name, + ssl_session_cache, sc); if (status != GRPC_SECURITY_OK) { return status; } diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index 3cc151bec7..3967112bb8 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -542,6 +542,46 @@ grpc_server_security_connector* grpc_fake_server_security_connector_create( /* --- Ssl implementation. --- */ +grpc_ssl_session_cache* grpc_ssl_session_cache_create_lru(size_t capacity) { + tsi_ssl_session_cache* cache = tsi_ssl_session_cache_create_lru(capacity); + return reinterpret_cast<grpc_ssl_session_cache*>(cache); +} + +void grpc_ssl_session_cache_destroy(grpc_ssl_session_cache* cache) { + tsi_ssl_session_cache* tsi_cache = + reinterpret_cast<tsi_ssl_session_cache*>(cache); + tsi_ssl_session_cache_unref(tsi_cache); +} + +static void* grpc_ssl_session_cache_arg_copy(void* p) { + tsi_ssl_session_cache* tsi_cache = + reinterpret_cast<tsi_ssl_session_cache*>(p); + // destroy call below will unref the pointer. + tsi_ssl_session_cache_ref(tsi_cache); + return p; +} + +static void grpc_ssl_session_cache_arg_destroy(void* p) { + tsi_ssl_session_cache* tsi_cache = + reinterpret_cast<tsi_ssl_session_cache*>(p); + tsi_ssl_session_cache_unref(tsi_cache); +} + +static int grpc_ssl_session_cache_arg_cmp(void* p, void* q) { + return GPR_ICMP(p, q); +} + +grpc_arg grpc_ssl_session_cache_create_channel_arg( + grpc_ssl_session_cache* cache) { + static const grpc_arg_pointer_vtable vtable = { + grpc_ssl_session_cache_arg_copy, + grpc_ssl_session_cache_arg_destroy, + grpc_ssl_session_cache_arg_cmp, + }; + return grpc_channel_arg_pointer_create( + const_cast<char*>(GRPC_SSL_SESSION_CACHE_ARG), cache, &vtable); +} + typedef struct { grpc_channel_security_connector base; tsi_ssl_client_handshaker_factory* client_handshaker_factory; @@ -760,6 +800,9 @@ grpc_auth_context* tsi_ssl_peer_to_auth_context(const tsi_peer* peer) { } else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) { grpc_auth_context_add_property(ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME, prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx, GRPC_SSL_SESSION_REUSED_PROPERTY, + prop->value.data, prop->value.length); } } if (peer_identity_property_name != nullptr) { @@ -922,91 +965,37 @@ static grpc_security_connector_vtable ssl_channel_vtable = { static grpc_security_connector_vtable ssl_server_vtable = { ssl_server_destroy, ssl_server_check_peer, ssl_server_cmp}; -/* returns a NULL terminated slice. */ -static grpc_slice compute_default_pem_root_certs_once(void) { - grpc_slice result = grpc_empty_slice(); - - /* First try to load the roots from the environment. */ - char* default_root_certs_path = - gpr_getenv(GRPC_DEFAULT_SSL_ROOTS_FILE_PATH_ENV_VAR); - if (default_root_certs_path != nullptr) { - GRPC_LOG_IF_ERROR("load_file", - grpc_load_file(default_root_certs_path, 1, &result)); - gpr_free(default_root_certs_path); - } - - /* Try overridden roots if needed. */ - grpc_ssl_roots_override_result ovrd_res = GRPC_SSL_ROOTS_OVERRIDE_FAIL; - if (GRPC_SLICE_IS_EMPTY(result) && ssl_roots_override_cb != nullptr) { - char* pem_root_certs = nullptr; - ovrd_res = ssl_roots_override_cb(&pem_root_certs); - if (ovrd_res == GRPC_SSL_ROOTS_OVERRIDE_OK) { - GPR_ASSERT(pem_root_certs != nullptr); - result = grpc_slice_from_copied_buffer( - pem_root_certs, - strlen(pem_root_certs) + 1); // NULL terminator. - } - gpr_free(pem_root_certs); - } - - /* Fall back to installed certs if needed. */ - if (GRPC_SLICE_IS_EMPTY(result) && - ovrd_res != GRPC_SSL_ROOTS_OVERRIDE_FAIL_PERMANENTLY) { - GRPC_LOG_IF_ERROR("load_file", - grpc_load_file(installed_roots_path, 1, &result)); - } - return result; -} - -static grpc_slice default_pem_root_certs; - -static void init_default_pem_root_certs(void) { - default_pem_root_certs = compute_default_pem_root_certs_once(); -} - -grpc_slice grpc_get_default_ssl_roots_for_testing(void) { - return compute_default_pem_root_certs_once(); -} - -const char* grpc_get_default_ssl_roots(void) { - /* TODO(jboeuf@google.com): Maybe revisit the approach which consists in - loading all the roots once for the lifetime of the process. */ - static gpr_once once = GPR_ONCE_INIT; - gpr_once_init(&once, init_default_pem_root_certs); - return GRPC_SLICE_IS_EMPTY(default_pem_root_certs) - ? nullptr - : reinterpret_cast<const char*> - GRPC_SLICE_START_PTR(default_pem_root_certs); -} - grpc_security_status grpc_ssl_channel_security_connector_create( grpc_channel_credentials* channel_creds, grpc_call_credentials* request_metadata_creds, const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name, grpc_channel_security_connector** sc) { - size_t num_alpn_protocols = 0; - const char** alpn_protocol_strings = - fill_alpn_protocol_strings(&num_alpn_protocols); + const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache, + grpc_channel_security_connector** sc) { tsi_result result = TSI_OK; grpc_ssl_channel_security_connector* c; - const char* pem_root_certs; char* port; bool has_key_cert_pair; + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.alpn_protocols = + fill_alpn_protocol_strings(&options.num_alpn_protocols); if (config == nullptr || target_name == nullptr) { gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); goto error; } if (config->pem_root_certs == nullptr) { - pem_root_certs = grpc_get_default_ssl_roots(); - if (pem_root_certs == nullptr) { + // Use default root certificates. + options.pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + options.root_store = grpc_core::DefaultSslRootStore::GetRootStore(); + if (options.pem_root_certs == nullptr) { gpr_log(GPR_ERROR, "Could not get default pem root certs."); goto error; } } else { - pem_root_certs = config->pem_root_certs; + options.pem_root_certs = config->pem_root_certs; } - c = static_cast<grpc_ssl_channel_security_connector*>( gpr_zalloc(sizeof(grpc_ssl_channel_security_connector))); @@ -1028,10 +1017,13 @@ grpc_security_status grpc_ssl_channel_security_connector_create( has_key_cert_pair = config->pem_key_cert_pair != nullptr && config->pem_key_cert_pair->private_key != nullptr && config->pem_key_cert_pair->cert_chain != nullptr; - result = tsi_create_ssl_client_handshaker_factory( - has_key_cert_pair ? config->pem_key_cert_pair : nullptr, pem_root_certs, - ssl_cipher_suites(), alpn_protocol_strings, - static_cast<uint16_t>(num_alpn_protocols), &c->client_handshaker_factory); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + result = tsi_create_ssl_client_handshaker_factory_with_options( + &options, &c->client_handshaker_factory); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); @@ -1040,11 +1032,11 @@ grpc_security_status grpc_ssl_channel_security_connector_create( goto error; } *sc = &c->base; - gpr_free((void*)alpn_protocol_strings); + gpr_free((void*)options.alpn_protocols); return GRPC_SECURITY_OK; error: - gpr_free((void*)alpn_protocol_strings); + gpr_free((void*)options.alpn_protocols); return GRPC_SECURITY_ERROR; } @@ -1109,3 +1101,79 @@ grpc_security_status grpc_ssl_server_security_connector_create( } return retval; } + +namespace grpc_core { + +tsi_ssl_root_certs_store* DefaultSslRootStore::default_root_store_; +grpc_slice DefaultSslRootStore::default_pem_root_certs_; + +const tsi_ssl_root_certs_store* DefaultSslRootStore::GetRootStore() { + InitRootStore(); + return default_root_store_; +} + +const char* DefaultSslRootStore::GetPemRootCerts() { + InitRootStore(); + return GRPC_SLICE_IS_EMPTY(default_pem_root_certs_) + ? nullptr + : reinterpret_cast<const char*> + GRPC_SLICE_START_PTR(default_pem_root_certs_); +} + +void DefaultSslRootStore::Initialize() { + default_root_store_ = nullptr; + default_pem_root_certs_ = grpc_empty_slice(); +} + +void DefaultSslRootStore::Destroy() { + tsi_ssl_root_certs_store_destroy(default_root_store_); + grpc_slice_unref_internal(default_pem_root_certs_); +} + +grpc_slice DefaultSslRootStore::ComputePemRootCerts() { + grpc_slice result = grpc_empty_slice(); + // First try to load the roots from the environment. + char* default_root_certs_path = + gpr_getenv(GRPC_DEFAULT_SSL_ROOTS_FILE_PATH_ENV_VAR); + if (default_root_certs_path != nullptr) { + GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(default_root_certs_path, 1, &result)); + gpr_free(default_root_certs_path); + } + // Try overridden roots if needed. + grpc_ssl_roots_override_result ovrd_res = GRPC_SSL_ROOTS_OVERRIDE_FAIL; + if (GRPC_SLICE_IS_EMPTY(result) && ssl_roots_override_cb != nullptr) { + char* pem_root_certs = nullptr; + ovrd_res = ssl_roots_override_cb(&pem_root_certs); + if (ovrd_res == GRPC_SSL_ROOTS_OVERRIDE_OK) { + GPR_ASSERT(pem_root_certs != nullptr); + result = grpc_slice_from_copied_buffer( + pem_root_certs, + strlen(pem_root_certs) + 1); // nullptr terminator. + } + gpr_free(pem_root_certs); + } + // Fall back to installed certs if needed. + if (GRPC_SLICE_IS_EMPTY(result) && + ovrd_res != GRPC_SSL_ROOTS_OVERRIDE_FAIL_PERMANENTLY) { + GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(installed_roots_path, 1, &result)); + } + return result; +} + +void DefaultSslRootStore::InitRootStore() { + static gpr_once once = GPR_ONCE_INIT; + gpr_once_init(&once, DefaultSslRootStore::InitRootStoreOnce); +} + +void DefaultSslRootStore::InitRootStoreOnce() { + default_pem_root_certs_ = ComputePemRootCerts(); + if (!GRPC_SLICE_IS_EMPTY(default_pem_root_certs_)) { + default_root_store_ = + tsi_ssl_root_certs_store_create(reinterpret_cast<const char*>( + GRPC_SLICE_START_PTR(default_pem_root_certs_))); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index 130c8ecd3e..5d3d1e0f44 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -212,13 +212,9 @@ grpc_security_status grpc_ssl_channel_security_connector_create( grpc_channel_credentials* channel_creds, grpc_call_credentials* request_metadata_creds, const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name, grpc_channel_security_connector** sc); - -/* Gets the default ssl roots. Returns NULL if not found. */ -const char* grpc_get_default_ssl_roots(void); - -/* Exposed for TESTING ONLY!. */ -grpc_slice grpc_get_default_ssl_roots_for_testing(void); + const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache, + grpc_channel_security_connector** sc); /* Config for ssl servers. */ typedef struct { @@ -248,4 +244,49 @@ tsi_peer tsi_shallow_peer_from_ssl_auth_context( const grpc_auth_context* auth_context); void tsi_shallow_peer_destruct(tsi_peer* peer); +/* --- Default SSL Root Store. --- */ +namespace grpc_core { + +// The class implements default SSL root store. +class DefaultSslRootStore { + public: + // Gets the default SSL root store. Returns nullptr if not found. + static const tsi_ssl_root_certs_store* GetRootStore(); + + // Gets the default PEM root certificate. + static const char* GetPemRootCerts(); + + // Initializes the SSL root store's underlying data structure. It does not + // load default SSL root certificates. Should only be called by + // grpc_security_init(). + static void Initialize(); + + // Destroys the default SSL root store. Should only be called by + // grpc_security_shutdown(). + static void Destroy(); + + protected: + // Returns default PEM root certificates in nullptr terminated grpc_slice. + // This function is protected instead of private, so that it can be tested. + static grpc_slice ComputePemRootCerts(); + + private: + // Construct me not! + DefaultSslRootStore(); + + // Initialization of default SSL root store. + static void InitRootStore(); + + // One-time initialization of default SSL root store. + static void InitRootStoreOnce(); + + // SSL root store in tsi_ssl_root_certs_store object. + static tsi_ssl_root_certs_store* default_root_store_; + + // Default PEM root certificates. + static grpc_slice default_pem_root_certs_; +}; + +} // namespace grpc_core + #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SECURITY_CONNECTOR_H */ diff --git a/src/core/lib/surface/channel.cc b/src/core/lib/surface/channel.cc index 03353d6beb..cecc15b2df 100644 --- a/src/core/lib/surface/channel.cc +++ b/src/core/lib/surface/channel.cc @@ -21,6 +21,7 @@ #include "src/core/lib/surface/channel.h" #include <inttypes.h> +#include <limits.h> #include <stdlib.h> #include <string.h> @@ -30,8 +31,12 @@ #include <grpc/support/string_util.h> #include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_trace.h" #include "src/core/lib/debug/stats.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/iomgr.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/surface/api_trace.h" @@ -62,6 +67,8 @@ struct grpc_channel { gpr_mu registered_call_mu; registered_call* registered_calls; + grpc_core::RefCountedPtr<grpc_core::ChannelTrace> tracer; + char* target; }; @@ -93,12 +100,14 @@ grpc_channel* grpc_channel_create_with_builder( grpc_error_string(error)); GRPC_ERROR_UNREF(error); gpr_free(target); - goto done; + grpc_channel_args_destroy(args); + return channel; } memset(channel, 0, sizeof(*channel)); channel->target = target; channel->is_client = grpc_channel_stack_type_is_client(channel_stack_type); + size_t channel_tracer_max_nodes = 0; // default to off gpr_mu_init(&channel->registered_call_mu); channel->registered_calls = nullptr; @@ -161,14 +170,33 @@ grpc_channel* grpc_channel_create_with_builder( channel->compression_options.enabled_algorithms_bitset = static_cast<uint32_t>(args->args[i].value.integer) | 0x1; /* always support no compression */ + } else if (0 == strcmp(args->args[i].key, + GRPC_ARG_MAX_CHANNEL_TRACE_EVENTS_PER_NODE)) { + GPR_ASSERT(channel_tracer_max_nodes == 0); + // max_nodes defaults to 0 (which is off), clamped between 0 and INT_MAX + const grpc_integer_options options = {0, 0, INT_MAX}; + channel_tracer_max_nodes = + (size_t)grpc_channel_arg_get_integer(&args->args[i], options); } } -done: grpc_channel_args_destroy(args); + channel->tracer = grpc_core::MakeRefCounted<grpc_core::ChannelTrace>( + channel_tracer_max_nodes); + channel->tracer->AddTraceEvent( + grpc_core::ChannelTrace::Severity::Info, + grpc_slice_from_static_string("Channel created")); return channel; } +char* grpc_channel_get_trace(grpc_channel* channel) { + return channel->tracer->RenderTrace(); +} + +intptr_t grpc_channel_get_uuid(grpc_channel* channel) { + return channel->tracer->GetUuid(); +} + grpc_channel* grpc_channel_create(const char* target, const grpc_channel_args* input_args, grpc_channel_stack_type channel_stack_type, @@ -377,6 +405,7 @@ static void destroy_channel(void* arg, grpc_error* error) { GRPC_MDELEM_UNREF(rc->authority); gpr_free(rc); } + channel->tracer.reset(); GRPC_MDELEM_UNREF(channel->default_authority); gpr_mu_destroy(&channel->registered_call_mu); gpr_free(channel->target); diff --git a/src/core/lib/surface/init.cc b/src/core/lib/surface/init.cc index ac9f9e6066..52e0ee1c44 100644 --- a/src/core/lib/surface/init.cc +++ b/src/core/lib/surface/init.cc @@ -27,6 +27,7 @@ #include <grpc/support/log.h> #include <grpc/support/time.h> #include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_trace_registry.h" #include "src/core/lib/channel/connected_channel.h" #include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/debug/stats.h" @@ -128,6 +129,7 @@ void grpc_init(void) { grpc_slice_intern_init(); grpc_mdctx_global_init(); grpc_channel_init_init(); + grpc_channel_trace_registry_init(); grpc_security_pre_init(); grpc_core::ExecCtx::GlobalInit(); grpc_iomgr_init(); @@ -170,12 +172,14 @@ void grpc_shutdown(void) { } } } + grpc_security_shutdown(); grpc_iomgr_shutdown(); gpr_timers_global_destroy(); grpc_tracer_shutdown(); grpc_mdctx_global_shutdown(); grpc_handshaker_factory_registry_shutdown(); grpc_slice_intern_shutdown(); + grpc_channel_trace_registry_shutdown(); grpc_stats_shutdown(); } grpc_core::ExecCtx::GlobalShutdown(); diff --git a/src/core/lib/surface/init.h b/src/core/lib/surface/init.h index 9353208332..d8282b475b 100644 --- a/src/core/lib/surface/init.h +++ b/src/core/lib/surface/init.h @@ -22,6 +22,7 @@ void grpc_register_security_filters(void); void grpc_security_pre_init(void); void grpc_security_init(void); +void grpc_security_shutdown(void); int grpc_is_initialized(void); #endif /* GRPC_CORE_LIB_SURFACE_INIT_H */ diff --git a/src/core/lib/surface/init_secure.cc b/src/core/lib/surface/init_secure.cc index 78e983e0cd..3e4f1a8829 100644 --- a/src/core/lib/surface/init_secure.cc +++ b/src/core/lib/surface/init_secure.cc @@ -75,4 +75,9 @@ void grpc_register_security_filters(void) { maybe_prepend_server_auth_filter, nullptr); } -void grpc_security_init() { grpc_security_register_handshaker_factories(); } +void grpc_security_init() { + grpc_security_register_handshaker_factories(); + grpc_core::DefaultSslRootStore::Initialize(); +} + +void grpc_security_shutdown() { grpc_core::DefaultSslRootStore::Destroy(); } diff --git a/src/core/lib/surface/init_unsecure.cc b/src/core/lib/surface/init_unsecure.cc index 2b3bc64382..1c8d07b38b 100644 --- a/src/core/lib/surface/init_unsecure.cc +++ b/src/core/lib/surface/init_unsecure.cc @@ -25,3 +25,5 @@ void grpc_security_pre_init(void) {} void grpc_register_security_filters(void) {} void grpc_security_init(void) {} + +void grpc_security_shutdown(void) {} diff --git a/src/core/plugin_registry/grpc_cronet_plugin_registry.cc b/src/core/plugin_registry/grpc_cronet_plugin_registry.cc index 49b9c7d4fe..84228cb38a 100644 --- a/src/core/plugin_registry/grpc_cronet_plugin_registry.cc +++ b/src/core/plugin_registry/grpc_cronet_plugin_registry.cc @@ -20,30 +20,50 @@ #include <grpc/grpc.h> -void grpc_http_filters_init(void); -void grpc_http_filters_shutdown(void); -void grpc_chttp2_plugin_init(void); -void grpc_chttp2_plugin_shutdown(void); void grpc_deadline_filter_init(void); void grpc_deadline_filter_shutdown(void); void grpc_client_channel_init(void); void grpc_client_channel_shutdown(void); -void grpc_tsi_alts_init(void); -void grpc_tsi_alts_shutdown(void); +void grpc_lb_policy_pick_first_init(void); +void grpc_lb_policy_pick_first_shutdown(void); +void grpc_max_age_filter_init(void); +void grpc_max_age_filter_shutdown(void); +void grpc_message_size_filter_init(void); +void grpc_message_size_filter_shutdown(void); +void grpc_resolver_dns_native_init(void); +void grpc_resolver_dns_native_shutdown(void); +void grpc_resolver_sockaddr_init(void); +void grpc_resolver_sockaddr_shutdown(void); void grpc_server_load_reporting_plugin_init(void); void grpc_server_load_reporting_plugin_shutdown(void); +void grpc_http_filters_init(void); +void grpc_http_filters_shutdown(void); +void grpc_chttp2_plugin_init(void); +void grpc_chttp2_plugin_shutdown(void); +void grpc_tsi_alts_init(void); +void grpc_tsi_alts_shutdown(void); void grpc_register_built_in_plugins(void) { - grpc_register_plugin(grpc_http_filters_init, - grpc_http_filters_shutdown); - grpc_register_plugin(grpc_chttp2_plugin_init, - grpc_chttp2_plugin_shutdown); grpc_register_plugin(grpc_deadline_filter_init, grpc_deadline_filter_shutdown); grpc_register_plugin(grpc_client_channel_init, grpc_client_channel_shutdown); - grpc_register_plugin(grpc_tsi_alts_init, - grpc_tsi_alts_shutdown); + grpc_register_plugin(grpc_lb_policy_pick_first_init, + grpc_lb_policy_pick_first_shutdown); + grpc_register_plugin(grpc_max_age_filter_init, + grpc_max_age_filter_shutdown); + grpc_register_plugin(grpc_message_size_filter_init, + grpc_message_size_filter_shutdown); + grpc_register_plugin(grpc_resolver_dns_native_init, + grpc_resolver_dns_native_shutdown); + grpc_register_plugin(grpc_resolver_sockaddr_init, + grpc_resolver_sockaddr_shutdown); grpc_register_plugin(grpc_server_load_reporting_plugin_init, grpc_server_load_reporting_plugin_shutdown); + grpc_register_plugin(grpc_http_filters_init, + grpc_http_filters_shutdown); + grpc_register_plugin(grpc_chttp2_plugin_init, + grpc_chttp2_plugin_shutdown); + grpc_register_plugin(grpc_tsi_alts_init, + grpc_tsi_alts_shutdown); } diff --git a/src/core/tsi/ssl/session_cache/ssl_session.h b/src/core/tsi/ssl/session_cache/ssl_session.h new file mode 100644 index 0000000000..115221ec06 --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session.h @@ -0,0 +1,73 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_TSI_SSL_SESSION_CACHE_SSL_SESSION_H +#define GRPC_CORE_TSI_SSL_SESSION_CACHE_SSL_SESSION_H + +#include <grpc/support/port_platform.h> + +#include <grpc/slice.h> + +extern "C" { +#include <openssl/ssl.h> +} + +#include "src/core/lib/gprpp/ref_counted.h" + +// The main purpose of code here is to provide means to cache SSL sessions +// in a way that they can be shared between connections. +// +// SSL_SESSION stands for single instance of session and is not generally safe +// to share between SSL contexts with different lifetimes. It happens because +// not all SSL implementations guarantee immutability of SSL_SESSION object. +// See SSL_SESSION documentation in BoringSSL and OpenSSL for more details. + +namespace tsi { + +struct SslSessionDeleter { + void operator()(SSL_SESSION* session) { SSL_SESSION_free(session); } +}; + +typedef std::unique_ptr<SSL_SESSION, SslSessionDeleter> SslSessionPtr; + +/// SslCachedSession is an immutable thread-safe storage for single session +/// representation. It provides means to share SSL session data (e.g. TLS +/// ticket) between encrypted connections regardless of SSL context lifetime. +class SslCachedSession { + public: + // Not copyable nor movable. + SslCachedSession(const SslCachedSession&) = delete; + SslCachedSession& operator=(const SslCachedSession&) = delete; + + /// Create single cached instance of \a session. + static grpc_core::UniquePtr<SslCachedSession> Create(SslSessionPtr session); + + virtual ~SslCachedSession() = default; + + /// Returns a copy of previously cached session. + virtual SslSessionPtr CopySession() const GRPC_ABSTRACT; + + GRPC_ABSTRACT_BASE_CLASS + + protected: + SslCachedSession() = default; +}; + +} // namespace tsi + +#endif /* GRPC_CORE_TSI_SSL_SESSION_CACHE_SSL_SESSION_H */ diff --git a/src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc b/src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc new file mode 100644 index 0000000000..0da5a96164 --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc @@ -0,0 +1,58 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +#ifdef OPENSSL_IS_BORINGSSL + +// BoringSSL allows SSL_SESSION to outlive SSL and SSL_CTX objects which are +// re-created by gRPC on every certificate rotation or subchannel creation. +// BoringSSL guarantees that SSL_SESSION is immutable so it's safe to share +// the same original session object between different threads and connections. + +namespace tsi { +namespace { + +class BoringSslCachedSession : public SslCachedSession { + public: + BoringSslCachedSession(SslSessionPtr session) + : session_(std::move(session)) {} + + SslSessionPtr CopySession() const override { + // SslSessionPtr will dereference on destruction. + SSL_SESSION_up_ref(session_.get()); + return SslSessionPtr(session_.get()); + } + + private: + SslSessionPtr session_; +}; + +} // namespace + +grpc_core::UniquePtr<SslCachedSession> SslCachedSession::Create( + SslSessionPtr session) { + return grpc_core::UniquePtr<SslCachedSession>( + grpc_core::New<BoringSslCachedSession>(std::move(session))); +} + +} // namespace tsi + +#endif /* OPENSSL_IS_BORINGSSL */ diff --git a/src/core/tsi/ssl/session_cache/ssl_session_cache.cc b/src/core/tsi/ssl/session_cache/ssl_session_cache.cc new file mode 100644 index 0000000000..fe4f83a13d --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_cache.cc @@ -0,0 +1,211 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" + +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> + +namespace tsi { + +static void cache_key_avl_destroy(void* key, void* unused) {} + +static void* cache_key_avl_copy(void* key, void* unused) { return key; } + +static long cache_key_avl_compare(void* key1, void* key2, void* unused) { + return grpc_slice_cmp(*static_cast<grpc_slice*>(key1), + *static_cast<grpc_slice*>(key2)); +} + +static void cache_value_avl_destroy(void* value, void* unused) {} + +static void* cache_value_avl_copy(void* value, void* unused) { return value; } + +// AVL only stores pointers, ownership belonges to the linked list. +static const grpc_avl_vtable cache_avl_vtable = { + cache_key_avl_destroy, cache_key_avl_copy, cache_key_avl_compare, + cache_value_avl_destroy, cache_value_avl_copy, +}; + +/// Node for single cached session. +class SslSessionLRUCache::Node { + public: + Node(const grpc_slice& key, SslSessionPtr session) : key_(key) { + SetSession(std::move(session)); + } + + ~Node() { grpc_slice_unref(key_); } + + // Not copyable nor movable. + Node(const Node&) = delete; + Node& operator=(const Node&) = delete; + + void* AvlKey() { return &key_; } + + /// Returns a copy of the node's cache session. + SslSessionPtr CopySession() const { return session_->CopySession(); } + + /// Set the \a session (which is moved) for the node. + void SetSession(SslSessionPtr session) { + session_ = SslCachedSession::Create(std::move(session)); + } + + private: + friend class SslSessionLRUCache; + + grpc_slice key_; + grpc_core::UniquePtr<SslCachedSession> session_; + + Node* next_ = nullptr; + Node* prev_ = nullptr; +}; + +SslSessionLRUCache::SslSessionLRUCache(size_t capacity) : capacity_(capacity) { + GPR_ASSERT(capacity > 0); + gpr_mu_init(&lock_); + entry_by_key_ = grpc_avl_create(&cache_avl_vtable); +} + +SslSessionLRUCache::~SslSessionLRUCache() { + Node* node = use_order_list_head_; + while (node) { + Node* next = node->next_; + grpc_core::Delete(node); + node = next; + } + grpc_avl_unref(entry_by_key_, nullptr); + gpr_mu_destroy(&lock_); +} + +size_t SslSessionLRUCache::Size() { + grpc_core::mu_guard guard(&lock_); + return use_order_list_size_; +} + +SslSessionLRUCache::Node* SslSessionLRUCache::FindLocked( + const grpc_slice& key) { + void* value = + grpc_avl_get(entry_by_key_, const_cast<grpc_slice*>(&key), nullptr); + if (value == nullptr) { + return nullptr; + } + Node* node = static_cast<Node*>(value); + // Move to the beginning. + Remove(node); + PushFront(node); + AssertInvariants(); + return node; +} + +void SslSessionLRUCache::Put(const char* key, SslSessionPtr session) { + grpc_core::mu_guard guard(&lock_); + Node* node = FindLocked(grpc_slice_from_static_string(key)); + if (node != nullptr) { + node->SetSession(std::move(session)); + return; + } + grpc_slice key_slice = grpc_slice_from_copied_string(key); + node = grpc_core::New<Node>(key_slice, std::move(session)); + PushFront(node); + entry_by_key_ = grpc_avl_add(entry_by_key_, node->AvlKey(), node, nullptr); + AssertInvariants(); + if (use_order_list_size_ > capacity_) { + GPR_ASSERT(use_order_list_tail_); + node = use_order_list_tail_; + Remove(node); + // Order matters, key is destroyed after deleting node. + entry_by_key_ = grpc_avl_remove(entry_by_key_, node->AvlKey(), nullptr); + grpc_core::Delete(node); + AssertInvariants(); + } +} + +SslSessionPtr SslSessionLRUCache::Get(const char* key) { + grpc_core::mu_guard guard(&lock_); + // Key is only used for lookups. + grpc_slice key_slice = grpc_slice_from_static_string(key); + Node* node = FindLocked(key_slice); + if (node == nullptr) { + return nullptr; + } + return node->CopySession(); +} + +void SslSessionLRUCache::Remove(SslSessionLRUCache::Node* node) { + if (node->prev_ == nullptr) { + use_order_list_head_ = node->next_; + } else { + node->prev_->next_ = node->next_; + } + if (node->next_ == nullptr) { + use_order_list_tail_ = node->prev_; + } else { + node->next_->prev_ = node->prev_; + } + GPR_ASSERT(use_order_list_size_ >= 1); + use_order_list_size_--; +} + +void SslSessionLRUCache::PushFront(SslSessionLRUCache::Node* node) { + if (use_order_list_head_ == nullptr) { + use_order_list_head_ = node; + use_order_list_tail_ = node; + node->next_ = nullptr; + node->prev_ = nullptr; + } else { + node->next_ = use_order_list_head_; + node->next_->prev_ = node; + use_order_list_head_ = node; + node->prev_ = nullptr; + } + use_order_list_size_++; +} + +#ifndef NDEBUG +static size_t calculate_tree_size(grpc_avl_node* node) { + if (node == nullptr) { + return 0; + } + return 1 + calculate_tree_size(node->left) + calculate_tree_size(node->right); +} + +void SslSessionLRUCache::AssertInvariants() { + size_t size = 0; + Node* prev = nullptr; + Node* current = use_order_list_head_; + while (current != nullptr) { + size++; + GPR_ASSERT(current->prev_ == prev); + void* node = grpc_avl_get(entry_by_key_, current->AvlKey(), nullptr); + GPR_ASSERT(node == current); + prev = current; + current = current->next_; + } + GPR_ASSERT(prev == use_order_list_tail_); + GPR_ASSERT(size == use_order_list_size_); + GPR_ASSERT(calculate_tree_size(entry_by_key_.root) == use_order_list_size_); +} +#else +void SslSessionLRUCache::AssertInvariants() {} +#endif + +} // namespace tsi diff --git a/src/core/tsi/ssl/session_cache/ssl_session_cache.h b/src/core/tsi/ssl/session_cache/ssl_session_cache.h new file mode 100644 index 0000000000..488638c9bb --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_cache.h @@ -0,0 +1,93 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_TSI_SSL_SESSION_CACHE_SSL_SESSION_CACHE_H +#define GRPC_CORE_TSI_SSL_SESSION_CACHE_SSL_SESSION_CACHE_H + +#include <grpc/support/port_platform.h> + +#include <grpc/slice.h> +#include <grpc/support/sync.h> + +extern "C" { +#include <openssl/ssl.h> +} + +#include "src/core/lib/avl/avl.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +/// Cache for SSL sessions for sessions resumption. +/// +/// Older sessions may be evicted from the cache using LRU policy if capacity +/// limit is hit. All sessions are associated with some key, usually server +/// name. Note that servers are required to share session ticket encryption keys +/// in order for cache to be effective. +/// +/// This class is thread safe. + +namespace tsi { + +class SslSessionLRUCache : public grpc_core::RefCounted<SslSessionLRUCache> { + public: + /// Create new LRU cache with the given capacity. + static grpc_core::RefCountedPtr<SslSessionLRUCache> Create(size_t capacity) { + return grpc_core::MakeRefCounted<SslSessionLRUCache>(capacity); + } + + // Not copyable nor movable. + SslSessionLRUCache(const SslSessionLRUCache&) = delete; + SslSessionLRUCache& operator=(const SslSessionLRUCache&) = delete; + + /// Returns current number of sessions in the cache. + size_t Size(); + /// Add \a session in the cache using \a key. This operation may discard older + /// sessions. + void Put(const char* key, SslSessionPtr session); + /// Returns the session from the cache associated with \a key or null if not + /// found. + SslSessionPtr Get(const char* key); + + private: + // So New() can call our private ctor. + template <typename T, typename... Args> + friend T* grpc_core::New(Args&&... args); + + class Node; + + explicit SslSessionLRUCache(size_t capacity); + ~SslSessionLRUCache(); + + Node* FindLocked(const grpc_slice& key); + void Remove(Node* node); + void PushFront(Node* node); + void AssertInvariants(); + + gpr_mu lock_; + size_t capacity_; + + Node* use_order_list_head_ = nullptr; + Node* use_order_list_tail_ = nullptr; + size_t use_order_list_size_ = 0; + grpc_avl entry_by_key_; +}; + +} // namespace tsi + +#endif /* GRPC_CORE_TSI_SSL_SESSION_CACHE_SSL_SESSION_CACHE_H */ diff --git a/src/core/tsi/ssl/session_cache/ssl_session_openssl.cc b/src/core/tsi/ssl/session_cache/ssl_session_openssl.cc new file mode 100644 index 0000000000..61c036c7eb --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_openssl.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +#include <grpc/support/log.h> + +#ifndef OPENSSL_IS_BORINGSSL + +// OpenSSL invalidates SSL_SESSION on SSL destruction making it pointless +// to cache sessions. The workaround is to serialize (relatively expensive) +// session into binary blob and re-create it from blob on every handshake. +// Note that it's safe to keep serialized session outside of SSL lifetime +// as openssl performs all necessary validation while attempting to use a +// session and creates a new one if something is wrong (e.g. server changed +// set of allowed codecs). + +namespace tsi { +namespace { + +class OpenSslCachedSession : public SslCachedSession { + public: + OpenSslCachedSession(SslSessionPtr session) { + int size = i2d_SSL_SESSION(session.get(), nullptr); + GPR_ASSERT(size > 0); + grpc_slice slice = grpc_slice_malloc(size_t(size)); + unsigned char* start = GRPC_SLICE_START_PTR(slice); + int second_size = i2d_SSL_SESSION(session.get(), &start); + GPR_ASSERT(size == second_size); + serialized_session_ = slice; + } + + virtual ~OpenSslCachedSession() { grpc_slice_unref(serialized_session_); } + + SslSessionPtr CopySession() const override { + const unsigned char* data = GRPC_SLICE_START_PTR(serialized_session_); + size_t length = GRPC_SLICE_LENGTH(serialized_session_); + SSL_SESSION* session = d2i_SSL_SESSION(nullptr, &data, length); + if (session == nullptr) { + return SslSessionPtr(); + } + return SslSessionPtr(session); + } + + private: + grpc_slice serialized_session_; +}; + +} // namespace + +grpc_core::UniquePtr<SslCachedSession> SslCachedSession::Create( + SslSessionPtr session) { + return grpc_core::UniquePtr<SslCachedSession>( + grpc_core::New<OpenSslCachedSession>(std::move(session))); +} + +} // namespace tsi + +#endif /* OPENSSL_IS_BORINGSSL */ diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index 971170b7c5..0ba6587678 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -35,6 +35,7 @@ #include <grpc/support/alloc.h> #include <grpc/support/log.h> +#include <grpc/support/string_util.h> #include <grpc/support/sync.h> #include <grpc/support/thd_id.h> @@ -47,6 +48,8 @@ extern "C" { #include <openssl/x509v3.h> } +#include "src/core/lib/gpr/useful.h" +#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" #include "src/core/tsi/ssl_types.h" #include "src/core/tsi/transport_security.h" @@ -68,6 +71,10 @@ extern "C" { /* --- Structure definitions. ---*/ +struct tsi_ssl_root_certs_store { + X509_STORE* store; +}; + struct tsi_ssl_handshaker_factory { const tsi_ssl_handshaker_factory_vtable* vtable; gpr_refcount refcount; @@ -78,6 +85,7 @@ struct tsi_ssl_client_handshaker_factory { SSL_CTX* ssl_context; unsigned char* alpn_protocol_list; size_t alpn_protocol_list_length; + grpc_core::RefCountedPtr<tsi::SslSessionLRUCache> session_cache; }; struct tsi_ssl_server_handshaker_factory { @@ -111,17 +119,19 @@ typedef struct { /* --- Library Initialization. ---*/ -static gpr_once init_openssl_once = GPR_ONCE_INIT; -static gpr_mu* openssl_mutexes = nullptr; +static gpr_once g_init_openssl_once = GPR_ONCE_INIT; +static gpr_mu* g_openssl_mutexes = nullptr; +static int g_ssl_ctx_ex_factory_index = -1; static void openssl_locking_cb(int mode, int type, const char* file, int line) GRPC_UNUSED; static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED; +static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'}; static void openssl_locking_cb(int mode, int type, const char* file, int line) { if (mode & CRYPTO_LOCK) { - gpr_mu_lock(&openssl_mutexes[type]); + gpr_mu_lock(&g_openssl_mutexes[type]); } else { - gpr_mu_unlock(&openssl_mutexes[type]); + gpr_mu_unlock(&g_openssl_mutexes[type]); } } @@ -137,13 +147,16 @@ static void init_openssl(void) { OpenSSL_add_all_algorithms(); num_locks = CRYPTO_num_locks(); GPR_ASSERT(num_locks > 0); - openssl_mutexes = static_cast<gpr_mu*>( + g_openssl_mutexes = static_cast<gpr_mu*>( gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu))); for (i = 0; i < CRYPTO_num_locks(); i++) { - gpr_mu_init(&openssl_mutexes[i]); + gpr_mu_init(&g_openssl_mutexes[i]); } CRYPTO_set_locking_callback(openssl_locking_cb); CRYPTO_set_id_callback(openssl_thread_id_cb); + g_ssl_ctx_ex_factory_index = + SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1); } /* --- Ssl utils. ---*/ @@ -544,21 +557,18 @@ static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key, /* Loads in-memory PEM verification certs into the SSL context and optionally returns the verification cert names (root_names can be NULL). */ -static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context, - const char* pem_roots, - size_t pem_roots_size, - STACK_OF(X509_NAME) * - *root_names) { +static tsi_result x509_store_load_certs(X509_STORE* cert_store, + const char* pem_roots, + size_t pem_roots_size, + STACK_OF(X509_NAME) * *root_names) { tsi_result result = TSI_OK; size_t num_roots = 0; X509* root = nullptr; X509_NAME* root_name = nullptr; BIO* pem; - X509_STORE* root_store; GPR_ASSERT(pem_roots_size <= INT_MAX); pem = BIO_new_mem_buf((void*)pem_roots, static_cast<int>(pem_roots_size)); - root_store = SSL_CTX_get_cert_store(context); - if (root_store == nullptr) return TSI_INVALID_ARGUMENT; + if (cert_store == nullptr) return TSI_INVALID_ARGUMENT; if (pem == nullptr) return TSI_OUT_OF_RESOURCES; if (root_names != nullptr) { *root_names = sk_X509_NAME_new_null(); @@ -586,7 +596,7 @@ static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context, sk_X509_NAME_push(*root_names, root_name); root_name = nullptr; } - if (!X509_STORE_add_cert(root_store, root)) { + if (!X509_STORE_add_cert(cert_store, root)) { gpr_log(GPR_ERROR, "Could not add root certificate to ssl context."); result = TSI_INTERNAL_ERROR; break; @@ -612,6 +622,16 @@ static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context, return result; } +static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context, + const char* pem_roots, + size_t pem_roots_size, + STACK_OF(X509_NAME) * + *root_name) { + X509_STORE* cert_store = SSL_CTX_get_cert_store(context); + return x509_store_load_certs(cert_store, pem_roots, pem_roots_size, + root_name); +} + /* Populates the SSL context with a private key and a cert chain, and sets the cipher list and the ephemeral ECDH key. */ static tsi_result populate_ssl_context( @@ -721,6 +741,60 @@ static int NullVerifyCallback(int preverify_ok, X509_STORE_CTX* ctx) { return 1; } +/* --- tsi_ssl_root_certs_store methods implementation. ---*/ + +tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create( + const char* pem_roots) { + if (pem_roots == nullptr) { + gpr_log(GPR_ERROR, "The root certificates are empty."); + return nullptr; + } + tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>( + gpr_zalloc(sizeof(tsi_ssl_root_certs_store))); + if (root_store == nullptr) { + gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store."); + return nullptr; + } + root_store->store = X509_STORE_new(); + if (root_store->store == nullptr) { + gpr_log(GPR_ERROR, "Could not allocate buffer for X509_STORE."); + gpr_free(root_store); + return nullptr; + } + tsi_result result = x509_store_load_certs(root_store->store, pem_roots, + strlen(pem_roots), nullptr); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Could not load root certificates."); + X509_STORE_free(root_store->store); + gpr_free(root_store); + return nullptr; + } + return root_store; +} + +void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) { + if (self == nullptr) return; + X509_STORE_free(self->store); + gpr_free(self); +} + +/* --- tsi_ssl_session_cache methods implementation. ---*/ + +tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) { + /* Pointer will be dereferenced by unref call. */ + return reinterpret_cast<tsi_ssl_session_cache*>( + tsi::SslSessionLRUCache::Create(capacity).release()); +} + +void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) { + /* Pointer will be dereferenced by unref call. */ + reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Ref().release(); +} + +void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) { + reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Unref(); +} + /* --- tsi_frame_protector methods implementation. ---*/ static tsi_result ssl_protector_protect(tsi_frame_protector* self, @@ -1015,25 +1089,34 @@ static tsi_result ssl_handshaker_extract_peer(tsi_handshaker* self, SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected, &alpn_selected_len); } + + // 1 is for session reused property. + size_t new_property_count = peer->property_count + 1; + if (alpn_selected != nullptr) new_property_count++; + tsi_peer_property* new_properties = static_cast<tsi_peer_property*>( + gpr_zalloc(sizeof(*new_properties) * new_property_count)); + for (size_t i = 0; i < peer->property_count; i++) { + new_properties[i] = peer->properties[i]; + } + if (peer->properties != nullptr) gpr_free(peer->properties); + peer->properties = new_properties; + if (alpn_selected != nullptr) { - size_t i; - tsi_peer_property* new_properties = static_cast<tsi_peer_property*>( - gpr_zalloc(sizeof(*new_properties) * (peer->property_count + 1))); - for (i = 0; i < peer->property_count; i++) { - new_properties[i] = peer->properties[i]; - } result = tsi_construct_string_peer_property( TSI_SSL_ALPN_SELECTED_PROTOCOL, reinterpret_cast<const char*>(alpn_selected), alpn_selected_len, - &new_properties[peer->property_count]); - if (result != TSI_OK) { - gpr_free(new_properties); - return result; - } - if (peer->properties != nullptr) gpr_free(peer->properties); + &peer->properties[peer->property_count]); + if (result != TSI_OK) return result; peer->property_count++; - peer->properties = new_properties; } + + const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false"; + result = tsi_construct_string_peer_property( + TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused, + strlen(session_reused) + 1, &peer->properties[peer->property_count]); + if (result != TSI_OK) return result; + peer->property_count++; + return result; } @@ -1103,6 +1186,19 @@ static const tsi_handshaker_vtable handshaker_vtable = { /* --- tsi_ssl_handshaker_factory common methods. --- */ +static void tsi_ssl_handshaker_resume_session( + SSL* ssl, tsi::SslSessionLRUCache* session_cache) { + const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (server_name == nullptr) { + return; + } + tsi::SslSessionPtr session = session_cache->Get(server_name); + if (session != nullptr) { + // SSL_set_session internally increments reference counter. + SSL_set_session(ssl, session.get()); + } +} + static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client, const char* server_name_indication, tsi_ssl_handshaker_factory* factory, @@ -1139,6 +1235,12 @@ static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client, return TSI_INTERNAL_ERROR; } } + tsi_ssl_client_handshaker_factory* client_factory = + reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory); + if (client_factory->session_cache != nullptr) { + tsi_ssl_handshaker_resume_session(ssl, + client_factory->session_cache.get()); + } ssl_result = SSL_do_handshake(ssl); ssl_result = SSL_get_error(ssl, ssl_result); if (ssl_result != SSL_ERROR_WANT_READ) { @@ -1214,6 +1316,7 @@ static void tsi_ssl_client_handshaker_factory_destroy( reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory); if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context); if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list); + self->session_cache.reset(); gpr_free(self); } @@ -1357,6 +1460,30 @@ static int server_handshaker_factory_npn_advertised_callback( return SSL_TLSEXT_ERR_OK; } +/// This callback is called when new \a session is established and ready to +/// be cached. This session can be reused for new connections to similar +/// servers at later point of time. +/// It's intended to be used with SSL_CTX_sess_set_new_cb function. +/// +/// It returns 1 if callback takes ownership over \a session and 0 otherwise. +static int server_handshaker_factory_new_session_callback( + SSL* ssl, SSL_SESSION* session) { + SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl); + if (ssl_context == nullptr) { + return 0; + } + void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index); + tsi_ssl_client_handshaker_factory* factory = + static_cast<tsi_ssl_client_handshaker_factory*>(arg); + const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (server_name == nullptr) { + return 0; + } + factory->session_cache->Put(server_name, tsi::SslSessionPtr(session)); + // Return 1 to indicate transfered ownership over the given session. + return 1; +} + /* --- tsi_ssl_handshaker_factory constructors. --- */ static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = { @@ -1367,15 +1494,31 @@ tsi_result tsi_create_ssl_client_handshaker_factory( const char* pem_root_certs, const char* cipher_suites, const char** alpn_protocols, uint16_t num_alpn_protocols, tsi_ssl_client_handshaker_factory** factory) { + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.pem_key_cert_pair = pem_key_cert_pair; + options.pem_root_certs = pem_root_certs; + options.cipher_suites = cipher_suites; + options.alpn_protocols = alpn_protocols; + options.num_alpn_protocols = num_alpn_protocols; + return tsi_create_ssl_client_handshaker_factory_with_options(&options, + factory); +} + +tsi_result tsi_create_ssl_client_handshaker_factory_with_options( + const tsi_ssl_client_handshaker_options* options, + tsi_ssl_client_handshaker_factory** factory) { SSL_CTX* ssl_context = nullptr; tsi_ssl_client_handshaker_factory* impl = nullptr; tsi_result result = TSI_OK; - gpr_once_init(&init_openssl_once, init_openssl); + gpr_once_init(&g_init_openssl_once, init_openssl); if (factory == nullptr) return TSI_INVALID_ARGUMENT; *factory = nullptr; - if (pem_root_certs == nullptr) return TSI_INVALID_ARGUMENT; + if (options->pem_root_certs == nullptr && options->root_store == nullptr) { + return TSI_INVALID_ARGUMENT; + } ssl_context = SSL_CTX_new(TLSv1_2_method()); if (ssl_context == nullptr) { @@ -1387,24 +1530,44 @@ tsi_result tsi_create_ssl_client_handshaker_factory( gpr_zalloc(sizeof(*impl))); tsi_ssl_handshaker_factory_init(&impl->base); impl->base.vtable = &client_handshaker_factory_vtable; - impl->ssl_context = ssl_context; + if (options->session_cache != nullptr) { + // Unref is called manually on factory destruction. + impl->session_cache = + reinterpret_cast<tsi::SslSessionLRUCache*>(options->session_cache) + ->Ref(); + SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl); + SSL_CTX_sess_set_new_cb(ssl_context, + server_handshaker_factory_new_session_callback); + SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT); + } do { - result = - populate_ssl_context(ssl_context, pem_key_cert_pair, cipher_suites); + result = populate_ssl_context(ssl_context, options->pem_key_cert_pair, + options->cipher_suites); if (result != TSI_OK) break; - result = ssl_ctx_load_verification_certs(ssl_context, pem_root_certs, - strlen(pem_root_certs), nullptr); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Cannot load server root certificates."); - break; + +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + // X509_STORE_up_ref is only available since OpenSSL 1.1. + if (options->root_store != nullptr) { + X509_STORE_up_ref(options->root_store->store); + SSL_CTX_set_cert_store(ssl_context, options->root_store->store); + } +#endif + if (OPENSSL_VERSION_NUMBER < 0x10100000 || options->root_store == nullptr) { + result = ssl_ctx_load_verification_certs( + ssl_context, options->pem_root_certs, strlen(options->pem_root_certs), + nullptr); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Cannot load server root certificates."); + break; + } } - if (num_alpn_protocols != 0) { - result = build_alpn_protocol_name_list(alpn_protocols, num_alpn_protocols, - &impl->alpn_protocol_list, - &impl->alpn_protocol_list_length); + if (options->num_alpn_protocols != 0) { + result = build_alpn_protocol_name_list( + options->alpn_protocols, options->num_alpn_protocols, + &impl->alpn_protocol_list, &impl->alpn_protocol_list_length); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Building alpn list failed with error %s.", tsi_result_to_string(result)); @@ -1457,15 +1620,32 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( tsi_client_certificate_request_type client_certificate_request, const char* cipher_suites, const char** alpn_protocols, uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) { + tsi_ssl_server_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.pem_key_cert_pairs = pem_key_cert_pairs; + options.num_key_cert_pairs = num_key_cert_pairs; + options.pem_client_root_certs = pem_client_root_certs; + options.client_certificate_request = client_certificate_request; + options.cipher_suites = cipher_suites; + options.alpn_protocols = alpn_protocols; + options.num_alpn_protocols = num_alpn_protocols; + return tsi_create_ssl_server_handshaker_factory_with_options(&options, + factory); +} + +tsi_result tsi_create_ssl_server_handshaker_factory_with_options( + const tsi_ssl_server_handshaker_options* options, + tsi_ssl_server_handshaker_factory** factory) { tsi_ssl_server_handshaker_factory* impl = nullptr; tsi_result result = TSI_OK; size_t i = 0; - gpr_once_init(&init_openssl_once, init_openssl); + gpr_once_init(&g_init_openssl_once, init_openssl); if (factory == nullptr) return TSI_INVALID_ARGUMENT; *factory = nullptr; - if (num_key_cert_pairs == 0 || pem_key_cert_pairs == nullptr) { + if (options->num_key_cert_pairs == 0 || + options->pem_key_cert_pairs == nullptr) { return TSI_INVALID_ARGUMENT; } @@ -1474,28 +1654,28 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( tsi_ssl_handshaker_factory_init(&impl->base); impl->base.vtable = &server_handshaker_factory_vtable; - impl->ssl_contexts = - static_cast<SSL_CTX**>(gpr_zalloc(num_key_cert_pairs * sizeof(SSL_CTX*))); - impl->ssl_context_x509_subject_names = - static_cast<tsi_peer*>(gpr_zalloc(num_key_cert_pairs * sizeof(tsi_peer))); + impl->ssl_contexts = static_cast<SSL_CTX**>( + gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*))); + impl->ssl_context_x509_subject_names = static_cast<tsi_peer*>( + gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer))); if (impl->ssl_contexts == nullptr || impl->ssl_context_x509_subject_names == nullptr) { tsi_ssl_handshaker_factory_unref(&impl->base); return TSI_OUT_OF_RESOURCES; } - impl->ssl_context_count = num_key_cert_pairs; + impl->ssl_context_count = options->num_key_cert_pairs; - if (num_alpn_protocols > 0) { - result = build_alpn_protocol_name_list(alpn_protocols, num_alpn_protocols, - &impl->alpn_protocol_list, - &impl->alpn_protocol_list_length); + if (options->num_alpn_protocols > 0) { + result = build_alpn_protocol_name_list( + options->alpn_protocols, options->num_alpn_protocols, + &impl->alpn_protocol_list, &impl->alpn_protocol_list_length); if (result != TSI_OK) { tsi_ssl_handshaker_factory_unref(&impl->base); return result; } } - for (i = 0; i < num_key_cert_pairs; i++) { + for (i = 0; i < options->num_key_cert_pairs; i++) { do { impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method()); if (impl->ssl_contexts[i] == nullptr) { @@ -1504,20 +1684,44 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( break; } result = populate_ssl_context(impl->ssl_contexts[i], - &pem_key_cert_pairs[i], cipher_suites); + &options->pem_key_cert_pairs[i], + options->cipher_suites); if (result != TSI_OK) break; - if (pem_client_root_certs != nullptr) { + // TODO(elessar): Provide ability to disable session ticket keys. + + // Allow client cache sessions (it's needed for OpenSSL only). + int set_sid_ctx_result = SSL_CTX_set_session_id_context( + impl->ssl_contexts[i], kSslSessionIdContext, + GPR_ARRAY_SIZE(kSslSessionIdContext)); + if (set_sid_ctx_result == 0) { + gpr_log(GPR_ERROR, "Failed to set session id context."); + result = TSI_INTERNAL_ERROR; + break; + } + + if (options->session_ticket_key != nullptr) { + if (SSL_CTX_set_tlsext_ticket_keys( + impl->ssl_contexts[i], + const_cast<char*>(options->session_ticket_key), + options->session_ticket_key_size) == 0) { + gpr_log(GPR_ERROR, "Invalid STEK size."); + result = TSI_INVALID_ARGUMENT; + break; + } + } + + if (options->pem_client_root_certs != nullptr) { STACK_OF(X509_NAME)* root_names = nullptr; result = ssl_ctx_load_verification_certs( - impl->ssl_contexts[i], pem_client_root_certs, - strlen(pem_client_root_certs), &root_names); + impl->ssl_contexts[i], options->pem_client_root_certs, + strlen(options->pem_client_root_certs), &root_names); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Invalid verification certs."); break; } SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names); - switch (client_certificate_request) { + switch (options->client_certificate_request) { case TSI_DONT_REQUEST_CLIENT_CERTIFICATE: SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr); break; @@ -1544,7 +1748,7 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( } result = extract_x509_subject_names_from_pem_cert( - pem_key_cert_pairs[i].cert_chain, + options->pem_key_cert_pairs[i].cert_chain, &impl->ssl_context_x509_subject_names[i]); if (result != TSI_OK) break; diff --git a/src/core/tsi/ssl_transport_security.h b/src/core/tsi/ssl_transport_security.h index edebadc1be..cabf583098 100644 --- a/src/core/tsi/ssl_transport_security.h +++ b/src/core/tsi/ssl_transport_security.h @@ -30,11 +30,41 @@ #define TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY "x509_subject_common_name" #define TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY \ "x509_subject_alternative_name" +#define TSI_SSL_SESSION_REUSED_PEER_PROPERTY "ssl_session_reused" #define TSI_X509_PEM_CERT_PROPERTY "x509_pem_cert" #define TSI_SSL_ALPN_SELECTED_PROTOCOL "ssl_alpn_selected_protocol" +/* --- tsi_ssl_root_certs_store object --- + + This object stores SSL root certificates. It can be shared by multiple SSL + context. */ +typedef struct tsi_ssl_root_certs_store tsi_ssl_root_certs_store; + +/* Given a NULL-terminated string containing the PEM encoding of the root + certificates, creates a tsi_ssl_root_certs_store object. */ +tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create( + const char* pem_roots); + +/* Destroys the tsi_ssl_root_certs_store object. */ +void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self); + +/* --- tsi_ssl_session_cache object --- + + Cache for SSL sessions for sessions resumption. */ + +typedef struct tsi_ssl_session_cache tsi_ssl_session_cache; + +/* Create LRU cache for SSL sessions with \a capacity. */ +tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity); + +/* Increment reference counter of \a cache. */ +void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache); + +/* Decrement reference counter of \a cache. */ +void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache); + /* --- tsi_ssl_client_handshaker_factory object --- This object creates a client tsi_handshaker objects implemented in terms of @@ -54,13 +84,13 @@ typedef struct { const char* cert_chain; } tsi_ssl_pem_key_cert_pair; -/* Creates a client handshaker factory. +/* TO BE DEPRECATED. + Creates a client handshaker factory. - pem_key_cert_pair is a pointer to the object containing client's private key and certificate chain. This parameter can be NULL if the client does not have such a key/cert pair. - pem_roots_cert is the NULL-terminated string containing the PEM encoding of - the client root certificates. This parameter may be NULL if the server does - not want the client to be authenticated with SSL. + the server root certificates. - cipher_suites contains an optional list of the ciphers that the client supports. The format of this string is described in: https://www.openssl.org/docs/apps/ciphers.html. @@ -81,6 +111,47 @@ tsi_result tsi_create_ssl_client_handshaker_factory( const char** alpn_protocols, uint16_t num_alpn_protocols, tsi_ssl_client_handshaker_factory** factory); +typedef struct { + /* pem_key_cert_pair is a pointer to the object containing client's private + key and certificate chain. This parameter can be NULL if the client does + not have such a key/cert pair. */ + const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair; + /* pem_roots_cert is the NULL-terminated string containing the PEM encoding of + the client root certificates. */ + const char* pem_root_certs; + /* root_store is a pointer to the ssl_root_certs_store object. If root_store + is not nullptr and SSL implementation permits, root_store will be used as + root certificates. Otherwise, pem_roots_cert will be used to load server + root certificates. */ + const tsi_ssl_root_certs_store* root_store; + /* cipher_suites contains an optional list of the ciphers that the client + supports. The format of this string is described in: + https://www.openssl.org/docs/apps/ciphers.html. + This parameter can be set to NULL to use the default set of ciphers. + TODO(jboeuf): Revisit the format of this parameter. */ + const char* cipher_suites; + /* alpn_protocols is an array containing the NULL terminated protocol names + that the handshakers created with this factory support. This parameter can + be NULL. */ + const char** alpn_protocols; + /* num_alpn_protocols is the number of alpn protocols and associated lengths + specified. If this parameter is 0, the other alpn parameters must be + NULL. */ + size_t num_alpn_protocols; + /* ssl_session_cache is a cache for reusable client-side sessions. */ + tsi_ssl_session_cache* session_cache; +} tsi_ssl_client_handshaker_options; + +/* Creates a client handshaker factory. + - options is the options used to create a factory. + - factory is the address of the factory pointer to be created. + + - This method returns TSI_OK on success or TSI_INVALID_PARAMETER in the case + where a parameter is invalid. */ +tsi_result tsi_create_ssl_client_handshaker_factory_with_options( + const tsi_ssl_client_handshaker_options* options, + tsi_ssl_client_handshaker_factory** factory); + /* Creates a client handshaker. - self is the factory from which the handshaker will be created. - server_name_indication indicates the name of the server the client is @@ -107,12 +178,14 @@ void tsi_ssl_client_handshaker_factory_unref( typedef struct tsi_ssl_server_handshaker_factory tsi_ssl_server_handshaker_factory; -/* Creates a server handshaker factory. +/* TO BE DEPRECATED. + Creates a server handshaker factory. - pem_key_cert_pairs is an array private key / certificate chains of the server. - num_key_cert_pairs is the number of items in the pem_key_cert_pairs array. - pem_root_certs is the NULL-terminated string containing the PEM encoding - of the server root certificates. + of the client root certificates. This parameter may be NULL if the server + does not want the client to be authenticated with SSL. - cipher_suites contains an optional list of the ciphers that the server supports. The format of this string is described in: https://www.openssl.org/docs/apps/ciphers.html. @@ -134,7 +207,8 @@ tsi_result tsi_create_ssl_server_handshaker_factory( const char** alpn_protocols, uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory); -/* Same as tsi_create_ssl_server_handshaker_factory method except uses +/* TO BE DEPRECATED. + Same as tsi_create_ssl_server_handshaker_factory method except uses tsi_client_certificate_request_type to support more ways to handle client certificate authentication. - client_certificate_request, if set to non-zero will force the client to @@ -147,6 +221,52 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( const char* cipher_suites, const char** alpn_protocols, uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory); +typedef struct { + /* pem_key_cert_pairs is an array private key / certificate chains of the + server. */ + const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs; + /* num_key_cert_pairs is the number of items in the pem_key_cert_pairs + array. */ + size_t num_key_cert_pairs; + /* pem_root_certs is the NULL-terminated string containing the PEM encoding + of the server root certificates. This parameter may be NULL if the server + does not want the client to be authenticated with SSL. */ + const char* pem_client_root_certs; + /* client_certificate_request, if set to non-zero will force the client to + authenticate with an SSL cert. Note that this option is ignored if + pem_client_root_certs is NULL or pem_client_roots_certs_size is 0. */ + tsi_client_certificate_request_type client_certificate_request; + /* cipher_suites contains an optional list of the ciphers that the server + supports. The format of this string is described in: + https://www.openssl.org/docs/apps/ciphers.html. + This parameter can be set to NULL to use the default set of ciphers. + TODO(jboeuf): Revisit the format of this parameter. */ + const char* cipher_suites; + /* alpn_protocols is an array containing the NULL terminated protocol names + that the handshakers created with this factory support. This parameter can + be NULL. */ + const char** alpn_protocols; + /* num_alpn_protocols is the number of alpn protocols and associated lengths + specified. If this parameter is 0, the other alpn parameters must be + NULL. */ + uint16_t num_alpn_protocols; + /* session_ticket_key is optional key for encrypting session keys. If paramter + is not specified it must be NULL. */ + const char* session_ticket_key; + /* session_ticket_key_size is a size of session ticket encryption key. */ + size_t session_ticket_key_size; +} tsi_ssl_server_handshaker_options; + +/* Creates a server handshaker factory. + - options is the options used to create a factory. + - factory is the address of the factory pointer to be created. + + - This method returns TSI_OK on success or TSI_INVALID_PARAMETER in the case + where a parameter is invalid. */ +tsi_result tsi_create_ssl_server_handshaker_factory_with_options( + const tsi_ssl_server_handshaker_options* options, + tsi_ssl_server_handshaker_factory** factory); + /* Creates a server handshaker. - self is the factory from which the handshaker will be created. - handshaker is the address of the handshaker pointer to be created. |