aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/cpp/util
diff options
context:
space:
mode:
Diffstat (limited to 'test/cpp/util')
-rw-r--r--test/cpp/util/test_credentials_provider.cc148
-rw-r--r--test/cpp/util/test_credentials_provider.h17
2 files changed, 138 insertions, 27 deletions
diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc
index 1086e14258..65d3205767 100644
--- a/test/cpp/util/test_credentials_provider.cc
+++ b/test/cpp/util/test_credentials_provider.cc
@@ -34,48 +34,142 @@
#include "test/cpp/util/test_credentials_provider.h"
+#include <unordered_map>
+
+#include <grpc/support/sync.h>
+#include <grpc++/impl/sync.h>
+
#include "test/core/end2end/data/ssl_test_data.h"
+namespace {
+
+using grpc::ChannelArguments;
+using grpc::ChannelCredentials;
+using grpc::InsecureChannelCredentials;
+using grpc::InsecureServerCredentials;
+using grpc::ServerCredentials;
+using grpc::SslCredentialsOptions;
+using grpc::SslServerCredentialsOptions;
+using grpc::testing::CredentialTypeProvider;
+
+// Provide test credentials. Thread-safe.
+class CredentialsProvider {
+ public:
+ virtual ~CredentialsProvider() {}
+
+ virtual void AddSecureType(
+ const grpc::string& type,
+ std::unique_ptr<CredentialTypeProvider> type_provider) = 0;
+ virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
+ const grpc::string& type, ChannelArguments* args) = 0;
+ virtual std::shared_ptr<ServerCredentials> GetServerCredentials(
+ const grpc::string& type) = 0;
+ virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0;
+};
+
+class DefaultCredentialsProvider : public CredentialsProvider {
+ public:
+ ~DefaultCredentialsProvider() override {}
+
+ void AddSecureType(
+ const grpc::string& type,
+ std::unique_ptr<CredentialTypeProvider> type_provider) override {
+ // This clobbers any existing entry for type, except the defaults, which
+ // can't be clobbered.
+ grpc::unique_lock<grpc::mutex> lock(mu_);
+ added_secure_types_[type] = std::move(type_provider);
+ }
+
+ std::shared_ptr<ChannelCredentials> GetChannelCredentials(
+ const grpc::string& type, ChannelArguments* args) override {
+ if (type == grpc::testing::kInsecureCredentialsType) {
+ return InsecureChannelCredentials();
+ } else if (type == grpc::testing::kTlsCredentialsType) {
+ SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
+ args->SetSslTargetNameOverride("foo.test.google.fr");
+ return SslCredentials(ssl_opts);
+ } else {
+ grpc::unique_lock<grpc::mutex> lock(mu_);
+ auto it(added_secure_types_.find(type));
+ if (it == added_secure_types_.end()) {
+ gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
+ return nullptr;
+ }
+ return it->second->GetChannelCredentials(args);
+ }
+ }
+
+ std::shared_ptr<ServerCredentials> GetServerCredentials(
+ const grpc::string& type) override {
+ if (type == grpc::testing::kInsecureCredentialsType) {
+ return InsecureServerCredentials();
+ } else if (type == grpc::testing::kTlsCredentialsType) {
+ SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
+ test_server1_cert};
+ SslServerCredentialsOptions ssl_opts;
+ ssl_opts.pem_root_certs = "";
+ ssl_opts.pem_key_cert_pairs.push_back(pkcp);
+ return SslServerCredentials(ssl_opts);
+ } else {
+ grpc::unique_lock<grpc::mutex> lock(mu_);
+ auto it(added_secure_types_.find(type));
+ if (it == added_secure_types_.end()) {
+ gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
+ return nullptr;
+ }
+ return it->second->GetServerCredentials();
+ }
+ }
+ std::vector<grpc::string> GetSecureCredentialsTypeList() override {
+ std::vector<grpc::string> types;
+ types.push_back(grpc::testing::kTlsCredentialsType);
+ grpc::unique_lock<grpc::mutex> lock(mu_);
+ for (const auto& type_pair : added_secure_types_) {
+ types.push_back(type_pair.first);
+ }
+ return types;
+ }
+
+ private:
+ grpc::mutex mu_;
+ std::unordered_map<grpc::string, std::unique_ptr<CredentialTypeProvider> >
+ added_secure_types_;
+};
+
+gpr_once g_once_init_provider = GPR_ONCE_INIT;
+CredentialsProvider* g_provider = nullptr;
+
+void CreateDefaultProvider() {
+ g_provider = new DefaultCredentialsProvider;
+}
+
+CredentialsProvider* GetProvider() {
+ gpr_once_init(&g_once_init_provider, &CreateDefaultProvider);
+ return g_provider;
+}
+
+} // namespace
+
namespace grpc {
namespace testing {
-const char kTlsCredentialsType[] = "TLS_CREDENTIALS";
+void AddSecureType(const grpc::string& type,
+ std::unique_ptr<CredentialTypeProvider> type_provider) {
+ GetProvider()->AddSecureType(type, std::move(type_provider));
+}
std::shared_ptr<ChannelCredentials> GetChannelCredentials(
const grpc::string& type, ChannelArguments* args) {
- if (type == kInsecureCredentialsType) {
- return InsecureChannelCredentials();
- } else if (type == kTlsCredentialsType) {
- SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
- args->SetSslTargetNameOverride("foo.test.google.fr");
- return SslCredentials(ssl_opts);
- } else {
- gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
- }
- return nullptr;
+ return GetProvider()->GetChannelCredentials(type, args);
}
std::shared_ptr<ServerCredentials> GetServerCredentials(
const grpc::string& type) {
- if (type == kInsecureCredentialsType) {
- return InsecureServerCredentials();
- } else if (type == kTlsCredentialsType) {
- SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
- test_server1_cert};
- SslServerCredentialsOptions ssl_opts;
- ssl_opts.pem_root_certs = "";
- ssl_opts.pem_key_cert_pairs.push_back(pkcp);
- return SslServerCredentials(ssl_opts);
- } else {
- gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
- }
- return nullptr;
+ return GetProvider()->GetServerCredentials(type);
}
std::vector<grpc::string> GetSecureCredentialsTypeList() {
- std::vector<grpc::string> types;
- types.push_back(kTlsCredentialsType);
- return types;
+ return GetProvider()->GetSecureCredentialsTypeList();
}
} // namespace testing
diff --git a/test/cpp/util/test_credentials_provider.h b/test/cpp/util/test_credentials_provider.h
index f7253051a9..50fadb53a2 100644
--- a/test/cpp/util/test_credentials_provider.h
+++ b/test/cpp/util/test_credentials_provider.h
@@ -44,6 +44,23 @@ namespace grpc {
namespace testing {
const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS";
+const char kTlsCredentialsType[] = "TLS_CREDENTIALS";
+
+// Provide test credentials of a particular type.
+class CredentialTypeProvider {
+ public:
+ virtual ~CredentialTypeProvider() {}
+
+ virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
+ ChannelArguments* args) = 0;
+ virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0;
+};
+
+// Add a secure type in addition to the defaults above
+// (kInsecureCredentialsType, kTlsCredentialsType) that can be returned from the
+// functions below.
+void AddSecureType(const grpc::string& type,
+ std::unique_ptr<CredentialTypeProvider> type_provider);
// Provide channel credentials according to the given type. Alter the channel
// arguments if needed.