aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--test/cpp/util/test_credentials_provider.cc118
-rw-r--r--test/cpp/util/test_credentials_provider.h16
2 files changed, 107 insertions, 27 deletions
diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc
index 1086e14258..cfd3ebbb11 100644
--- a/test/cpp/util/test_credentials_provider.cc
+++ b/test/cpp/util/test_credentials_provider.cc
@@ -34,48 +34,112 @@
#include "test/cpp/util/test_credentials_provider.h"
+#include <grpc/support/sync.h>
+#include <grpc++/impl/sync.h>
+
#include "test/core/end2end/data/ssl_test_data.h"
+namespace {
+
+using grpc::ChannelArguments;
+using grpc::ChannelCredentials;
+using grpc::InsecureChannelCredentials;
+using grpc::InsecureServerCredentials;
+using grpc::ServerCredentials;
+using grpc::SslCredentialsOptions;
+using grpc::SslServerCredentialsOptions;
+using grpc::testing::CredentialsProvider;
+
+class DefaultCredentialsProvider : public CredentialsProvider {
+ public:
+ ~DefaultCredentialsProvider() override {}
+
+ std::shared_ptr<ChannelCredentials> GetChannelCredentials(
+ const grpc::string& type, ChannelArguments* args) override {
+ if (type == grpc::testing::kInsecureCredentialsType) {
+ return InsecureChannelCredentials();
+ } else if (type == grpc::testing::kTlsCredentialsType) {
+ SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
+ args->SetSslTargetNameOverride("foo.test.google.fr");
+ return SslCredentials(ssl_opts);
+ } else {
+ gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
+ }
+ return nullptr;
+ }
+
+ std::shared_ptr<ServerCredentials> GetServerCredentials(
+ const grpc::string& type) override {
+ if (type == grpc::testing::kInsecureCredentialsType) {
+ return InsecureServerCredentials();
+ } else if (type == grpc::testing::kTlsCredentialsType) {
+ SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
+ test_server1_cert};
+ SslServerCredentialsOptions ssl_opts;
+ ssl_opts.pem_root_certs = "";
+ ssl_opts.pem_key_cert_pairs.push_back(pkcp);
+ return SslServerCredentials(ssl_opts);
+ } else {
+ gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
+ }
+ return nullptr;
+ }
+ std::vector<grpc::string> GetSecureCredentialsTypeList() override {
+ std::vector<grpc::string> types;
+ types.push_back(grpc::testing::kTlsCredentialsType);
+ return types;
+ }
+};
+
+gpr_once g_once_init_provider_mu = GPR_ONCE_INIT;
+grpc::mutex* g_provider_mu = nullptr;
+CredentialsProvider* g_provider = nullptr;
+
+void InitProviderMu() {
+ g_provider_mu = new grpc::mutex;
+}
+
+grpc::mutex& GetMu() {
+ gpr_once_init(&g_once_init_provider_mu, &InitProviderMu);
+ return *g_provider_mu;
+}
+
+CredentialsProvider* GetProvider() {
+ grpc::unique_lock<grpc::mutex> lock(GetMu());
+ if (g_provider == nullptr) {
+ g_provider = new DefaultCredentialsProvider;
+ }
+ return g_provider;
+}
+
+} // namespace
+
namespace grpc {
namespace testing {
-const char kTlsCredentialsType[] = "TLS_CREDENTIALS";
+// Note that it is not thread-safe to set a provider while concurrently using
+// the previously set provider, as this deletes and replaces it. nullptr may be
+// given to reset to the default.
+void SetTestCredentialsProvider(std::unique_ptr<CredentialsProvider> provider) {
+ grpc::unique_lock<grpc::mutex> lock(GetMu());
+ if (g_provider != nullptr) {
+ delete g_provider;
+ }
+ g_provider = provider.release();
+}
std::shared_ptr<ChannelCredentials> GetChannelCredentials(
const grpc::string& type, ChannelArguments* args) {
- if (type == kInsecureCredentialsType) {
- return InsecureChannelCredentials();
- } else if (type == kTlsCredentialsType) {
- SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
- args->SetSslTargetNameOverride("foo.test.google.fr");
- return SslCredentials(ssl_opts);
- } else {
- gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
- }
- return nullptr;
+ return GetProvider()->GetChannelCredentials(type, args);
}
std::shared_ptr<ServerCredentials> GetServerCredentials(
const grpc::string& type) {
- if (type == kInsecureCredentialsType) {
- return InsecureServerCredentials();
- } else if (type == kTlsCredentialsType) {
- SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
- test_server1_cert};
- SslServerCredentialsOptions ssl_opts;
- ssl_opts.pem_root_certs = "";
- ssl_opts.pem_key_cert_pairs.push_back(pkcp);
- return SslServerCredentials(ssl_opts);
- } else {
- gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
- }
- return nullptr;
+ return GetProvider()->GetServerCredentials(type);
}
std::vector<grpc::string> GetSecureCredentialsTypeList() {
- std::vector<grpc::string> types;
- types.push_back(kTlsCredentialsType);
- return types;
+ return GetProvider()->GetSecureCredentialsTypeList();
}
} // namespace testing
diff --git a/test/cpp/util/test_credentials_provider.h b/test/cpp/util/test_credentials_provider.h
index f7253051a9..a6b547cb07 100644
--- a/test/cpp/util/test_credentials_provider.h
+++ b/test/cpp/util/test_credentials_provider.h
@@ -44,6 +44,22 @@ namespace grpc {
namespace testing {
const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS";
+const char kTlsCredentialsType[] = "TLS_CREDENTIALS";
+
+class CredentialsProvider {
+ public:
+ virtual ~CredentialsProvider() {}
+
+ virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
+ const grpc::string& type, ChannelArguments* args) = 0;
+ virtual std::shared_ptr<ServerCredentials> GetServerCredentials(
+ const grpc::string& type) = 0;
+ virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0;
+};
+
+// Set the CredentialsProvider used by the other functions in this file. If this
+// is not set, a default provider will be used.
+void SetTestCredentialsProvider(std::unique_ptr<CredentialsProvider> provider);
// Provide channel credentials according to the given type. Alter the channel
// arguments if needed.