diff options
author | 2017-09-07 15:54:59 -0700 | |
---|---|---|
committer | 2017-09-07 15:54:59 -0700 | |
commit | 56969c2784911583825d75212ca027806d1de056 (patch) | |
tree | 35ccaa433f197d7977533e24210ddefb7632a12e /test | |
parent | f0ba70a9ea92b6c7422f40206e763141c05eb281 (diff) | |
parent | 41630a29507c8dd5b6110f0397b346b7feab442b (diff) |
Merge github.com:grpc/grpc into pollset_kick_stats
Diffstat (limited to 'test')
28 files changed, 1836 insertions, 131 deletions
diff --git a/test/core/channel/BUILD b/test/core/channel/BUILD index ef861cc5e7..5ac77c449b 100644 --- a/test/core/channel/BUILD +++ b/test/core/channel/BUILD @@ -41,3 +41,15 @@ grpc_cc_test( "//test/core/util:grpc_test_util", ], ) + +grpc_cc_test( + name = "channel_stack_builder_test", + srcs = ["channel_stack_builder_test.c"], + language = "C", + deps = [ + "//:gpr", + "//:grpc", + "//test/core/util:gpr_test_util", + "//test/core/util:grpc_test_util", + ], +) diff --git a/test/core/channel/channel_stack_builder_test.c b/test/core/channel/channel_stack_builder_test.c new file mode 100644 index 0000000000..be6afb7c07 --- /dev/null +++ b/test/core/channel/channel_stack_builder_test.c @@ -0,0 +1,152 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/channel/channel_stack_builder.h" + +#include <limits.h> +#include <string.h> + +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> + +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/util/test_config.h" + +static grpc_error *channel_init_func(grpc_exec_ctx *exec_ctx, + grpc_channel_element *elem, + grpc_channel_element_args *args) { + return GRPC_ERROR_NONE; +} + +static grpc_error *call_init_func(grpc_exec_ctx *exec_ctx, + grpc_call_element *elem, + const grpc_call_element_args *args) { + return GRPC_ERROR_NONE; +} + +static void channel_destroy_func(grpc_exec_ctx *exec_ctx, + grpc_channel_element *elem) {} + +static void call_destroy_func(grpc_exec_ctx *exec_ctx, grpc_call_element *elem, + const grpc_call_final_info *final_info, + grpc_closure *ignored) {} + +static void call_func(grpc_exec_ctx *exec_ctx, grpc_call_element *elem, + grpc_transport_stream_op_batch *op) {} + +static void channel_func(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, + grpc_transport_op *op) { + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(op->disconnect_with_error); + } + GRPC_CLOSURE_SCHED(exec_ctx, op->on_consumed, GRPC_ERROR_NONE); +} + +static char *get_peer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem) { + return gpr_strdup("peer"); +} + +bool g_replacement_fn_called = false; +bool g_original_fn_called = false; +void set_arg_once_fn(grpc_channel_stack *channel_stack, + grpc_channel_element *elem, void *arg) { + bool *called = arg; + // Make sure this function is only called once per arg. + GPR_ASSERT(*called == false); + *called = true; +} + +static void test_channel_stack_builder_filter_replace(void) { + grpc_channel *channel = + grpc_insecure_channel_create("target name isn't used", NULL, NULL); + GPR_ASSERT(channel != NULL); + // Make sure the high priority filter has been created. + GPR_ASSERT(g_replacement_fn_called); + // ... and that the low priority one hasn't. + GPR_ASSERT(!g_original_fn_called); + grpc_channel_destroy(channel); +} + +const grpc_channel_filter replacement_filter = { + call_func, + channel_func, + 0, + call_init_func, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + call_destroy_func, + 0, + channel_init_func, + channel_destroy_func, + get_peer, + grpc_channel_next_get_info, + "filter_name"}; + +const grpc_channel_filter original_filter = { + call_func, + channel_func, + 0, + call_init_func, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + call_destroy_func, + 0, + channel_init_func, + channel_destroy_func, + get_peer, + grpc_channel_next_get_info, + "filter_name"}; + +static bool add_replacement_filter(grpc_exec_ctx *exec_ctx, + grpc_channel_stack_builder *builder, + void *arg) { + const grpc_channel_filter *filter = arg; + // Get rid of any other version of the filter, as determined by having the + // same name. + GPR_ASSERT(grpc_channel_stack_builder_remove_filter(builder, filter->name)); + return grpc_channel_stack_builder_prepend_filter( + builder, filter, set_arg_once_fn, &g_replacement_fn_called); +} + +static bool add_original_filter(grpc_exec_ctx *exec_ctx, + grpc_channel_stack_builder *builder, + void *arg) { + return grpc_channel_stack_builder_prepend_filter( + builder, (const grpc_channel_filter *)arg, set_arg_once_fn, + &g_original_fn_called); +} + +static void init_plugin(void) { + grpc_channel_init_register_stage(GRPC_CLIENT_CHANNEL, INT_MAX, + add_original_filter, + (void *)&original_filter); + grpc_channel_init_register_stage(GRPC_CLIENT_CHANNEL, INT_MAX, + add_replacement_filter, + (void *)&replacement_filter); +} + +static void destroy_plugin(void) {} + +int main(int argc, char **argv) { + grpc_test_init(argc, argv); + grpc_register_plugin(init_plugin, destroy_plugin); + grpc_init(); + test_channel_stack_builder_filter_replace(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/channel/channel_stack_test.c b/test/core/channel/channel_stack_test.c index 0c4bae6ded..7c3614b4a2 100644 --- a/test/core/channel/channel_stack_test.c +++ b/test/core/channel/channel_stack_test.c @@ -67,10 +67,6 @@ static void channel_func(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, ++*(int *)(elem->channel_data); } -static char *get_peer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem) { - return gpr_strdup("peer"); -} - static void free_channel(grpc_exec_ctx *exec_ctx, void *arg, grpc_error *error) { grpc_channel_stack_destroy(exec_ctx, arg); @@ -93,7 +89,6 @@ static void test_create_channel_stack(void) { sizeof(int), channel_init_func, channel_destroy_func, - get_peer, grpc_channel_next_get_info, "some_test_filter"}; const grpc_channel_filter *filters = &filter; diff --git a/test/core/channel/minimal_stack_is_minimal_test.c b/test/core/channel/minimal_stack_is_minimal_test.c index c99b54c6ac..b4528346f7 100644 --- a/test/core/channel/minimal_stack_is_minimal_test.c +++ b/test/core/channel/minimal_stack_is_minimal_test.c @@ -89,14 +89,14 @@ int main(int argc, char **argv) { "connected", NULL); errors += CHECK_STACK("unknown", NULL, GRPC_SERVER_CHANNEL, "server", "message_size", "deadline", "connected", NULL); - errors += - CHECK_STACK("chttp2", NULL, GRPC_CLIENT_DIRECT_CHANNEL, "message_size", - "deadline", "http-client", "compress", "connected", NULL); + errors += CHECK_STACK("chttp2", NULL, GRPC_CLIENT_DIRECT_CHANNEL, + "message_size", "deadline", "http-client", + "message_compress", "connected", NULL); errors += CHECK_STACK("chttp2", NULL, GRPC_CLIENT_SUBCHANNEL, "message_size", - "http-client", "compress", "connected", NULL); - errors += - CHECK_STACK("chttp2", NULL, GRPC_SERVER_CHANNEL, "server", "message_size", - "deadline", "http-server", "compress", "connected", NULL); + "http-client", "message_compress", "connected", NULL); + errors += CHECK_STACK("chttp2", NULL, GRPC_SERVER_CHANNEL, "server", + "message_size", "deadline", "http-server", + "message_compress", "connected", NULL); errors += CHECK_STACK(NULL, NULL, GRPC_CLIENT_CHANNEL, "client-channel", NULL); diff --git a/test/core/debug/stats_test.cc b/test/core/debug/stats_test.cc index 65ccc7a5c8..c85ab3598a 100644 --- a/test/core/debug/stats_test.cc +++ b/test/core/debug/stats_test.cc @@ -69,34 +69,32 @@ static int FindExpectedBucket(int i, int j) { if (j < 0) { return 0; } - if (j >= - grpc_stats_histo_bucket_boundaries[i][grpc_stats_histo_buckets[i] - 1]) { + if (j >= grpc_stats_histo_bucket_boundaries[i][grpc_stats_histo_buckets[i]]) { return grpc_stats_histo_buckets[i] - 1; } - int r = 0; - while (grpc_stats_histo_bucket_boundaries[i][r + 1] <= j) r++; - return r; -} - -static int FindNonZeroBucket(const grpc_stats_data& data, int i) { - for (int j = 0; j < grpc_stats_histo_buckets[i]; j++) { - if (data.histograms[grpc_stats_histo_start[i] + j] != 0) { - return j; - } - } - return -1; + return std::upper_bound(grpc_stats_histo_bucket_boundaries[i], + grpc_stats_histo_bucket_boundaries[i] + + grpc_stats_histo_buckets[i], + j) - + grpc_stats_histo_bucket_boundaries[i] - 1; } TEST(StatsTest, IncHistogram) { for (int i = 0; i < GRPC_STATS_HISTOGRAM_COUNT; i++) { + std::vector<int> test_values; for (int j = -1000; j < grpc_stats_histo_bucket_boundaries[i] [grpc_stats_histo_buckets[i] - 1] + 1000; j++) { - gpr_log(GPR_DEBUG, "histo:%d value:%d", i, j); - + test_values.push_back(j); + } + std::random_shuffle(test_values.begin(), test_values.end()); + if (test_values.size() > 10000) { + test_values.resize(10000); + } + for (auto j : test_values) { Snapshot snapshot; int expected_bucket = FindExpectedBucket(i, j); @@ -106,9 +104,7 @@ TEST(StatsTest, IncHistogram) { grpc_exec_ctx_finish(&exec_ctx); auto delta = snapshot.delta(); - int got_bucket = FindNonZeroBucket(delta, i); - EXPECT_EQ(expected_bucket, got_bucket); EXPECT_EQ(delta.histograms[grpc_stats_histo_start[i] + expected_bucket], 1); } diff --git a/test/core/end2end/tests/cancel_after_accept.c b/test/core/end2end/tests/cancel_after_accept.c index a360c6f0d2..c3ac0c3201 100644 --- a/test/core/end2end/tests/cancel_after_accept.c +++ b/test/core/end2end/tests/cancel_after_accept.c @@ -39,10 +39,13 @@ static void *tag(intptr_t t) { return (void *)t; } static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char *test_name, + cancellation_mode mode, + bool use_service_config, grpc_channel_args *client_args, grpc_channel_args *server_args) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + gpr_log(GPR_INFO, "Running test: %s/%s/%s/%s", test_name, config.name, + mode.name, use_service_config ? "service_config" : "client_api"); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -135,8 +138,8 @@ static void test_cancel_after_accept(grpc_end2end_test_config config, args = grpc_channel_args_copy_and_add(args, &arg, 1); } - grpc_end2end_test_fixture f = - begin_test(config, "cancel_after_accept", args, NULL); + grpc_end2end_test_fixture f = begin_test(config, "cancel_after_accept", mode, + use_service_config, args, NULL); cq_verifier *cqv = cq_verifier_create(f.cq); gpr_timespec deadline = use_service_config diff --git a/test/core/end2end/tests/cancel_after_client_done.c b/test/core/end2end/tests/cancel_after_client_done.c index 502005b6af..0e2a751d83 100644 --- a/test/core/end2end/tests/cancel_after_client_done.c +++ b/test/core/end2end/tests/cancel_after_client_done.c @@ -33,10 +33,12 @@ static void *tag(intptr_t t) { return (void *)t; } static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char *test_name, + cancellation_mode mode, grpc_channel_args *client_args, grpc_channel_args *server_args) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + gpr_log(GPR_INFO, "Running test: %s/%s/%s", test_name, config.name, + mode.name); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -93,7 +95,7 @@ static void test_cancel_after_accept_and_writes_closed( grpc_call *c; grpc_call *s; grpc_end2end_test_fixture f = begin_test( - config, "test_cancel_after_accept_and_writes_closed", NULL, NULL); + config, "test_cancel_after_accept_and_writes_closed", mode, NULL, NULL); cq_verifier *cqv = cq_verifier_create(f.cq); grpc_metadata_array initial_metadata_recv; grpc_metadata_array trailing_metadata_recv; diff --git a/test/core/end2end/tests/cancel_after_round_trip.c b/test/core/end2end/tests/cancel_after_round_trip.c index ad24b4e538..bc41bd3a6d 100644 --- a/test/core/end2end/tests/cancel_after_round_trip.c +++ b/test/core/end2end/tests/cancel_after_round_trip.c @@ -39,10 +39,13 @@ static void *tag(intptr_t t) { return (void *)t; } static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char *test_name, + cancellation_mode mode, + bool use_service_config, grpc_channel_args *client_args, grpc_channel_args *server_args) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + gpr_log(GPR_INFO, "Running test: %s/%s/%s/%s", test_name, config.name, + mode.name, use_service_config ? "service_config" : "client_api"); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -137,8 +140,8 @@ static void test_cancel_after_round_trip(grpc_end2end_test_config config, args = grpc_channel_args_copy_and_add(args, &arg, 1); } - grpc_end2end_test_fixture f = - begin_test(config, "cancel_after_round_trip", args, NULL); + grpc_end2end_test_fixture f = begin_test( + config, "cancel_after_round_trip", mode, use_service_config, args, NULL); cq_verifier *cqv = cq_verifier_create(f.cq); gpr_timespec deadline = use_service_config diff --git a/test/core/end2end/tests/cancel_before_invoke.c b/test/core/end2end/tests/cancel_before_invoke.c index 423194b63e..397e8b8ba6 100644 --- a/test/core/end2end/tests/cancel_before_invoke.c +++ b/test/core/end2end/tests/cancel_before_invoke.c @@ -32,10 +32,12 @@ static void *tag(intptr_t t) { return (void *)t; } static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char *test_name, + size_t num_ops, grpc_channel_args *client_args, grpc_channel_args *server_args) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + gpr_log(GPR_INFO, "Running test: %s/%s [%" PRIdPTR " ops]", test_name, + config.name, num_ops); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -91,7 +93,7 @@ static void test_cancel_before_invoke(grpc_end2end_test_config config, grpc_op *op; grpc_call *c; grpc_end2end_test_fixture f = - begin_test(config, "cancel_before_invoke", NULL, NULL); + begin_test(config, "cancel_before_invoke", test_ops, NULL, NULL); cq_verifier *cqv = cq_verifier_create(f.cq); grpc_metadata_array initial_metadata_recv; grpc_metadata_array trailing_metadata_recv; diff --git a/test/core/end2end/tests/cancel_in_a_vacuum.c b/test/core/end2end/tests/cancel_in_a_vacuum.c index f64cbdf929..cd9551bef9 100644 --- a/test/core/end2end/tests/cancel_in_a_vacuum.c +++ b/test/core/end2end/tests/cancel_in_a_vacuum.c @@ -33,10 +33,12 @@ static void *tag(intptr_t t) { return (void *)t; } static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char *test_name, + cancellation_mode mode, grpc_channel_args *client_args, grpc_channel_args *server_args) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + gpr_log(GPR_INFO, "Running test: %s/%s/%s", test_name, config.name, + mode.name); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -90,7 +92,7 @@ static void test_cancel_in_a_vacuum(grpc_end2end_test_config config, cancellation_mode mode) { grpc_call *c; grpc_end2end_test_fixture f = - begin_test(config, "test_cancel_in_a_vacuum", NULL, NULL); + begin_test(config, "test_cancel_in_a_vacuum", mode, NULL, NULL); cq_verifier *v_client = cq_verifier_create(f.cq); gpr_timespec deadline = five_seconds_from_now(); diff --git a/test/core/end2end/tests/cancel_with_status.c b/test/core/end2end/tests/cancel_with_status.c index fd26fd122e..ab8c4f4187 100644 --- a/test/core/end2end/tests/cancel_with_status.c +++ b/test/core/end2end/tests/cancel_with_status.c @@ -35,10 +35,12 @@ static void *tag(intptr_t t) { return (void *)t; } static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char *test_name, + size_t num_ops, grpc_channel_args *client_args, grpc_channel_args *server_args) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + gpr_log(GPR_INFO, "Running test: %s/%s [%" PRIdPTR " ops]", test_name, + config.name, num_ops); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -165,7 +167,7 @@ static void test_invoke_simple_request(grpc_end2end_test_config config, size_t num_ops) { grpc_end2end_test_fixture f; - f = begin_test(config, "test_invoke_simple_request", NULL, NULL); + f = begin_test(config, "test_invoke_simple_request", num_ops, NULL, NULL); simple_request_body(config, f, num_ops); end_test(&f); config.tear_down_data(&f); diff --git a/test/core/end2end/tests/filter_call_init_fails.c b/test/core/end2end/tests/filter_call_init_fails.c index b6be375a51..09e9dbcd7b 100644 --- a/test/core/end2end/tests/filter_call_init_fails.c +++ b/test/core/end2end/tests/filter_call_init_fails.c @@ -430,7 +430,6 @@ static const grpc_channel_filter test_filter = { 0, init_channel_elem, destroy_channel_elem, - grpc_call_next_get_peer, grpc_channel_next_get_info, "filter_call_init_fails"}; diff --git a/test/core/end2end/tests/filter_causes_close.c b/test/core/end2end/tests/filter_causes_close.c index aff39dd89d..5a8c96d121 100644 --- a/test/core/end2end/tests/filter_causes_close.c +++ b/test/core/end2end/tests/filter_causes_close.c @@ -197,7 +197,7 @@ static void recv_im_ready(grpc_exec_ctx *exec_ctx, void *arg, grpc_error *error) { grpc_call_element *elem = arg; call_data *calld = elem->call_data; - GRPC_CLOSURE_SCHED( + GRPC_CLOSURE_RUN( exec_ctx, calld->recv_im_ready, grpc_error_set_int(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Failure that's not preventable.", &error, 1), @@ -247,7 +247,6 @@ static const grpc_channel_filter test_filter = { sizeof(channel_data), init_channel_elem, destroy_channel_elem, - grpc_call_next_get_peer, grpc_channel_next_get_info, "filter_causes_close"}; diff --git a/test/core/end2end/tests/filter_latency.c b/test/core/end2end/tests/filter_latency.c index 5dbbc4d18d..8918c3b2f6 100644 --- a/test/core/end2end/tests/filter_latency.c +++ b/test/core/end2end/tests/filter_latency.c @@ -290,7 +290,6 @@ static const grpc_channel_filter test_client_filter = { 0, init_channel_elem, destroy_channel_elem, - grpc_call_next_get_peer, grpc_channel_next_get_info, "client_filter_latency"}; @@ -304,7 +303,6 @@ static const grpc_channel_filter test_server_filter = { 0, init_channel_elem, destroy_channel_elem, - grpc_call_next_get_peer, grpc_channel_next_get_info, "server_filter_latency"}; diff --git a/test/core/iomgr/fd_conservation_posix_test.c b/test/core/iomgr/fd_conservation_posix_test.c index 3c61173ecd..d29b1e8e41 100644 --- a/test/core/iomgr/fd_conservation_posix_test.c +++ b/test/core/iomgr/fd_conservation_posix_test.c @@ -30,9 +30,8 @@ int main(int argc, char **argv) { grpc_endpoint_pair p; grpc_test_init(argc, argv); + grpc_init(); grpc_exec_ctx exec_ctx = GRPC_EXEC_CTX_INIT; - grpc_iomgr_init(&exec_ctx); - grpc_iomgr_start(&exec_ctx); /* set max # of file descriptors to a low value, and verify we can create and destroy many more than this number @@ -51,7 +50,7 @@ int main(int argc, char **argv) { grpc_resource_quota_unref(resource_quota); - grpc_iomgr_shutdown(&exec_ctx); grpc_exec_ctx_finish(&exec_ctx); + grpc_shutdown(); return 0; } diff --git a/test/core/support/string_test.c b/test/core/support/string_test.c index a3c33c3fa4..bee2139477 100644 --- a/test/core/support/string_test.c +++ b/test/core/support/string_test.c @@ -279,6 +279,21 @@ static void test_memrchr(void) { GPR_ASSERT(0 == strcmp((const char *)gpr_memrchr("hello", 'l', 5), "lo")); } +static void test_is_true(void) { + LOG_TEST_NAME("test_is_true"); + + GPR_ASSERT(true == gpr_is_true("True")); + GPR_ASSERT(true == gpr_is_true("true")); + GPR_ASSERT(true == gpr_is_true("TRUE")); + GPR_ASSERT(true == gpr_is_true("Yes")); + GPR_ASSERT(true == gpr_is_true("yes")); + GPR_ASSERT(true == gpr_is_true("YES")); + GPR_ASSERT(true == gpr_is_true("1")); + GPR_ASSERT(false == gpr_is_true(NULL)); + GPR_ASSERT(false == gpr_is_true("")); + GPR_ASSERT(false == gpr_is_true("0")); +} + int main(int argc, char **argv) { grpc_test_init(argc, argv); test_strdup(); @@ -292,5 +307,6 @@ int main(int argc, char **argv) { test_leftpad(); test_stricmp(); test_memrchr(); + test_is_true(); return 0; } diff --git a/test/core/tsi/BUILD b/test/core/tsi/BUILD index 82e0e5fb80..0c5509dda6 100644 --- a/test/core/tsi/BUILD +++ b/test/core/tsi/BUILD @@ -18,13 +18,63 @@ licenses(["notice"]) # Apache v2 grpc_package(name = "test/core/tsi") +grpc_cc_library( + name = "transport_security_test_lib", + srcs = ["transport_security_test_lib.c"], + hdrs = ["transport_security_test_lib.h"], + deps = [ + "//:grpc", + "//:tsi", + ], +) + +grpc_cc_test( + name = "fake_transport_security_test", + srcs = ["fake_transport_security_test.c"], + language = "C", + deps = [ + ":transport_security_test_lib", + "//:grpc", + "//:gpr", + "//:tsi", + "//test/core/util:gpr_test_util", + ], +) + + +grpc_cc_test( + name = "ssl_transport_security_test", + srcs = ["ssl_transport_security_test.c"], + data = [ + "//src/core/tsi/test_creds:badclient.key", + "//src/core/tsi/test_creds:badclient.pem", + "//src/core/tsi/test_creds:badserver.key", + "//src/core/tsi/test_creds:badserver.pem", + "//src/core/tsi/test_creds:ca.pem", + "//src/core/tsi/test_creds:client.key", + "//src/core/tsi/test_creds:client.pem", + "//src/core/tsi/test_creds:server0.key", + "//src/core/tsi/test_creds:server0.pem", + "//src/core/tsi/test_creds:server1.key", + "//src/core/tsi/test_creds:server1.pem", + ], + language = "C", + deps = [ + ":transport_security_test_lib", + "//:grpc", + "//:gpr", + "//:tsi", + "//test/core/util:gpr_test_util", + ], +) + grpc_cc_test( name = "transport_security_test", srcs = ["transport_security_test.c"], language = "C", deps = [ - "//:gpr", "//:grpc", + "//:gpr", "//test/core/util:gpr_test_util", "//test/core/util:grpc_test_util", ], diff --git a/test/core/tsi/fake_transport_security_test.c b/test/core/tsi/fake_transport_security_test.c new file mode 100644 index 0000000000..11be8802b7 --- /dev/null +++ b/test/core/tsi/fake_transport_security_test.c @@ -0,0 +1,148 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <stdbool.h> +#include <stdio.h> +#include <string.h> + +#include "src/core/lib/security/transport/security_connector.h" +#include "src/core/tsi/fake_transport_security.h" +#include "test/core/tsi/transport_security_test_lib.h" +#include "test/core/util/test_config.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> + +typedef struct fake_tsi_test_fixture { + tsi_test_fixture base; +} fake_tsi_test_fixture; + +static void fake_test_setup_handshakers(tsi_test_fixture *fixture) { + fixture->client_handshaker = + tsi_create_fake_handshaker(true /* is_client. */); + fixture->server_handshaker = + tsi_create_fake_handshaker(false /* is_client. */); +} + +static void validate_handshaker_peers(tsi_handshaker_result *result) { + GPR_ASSERT(result != NULL); + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); + const tsi_peer_property *property = + tsi_peer_get_property_by_name(&peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); + GPR_ASSERT(property != NULL); + GPR_ASSERT(memcmp(property->value.data, TSI_FAKE_CERTIFICATE_TYPE, + property->value.length) == 0); + tsi_peer_destruct(&peer); +} + +static void fake_test_check_handshaker_peers(tsi_test_fixture *fixture) { + validate_handshaker_peers(fixture->client_result); + validate_handshaker_peers(fixture->server_result); +} + +static void fake_test_destruct(tsi_test_fixture *fixture) {} + +static const struct tsi_test_fixture_vtable vtable = { + fake_test_setup_handshakers, fake_test_check_handshaker_peers, + fake_test_destruct}; + +static tsi_test_fixture *fake_tsi_test_fixture_create() { + fake_tsi_test_fixture *fake_fixture = gpr_zalloc(sizeof(*fake_fixture)); + tsi_test_fixture_init(&fake_fixture->base); + fake_fixture->base.vtable = &vtable; + return &fake_fixture->base; +} + +void fake_tsi_test_do_handshake_tiny_handshake_buffer() { + tsi_test_fixture *fixture = fake_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_TINY_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void fake_tsi_test_do_handshake_small_handshake_buffer() { + tsi_test_fixture *fixture = fake_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_SMALL_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void fake_tsi_test_do_handshake() { + tsi_test_fixture *fixture = fake_tsi_test_fixture_create(); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void fake_tsi_test_do_round_trip_for_all_configs() { + unsigned int *bit_array = + gpr_zalloc(sizeof(unsigned int) * TSI_TEST_NUM_OF_ARGUMENTS); + const unsigned int mask = 1U << (TSI_TEST_NUM_OF_ARGUMENTS - 1); + for (unsigned int val = 0; val < TSI_TEST_NUM_OF_COMBINATIONS; val++) { + unsigned int v = val; + for (unsigned int ind = 0; ind < TSI_TEST_NUM_OF_ARGUMENTS; ind++) { + bit_array[ind] = (v & mask) ? 1 : 0; + v <<= 1; + } + tsi_test_fixture *fixture = fake_tsi_test_fixture_create(); + fake_tsi_test_fixture *fake_fixture = (fake_tsi_test_fixture *)fixture; + tsi_test_frame_protector_config_destroy(fake_fixture->base.config); + fake_fixture->base.config = tsi_test_frame_protector_config_create( + bit_array[0], bit_array[1], bit_array[2], bit_array[3], bit_array[4], + bit_array[5], bit_array[6], bit_array[7]); + tsi_test_do_round_trip(&fake_fixture->base); + tsi_test_fixture_destroy(fixture); + } + gpr_free(bit_array); +} + +void fake_tsi_test_do_round_trip_odd_buffer_size() { + const size_t odd_sizes[] = {1025, 2051, 4103, 8207, 16409}; + const size_t size = sizeof(odd_sizes) / sizeof(size_t); + for (size_t ind1 = 0; ind1 < size; ind1++) { + for (size_t ind2 = 0; ind2 < size; ind2++) { + for (size_t ind3 = 0; ind3 < size; ind3++) { + for (size_t ind4 = 0; ind4 < size; ind4++) { + for (size_t ind5 = 0; ind5 < size; ind5++) { + tsi_test_fixture *fixture = fake_tsi_test_fixture_create(); + fake_tsi_test_fixture *fake_fixture = + (fake_tsi_test_fixture *)fixture; + tsi_test_frame_protector_config_set_buffer_size( + fake_fixture->base.config, odd_sizes[ind1], odd_sizes[ind2], + odd_sizes[ind3], odd_sizes[ind4], odd_sizes[ind5]); + tsi_test_do_round_trip(&fake_fixture->base); + tsi_test_fixture_destroy(fixture); + } + } + } + } + } +} + +int main(int argc, char **argv) { + grpc_test_init(argc, argv); + grpc_init(); + fake_tsi_test_do_handshake_tiny_handshake_buffer(); + fake_tsi_test_do_handshake_small_handshake_buffer(); + fake_tsi_test_do_handshake(); + fake_tsi_test_do_round_trip_for_all_configs(); + fake_tsi_test_do_round_trip_odd_buffer_size(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/ssl_transport_security_test.c b/test/core/tsi/ssl_transport_security_test.c new file mode 100644 index 0000000000..364dfa1b73 --- /dev/null +++ b/test/core/tsi/ssl_transport_security_test.c @@ -0,0 +1,558 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <stdbool.h> +#include <stdio.h> +#include <string.h> + +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/transport/security_connector.h" +#include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security_adapter.h" +#include "test/core/tsi/transport_security_test_lib.h" +#include "test/core/util/test_config.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> + +#define SSL_TSI_TEST_ALPN1 "foo" +#define SSL_TSI_TEST_ALPN2 "toto" +#define SSL_TSI_TEST_ALPN3 "baz" +#define SSL_TSI_TEST_ALPN_NUM 2 +#define SSL_TSI_TEST_SERVER_KEY_CERT_PAIRS_NUM 2 +#define SSL_TSI_TEST_BAD_SERVER_KEY_CERT_PAIRS_NUM 1 +#define SSL_TSI_TEST_CREDENTIALS_DIR "src/core/tsi/test_creds/" + +typedef enum AlpnMode { + NO_ALPN, + ALPN_CLIENT_NO_SERVER, + ALPN_SERVER_NO_CLIENT, + ALPN_CLIENT_SERVER_OK, + ALPN_CLIENT_SERVER_MISMATCH +} AlpnMode; + +typedef struct ssl_alpn_lib { + AlpnMode alpn_mode; + char **server_alpn_protocols; + char **client_alpn_protocols; + uint16_t num_server_alpn_protocols; + uint16_t num_client_alpn_protocols; +} ssl_alpn_lib; + +typedef struct ssl_key_cert_lib { + bool use_bad_server_cert; + bool use_bad_client_cert; + char *root_cert; + tsi_ssl_pem_key_cert_pair *server_pem_key_cert_pairs; + tsi_ssl_pem_key_cert_pair *bad_server_pem_key_cert_pairs; + tsi_ssl_pem_key_cert_pair client_pem_key_cert_pair; + tsi_ssl_pem_key_cert_pair bad_client_pem_key_cert_pair; + uint16_t server_num_key_cert_pairs; + uint16_t bad_server_num_key_cert_pairs; +} ssl_key_cert_lib; + +typedef struct ssl_tsi_test_fixture { + tsi_test_fixture base; + ssl_key_cert_lib *key_cert_lib; + ssl_alpn_lib *alpn_lib; + bool force_client_auth; + char *server_name_indication; + tsi_ssl_server_handshaker_factory *server_handshaker_factory; + tsi_ssl_client_handshaker_factory *client_handshaker_factory; +} ssl_tsi_test_fixture; + +static void ssl_test_setup_handshakers(tsi_test_fixture *fixture) { + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + GPR_ASSERT(ssl_fixture != NULL); + GPR_ASSERT(ssl_fixture->key_cert_lib != NULL); + GPR_ASSERT(ssl_fixture->alpn_lib != NULL); + ssl_key_cert_lib *key_cert_lib = ssl_fixture->key_cert_lib; + ssl_alpn_lib *alpn_lib = ssl_fixture->alpn_lib; + /* Create client handshaker factory. */ + tsi_ssl_pem_key_cert_pair *client_key_cert_pair = NULL; + if (ssl_fixture->force_client_auth) { + client_key_cert_pair = key_cert_lib->use_bad_client_cert + ? &key_cert_lib->bad_client_pem_key_cert_pair + : &key_cert_lib->client_pem_key_cert_pair; + } + char **client_alpn_protocols = NULL; + uint16_t num_client_alpn_protocols = 0; + if (alpn_lib->alpn_mode == ALPN_CLIENT_NO_SERVER || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_OK || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_MISMATCH) { + client_alpn_protocols = alpn_lib->client_alpn_protocols; + num_client_alpn_protocols = alpn_lib->num_client_alpn_protocols; + } + GPR_ASSERT(tsi_create_ssl_client_handshaker_factory( + client_key_cert_pair, key_cert_lib->root_cert, NULL, + (const char **)client_alpn_protocols, + num_client_alpn_protocols, + &ssl_fixture->client_handshaker_factory) == TSI_OK); + /* Create server handshaker factory. */ + char **server_alpn_protocols = NULL; + uint16_t num_server_alpn_protocols = 0; + if (alpn_lib->alpn_mode == ALPN_SERVER_NO_CLIENT || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_OK || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_MISMATCH) { + server_alpn_protocols = alpn_lib->server_alpn_protocols; + num_server_alpn_protocols = alpn_lib->num_server_alpn_protocols; + if (alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_MISMATCH) { + num_server_alpn_protocols--; + } + } + GPR_ASSERT(tsi_create_ssl_server_handshaker_factory( + key_cert_lib->use_bad_server_cert + ? key_cert_lib->bad_server_pem_key_cert_pairs + : key_cert_lib->server_pem_key_cert_pairs, + key_cert_lib->use_bad_server_cert + ? key_cert_lib->bad_server_num_key_cert_pairs + : key_cert_lib->server_num_key_cert_pairs, + key_cert_lib->root_cert, ssl_fixture->force_client_auth, NULL, + (const char **)server_alpn_protocols, + num_server_alpn_protocols, + &ssl_fixture->server_handshaker_factory) == TSI_OK); + /* Create server and client handshakers. */ + tsi_handshaker *client_handshaker = NULL; + GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker( + ssl_fixture->client_handshaker_factory, + ssl_fixture->server_name_indication, + &client_handshaker) == TSI_OK); + ssl_fixture->base.client_handshaker = + tsi_create_adapter_handshaker(client_handshaker); + tsi_handshaker *server_handshaker = NULL; + GPR_ASSERT(tsi_ssl_server_handshaker_factory_create_handshaker( + ssl_fixture->server_handshaker_factory, &server_handshaker) == + TSI_OK); + ssl_fixture->base.server_handshaker = + tsi_create_adapter_handshaker(server_handshaker); +} + +static void check_alpn(ssl_tsi_test_fixture *ssl_fixture, + const tsi_peer *peer) { + GPR_ASSERT(ssl_fixture != NULL); + GPR_ASSERT(ssl_fixture->alpn_lib != NULL); + ssl_alpn_lib *alpn_lib = ssl_fixture->alpn_lib; + const tsi_peer_property *alpn_property = + tsi_peer_get_property_by_name(peer, TSI_SSL_ALPN_SELECTED_PROTOCOL); + if (alpn_lib->alpn_mode != ALPN_CLIENT_SERVER_OK) { + GPR_ASSERT(alpn_property == NULL); + } else { + GPR_ASSERT(alpn_property != NULL); + const char *expected_match = "baz"; + GPR_ASSERT(memcmp(alpn_property->value.data, expected_match, + alpn_property->value.length) == 0); + } +} + +static const tsi_peer_property * +check_basic_authenticated_peer_and_get_common_name(const tsi_peer *peer) { + const tsi_peer_property *cert_type_property = + tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); + GPR_ASSERT(cert_type_property != NULL); + GPR_ASSERT(memcmp(cert_type_property->value.data, TSI_X509_CERTIFICATE_TYPE, + cert_type_property->value.length) == 0); + const tsi_peer_property *property = tsi_peer_get_property_by_name( + peer, TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY); + GPR_ASSERT(property != NULL); + return property; +} + +void check_server0_peer(tsi_peer *peer) { + const tsi_peer_property *property = + check_basic_authenticated_peer_and_get_common_name(peer); + const char *expected_match = "*.test.google.com.au"; + GPR_ASSERT(memcmp(property->value.data, expected_match, + property->value.length) == 0); + GPR_ASSERT(tsi_peer_get_property_by_name( + peer, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == + NULL); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.test.google.com.au") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.test.google.com.au") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.test.google.blah") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.bar.test.google.com.au") == + 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "test.google.com.au") == 0); + tsi_peer_destruct(peer); +} + +static bool check_subject_alt_name(tsi_peer *peer, const char *name) { + for (size_t i = 0; i < peer->property_count; i++) { + const tsi_peer_property *prop = &peer->properties[i]; + if (strcmp(prop->name, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == + 0) { + if (memcmp(prop->value.data, name, prop->value.length) == 0) { + return true; + } + } + } + return false; +} + +void check_server1_peer(tsi_peer *peer) { + const tsi_peer_property *property = + check_basic_authenticated_peer_and_get_common_name(peer); + const char *expected_match = "*.test.google.com"; + GPR_ASSERT(memcmp(property->value.data, expected_match, + property->value.length) == 0); + GPR_ASSERT(check_subject_alt_name(peer, "*.test.google.fr") == 1); + GPR_ASSERT(check_subject_alt_name(peer, "waterzooi.test.google.be") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.test.google.fr") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.test.google.fr") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "waterzooi.test.google.be") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.test.youtube.com") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.foo.test.google.com") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "test.google.fr") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "tartines.test.google.be") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "tartines.youtube.com") == 0); + tsi_peer_destruct(peer); +} + +static void check_client_peer(ssl_tsi_test_fixture *ssl_fixture, + tsi_peer *peer) { + GPR_ASSERT(ssl_fixture != NULL); + GPR_ASSERT(ssl_fixture->alpn_lib != NULL); + ssl_alpn_lib *alpn_lib = ssl_fixture->alpn_lib; + if (!ssl_fixture->force_client_auth) { + GPR_ASSERT(peer->property_count == + (alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_OK ? 1 : 0)); + } else { + const tsi_peer_property *property = + check_basic_authenticated_peer_and_get_common_name(peer); + const char *expected_match = "testclient"; + GPR_ASSERT(memcmp(property->value.data, expected_match, + property->value.length) == 0); + } + tsi_peer_destruct(peer); +} + +static void ssl_test_check_handshaker_peers(tsi_test_fixture *fixture) { + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + GPR_ASSERT(ssl_fixture != NULL); + GPR_ASSERT(ssl_fixture->key_cert_lib != NULL); + ssl_key_cert_lib *key_cert_lib = ssl_fixture->key_cert_lib; + tsi_peer peer; + bool expect_success = + !(key_cert_lib->use_bad_server_cert || + (key_cert_lib->use_bad_client_cert && ssl_fixture->force_client_auth)); + if (expect_success) { + GPR_ASSERT(tsi_handshaker_result_extract_peer( + ssl_fixture->base.client_result, &peer) == TSI_OK); + check_alpn(ssl_fixture, &peer); + + if (ssl_fixture->server_name_indication != NULL) { + check_server1_peer(&peer); + } else { + check_server0_peer(&peer); + } + } else { + GPR_ASSERT(ssl_fixture->base.client_result == NULL); + } + if (expect_success) { + GPR_ASSERT(tsi_handshaker_result_extract_peer( + ssl_fixture->base.server_result, &peer) == TSI_OK); + check_alpn(ssl_fixture, &peer); + check_client_peer(ssl_fixture, &peer); + } else { + GPR_ASSERT(ssl_fixture->base.server_result == NULL); + } +} + +static void ssl_test_pem_key_cert_pair_destroy(tsi_ssl_pem_key_cert_pair kp) { + gpr_free((void *)kp.private_key); + gpr_free((void *)kp.cert_chain); +} + +static void ssl_test_destruct(tsi_test_fixture *fixture) { + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + if (ssl_fixture == NULL) { + return; + } + /* Destroy ssl_alpn_lib. */ + ssl_alpn_lib *alpn_lib = ssl_fixture->alpn_lib; + for (size_t i = 0; i < alpn_lib->num_server_alpn_protocols; i++) { + gpr_free(alpn_lib->server_alpn_protocols[i]); + } + gpr_free(alpn_lib->server_alpn_protocols); + for (size_t i = 0; i < alpn_lib->num_client_alpn_protocols; i++) { + gpr_free(alpn_lib->client_alpn_protocols[i]); + } + gpr_free(alpn_lib->client_alpn_protocols); + gpr_free(alpn_lib); + /* Destroy ssl_key_cert_lib. */ + ssl_key_cert_lib *key_cert_lib = ssl_fixture->key_cert_lib; + for (size_t i = 0; i < key_cert_lib->server_num_key_cert_pairs; i++) { + ssl_test_pem_key_cert_pair_destroy( + key_cert_lib->server_pem_key_cert_pairs[i]); + } + gpr_free(key_cert_lib->server_pem_key_cert_pairs); + for (size_t i = 0; i < key_cert_lib->bad_server_num_key_cert_pairs; i++) { + ssl_test_pem_key_cert_pair_destroy( + key_cert_lib->bad_server_pem_key_cert_pairs[i]); + } + gpr_free(key_cert_lib->bad_server_pem_key_cert_pairs); + ssl_test_pem_key_cert_pair_destroy(key_cert_lib->client_pem_key_cert_pair); + ssl_test_pem_key_cert_pair_destroy( + key_cert_lib->bad_client_pem_key_cert_pair); + gpr_free(key_cert_lib->root_cert); + gpr_free(key_cert_lib); + /* Destroy others. */ + tsi_ssl_server_handshaker_factory_destroy( + ssl_fixture->server_handshaker_factory); + tsi_ssl_client_handshaker_factory_destroy( + ssl_fixture->client_handshaker_factory); +} + +static const struct tsi_test_fixture_vtable vtable = { + ssl_test_setup_handshakers, ssl_test_check_handshaker_peers, + ssl_test_destruct}; + +static char *load_file(const char *dir_path, const char *file_name) { + char *file_path = + gpr_zalloc(sizeof(char) * (strlen(dir_path) + strlen(file_name) + 1)); + memcpy(file_path, dir_path, strlen(dir_path)); + memcpy(file_path + strlen(dir_path), file_name, strlen(file_name)); + grpc_slice slice; + GPR_ASSERT(grpc_load_file(file_path, 1, &slice) == GRPC_ERROR_NONE); + char *data = grpc_slice_to_c_string(slice); + grpc_slice_unref(slice); + gpr_free(file_path); + return data; +} + +static tsi_test_fixture *ssl_tsi_test_fixture_create() { + ssl_tsi_test_fixture *ssl_fixture = gpr_zalloc(sizeof(*ssl_fixture)); + tsi_test_fixture_init(&ssl_fixture->base); + ssl_fixture->base.test_unused_bytes = false; + ssl_fixture->base.vtable = &vtable; + /* Create ssl_key_cert_lib. */ + ssl_key_cert_lib *key_cert_lib = gpr_zalloc(sizeof(*key_cert_lib)); + key_cert_lib->use_bad_server_cert = false; + key_cert_lib->use_bad_client_cert = false; + key_cert_lib->server_num_key_cert_pairs = + SSL_TSI_TEST_SERVER_KEY_CERT_PAIRS_NUM; + key_cert_lib->bad_server_num_key_cert_pairs = + SSL_TSI_TEST_BAD_SERVER_KEY_CERT_PAIRS_NUM; + key_cert_lib->server_pem_key_cert_pairs = + gpr_malloc(sizeof(tsi_ssl_pem_key_cert_pair) * + key_cert_lib->server_num_key_cert_pairs); + key_cert_lib->bad_server_pem_key_cert_pairs = + gpr_malloc(sizeof(tsi_ssl_pem_key_cert_pair) * + key_cert_lib->bad_server_num_key_cert_pairs); + key_cert_lib->server_pem_key_cert_pairs[0].private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.key"); + key_cert_lib->server_pem_key_cert_pairs[0].cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.pem"); + key_cert_lib->server_pem_key_cert_pairs[1].private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server1.key"); + key_cert_lib->server_pem_key_cert_pairs[1].cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server1.pem"); + key_cert_lib->bad_server_pem_key_cert_pairs[0].private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badserver.key"); + key_cert_lib->bad_server_pem_key_cert_pairs[0].cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badserver.pem"); + key_cert_lib->client_pem_key_cert_pair.private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.key"); + key_cert_lib->client_pem_key_cert_pair.cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.pem"); + key_cert_lib->bad_client_pem_key_cert_pair.private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badclient.key"); + key_cert_lib->bad_client_pem_key_cert_pair.cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badclient.pem"); + key_cert_lib->root_cert = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "ca.pem"); + ssl_fixture->key_cert_lib = key_cert_lib; + /* Create ssl_alpn_lib. */ + ssl_alpn_lib *alpn_lib = gpr_zalloc(sizeof(*alpn_lib)); + alpn_lib->server_alpn_protocols = + gpr_zalloc(sizeof(char *) * SSL_TSI_TEST_ALPN_NUM); + alpn_lib->client_alpn_protocols = + gpr_zalloc(sizeof(char *) * SSL_TSI_TEST_ALPN_NUM); + alpn_lib->server_alpn_protocols[0] = gpr_strdup(SSL_TSI_TEST_ALPN1); + alpn_lib->server_alpn_protocols[1] = gpr_strdup(SSL_TSI_TEST_ALPN3); + alpn_lib->client_alpn_protocols[0] = gpr_strdup(SSL_TSI_TEST_ALPN2); + alpn_lib->client_alpn_protocols[1] = gpr_strdup(SSL_TSI_TEST_ALPN3); + alpn_lib->num_server_alpn_protocols = SSL_TSI_TEST_ALPN_NUM; + alpn_lib->num_client_alpn_protocols = SSL_TSI_TEST_ALPN_NUM; + alpn_lib->alpn_mode = NO_ALPN; + ssl_fixture->alpn_lib = alpn_lib; + ssl_fixture->base.vtable = &vtable; + ssl_fixture->server_name_indication = NULL; + ssl_fixture->force_client_auth = false; + return &ssl_fixture->base; +} + +void ssl_tsi_test_do_handshake_tiny_handshake_buffer() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_TINY_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_small_handshake_buffer() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_SMALL_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_client_authentication() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->force_client_auth = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_server_name_indication_exact_domain() { + /* server1 cert contains "waterzooi.test.google.be" in SAN. */ + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->server_name_indication = "waterzooi.test.google.be"; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_server_name_indication_wild_star_domain() { + /* server1 cert contains "*.test.google.fr" in SAN. */ + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->server_name_indication = "juju.test.google.fr"; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_bad_server_cert() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->key_cert_lib->use_bad_server_cert = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_bad_client_cert() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->key_cert_lib->use_bad_client_cert = true; + ssl_fixture->force_client_auth = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_client_no_server() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->alpn_lib->alpn_mode = ALPN_CLIENT_NO_SERVER; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_server_no_client() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->alpn_lib->alpn_mode = ALPN_SERVER_NO_CLIENT; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_client_server_mismatch() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->alpn_lib->alpn_mode = ALPN_CLIENT_SERVER_MISMATCH; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_client_server_ok() { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + ssl_fixture->alpn_lib->alpn_mode = ALPN_CLIENT_SERVER_OK; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_round_trip_for_all_configs() { + unsigned int *bit_array = + gpr_zalloc(sizeof(unsigned int) * TSI_TEST_NUM_OF_ARGUMENTS); + const unsigned int mask = 1U << (TSI_TEST_NUM_OF_ARGUMENTS - 1); + for (unsigned int val = 0; val < TSI_TEST_NUM_OF_COMBINATIONS; val++) { + unsigned int v = val; + for (unsigned int ind = 0; ind < TSI_TEST_NUM_OF_ARGUMENTS; ind++) { + bit_array[ind] = (v & mask) ? 1 : 0; + v <<= 1; + } + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + tsi_test_frame_protector_config_destroy(ssl_fixture->base.config); + ssl_fixture->base.config = tsi_test_frame_protector_config_create( + bit_array[0], bit_array[1], bit_array[2], bit_array[3], bit_array[4], + bit_array[5], bit_array[6], bit_array[7]); + tsi_test_do_round_trip(&ssl_fixture->base); + tsi_test_fixture_destroy(fixture); + } + gpr_free(bit_array); +} + +void ssl_tsi_test_do_round_trip_odd_buffer_size() { + const size_t odd_sizes[] = {1025, 2051, 4103, 8207, 16409}; + const size_t size = sizeof(odd_sizes) / sizeof(size_t); + for (size_t ind1 = 0; ind1 < size; ind1++) { + for (size_t ind2 = 0; ind2 < size; ind2++) { + for (size_t ind3 = 0; ind3 < size; ind3++) { + for (size_t ind4 = 0; ind4 < size; ind4++) { + for (size_t ind5 = 0; ind5 < size; ind5++) { + tsi_test_fixture *fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture *ssl_fixture = (ssl_tsi_test_fixture *)fixture; + tsi_test_frame_protector_config_set_buffer_size( + ssl_fixture->base.config, odd_sizes[ind1], odd_sizes[ind2], + odd_sizes[ind3], odd_sizes[ind4], odd_sizes[ind5]); + tsi_test_do_round_trip(&ssl_fixture->base); + tsi_test_fixture_destroy(fixture); + } + } + } + } + } +} + +int main(int argc, char **argv) { + grpc_test_init(argc, argv); + grpc_init(); + ssl_tsi_test_do_handshake_tiny_handshake_buffer(); + ssl_tsi_test_do_handshake_small_handshake_buffer(); + ssl_tsi_test_do_handshake(); + ssl_tsi_test_do_handshake_with_client_authentication(); + ssl_tsi_test_do_handshake_with_server_name_indication_exact_domain(); + ssl_tsi_test_do_handshake_with_server_name_indication_wild_star_domain(); + ssl_tsi_test_do_handshake_with_bad_server_cert(); + ssl_tsi_test_do_handshake_with_bad_client_cert(); + ssl_tsi_test_do_handshake_alpn_client_no_server(); + ssl_tsi_test_do_handshake_alpn_server_no_client(); + ssl_tsi_test_do_handshake_alpn_client_server_mismatch(); + ssl_tsi_test_do_handshake_alpn_client_server_ok(); + ssl_tsi_test_do_round_trip_for_all_configs(); + ssl_tsi_test_do_round_trip_odd_buffer_size(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/transport_security_test_lib.c b/test/core/tsi/transport_security_test_lib.c new file mode 100644 index 0000000000..7d66e110f2 --- /dev/null +++ b/test/core/tsi/transport_security_test_lib.c @@ -0,0 +1,550 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include "src/core/lib/security/transport/tsi_error.h" +#include "test/core/tsi/transport_security_test_lib.h" + +typedef struct handshaker_args { + tsi_test_fixture *fixture; + unsigned char *handshake_buffer; + size_t handshake_buffer_size; + bool is_client; + bool transferred_data; + bool appended_unused_bytes; + grpc_error *error; +} handshaker_args; + +static handshaker_args *handshaker_args_create(tsi_test_fixture *fixture, + bool is_client) { + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(fixture->config != NULL); + handshaker_args *args = gpr_zalloc(sizeof(*args)); + args->fixture = fixture; + args->handshake_buffer_size = fixture->handshake_buffer_size; + args->handshake_buffer = gpr_zalloc(args->handshake_buffer_size); + args->is_client = is_client; + args->error = GRPC_ERROR_NONE; + return args; +} + +static void handshaker_args_destroy(handshaker_args *args) { + gpr_free(args->handshake_buffer); + GRPC_ERROR_UNREF(args->error); + gpr_free(args); +} + +static void do_handshaker_next(handshaker_args *args); + +static void setup_handshakers(tsi_test_fixture *fixture) { + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(fixture->vtable != NULL); + GPR_ASSERT(fixture->vtable->setup_handshakers != NULL); + fixture->vtable->setup_handshakers(fixture); +} + +static void check_unused_bytes(tsi_test_fixture *fixture) { + tsi_handshaker_result *result_with_unused_bytes = + fixture->has_client_finished_first ? fixture->server_result + : fixture->client_result; + tsi_handshaker_result *result_without_unused_bytes = + fixture->has_client_finished_first ? fixture->client_result + : fixture->server_result; + const unsigned char *bytes = NULL; + size_t bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes( + result_with_unused_bytes, &bytes, &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == strlen(TSI_TEST_UNUSED_BYTES)); + GPR_ASSERT(memcmp(bytes, TSI_TEST_UNUSED_BYTES, bytes_size) == 0); + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes( + result_without_unused_bytes, &bytes, &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == 0); + GPR_ASSERT(bytes == NULL); +} + +static void check_handshake_results(tsi_test_fixture *fixture) { + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(fixture->vtable != NULL); + GPR_ASSERT(fixture->vtable->check_handshaker_peers != NULL); + /* Check handshaker peers. */ + fixture->vtable->check_handshaker_peers(fixture); + /* Check unused bytes. */ + if (fixture->test_unused_bytes) { + if (fixture->server_result != NULL && fixture->client_result != NULL) { + check_unused_bytes(fixture); + } + fixture->bytes_written_to_server_channel = 0; + fixture->bytes_written_to_client_channel = 0; + fixture->bytes_read_from_client_channel = 0; + fixture->bytes_read_from_server_channel = 0; + } +} + +static void send_bytes_to_peer(tsi_test_fixture *fixture, + const unsigned char *buf, size_t buf_size, + bool is_client) { + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(buf != NULL); + uint8_t *channel = + is_client ? fixture->server_channel : fixture->client_channel; + GPR_ASSERT(channel != NULL); + size_t *bytes_written = is_client ? &fixture->bytes_written_to_server_channel + : &fixture->bytes_written_to_client_channel; + GPR_ASSERT(bytes_written != NULL); + GPR_ASSERT(*bytes_written + buf_size <= TSI_TEST_DEFAULT_CHANNEL_SIZE); + /* Write data to channel. */ + memcpy(channel + *bytes_written, buf, buf_size); + *bytes_written += buf_size; +} + +static void maybe_append_unused_bytes(handshaker_args *args) { + GPR_ASSERT(args != NULL); + GPR_ASSERT(args->fixture != NULL); + tsi_test_fixture *fixture = args->fixture; + if (fixture->test_unused_bytes && !args->appended_unused_bytes) { + args->appended_unused_bytes = true; + send_bytes_to_peer(fixture, (const unsigned char *)TSI_TEST_UNUSED_BYTES, + strlen(TSI_TEST_UNUSED_BYTES), args->is_client); + if (fixture->client_result != NULL && fixture->server_result == NULL) { + fixture->has_client_finished_first = true; + } + } +} + +static void receive_bytes_from_peer(tsi_test_fixture *fixture, + unsigned char **buf, size_t *buf_size, + bool is_client) { + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(*buf != NULL); + GPR_ASSERT(buf_size != NULL); + uint8_t *channel = + is_client ? fixture->client_channel : fixture->server_channel; + GPR_ASSERT(channel != NULL); + size_t *bytes_read = is_client ? &fixture->bytes_read_from_client_channel + : &fixture->bytes_read_from_server_channel; + size_t *bytes_written = is_client ? &fixture->bytes_written_to_client_channel + : &fixture->bytes_written_to_server_channel; + GPR_ASSERT(bytes_read != NULL); + GPR_ASSERT(bytes_written != NULL); + size_t to_read = *buf_size < *bytes_written - *bytes_read + ? *buf_size + : *bytes_written - *bytes_read; + /* Read data from channel. */ + memcpy(*buf, channel + *bytes_read, to_read); + *buf_size = to_read; + *bytes_read += to_read; +} + +static void send_message_to_peer(tsi_test_fixture *fixture, + tsi_frame_protector *protector, + bool is_client) { + /* Initialization. */ + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(fixture->config != NULL); + GPR_ASSERT(protector != NULL); + tsi_test_frame_protector_config *config = fixture->config; + unsigned char *protected_buffer = gpr_zalloc(config->protected_buffer_size); + size_t message_size = + is_client ? config->client_message_size : config->server_message_size; + uint8_t *message = + is_client ? config->client_message : config->server_message; + GPR_ASSERT(message != NULL); + const unsigned char *message_bytes = (const unsigned char *)message; + tsi_result result = TSI_OK; + /* Do protect and send protected data to peer. */ + while (message_size > 0 && result == TSI_OK) { + size_t protected_buffer_size_to_send = config->protected_buffer_size; + size_t processed_message_size = message_size; + /* Do protect. */ + result = tsi_frame_protector_protect( + protector, message_bytes, &processed_message_size, protected_buffer, + &protected_buffer_size_to_send); + GPR_ASSERT(result == TSI_OK); + /* Send protected data to peer. */ + send_bytes_to_peer(fixture, protected_buffer, protected_buffer_size_to_send, + is_client); + message_bytes += processed_message_size; + message_size -= processed_message_size; + /* Flush if we're done. */ + if (message_size == 0) { + size_t still_pending_size; + do { + protected_buffer_size_to_send = config->protected_buffer_size; + result = tsi_frame_protector_protect_flush( + protector, protected_buffer, &protected_buffer_size_to_send, + &still_pending_size); + GPR_ASSERT(result == TSI_OK); + send_bytes_to_peer(fixture, protected_buffer, + protected_buffer_size_to_send, is_client); + } while (still_pending_size > 0 && result == TSI_OK); + GPR_ASSERT(result == TSI_OK); + } + } + GPR_ASSERT(result == TSI_OK); + gpr_free(protected_buffer); +} + +static void receive_message_from_peer(tsi_test_fixture *fixture, + tsi_frame_protector *protector, + unsigned char *message, + size_t *bytes_received, bool is_client) { + /* Initialization. */ + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(protector != NULL); + GPR_ASSERT(message != NULL); + GPR_ASSERT(bytes_received != NULL); + GPR_ASSERT(fixture->config != NULL); + tsi_test_frame_protector_config *config = fixture->config; + size_t read_offset = 0; + size_t message_offset = 0; + size_t read_from_peer_size = 0; + tsi_result result = TSI_OK; + bool done = false; + unsigned char *read_buffer = gpr_zalloc(config->read_buffer_allocated_size); + unsigned char *message_buffer = + gpr_zalloc(config->message_buffer_allocated_size); + /* Do unprotect on data received from peer. */ + while (!done && result == TSI_OK) { + /* Receive data from peer. */ + if (read_from_peer_size == 0) { + read_from_peer_size = config->read_buffer_allocated_size; + receive_bytes_from_peer(fixture, &read_buffer, &read_from_peer_size, + is_client); + read_offset = 0; + } + if (read_from_peer_size == 0) { + done = true; + } + /* Do unprotect. */ + size_t message_buffer_size; + do { + message_buffer_size = config->message_buffer_allocated_size; + size_t processed_size = read_from_peer_size; + result = tsi_frame_protector_unprotect( + protector, read_buffer + read_offset, &processed_size, message_buffer, + &message_buffer_size); + GPR_ASSERT(result == TSI_OK); + if (message_buffer_size > 0) { + memcpy(message + message_offset, message_buffer, message_buffer_size); + message_offset += message_buffer_size; + } + read_offset += processed_size; + read_from_peer_size -= processed_size; + } while ((read_from_peer_size > 0 || message_buffer_size > 0) && + result == TSI_OK); + GPR_ASSERT(result == TSI_OK); + } + GPR_ASSERT(result == TSI_OK); + *bytes_received = message_offset; + gpr_free(read_buffer); + gpr_free(message_buffer); +} + +grpc_error *on_handshake_next_done(tsi_result result, void *user_data, + const unsigned char *bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result *handshaker_result) { + handshaker_args *args = (handshaker_args *)user_data; + GPR_ASSERT(args != NULL); + GPR_ASSERT(args->fixture != NULL); + tsi_test_fixture *fixture = args->fixture; + grpc_error *error = GRPC_ERROR_NONE; + /* Read more data if we need to. */ + if (result == TSI_INCOMPLETE_DATA) { + GPR_ASSERT(bytes_to_send_size == 0); + return error; + } + if (result != TSI_OK) { + return grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake failed"), result); + } + /* Update handshaker result. */ + if (handshaker_result != NULL) { + tsi_handshaker_result **result_to_write = + args->is_client ? &fixture->client_result : &fixture->server_result; + GPR_ASSERT(*result_to_write == NULL); + *result_to_write = handshaker_result; + } + /* Send data to peer, if needed. */ + if (bytes_to_send_size > 0) { + send_bytes_to_peer(args->fixture, bytes_to_send, bytes_to_send_size, + args->is_client); + args->transferred_data = true; + } + if (handshaker_result != NULL) { + maybe_append_unused_bytes(args); + } + return error; +} + +static void on_handshake_next_done_wrapper( + tsi_result result, void *user_data, const unsigned char *bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result *handshaker_result) { + handshaker_args *args = (handshaker_args *)user_data; + args->error = on_handshake_next_done(result, user_data, bytes_to_send, + bytes_to_send_size, handshaker_result); +} + +static bool is_handshake_finished_properly(handshaker_args *args) { + GPR_ASSERT(args != NULL); + GPR_ASSERT(args->fixture != NULL); + tsi_test_fixture *fixture = args->fixture; + if ((args->is_client && fixture->client_result != NULL) || + (!args->is_client && fixture->server_result != NULL)) { + return true; + } + return false; +} + +static void do_handshaker_next(handshaker_args *args) { + /* Initialization. */ + GPR_ASSERT(args != NULL); + GPR_ASSERT(args->fixture != NULL); + tsi_test_fixture *fixture = args->fixture; + tsi_handshaker *handshaker = + args->is_client ? fixture->client_handshaker : fixture->server_handshaker; + if (is_handshake_finished_properly(args)) { + return; + } + tsi_handshaker_result *handshaker_result = NULL; + unsigned char *bytes_to_send = NULL; + size_t bytes_to_send_size = 0; + /* Receive data from peer, if available. */ + size_t buf_size = args->handshake_buffer_size; + receive_bytes_from_peer(args->fixture, &args->handshake_buffer, &buf_size, + args->is_client); + if (buf_size > 0) { + args->transferred_data = true; + } + /* Peform handshaker next. */ + tsi_result result = tsi_handshaker_next( + handshaker, args->handshake_buffer, buf_size, + (const unsigned char **)&bytes_to_send, &bytes_to_send_size, + &handshaker_result, &on_handshake_next_done_wrapper, args); + if (result != TSI_ASYNC) { + args->error = on_handshake_next_done(result, args, bytes_to_send, + bytes_to_send_size, handshaker_result); + } +} + +void tsi_test_do_handshake(tsi_test_fixture *fixture) { + /* Initializaiton. */ + setup_handshakers(fixture); + handshaker_args *client_args = + handshaker_args_create(fixture, true /* is_client */); + handshaker_args *server_args = + handshaker_args_create(fixture, false /* is_client */); + /* Do handshake. */ + do { + client_args->transferred_data = false; + server_args->transferred_data = false; + do_handshaker_next(client_args); + if (client_args->error != GRPC_ERROR_NONE) { + break; + } + do_handshaker_next(server_args); + if (server_args->error != GRPC_ERROR_NONE) { + break; + } + GPR_ASSERT(client_args->transferred_data || server_args->transferred_data); + } while (fixture->client_result == NULL || fixture->server_result == NULL); + /* Verify handshake results. */ + check_handshake_results(fixture); + /* Cleanup. */ + handshaker_args_destroy(client_args); + handshaker_args_destroy(server_args); +} + +void tsi_test_do_round_trip(tsi_test_fixture *fixture) { + /* Initialization. */ + GPR_ASSERT(fixture != NULL); + GPR_ASSERT(fixture->config != NULL); + tsi_test_frame_protector_config *config = fixture->config; + tsi_frame_protector *client_frame_protector = NULL; + tsi_frame_protector *server_frame_protector = NULL; + /* Perform handshake. */ + tsi_test_do_handshake(fixture); + /* Create frame protectors.*/ + size_t client_max_output_protected_frame_size = + config->client_max_output_protected_frame_size; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + fixture->client_result, + client_max_output_protected_frame_size == 0 + ? NULL + : &client_max_output_protected_frame_size, + &client_frame_protector) == TSI_OK); + size_t server_max_output_protected_frame_size = + config->server_max_output_protected_frame_size; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + fixture->server_result, + server_max_output_protected_frame_size == 0 + ? NULL + : &server_max_output_protected_frame_size, + &server_frame_protector) == TSI_OK); + /* Client sends a message to server. */ + send_message_to_peer(fixture, client_frame_protector, true /* is_client */); + unsigned char *server_received_message = + gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE); + size_t server_received_message_size = 0; + receive_message_from_peer( + fixture, server_frame_protector, server_received_message, + &server_received_message_size, false /* is_client */); + GPR_ASSERT(config->client_message_size == server_received_message_size); + GPR_ASSERT(memcmp(config->client_message, server_received_message, + server_received_message_size) == 0); + /* Server sends a message to client. */ + send_message_to_peer(fixture, server_frame_protector, false /* is_client */); + unsigned char *client_received_message = + gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE); + size_t client_received_message_size = 0; + receive_message_from_peer( + fixture, client_frame_protector, client_received_message, + &client_received_message_size, true /* is_client */); + GPR_ASSERT(config->server_message_size == client_received_message_size); + GPR_ASSERT(memcmp(config->server_message, client_received_message, + client_received_message_size) == 0); + /* Destroy server and client frame protectors. */ + tsi_frame_protector_destroy(client_frame_protector); + tsi_frame_protector_destroy(server_frame_protector); + gpr_free(server_received_message); + gpr_free(client_received_message); +} + +static unsigned char *generate_random_message(size_t size) { + size_t i; + unsigned char chars[] = "abcdefghijklmnopqrstuvwxyz1234567890"; + unsigned char *output = gpr_zalloc(sizeof(unsigned char) * size); + for (i = 0; i < size - 1; ++i) { + output[i] = chars[rand() % (int)(sizeof(chars) - 1)]; + } + return output; +} + +tsi_test_frame_protector_config *tsi_test_frame_protector_config_create( + bool use_default_read_buffer_allocated_size, + bool use_default_message_buffer_allocated_size, + bool use_default_protected_buffer_size, bool use_default_client_message, + bool use_default_server_message, + bool use_default_client_max_output_protected_frame_size, + bool use_default_server_max_output_protected_frame_size, + bool use_default_handshake_buffer_size) { + tsi_test_frame_protector_config *config = gpr_zalloc(sizeof(*config)); + /* Set the value for read_buffer_allocated_size. */ + config->read_buffer_allocated_size = + use_default_read_buffer_allocated_size + ? TSI_TEST_DEFAULT_BUFFER_SIZE + : TSI_TEST_SMALL_READ_BUFFER_ALLOCATED_SIZE; + /* Set the value for message_buffer_allocated_size. */ + config->message_buffer_allocated_size = + use_default_message_buffer_allocated_size + ? TSI_TEST_DEFAULT_BUFFER_SIZE + : TSI_TEST_SMALL_MESSAGE_BUFFER_ALLOCATED_SIZE; + /* Set the value for protected_buffer_size. */ + config->protected_buffer_size = use_default_protected_buffer_size + ? TSI_TEST_DEFAULT_PROTECTED_BUFFER_SIZE + : TSI_TEST_SMALL_PROTECTED_BUFFER_SIZE; + /* Set the value for client message. */ + config->client_message_size = use_default_client_message + ? TSI_TEST_BIG_MESSAGE_SIZE + : TSI_TEST_SMALL_MESSAGE_SIZE; + config->client_message = + use_default_client_message + ? generate_random_message(TSI_TEST_BIG_MESSAGE_SIZE) + : generate_random_message(TSI_TEST_SMALL_MESSAGE_SIZE); + /* Set the value for server message. */ + config->server_message_size = use_default_server_message + ? TSI_TEST_BIG_MESSAGE_SIZE + : TSI_TEST_SMALL_MESSAGE_SIZE; + config->server_message = + use_default_server_message + ? generate_random_message(TSI_TEST_BIG_MESSAGE_SIZE) + : generate_random_message(TSI_TEST_SMALL_MESSAGE_SIZE); + /* Set the value for client max_output_protected_frame_size. + If it is 0, we pass NULL to tsi_handshaker_result_create_frame_protector(), + which then uses default protected frame size for it. */ + config->client_max_output_protected_frame_size = + use_default_client_max_output_protected_frame_size + ? 0 + : TSI_TEST_SMALL_CLIENT_MAX_OUTPUT_PROTECTED_FRAME_SIZE; + /* Set the value for server max_output_protected_frame_size. + If it is 0, we pass NULL to tsi_handshaker_result_create_frame_protector(), + which then uses default protected frame size for it. */ + config->server_max_output_protected_frame_size = + use_default_server_max_output_protected_frame_size + ? 0 + : TSI_TEST_SMALL_SERVER_MAX_OUTPUT_PROTECTED_FRAME_SIZE; + return config; +} + +void tsi_test_frame_protector_config_set_buffer_size( + tsi_test_frame_protector_config *config, size_t read_buffer_allocated_size, + size_t message_buffer_allocated_size, size_t protected_buffer_size, + size_t client_max_output_protected_frame_size, + size_t server_max_output_protected_frame_size) { + GPR_ASSERT(config != NULL); + config->read_buffer_allocated_size = read_buffer_allocated_size; + config->message_buffer_allocated_size = message_buffer_allocated_size; + config->protected_buffer_size = protected_buffer_size; + config->client_max_output_protected_frame_size = + client_max_output_protected_frame_size; + config->server_max_output_protected_frame_size = + server_max_output_protected_frame_size; +} + +void tsi_test_frame_protector_config_destroy( + tsi_test_frame_protector_config *config) { + GPR_ASSERT(config != NULL); + gpr_free(config->client_message); + gpr_free(config->server_message); + gpr_free(config); +} + +void tsi_test_fixture_init(tsi_test_fixture *fixture) { + fixture->config = tsi_test_frame_protector_config_create( + true, true, true, true, true, true, true, true); + fixture->handshake_buffer_size = TSI_TEST_DEFAULT_BUFFER_SIZE; + fixture->client_channel = gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE); + fixture->server_channel = gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE); + fixture->bytes_written_to_client_channel = 0; + fixture->bytes_written_to_server_channel = 0; + fixture->bytes_read_from_client_channel = 0; + fixture->bytes_read_from_server_channel = 0; + fixture->test_unused_bytes = true; + fixture->has_client_finished_first = false; +} + +void tsi_test_fixture_destroy(tsi_test_fixture *fixture) { + GPR_ASSERT(fixture != NULL); + tsi_test_frame_protector_config_destroy(fixture->config); + tsi_handshaker_destroy(fixture->client_handshaker); + tsi_handshaker_destroy(fixture->server_handshaker); + tsi_handshaker_result_destroy(fixture->client_result); + tsi_handshaker_result_destroy(fixture->server_result); + gpr_free(fixture->client_channel); + gpr_free(fixture->server_channel); + GPR_ASSERT(fixture->vtable != NULL); + GPR_ASSERT(fixture->vtable->destruct != NULL); + fixture->vtable->destruct(fixture); + gpr_free(fixture); +} diff --git a/test/core/tsi/transport_security_test_lib.h b/test/core/tsi/transport_security_test_lib.h new file mode 100644 index 0000000000..8ae2024ee4 --- /dev/null +++ b/test/core/tsi/transport_security_test_lib.h @@ -0,0 +1,165 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CORE_TSI_TRANSPORT_SECURITY_TEST_LIB_H_ +#define GRPC_TEST_CORE_TSI_TRANSPORT_SECURITY_TEST_LIB_H_ + +#include "src/core/tsi/transport_security_interface.h" + +#define TSI_TEST_TINY_HANDSHAKE_BUFFER_SIZE 32 +#define TSI_TEST_SMALL_HANDSHAKE_BUFFER_SIZE 128 +#define TSI_TEST_SMALL_READ_BUFFER_ALLOCATED_SIZE 41 +#define TSI_TEST_SMALL_PROTECTED_BUFFER_SIZE 37 +#define TSI_TEST_SMALL_MESSAGE_BUFFER_ALLOCATED_SIZE 42 +#define TSI_TEST_SMALL_CLIENT_MAX_OUTPUT_PROTECTED_FRAME_SIZE 39 +#define TSI_TEST_SMALL_SERVER_MAX_OUTPUT_PROTECTED_FRAME_SIZE 43 +#define TSI_TEST_DEFAULT_BUFFER_SIZE 4096 +#define TSI_TEST_DEFAULT_PROTECTED_BUFFER_SIZE 16384 +#define TSI_TEST_DEFAULT_CHANNEL_SIZE 32768 +#define TSI_TEST_BIG_MESSAGE_SIZE 17000 +#define TSI_TEST_SMALL_MESSAGE_SIZE 10 +#define TSI_TEST_NUM_OF_ARGUMENTS 8 +#define TSI_TEST_NUM_OF_COMBINATIONS 256 +#define TSI_TEST_UNUSED_BYTES "HELLO GOOGLE" + +/* --- tsi_test_fixture object --- + The tests for specific TSI implementations should create their own + custom "subclass" of this fixture, which wraps all information + that will be used to test correctness of TSI handshakes and frame + protect/unprotect operations with respect to TSI implementations. */ +typedef struct tsi_test_fixture tsi_test_fixture; + +/* --- tsi_test_frame_protector_config object --- + + This object is used to configure different parameters of TSI frame protector + APIs. */ +typedef struct tsi_test_frame_protector_config tsi_test_frame_protector_config; + +/* V-table for tsi_test_fixture operations that are implemented differently in + different TSI implementations. */ +typedef struct tsi_test_fixture_vtable { + void (*setup_handshakers)(tsi_test_fixture *fixture); + void (*check_handshaker_peers)(tsi_test_fixture *fixture); + void (*destruct)(tsi_test_fixture *fixture); +} tranport_security_test_vtable; + +struct tsi_test_fixture { + const struct tsi_test_fixture_vtable *vtable; + /* client/server TSI handshaker used to perform TSI handshakes, and will get + instantiated during the call to setup_handshakers. */ + tsi_handshaker *client_handshaker; + tsi_handshaker *server_handshaker; + /* client/server TSI handshaker results used to store the result of TSI + handshake. If the handshake fails, the result will store NULL upon + finishing the handshake. */ + tsi_handshaker_result *client_result; + tsi_handshaker_result *server_result; + /* size of buffer used to store data received from the peer. */ + size_t handshake_buffer_size; + /* simulated channels between client and server. If the server (client) + wants to send data to the client (server), he will write data to + client_channel (server_channel), which will be read by client (server). */ + uint8_t *client_channel; + uint8_t *server_channel; + /* size of data written to the client/server channel. */ + size_t bytes_written_to_client_channel; + size_t bytes_written_to_server_channel; + /* size of data read from the client/server channel */ + size_t bytes_read_from_client_channel; + size_t bytes_read_from_server_channel; + /* tsi_test_frame_protector_config instance */ + tsi_test_frame_protector_config *config; + /* a flag indicating if client has finished TSI handshake first (i.e., before + server). + The flag should be referred if and only if TSI handshake finishes + successfully. */ + bool has_client_finished_first; + /* a flag indicating whether to test tsi_handshaker_result_get_unused_bytes() + for TSI implementation. This field is true by default, and false + for SSL TSI implementation due to grpc issue #12164 + (https://github.com/grpc/grpc/issues/12164). + */ + bool test_unused_bytes; +}; + +struct tsi_test_frame_protector_config { + /* size of buffer used to store protected frames to be unprotected. */ + size_t read_buffer_allocated_size; + /* size of buffer used to store bytes resulted from unprotect operations. */ + size_t message_buffer_allocated_size; + /* size of buffer used to store frames resulted from protect operations. */ + size_t protected_buffer_size; + /* size of client/server maximum frame size. */ + size_t client_max_output_protected_frame_size; + size_t server_max_output_protected_frame_size; + /* pointer that points to client/server message to be protected. */ + uint8_t *client_message; + uint8_t *server_message; + /* size of client/server message. */ + size_t client_message_size; + size_t server_message_size; +}; + +/* This method creates a tsi_test_frame_protector_config instance. Each + parameter of this function is a boolean value indicating whether to set the + corresponding parameter with a default value or not. If it's false, it will + be set with a specific value which is usually much smaller than the default. + Both values are defined with #define directive. */ +tsi_test_frame_protector_config *tsi_test_frame_protector_config_create( + bool use_default_read_buffer_allocated_size, + bool use_default_message_buffer_allocated_size, + bool use_default_protected_buffer_size, bool use_default_client_message, + bool use_default_server_message, + bool use_default_client_max_output_protected_frame_size, + bool use_default_server_max_output_protected_frame_size, + bool use_default_handshake_buffer_size); + +/* This method sets different buffer and frame sizes of a + tsi_test_frame_protector_config instance with user provided values. */ +void tsi_test_frame_protector_config_set_buffer_size( + tsi_test_frame_protector_config *config, size_t read_buffer_allocated_size, + size_t message_buffer_allocated_size, size_t protected_buffer_size, + size_t client_max_output_protected_frame_size, + size_t server_max_output_protected_frame_size); + +/* This method destroys a tsi_test_frame_protector_config instance. */ +void tsi_test_frame_protector_config_destroy( + tsi_test_frame_protector_config *config); + +/* This method initializes members of tsi_test_fixture instance. + Note that the struct instance should be allocated before making + this call. */ +void tsi_test_fixture_init(tsi_test_fixture *fixture); + +/* This method destroys a tsi_test_fixture instance. Note that the + fixture intance must be dynamically allocated and will be freed by + this function. */ +void tsi_test_fixture_destroy(tsi_test_fixture *fixture); + +/* This method performs a full TSI handshake between a client and a server. + Note that the test library will implement the new TSI handshaker API to + perform handshakes. */ +void tsi_test_do_handshake(tsi_test_fixture *fixture); + +/* This method performs a round trip test between the client and the server. + That is, the client sends a protected message to a server who receives the + message, and unprotects it. The same operation is triggered again with + the client and server switching its role. */ +void tsi_test_do_round_trip(tsi_test_fixture *fixture); + +#endif // GRPC_TEST_CORE_TSI_TRANSPORT_SECURITY_TEST_LIB_H_ diff --git a/test/cpp/end2end/BUILD b/test/cpp/end2end/BUILD index b8505c1ae7..b29a13d4fb 100644 --- a/test/cpp/end2end/BUILD +++ b/test/cpp/end2end/BUILD @@ -193,6 +193,7 @@ grpc_cc_test( "//test/cpp/util:test_util", ], external_deps = [ + "gmock", "gtest", ], ) @@ -235,6 +236,7 @@ grpc_cc_test( "//test/cpp/util:test_util", ], external_deps = [ + "gmock", "gtest", ], ) diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index 7cb7b262de..e841a702d4 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -260,11 +260,31 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { server_address_ << "localhost:" << port_; // Setup server + BuildAndStartServer(); + + gpr_tls_set(&g_is_async_end2end_test, 1); + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) + ; + stub_.reset(); + poll_overrider_.reset(); + gpr_tls_set(&g_is_async_end2end_test, 0); + grpc_recycle_unused_port(port_); + } + + void BuildAndStartServer() { ServerBuilder builder; auto server_creds = GetCredentialsProvider()->GetServerCredentials( GetParam().credentials_type); builder.AddListeningPort(server_address_.str(), server_creds); - builder.RegisterService(&service_); + service_.reset(new grpc::testing::EchoTestService::AsyncService()); + builder.RegisterService(service_.get()); if (GetParam().health_check_service) { builder.RegisterService(&health_check_); } @@ -276,20 +296,6 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { new ServerBuilderSyncPluginDisabler()); builder.SetOption(move(sync_plugin_disabler)); server_ = builder.BuildAndStart(); - - gpr_tls_set(&g_is_async_end2end_test, 1); - } - - void TearDown() override { - server_->Shutdown(); - void* ignored_tag; - bool ignored_ok; - cq_->Shutdown(); - while (cq_->Next(&ignored_tag, &ignored_ok)) - ; - poll_overrider_.reset(); - gpr_tls_set(&g_is_async_end2end_test, 0); - grpc_recycle_unused_port(port_); } void ResetStub() { @@ -319,8 +325,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, + cq_.get(), cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); @@ -341,7 +347,7 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { std::unique_ptr<ServerCompletionQueue> cq_; std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; std::unique_ptr<Server> server_; - grpc::testing::EchoTestService::AsyncService service_; + std::unique_ptr<grpc::testing::EchoTestService::AsyncService> service_; HealthCheck health_check_; std::ostringstream server_address_; int port_; @@ -359,6 +365,26 @@ TEST_P(AsyncEnd2endTest, SequentialRpcs) { SendRpc(10); } +TEST_P(AsyncEnd2endTest, ReconnectChannel) { + if (GetParam().inproc) { + return; + } + ResetStub(); + SendRpc(1); + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) + ; + BuildAndStartServer(); + // It needs more than kConnectivityCheckIntervalMsec time to reconnect the + // channel. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(1600, GPR_TIMESPAN))); + SendRpc(1); +} + // We do not need to protect notify because the use is synchronized. void ServerWait(Server* server, int* notify) { server->Wait(); @@ -407,8 +433,8 @@ TEST_P(AsyncEnd2endTest, AsyncNextRpc) { Verifier(GetParam().disable_blocking).Verify(cq_.get(), time_now); Verifier(GetParam().disable_blocking).Verify(cq_.get(), time_now); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking) .Expect(2, true) @@ -444,8 +470,8 @@ TEST_P(AsyncEnd2endTest, SimpleClientStreaming) { std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); - service_.RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); Verifier(GetParam().disable_blocking) .Expect(2, true) @@ -506,8 +532,8 @@ TEST_P(AsyncEnd2endTest, SimpleClientStreamingWithCoalescingApi) { std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); - service_.RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); cli_stream->Write(send_request, tag(3)); @@ -579,8 +605,8 @@ TEST_P(AsyncEnd2endTest, SimpleServerStreaming) { std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); - service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, - cq_.get(), cq_.get(), tag(2)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); Verifier(GetParam().disable_blocking) .Expect(1, true) @@ -635,8 +661,8 @@ TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWAF) { std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); - service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, - cq_.get(), cq_.get(), tag(2)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); Verifier(GetParam().disable_blocking) .Expect(1, true) @@ -687,8 +713,8 @@ TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWL) { std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); - service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, - cq_.get(), cq_.get(), tag(2)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); Verifier(GetParam().disable_blocking) .Expect(1, true) @@ -741,8 +767,8 @@ TEST_P(AsyncEnd2endTest, SimpleBidiStreaming) { std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); - service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); Verifier(GetParam().disable_blocking) .Expect(1, true) @@ -801,8 +827,8 @@ TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWAF) { std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); - service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); cli_stream->WriteLast(send_request, WriteOptions(), tag(3)); @@ -869,8 +895,8 @@ TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWL) { std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); - service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); cli_stream->WriteLast(send_request, WriteOptions(), tag(3)); @@ -946,8 +972,8 @@ TEST_P(AsyncEnd2endTest, ClientInitialMetadataRpc) { std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); auto client_initial_metadata = srv_ctx.client_metadata(); @@ -991,8 +1017,8 @@ TEST_P(AsyncEnd2endTest, ServerInitialMetadataRpc) { std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); srv_ctx.AddInitialMetadata(meta1.first, meta1.second); @@ -1041,8 +1067,8 @@ TEST_P(AsyncEnd2endTest, ServerTrailingMetadataRpc) { std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); response_writer.SendInitialMetadata(tag(3)); @@ -1104,8 +1130,8 @@ TEST_P(AsyncEnd2endTest, MetadataRpc) { std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); auto client_initial_metadata = srv_ctx.client_metadata(); @@ -1168,8 +1194,8 @@ TEST_P(AsyncEnd2endTest, ServerCheckCancellation) { stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); srv_ctx.AsyncNotifyWhenDone(tag(5)); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); @@ -1203,8 +1229,8 @@ TEST_P(AsyncEnd2endTest, ServerCheckDone) { stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); srv_ctx.AsyncNotifyWhenDone(tag(5)); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), - cq_.get(), tag(2)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); @@ -1295,8 +1321,8 @@ class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest { // On the server, request to be notified of 'RequestStream' calls // and receive the 'RequestStream' call just made by the client srv_ctx.AsyncNotifyWhenDone(tag(11)); - service_.RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); // Client sends 3 messages (tags 3, 4 and 5) @@ -1426,8 +1452,8 @@ class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest { // On the server, request to be notified of 'ResponseStream' calls and // receive the call just made by the client srv_ctx.AsyncNotifyWhenDone(tag(11)); - service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, - cq_.get(), cq_.get(), tag(2)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); EXPECT_EQ(send_request.message(), recv_request.message()); @@ -1562,8 +1588,8 @@ class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest { // On the server, request to be notified of the 'BidiStream' call and // receive the call just made by the client srv_ctx.AsyncNotifyWhenDone(tag(11)); - service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), - tag(2)); + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); Verifier(GetParam().disable_blocking).Expect(2, true).Verify(cq_.get()); // Client sends the first and the only message diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index b588eda84f..54408db600 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -180,16 +180,18 @@ class ClientLbEnd2endTest : public ::testing::Test { std::unique_ptr<Server> server_; MyTestServiceImpl service_; std::unique_ptr<std::thread> thread_; + bool server_ready_ = false; explicit ServerData(const grpc::string& server_host, int port = 0) { port_ = port > 0 ? port : grpc_pick_unused_port_or_die(); gpr_log(GPR_INFO, "starting server on port %d", port_); std::mutex mu; + std::unique_lock<std::mutex> lock(mu); std::condition_variable cond; thread_.reset(new std::thread( std::bind(&ServerData::Start, this, server_host, &mu, &cond))); - std::unique_lock<std::mutex> lock(mu); - cond.wait(lock); + cond.wait(lock, [this] { return server_ready_; }); + server_ready_ = false; gpr_log(GPR_INFO, "server startup complete"); } @@ -203,6 +205,7 @@ class ClientLbEnd2endTest : public ::testing::Test { builder.RegisterService(&service_); server_ = builder.BuildAndStart(); std::lock_guard<std::mutex> lock(*mu); + server_ready_ = true; cond->notify_one(); } diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 8bada48a2b..1f4861a7e6 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -238,6 +238,18 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> { int port = grpc_pick_unused_port_or_die(); server_address_ << "127.0.0.1:" << port; // Setup server + BuildAndStartServer(processor); + } + + void RestartServer(const std::shared_ptr<AuthMetadataProcessor>& processor) { + if (is_server_started_) { + server_->Shutdown(); + BuildAndStartServer(processor); + } + } + + void BuildAndStartServer( + const std::shared_ptr<AuthMetadataProcessor>& processor) { ServerBuilder builder; ConfigureServerBuilder(&builder); auto server_creds = GetCredentialsProvider()->GetServerCredentials( @@ -685,6 +697,20 @@ TEST_P(End2endTest, MultipleRpcs) { } } +TEST_P(End2endTest, ReconnectChannel) { + if (GetParam().inproc) { + return; + } + ResetStub(); + SendRpc(stub_.get(), 1, false); + RestartServer(std::shared_ptr<AuthMetadataProcessor>()); + // It needs more than kConnectivityCheckIntervalMsec time to reconnect the + // channel. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(1600, GPR_TIMESPAN))); + SendRpc(stub_.get(), 1, false); +} + TEST_P(End2endTest, RequestStreamOneRequest) { ResetStub(); EchoRequest request; diff --git a/test/cpp/microbenchmarks/bm_call_create.cc b/test/cpp/microbenchmarks/bm_call_create.cc index 508f7f94d6..518c65ac8d 100644 --- a/test/cpp/microbenchmarks/bm_call_create.cc +++ b/test/cpp/microbenchmarks/bm_call_create.cc @@ -39,6 +39,7 @@ extern "C" { #include "src/core/ext/filters/message_size/message_size_filter.h" #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/iomgr/call_combiner.h" #include "src/core/lib/profiling/timers.h" #include "src/core/lib/surface/channel.h" #include "src/core/lib/transport/transport_impl.h" @@ -396,10 +397,6 @@ grpc_error *InitChannelElem(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, void DestroyChannelElem(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {} -char *GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem) { - return gpr_strdup("peer"); -} - void GetChannelInfo(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, const grpc_channel_info *channel_info) {} @@ -412,7 +409,6 @@ static const grpc_channel_filter dummy_filter = {StartTransportStreamOp, 0, InitChannelElem, DestroyChannelElem, - GetPeer, GetChannelInfo, "dummy_filter"}; @@ -459,11 +455,6 @@ void DestroyStream(grpc_exec_ctx *exec_ctx, grpc_transport *self, /* implementation of grpc_transport_destroy */ void Destroy(grpc_exec_ctx *exec_ctx, grpc_transport *self) {} -/* implementation of grpc_transport_get_peer */ -char *GetPeer(grpc_exec_ctx *exec_ctx, grpc_transport *self) { - return gpr_strdup("transport_peer"); -} - /* implementation of grpc_transport_get_endpoint */ grpc_endpoint *GetEndpoint(grpc_exec_ctx *exec_ctx, grpc_transport *self) { return nullptr; @@ -473,7 +464,7 @@ static const grpc_transport_vtable dummy_transport_vtable = { 0, "dummy_http2", InitStream, SetPollset, SetPollsetSet, PerformStreamOp, PerformOp, DestroyStream, Destroy, - GetPeer, GetEndpoint}; + GetEndpoint}; static grpc_transport dummy_transport = {&dummy_transport_vtable}; @@ -639,18 +630,22 @@ BENCHMARK_TEMPLATE(BM_IsolatedFilter, LoadReportingFilter, SendEmptyMetadata); namespace isolated_call_filter { +typedef struct { grpc_call_combiner *call_combiner; } call_data; + static void StartTransportStreamOp(grpc_exec_ctx *exec_ctx, grpc_call_element *elem, grpc_transport_stream_op_batch *op) { + call_data *calld = static_cast<call_data *>(elem->call_data); if (op->recv_initial_metadata) { - GRPC_CLOSURE_SCHED( - exec_ctx, + GRPC_CALL_COMBINER_START( + exec_ctx, calld->call_combiner, op->payload->recv_initial_metadata.recv_initial_metadata_ready, - GRPC_ERROR_NONE); + GRPC_ERROR_NONE, "recv_initial_metadata"); } if (op->recv_message) { - GRPC_CLOSURE_SCHED(exec_ctx, op->payload->recv_message.recv_message_ready, - GRPC_ERROR_NONE); + GRPC_CALL_COMBINER_START(exec_ctx, calld->call_combiner, + op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE, "recv_message"); } GRPC_CLOSURE_SCHED(exec_ctx, op->on_complete, GRPC_ERROR_NONE); } @@ -667,6 +662,8 @@ static void StartTransportOp(grpc_exec_ctx *exec_ctx, static grpc_error *InitCallElem(grpc_exec_ctx *exec_ctx, grpc_call_element *elem, const grpc_call_element_args *args) { + call_data *calld = static_cast<call_data *>(elem->call_data); + calld->call_combiner = args->call_combiner; return GRPC_ERROR_NONE; } @@ -687,24 +684,19 @@ grpc_error *InitChannelElem(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, void DestroyChannelElem(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {} -char *GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem) { - return gpr_strdup("peer"); -} - void GetChannelInfo(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, const grpc_channel_info *channel_info) {} static const grpc_channel_filter isolated_call_filter = { StartTransportStreamOp, StartTransportOp, - 0, + sizeof(call_data), InitCallElem, SetPollsetOrPollsetSet, DestroyCallElem, 0, InitChannelElem, DestroyChannelElem, - GetPeer, GetChannelInfo, "isolated_call_filter"}; } // namespace isolated_call_filter diff --git a/test/cpp/microbenchmarks/bm_chttp2_transport.cc b/test/cpp/microbenchmarks/bm_chttp2_transport.cc index cb113c5254..936681fec1 100644 --- a/test/cpp/microbenchmarks/bm_chttp2_transport.cc +++ b/test/cpp/microbenchmarks/bm_chttp2_transport.cc @@ -286,6 +286,7 @@ static void BM_StreamCreateSendInitialMetadataDestroy(benchmark::State &state) { Stream s(&f); grpc_transport_stream_op_batch op; grpc_transport_stream_op_batch_payload op_payload; + memset(&op_payload, 0, sizeof(op_payload)); std::unique_ptr<Closure> start; std::unique_ptr<Closure> done; @@ -337,6 +338,7 @@ static void BM_TransportEmptyOp(benchmark::State &state) { s.Init(state); grpc_transport_stream_op_batch op; grpc_transport_stream_op_batch_payload op_payload; + memset(&op_payload, 0, sizeof(op_payload)); auto reset_op = [&]() { memset(&op, 0, sizeof(op)); op.payload = &op_payload; @@ -364,6 +366,7 @@ static void BM_TransportStreamSend(benchmark::State &state) { s.Init(state); grpc_transport_stream_op_batch op; grpc_transport_stream_op_batch_payload op_payload; + memset(&op_payload, 0, sizeof(op_payload)); auto reset_op = [&]() { memset(&op, 0, sizeof(op)); op.payload = &op_payload; @@ -485,6 +488,7 @@ static void BM_TransportStreamRecv(benchmark::State &state) { Stream s(&f); s.Init(state); grpc_transport_stream_op_batch_payload op_payload; + memset(&op_payload, 0, sizeof(op_payload)); grpc_transport_stream_op_batch op; grpc_byte_stream *recv_stream; grpc_slice incoming_data = CreateIncomingDataSlice(state.range(0), 16384); diff --git a/test/cpp/qps/BUILD b/test/cpp/qps/BUILD index 31f210dec0..3352269517 100644 --- a/test/cpp/qps/BUILD +++ b/test/cpp/qps/BUILD @@ -46,6 +46,7 @@ grpc_cc_library( ":usage_timer", "//:grpc", "//:grpc++", + "//:grpc++_core_stats", "//src/proto/grpc/testing:control_proto", "//src/proto/grpc/testing:payloads_proto", "//src/proto/grpc/testing:services_proto", |