diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/core/support/BUILD | 2 | ||||
-rw-r--r-- | test/core/util/memory_counters.h | 2 | ||||
-rw-r--r-- | test/cpp/end2end/async_end2end_test.cc | 19 | ||||
-rw-r--r-- | test/cpp/end2end/end2end_test.cc | 20 | ||||
-rw-r--r-- | test/cpp/interop/client.cc | 1 | ||||
-rw-r--r-- | test/cpp/interop/client_helper.cc | 10 | ||||
-rw-r--r-- | test/cpp/interop/http2_client.cc | 272 | ||||
-rw-r--r-- | test/cpp/interop/http2_client.h | 80 | ||||
-rw-r--r-- | test/cpp/interop/interop_server.cc | 1 | ||||
-rw-r--r-- | test/cpp/interop/server_helper.cc | 18 | ||||
-rw-r--r-- | test/cpp/interop/stress_test.cc | 1 | ||||
-rw-r--r-- | test/cpp/qps/client.h | 13 | ||||
-rw-r--r-- | test/cpp/qps/qps_json_driver.cc | 1 | ||||
-rw-r--r-- | test/cpp/util/create_test_channel.cc | 64 | ||||
-rw-r--r-- | test/cpp/util/create_test_channel.h | 4 | ||||
-rw-r--r-- | test/cpp/util/test_credentials_provider.cc | 52 | ||||
-rw-r--r-- | test/cpp/util/test_credentials_provider.h | 50 | ||||
-rw-r--r-- | test/http2_test/http2_base_server.py | 1 | ||||
-rw-r--r-- | test/http2_test/http2_test_server.py | 34 |
19 files changed, 540 insertions, 105 deletions
diff --git a/test/core/support/BUILD b/test/core/support/BUILD index 77f0a9a048..dfe952eb37 100644 --- a/test/core/support/BUILD +++ b/test/core/support/BUILD @@ -27,8 +27,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer") - cc_test( name = "alloc_test", srcs = ["alloc_test.c"], diff --git a/test/core/util/memory_counters.h b/test/core/util/memory_counters.h index f332816501..b9b2b3adda 100644 --- a/test/core/util/memory_counters.h +++ b/test/core/util/memory_counters.h @@ -34,6 +34,8 @@ #ifndef GRPC_TEST_CORE_UTIL_MEMORY_COUNTERS_H #define GRPC_TEST_CORE_UTIL_MEMORY_COUNTERS_H +#include <stddef.h> + struct grpc_memory_counters { size_t total_size_relative; size_t total_size_absolute; diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index 8e385d100c..2ce3f2f7bd 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -254,7 +254,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { // Setup server ServerBuilder builder; - auto server_creds = GetServerCredentials(GetParam().credentials_type); + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); builder.AddListeningPort(server_address_.str(), server_creds); builder.RegisterService(&service_); cq_ = builder.AddCompletionQueue(); @@ -283,8 +284,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { void ResetStub() { ChannelArguments args; - auto channel_creds = - GetChannelCredentials(GetParam().credentials_type, &args); + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); std::shared_ptr<Channel> channel = CreateCustomChannel(server_address_.str(), channel_creds, args); stub_ = grpc::testing::EchoTestService::NewStub(channel); @@ -892,8 +893,8 @@ TEST_P(AsyncEnd2endTest, ServerCheckDone) { TEST_P(AsyncEnd2endTest, UnimplementedRpc) { ChannelArguments args; - auto channel_creds = - GetChannelCredentials(GetParam().credentials_type, &args); + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); std::shared_ptr<Channel> channel = CreateCustomChannel(server_address_.str(), channel_creds, args); std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; @@ -1404,11 +1405,15 @@ std::vector<TestScenario> CreateTestScenarios(bool test_disable_blocking, std::vector<grpc::string> credentials_types; std::vector<grpc::string> messages; - credentials_types.push_back(kInsecureCredentialsType); - auto sec_list = GetSecureCredentialsTypeList(); + if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType, + nullptr) != nullptr) { + credentials_types.push_back(kInsecureCredentialsType); + } + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { credentials_types.push_back(*sec); } + GPR_ASSERT(!credentials_types.empty()); messages.push_back("Hello"); for (int sz = 1; sz < test_big_limit; sz *= 2) { diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 9bb892c694..1a1a94e87c 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -242,7 +242,8 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> { // Setup server ServerBuilder builder; ConfigureServerBuilder(&builder); - auto server_creds = GetServerCredentials(GetParam().credentials_type); + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); if (GetParam().credentials_type != kInsecureCredentialsType) { server_creds->SetAuthMetadataProcessor(processor); } @@ -270,8 +271,8 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> { } EXPECT_TRUE(is_server_started_); ChannelArguments args; - auto channel_creds = - GetChannelCredentials(GetParam().credentials_type, &args); + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); if (!user_agent_prefix_.empty()) { args.SetUserAgentPrefix(user_agent_prefix_); } @@ -1520,11 +1521,18 @@ std::vector<TestScenario> CreateTestScenarios(bool use_proxy, std::vector<TestScenario> scenarios; std::vector<grpc::string> credentials_types; if (test_secure) { - credentials_types = GetSecureCredentialsTypeList(); + credentials_types = + GetCredentialsProvider()->GetSecureCredentialsTypeList(); } if (test_insecure) { - credentials_types.push_back(kInsecureCredentialsType); + // Only add insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + if (GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr) { + credentials_types.push_back(kInsecureCredentialsType); + } } + GPR_ASSERT(!credentials_types.empty()); for (auto it = credentials_types.begin(); it != credentials_types.end(); ++it) { scenarios.emplace_back(false, *it); @@ -1541,7 +1549,7 @@ INSTANTIATE_TEST_CASE_P(End2end, End2endTest, INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest, ::testing::ValuesIn(CreateTestScenarios(false, true, - false))); + true))); INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest, ::testing::ValuesIn(CreateTestScenarios(true, true, diff --git a/test/cpp/interop/client.cc b/test/cpp/interop/client.cc index c58910abc3..3265554444 100644 --- a/test/cpp/interop/client.cc +++ b/test/cpp/interop/client.cc @@ -49,6 +49,7 @@ #include "test/cpp/util/test_config.h" DEFINE_bool(use_tls, false, "Whether to use tls."); +DEFINE_string(custom_credentials_type, "", "User provided credentials type."); DEFINE_bool(use_test_ca, false, "False to use SSL roots for google"); DEFINE_int32(server_port, 0, "Server port."); DEFINE_string(server_host, "127.0.0.1", "Server host to connect to"); diff --git a/test/cpp/interop/client_helper.cc b/test/cpp/interop/client_helper.cc index c171969e14..91564e5dce 100644 --- a/test/cpp/interop/client_helper.cc +++ b/test/cpp/interop/client_helper.cc @@ -50,8 +50,10 @@ #include "src/cpp/client/secure_credentials.h" #include "test/core/security/oauth2_utils.h" #include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/test_credentials_provider.h" DECLARE_bool(use_tls); +DECLARE_string(custom_credentials_type); DECLARE_bool(use_test_ca); DECLARE_int32(server_port); DECLARE_string(server_host); @@ -114,8 +116,12 @@ std::shared_ptr<Channel> CreateChannelForTestCase( creds = AccessTokenCredentials(raw_token); GPR_ASSERT(creds); } - return CreateTestChannel(host_port, FLAGS_server_host_override, FLAGS_use_tls, - !FLAGS_use_test_ca, creds); + if (FLAGS_custom_credentials_type.empty()) { + return CreateTestChannel(host_port, FLAGS_server_host_override, + FLAGS_use_tls, !FLAGS_use_test_ca, creds); + } else { + return CreateTestChannel(host_port, FLAGS_custom_credentials_type, creds); + } } } // namespace testing diff --git a/test/cpp/interop/http2_client.cc b/test/cpp/interop/http2_client.cc new file mode 100644 index 0000000000..38aee43b26 --- /dev/null +++ b/test/cpp/interop/http2_client.cc @@ -0,0 +1,272 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <thread> + +#include <gflags/gflags.h> +#include <grpc++/channel.h> +#include <grpc++/client_context.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/useful.h> + +#include "src/core/lib/transport/byte_stream.h" +#include "src/proto/grpc/testing/messages.grpc.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/cpp/interop/http2_client.h" + +#include "src/core/lib/support/string.h" +#include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +namespace { +const int kLargeRequestSize = 271828; +const int kLargeResponseSize = 314159; +} // namespace + +Http2Client::ServiceStub::ServiceStub(std::shared_ptr<Channel> channel) + : channel_(channel) { + stub_ = TestService::NewStub(channel); +} + +TestService::Stub* Http2Client::ServiceStub::Get() { return stub_.get(); } + +Http2Client::Http2Client(std::shared_ptr<Channel> channel) + : serviceStub_(channel), channel_(channel) {} + +bool Http2Client::AssertStatusCode(const Status& s, StatusCode expected_code) { + if (s.error_code() == expected_code) { + return true; + } + + gpr_log(GPR_ERROR, "Error status code: %d (expected: %d), message: %s", + s.error_code(), expected_code, s.error_message().c_str()); + abort(); +} + +bool Http2Client::DoRstAfterHeader() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting reset stream after header"); + + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::UNKNOWN); + GPR_ASSERT(!response.has_payload()); // no data should be received + + gpr_log(GPR_DEBUG, "Done testing reset stream after header"); + return true; +} + +bool Http2Client::DoRstAfterData() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting reset stream after data"); + + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::UNKNOWN); + GPR_ASSERT(response.has_payload()); // data should be received + + gpr_log(GPR_DEBUG, "Done testing reset stream after data"); + return true; +} + +bool Http2Client::DoRstDuringData() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting reset stream during data"); + + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::UNKNOWN); + GPR_ASSERT(!response.has_payload()); // no data should be received + + gpr_log(GPR_DEBUG, "Done testing reset stream during data"); + return true; +} + +bool Http2Client::DoGoaway() { + gpr_log(GPR_DEBUG, "Sending two RPCs and expecting goaway"); + + int numCalls = 2; + for (int i = 0; i < numCalls; i++) { + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + grpc::string(kLargeResponseSize, '\0')); + } + + gpr_log(GPR_DEBUG, "Done testing goaway"); + return true; +} + +bool Http2Client::DoPing() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting ping"); + + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + grpc::string(kLargeResponseSize, '\0')); + + gpr_log(GPR_DEBUG, "Done testing ping"); + return true; +} + +void Http2Client::MaxStreamsWorker(std::shared_ptr<grpc::Channel> channel) { + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = + TestService::NewStub(channel)->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + grpc::string(kLargeResponseSize, '\0')); +} + +bool Http2Client::DoMaxStreams() { + gpr_log(GPR_DEBUG, "Testing max streams"); + + // Make an initial call on the channel to ensure the server's max streams + // setting is received + ClientContext context; + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + grpc::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + Status s = + TestService::NewStub(channel_)->UnaryCall(&context, request, &response); + AssertStatusCode(s, grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + grpc::string(kLargeResponseSize, '\0')); + + std::vector<std::thread> test_threads; + + for (int i = 0; i < 10; i++) { + test_threads.emplace_back( + std::thread(&Http2Client::MaxStreamsWorker, this, channel_)); + } + + for (auto it = test_threads.begin(); it != test_threads.end(); it++) { + it->join(); + } + + gpr_log(GPR_DEBUG, "Done testing max streams"); + return true; +} + +} // namespace testing +} // namespace grpc + +DEFINE_int32(server_port, 0, "Server port."); +DEFINE_string(server_host, "127.0.0.1", "Server host to connect to"); +DEFINE_string(test_case, "rst_after_header", + "Configure different test cases. Valid options are:\n\n" + "goaway\n" + "max_streams\n" + "ping\n" + "rst_after_data\n" + "rst_after_header\n" + "rst_during_data\n"); + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + GPR_ASSERT(FLAGS_server_port); + const int host_port_buf_size = 1024; + char host_port[host_port_buf_size]; + snprintf(host_port, host_port_buf_size, "%s:%d", FLAGS_server_host.c_str(), + FLAGS_server_port); + grpc::testing::Http2Client client(grpc::CreateTestChannel(host_port, false)); + gpr_log(GPR_INFO, "Testing case: %s", FLAGS_test_case.c_str()); + int ret = 0; + if (FLAGS_test_case == "rst_after_header") { + client.DoRstAfterHeader(); + } else if (FLAGS_test_case == "rst_after_data") { + client.DoRstAfterData(); + } else if (FLAGS_test_case == "rst_during_data") { + client.DoRstDuringData(); + } else if (FLAGS_test_case == "goaway") { + client.DoGoaway(); + } else if (FLAGS_test_case == "ping") { + client.DoPing(); + } else if (FLAGS_test_case == "max_streams") { + client.DoMaxStreams(); + } else { + const char* testcases[] = { + "goaway", "max_streams", "ping", + "rst_after_data", "rst_after_header", "rst_during_data"}; + char* joined_testcases = + gpr_strjoin_sep(testcases, GPR_ARRAY_SIZE(testcases), "\n", NULL); + + gpr_log(GPR_ERROR, "Unsupported test case %s. Valid options are\n%s", + FLAGS_test_case.c_str(), joined_testcases); + gpr_free(joined_testcases); + ret = 1; + } + + return ret; +} diff --git a/test/cpp/interop/http2_client.h b/test/cpp/interop/http2_client.h new file mode 100644 index 0000000000..6a315f5abb --- /dev/null +++ b/test/cpp/interop/http2_client.h @@ -0,0 +1,80 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef GRPC_TEST_CPP_INTEROP_HTTP2_CLIENT_H +#define GRPC_TEST_CPP_INTEROP_HTTP2_CLIENT_H + +#include <memory> + +#include <grpc++/channel.h> +#include <grpc/grpc.h> +#include "src/proto/grpc/testing/messages.grpc.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" + +namespace grpc { +namespace testing { + +class Http2Client { + public: + explicit Http2Client(std::shared_ptr<Channel> channel); + ~Http2Client() {} + + bool DoRstAfterHeader(); + bool DoRstAfterData(); + bool DoRstDuringData(); + bool DoGoaway(); + bool DoPing(); + bool DoMaxStreams(); + + private: + class ServiceStub { + public: + ServiceStub(std::shared_ptr<Channel> channel); + + TestService::Stub* Get(); + + private: + std::unique_ptr<TestService::Stub> stub_; + std::shared_ptr<Channel> channel_; + }; + + void MaxStreamsWorker(std::shared_ptr<grpc::Channel> channel); + bool AssertStatusCode(const Status& s, StatusCode expected_code); + ServiceStub serviceStub_; + std::shared_ptr<Channel> channel_; +}; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_INTEROP_HTTP2_CLIENT_H diff --git a/test/cpp/interop/interop_server.cc b/test/cpp/interop/interop_server.cc index 67456ce18b..956840ba70 100644 --- a/test/cpp/interop/interop_server.cc +++ b/test/cpp/interop/interop_server.cc @@ -56,6 +56,7 @@ #include "test/cpp/util/test_config.h" DEFINE_bool(use_tls, false, "Whether to use tls."); +DEFINE_string(custom_credentials_type, "", "User provided credentials type."); DEFINE_int32(port, 0, "Server port."); DEFINE_int32(max_send_message_size, -1, "The maximum send message size."); diff --git a/test/cpp/interop/server_helper.cc b/test/cpp/interop/server_helper.cc index 8b0b511bcb..d395f50fa5 100644 --- a/test/cpp/interop/server_helper.cc +++ b/test/cpp/interop/server_helper.cc @@ -39,23 +39,23 @@ #include <grpc++/security/server_credentials.h> #include "src/core/lib/surface/call_test_only.h" -#include "test/core/end2end/data/ssl_test_data.h" +#include "test/cpp/util/test_credentials_provider.h" DECLARE_bool(use_tls); +DECLARE_string(custom_credentials_type); namespace grpc { namespace testing { std::shared_ptr<ServerCredentials> CreateInteropServerCredentials() { - if (FLAGS_use_tls) { - SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, - test_server1_cert}; - SslServerCredentialsOptions ssl_opts; - ssl_opts.pem_root_certs = ""; - ssl_opts.pem_key_cert_pairs.push_back(pkcp); - return SslServerCredentials(ssl_opts); + if (!FLAGS_custom_credentials_type.empty()) { + return GetCredentialsProvider()->GetServerCredentials( + FLAGS_custom_credentials_type); + } else if (FLAGS_use_tls) { + return GetCredentialsProvider()->GetServerCredentials(kTlsCredentialsType); } else { - return InsecureServerCredentials(); + return GetCredentialsProvider()->GetServerCredentials( + kInsecureCredentialsType); } } diff --git a/test/cpp/interop/stress_test.cc b/test/cpp/interop/stress_test.cc index 97e658869f..562522de77 100644 --- a/test/cpp/interop/stress_test.cc +++ b/test/cpp/interop/stress_test.cc @@ -147,6 +147,7 @@ DEFINE_bool(do_not_abort_on_transient_failures, true, // Options from client.cc (for compatibility with interop test). // TODO(sreek): Consolidate overlapping options DEFINE_bool(use_tls, false, "Whether to use tls."); +DEFINE_string(custom_credentials_type, "", "User provided credentials type."); DEFINE_bool(use_test_ca, false, "False to use SSL roots for google"); DEFINE_int32(server_port, 0, "Server port."); DEFINE_string(server_host, "127.0.0.1", "Server host to connect to"); diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index fdd78ebb89..18f9778fc6 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -409,6 +409,7 @@ class ClientImpl : public Client { // old compilers happy with using this in std::vector ChannelArguments args; args.SetInt("shard_to_ensure_no_subchannel_merges", shard); + set_channel_args(config, &args); channel_ = CreateTestChannel( target, config.security_params().server_host_override(), config.has_security_params(), !config.security_params().use_test_ca(), @@ -423,6 +424,18 @@ class ClientImpl : public Client { StubType* get_stub() { return stub_.get(); } private: + void set_channel_args(const ClientConfig& config, ChannelArguments* args) { + for (auto channel_arg : config.channel_args()) { + if (channel_arg.value_case() == ChannelArg::kStrValue) { + args->SetString(channel_arg.name(), channel_arg.str_value()); + } else if (channel_arg.value_case() == ChannelArg::kIntValue) { + args->SetInt(channel_arg.name(), channel_arg.int_value()); + } else { + gpr_log(GPR_ERROR, "Empty channel arg value."); + } + } + } + std::shared_ptr<Channel> channel_; std::unique_ptr<StubType> stub_; }; diff --git a/test/cpp/qps/qps_json_driver.cc b/test/cpp/qps/qps_json_driver.cc index da835b995a..57ee5ef63c 100644 --- a/test/cpp/qps/qps_json_driver.cc +++ b/test/cpp/qps/qps_json_driver.cc @@ -212,6 +212,7 @@ static bool QpsDriver() { SearchOfferedLoad(FLAGS_initial_search_value, FLAGS_targeted_cpu_load, scenario, &success); gpr_log(GPR_INFO, "targeted_offered_load %f", targeted_offered_load); + GetCpuLoad(scenario, targeted_offered_load, &success); } else { gpr_log(GPR_ERROR, "Unimplemented search param"); } diff --git a/test/cpp/util/create_test_channel.cc b/test/cpp/util/create_test_channel.cc index fe8b5d5423..ad62e03490 100644 --- a/test/cpp/util/create_test_channel.cc +++ b/test/cpp/util/create_test_channel.cc @@ -35,11 +35,37 @@ #include <grpc++/create_channel.h> #include <grpc++/security/credentials.h> +#include <grpc/support/log.h> -#include "test/core/end2end/data/ssl_test_data.h" +#include "test/cpp/util/test_credentials_provider.h" namespace grpc { +namespace { + +const char kProdTlsCredentialsType[] = "prod_ssl"; + +class SslCredentialProvider : public testing::CredentialTypeProvider { + public: + std::shared_ptr<ChannelCredentials> GetChannelCredentials( + grpc::ChannelArguments* args) override { + return SslCredentials(SslCredentialsOptions()); + } + std::shared_ptr<ServerCredentials> GetServerCredentials() override { + return nullptr; + } +}; + +gpr_once g_once_init_add_prod_ssl_provider = GPR_ONCE_INIT; +// Register ssl with non-test roots type to the credentials provider. +void AddProdSslType() { + testing::GetCredentialsProvider()->AddSecureType( + kProdTlsCredentialsType, std::unique_ptr<testing::CredentialTypeProvider>( + new SslCredentialProvider)); +} + +} // namespace + // When ssl is enabled, if server is empty, override_hostname is used to // create channel. Otherwise, connect to server and override hostname if // override_hostname is provided. @@ -61,16 +87,22 @@ std::shared_ptr<Channel> CreateTestChannel( const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args) { ChannelArguments channel_args(args); + std::shared_ptr<ChannelCredentials> channel_creds; if (enable_ssl) { - const char* roots_certs = use_prod_roots ? "" : test_root_cert; - SslCredentialsOptions ssl_opts = {roots_certs, "", ""}; - - std::shared_ptr<ChannelCredentials> channel_creds = - SslCredentials(ssl_opts); - - if (!server.empty() && !override_hostname.empty()) { - channel_args.SetSslTargetNameOverride(override_hostname); + if (use_prod_roots) { + gpr_once_init(&g_once_init_add_prod_ssl_provider, &AddProdSslType); + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + kProdTlsCredentialsType, &channel_args); + if (!server.empty() && !override_hostname.empty()) { + channel_args.SetSslTargetNameOverride(override_hostname); + } + } else { + // override_hostname is discarded as the provider handles it. + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + testing::kTlsCredentialsType, &channel_args); } + GPR_ASSERT(channel_creds != nullptr); + const grpc::string& connect_to = server.empty() ? override_hostname : server; if (creds.get()) { @@ -103,4 +135,18 @@ std::shared_ptr<Channel> CreateTestChannel(const grpc::string& server, return CreateTestChannel(server, "foo.test.google.fr", enable_ssl, false); } +std::shared_ptr<Channel> CreateTestChannel( + const grpc::string& server, const grpc::string& credential_type, + const std::shared_ptr<CallCredentials>& creds) { + ChannelArguments channel_args; + std::shared_ptr<ChannelCredentials> channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, + &channel_args); + GPR_ASSERT(channel_creds != nullptr); + if (creds.get()) { + channel_creds = CompositeChannelCredentials(channel_creds, creds); + } + return CreateCustomChannel(server, channel_creds, channel_args); +} + } // namespace grpc diff --git a/test/cpp/util/create_test_channel.h b/test/cpp/util/create_test_channel.h index 4ff666dc1b..ce71a97edb 100644 --- a/test/cpp/util/create_test_channel.h +++ b/test/cpp/util/create_test_channel.h @@ -59,6 +59,10 @@ std::shared_ptr<Channel> CreateTestChannel( const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args); +std::shared_ptr<Channel> CreateTestChannel( + const grpc::string& server, const grpc::string& credential_type, + const std::shared_ptr<CallCredentials>& creds); + } // namespace grpc #endif // GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc index 0456b96667..909b02a701 100644 --- a/test/cpp/util/test_credentials_provider.cc +++ b/test/cpp/util/test_credentials_provider.cc @@ -43,25 +43,9 @@ #include "test/core/end2end/data/ssl_test_data.h" namespace grpc { +namespace testing { namespace { -using grpc::testing::CredentialTypeProvider; - -// Provide test credentials. Thread-safe. -class CredentialsProvider { - public: - virtual ~CredentialsProvider() {} - - virtual void AddSecureType( - const grpc::string& type, - std::unique_ptr<CredentialTypeProvider> type_provider) = 0; - virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( - const grpc::string& type, ChannelArguments* args) = 0; - virtual std::shared_ptr<ServerCredentials> GetServerCredentials( - const grpc::string& type) = 0; - virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0; -}; - class DefaultCredentialsProvider : public CredentialsProvider { public: ~DefaultCredentialsProvider() override {} @@ -145,37 +129,21 @@ class DefaultCredentialsProvider : public CredentialsProvider { added_secure_type_providers_; }; -gpr_once g_once_init_provider = GPR_ONCE_INIT; CredentialsProvider* g_provider = nullptr; -void CreateDefaultProvider() { g_provider = new DefaultCredentialsProvider; } - -CredentialsProvider* GetProvider() { - gpr_once_init(&g_once_init_provider, &CreateDefaultProvider); - return g_provider; -} - } // namespace -namespace testing { - -void AddSecureType(const grpc::string& type, - std::unique_ptr<CredentialTypeProvider> type_provider) { - GetProvider()->AddSecureType(type, std::move(type_provider)); -} - -std::shared_ptr<ChannelCredentials> GetChannelCredentials( - const grpc::string& type, ChannelArguments* args) { - return GetProvider()->GetChannelCredentials(type, args); -} - -std::shared_ptr<ServerCredentials> GetServerCredentials( - const grpc::string& type) { - return GetProvider()->GetServerCredentials(type); +CredentialsProvider* GetCredentialsProvider() { + if (g_provider == nullptr) { + g_provider = new DefaultCredentialsProvider; + } + return g_provider; } -std::vector<grpc::string> GetSecureCredentialsTypeList() { - return GetProvider()->GetSecureCredentialsTypeList(); +void SetCredentialsProvider(CredentialsProvider* provider) { + // For now, forbids overriding provider. + GPR_ASSERT(g_provider == nullptr); + g_provider = provider; } } // namespace testing diff --git a/test/cpp/util/test_credentials_provider.h b/test/cpp/util/test_credentials_provider.h index 1fb311e556..0bc52ebe4d 100644 --- a/test/cpp/util/test_credentials_provider.h +++ b/test/cpp/util/test_credentials_provider.h @@ -59,23 +59,39 @@ class CredentialTypeProvider { virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0; }; -// Add a secure type in addition to the defaults above -// (kInsecureCredentialsType, kTlsCredentialsType) that can be returned from the -// functions below. -void AddSecureType(const grpc::string& type, - std::unique_ptr<CredentialTypeProvider> type_provider); - -// Provide channel credentials according to the given type. Alter the channel -// arguments if needed. -std::shared_ptr<ChannelCredentials> GetChannelCredentials( - const grpc::string& type, ChannelArguments* args); - -// Provide server credentials according to the given type. -std::shared_ptr<ServerCredentials> GetServerCredentials( - const grpc::string& type); - -// Provide a list of secure credentials type. -std::vector<grpc::string> GetSecureCredentialsTypeList(); +// Provide test credentials. Thread-safe. +class CredentialsProvider { + public: + virtual ~CredentialsProvider() {} + + // Add a secure type in addition to the defaults. The default provider has + // (kInsecureCredentialsType, kTlsCredentialsType). + virtual void AddSecureType( + const grpc::string& type, + std::unique_ptr<CredentialTypeProvider> type_provider) = 0; + + // Provide channel credentials according to the given type. Alter the channel + // arguments if needed. Return nullptr if type is not registered. + virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const grpc::string& type, ChannelArguments* args) = 0; + + // Provide server credentials according to the given type. + // Return nullptr if type is not registered. + virtual std::shared_ptr<ServerCredentials> GetServerCredentials( + const grpc::string& type) = 0; + + // Provide a list of secure credentials type. + virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0; +}; + +// Get the current provider. Create a default one if not set. +// Not thread-safe. +CredentialsProvider* GetCredentialsProvider(); + +// Set the global provider. Takes ownership. The previous set provider will be +// destroyed. +// Not thread-safe. +void SetCredentialsProvider(CredentialsProvider* provider); } // namespace testing } // namespace grpc diff --git a/test/http2_test/http2_base_server.py b/test/http2_test/http2_base_server.py index ee7719b1a8..8de028ceb1 100644 --- a/test/http2_test/http2_base_server.py +++ b/test/http2_test/http2_base_server.py @@ -73,7 +73,6 @@ class H2ProtocolBaseServer(twisted.internet.protocol.Protocol): def on_connection_lost(self, reason): logging.info('Disconnected %s' % reason) - twisted.internet.reactor.callFromThread(twisted.internet.reactor.stop) def dataReceived(self, data): try: diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py index 44e36d34b6..abde3433ad 100644 --- a/test/http2_test/http2_test_server.py +++ b/test/http2_test/http2_test_server.py @@ -73,18 +73,32 @@ class H2Factory(twisted.internet.protocol.Factory): else: return t().get_base_server() +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--base_port', type=int, default=8080, + help='base port to run the servers (default: 8080). One test server is ' + 'started on each incrementing port, beginning with base_port, in the ' + 'following order: goaway,max_streams,ping,rst_after_data,rst_after_header,' + 'rst_during_data' + ) + return parser.parse_args() + +def start_test_servers(base_port): + """ Start one server per test case on incrementing port numbers + beginning with base_port """ + index = 0 + for test_case in sorted(_TEST_CASE_MAPPING.keys()): + portnum = base_port + index + logging.warning('serving on port %d : %s'%(portnum, test_case)) + endpoint = twisted.internet.endpoints.TCP4ServerEndpoint( + twisted.internet.reactor, portnum, backlog=128) + endpoint.listen(H2Factory(test_case)) + index += 1 + if __name__ == '__main__': logging.basicConfig( format='%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s', level=logging.INFO) - parser = argparse.ArgumentParser() - parser.add_argument('--test_case', choices=sorted(_TEST_CASE_MAPPING.keys()), - help='test case to run', required=True) - parser.add_argument('--port', type=int, default=8080, - help='port to run the server (default: 8080)') - args = parser.parse_args() - logging.info('Running test case %s on port %d' % (args.test_case, args.port)) - endpoint = twisted.internet.endpoints.TCP4ServerEndpoint( - twisted.internet.reactor, args.port, backlog=128) - endpoint.listen(H2Factory(args.test_case)) + args = parse_arguments() + start_test_servers(args.base_port) twisted.internet.reactor.run() |