diff options
Diffstat (limited to 'test/cpp/util')
-rw-r--r-- | test/cpp/util/test_credentials_provider.cc | 148 | ||||
-rw-r--r-- | test/cpp/util/test_credentials_provider.h | 17 |
2 files changed, 138 insertions, 27 deletions
diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc index 1086e14258..65d3205767 100644 --- a/test/cpp/util/test_credentials_provider.cc +++ b/test/cpp/util/test_credentials_provider.cc @@ -34,48 +34,142 @@ #include "test/cpp/util/test_credentials_provider.h" +#include <unordered_map> + +#include <grpc/support/sync.h> +#include <grpc++/impl/sync.h> + #include "test/core/end2end/data/ssl_test_data.h" +namespace { + +using grpc::ChannelArguments; +using grpc::ChannelCredentials; +using grpc::InsecureChannelCredentials; +using grpc::InsecureServerCredentials; +using grpc::ServerCredentials; +using grpc::SslCredentialsOptions; +using grpc::SslServerCredentialsOptions; +using grpc::testing::CredentialTypeProvider; + +// Provide test credentials. Thread-safe. +class CredentialsProvider { + public: + virtual ~CredentialsProvider() {} + + virtual void AddSecureType( + const grpc::string& type, + std::unique_ptr<CredentialTypeProvider> type_provider) = 0; + virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const grpc::string& type, ChannelArguments* args) = 0; + virtual std::shared_ptr<ServerCredentials> GetServerCredentials( + const grpc::string& type) = 0; + virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0; +}; + +class DefaultCredentialsProvider : public CredentialsProvider { + public: + ~DefaultCredentialsProvider() override {} + + void AddSecureType( + const grpc::string& type, + std::unique_ptr<CredentialTypeProvider> type_provider) override { + // This clobbers any existing entry for type, except the defaults, which + // can't be clobbered. + grpc::unique_lock<grpc::mutex> lock(mu_); + added_secure_types_[type] = std::move(type_provider); + } + + std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const grpc::string& type, ChannelArguments* args) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureChannelCredentials(); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + args->SetSslTargetNameOverride("foo.test.google.fr"); + return SslCredentials(ssl_opts); + } else { + grpc::unique_lock<grpc::mutex> lock(mu_); + auto it(added_secure_types_.find(type)); + if (it == added_secure_types_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return it->second->GetChannelCredentials(args); + } + } + + std::shared_ptr<ServerCredentials> GetServerCredentials( + const grpc::string& type) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureServerCredentials(); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, + test_server1_cert}; + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + return SslServerCredentials(ssl_opts); + } else { + grpc::unique_lock<grpc::mutex> lock(mu_); + auto it(added_secure_types_.find(type)); + if (it == added_secure_types_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return it->second->GetServerCredentials(); + } + } + std::vector<grpc::string> GetSecureCredentialsTypeList() override { + std::vector<grpc::string> types; + types.push_back(grpc::testing::kTlsCredentialsType); + grpc::unique_lock<grpc::mutex> lock(mu_); + for (const auto& type_pair : added_secure_types_) { + types.push_back(type_pair.first); + } + return types; + } + + private: + grpc::mutex mu_; + std::unordered_map<grpc::string, std::unique_ptr<CredentialTypeProvider> > + added_secure_types_; +}; + +gpr_once g_once_init_provider = GPR_ONCE_INIT; +CredentialsProvider* g_provider = nullptr; + +void CreateDefaultProvider() { + g_provider = new DefaultCredentialsProvider; +} + +CredentialsProvider* GetProvider() { + gpr_once_init(&g_once_init_provider, &CreateDefaultProvider); + return g_provider; +} + +} // namespace + namespace grpc { namespace testing { -const char kTlsCredentialsType[] = "TLS_CREDENTIALS"; +void AddSecureType(const grpc::string& type, + std::unique_ptr<CredentialTypeProvider> type_provider) { + GetProvider()->AddSecureType(type, std::move(type_provider)); +} std::shared_ptr<ChannelCredentials> GetChannelCredentials( const grpc::string& type, ChannelArguments* args) { - if (type == kInsecureCredentialsType) { - return InsecureChannelCredentials(); - } else if (type == kTlsCredentialsType) { - SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; - args->SetSslTargetNameOverride("foo.test.google.fr"); - return SslCredentials(ssl_opts); - } else { - gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); - } - return nullptr; + return GetProvider()->GetChannelCredentials(type, args); } std::shared_ptr<ServerCredentials> GetServerCredentials( const grpc::string& type) { - if (type == kInsecureCredentialsType) { - return InsecureServerCredentials(); - } else if (type == kTlsCredentialsType) { - SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, - test_server1_cert}; - SslServerCredentialsOptions ssl_opts; - ssl_opts.pem_root_certs = ""; - ssl_opts.pem_key_cert_pairs.push_back(pkcp); - return SslServerCredentials(ssl_opts); - } else { - gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); - } - return nullptr; + return GetProvider()->GetServerCredentials(type); } std::vector<grpc::string> GetSecureCredentialsTypeList() { - std::vector<grpc::string> types; - types.push_back(kTlsCredentialsType); - return types; + return GetProvider()->GetSecureCredentialsTypeList(); } } // namespace testing diff --git a/test/cpp/util/test_credentials_provider.h b/test/cpp/util/test_credentials_provider.h index f7253051a9..50fadb53a2 100644 --- a/test/cpp/util/test_credentials_provider.h +++ b/test/cpp/util/test_credentials_provider.h @@ -44,6 +44,23 @@ namespace grpc { namespace testing { const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS"; +const char kTlsCredentialsType[] = "TLS_CREDENTIALS"; + +// Provide test credentials of a particular type. +class CredentialTypeProvider { + public: + virtual ~CredentialTypeProvider() {} + + virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( + ChannelArguments* args) = 0; + virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0; +}; + +// Add a secure type in addition to the defaults above +// (kInsecureCredentialsType, kTlsCredentialsType) that can be returned from the +// functions below. +void AddSecureType(const grpc::string& type, + std::unique_ptr<CredentialTypeProvider> type_provider); // Provide channel credentials according to the given type. Alter the channel // arguments if needed. |