diff options
-rw-r--r-- | test/cpp/util/test_credentials_provider.cc | 118 | ||||
-rw-r--r-- | test/cpp/util/test_credentials_provider.h | 16 |
2 files changed, 107 insertions, 27 deletions
diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc index 1086e14258..cfd3ebbb11 100644 --- a/test/cpp/util/test_credentials_provider.cc +++ b/test/cpp/util/test_credentials_provider.cc @@ -34,48 +34,112 @@ #include "test/cpp/util/test_credentials_provider.h" +#include <grpc/support/sync.h> +#include <grpc++/impl/sync.h> + #include "test/core/end2end/data/ssl_test_data.h" +namespace { + +using grpc::ChannelArguments; +using grpc::ChannelCredentials; +using grpc::InsecureChannelCredentials; +using grpc::InsecureServerCredentials; +using grpc::ServerCredentials; +using grpc::SslCredentialsOptions; +using grpc::SslServerCredentialsOptions; +using grpc::testing::CredentialsProvider; + +class DefaultCredentialsProvider : public CredentialsProvider { + public: + ~DefaultCredentialsProvider() override {} + + std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const grpc::string& type, ChannelArguments* args) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureChannelCredentials(); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + args->SetSslTargetNameOverride("foo.test.google.fr"); + return SslCredentials(ssl_opts); + } else { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + } + return nullptr; + } + + std::shared_ptr<ServerCredentials> GetServerCredentials( + const grpc::string& type) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureServerCredentials(); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, + test_server1_cert}; + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + return SslServerCredentials(ssl_opts); + } else { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + } + return nullptr; + } + std::vector<grpc::string> GetSecureCredentialsTypeList() override { + std::vector<grpc::string> types; + types.push_back(grpc::testing::kTlsCredentialsType); + return types; + } +}; + +gpr_once g_once_init_provider_mu = GPR_ONCE_INIT; +grpc::mutex* g_provider_mu = nullptr; +CredentialsProvider* g_provider = nullptr; + +void InitProviderMu() { + g_provider_mu = new grpc::mutex; +} + +grpc::mutex& GetMu() { + gpr_once_init(&g_once_init_provider_mu, &InitProviderMu); + return *g_provider_mu; +} + +CredentialsProvider* GetProvider() { + grpc::unique_lock<grpc::mutex> lock(GetMu()); + if (g_provider == nullptr) { + g_provider = new DefaultCredentialsProvider; + } + return g_provider; +} + +} // namespace + namespace grpc { namespace testing { -const char kTlsCredentialsType[] = "TLS_CREDENTIALS"; +// Note that it is not thread-safe to set a provider while concurrently using +// the previously set provider, as this deletes and replaces it. nullptr may be +// given to reset to the default. +void SetTestCredentialsProvider(std::unique_ptr<CredentialsProvider> provider) { + grpc::unique_lock<grpc::mutex> lock(GetMu()); + if (g_provider != nullptr) { + delete g_provider; + } + g_provider = provider.release(); +} std::shared_ptr<ChannelCredentials> GetChannelCredentials( const grpc::string& type, ChannelArguments* args) { - if (type == kInsecureCredentialsType) { - return InsecureChannelCredentials(); - } else if (type == kTlsCredentialsType) { - SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; - args->SetSslTargetNameOverride("foo.test.google.fr"); - return SslCredentials(ssl_opts); - } else { - gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); - } - return nullptr; + return GetProvider()->GetChannelCredentials(type, args); } std::shared_ptr<ServerCredentials> GetServerCredentials( const grpc::string& type) { - if (type == kInsecureCredentialsType) { - return InsecureServerCredentials(); - } else if (type == kTlsCredentialsType) { - SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, - test_server1_cert}; - SslServerCredentialsOptions ssl_opts; - ssl_opts.pem_root_certs = ""; - ssl_opts.pem_key_cert_pairs.push_back(pkcp); - return SslServerCredentials(ssl_opts); - } else { - gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); - } - return nullptr; + return GetProvider()->GetServerCredentials(type); } std::vector<grpc::string> GetSecureCredentialsTypeList() { - std::vector<grpc::string> types; - types.push_back(kTlsCredentialsType); - return types; + return GetProvider()->GetSecureCredentialsTypeList(); } } // namespace testing diff --git a/test/cpp/util/test_credentials_provider.h b/test/cpp/util/test_credentials_provider.h index f7253051a9..a6b547cb07 100644 --- a/test/cpp/util/test_credentials_provider.h +++ b/test/cpp/util/test_credentials_provider.h @@ -44,6 +44,22 @@ namespace grpc { namespace testing { const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS"; +const char kTlsCredentialsType[] = "TLS_CREDENTIALS"; + +class CredentialsProvider { + public: + virtual ~CredentialsProvider() {} + + virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const grpc::string& type, ChannelArguments* args) = 0; + virtual std::shared_ptr<ServerCredentials> GetServerCredentials( + const grpc::string& type) = 0; + virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0; +}; + +// Set the CredentialsProvider used by the other functions in this file. If this +// is not set, a default provider will be used. +void SetTestCredentialsProvider(std::unique_ptr<CredentialsProvider> provider); // Provide channel credentials according to the given type. Alter the channel // arguments if needed. |