aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/core/lib/security/security_connector/alts/alts_security_connector.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/lib/security/security_connector/alts/alts_security_connector.cc')
-rw-r--r--src/core/lib/security/security_connector/alts/alts_security_connector.cc329
1 files changed, 152 insertions, 177 deletions
diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc
index dd71c8bc60..3ad0cc353c 100644
--- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc
+++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc
@@ -28,6 +28,7 @@
#include <grpc/support/log.h>
#include <grpc/support/string_util.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/alts/alts_credentials.h"
#include "src/core/lib/security/transport/security_handshaker.h"
#include "src/core/lib/slice/slice_internal.h"
@@ -35,64 +36,9 @@
#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
#include "src/core/tsi/transport_security.h"
-typedef struct {
- grpc_channel_security_connector base;
- char* target_name;
-} grpc_alts_channel_security_connector;
+namespace {
-typedef struct {
- grpc_server_security_connector base;
-} grpc_alts_server_security_connector;
-
-static void alts_channel_destroy(grpc_security_connector* sc) {
- if (sc == nullptr) {
- return;
- }
- auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
- grpc_call_credentials_unref(c->base.request_metadata_creds);
- grpc_channel_credentials_unref(c->base.channel_creds);
- gpr_free(c->target_name);
- gpr_free(sc);
-}
-
-static void alts_server_destroy(grpc_security_connector* sc) {
- if (sc == nullptr) {
- return;
- }
- auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc);
- grpc_server_credentials_unref(c->base.server_creds);
- gpr_free(sc);
-}
-
-static void alts_channel_add_handshakers(
- grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_manager) {
- tsi_handshaker* handshaker = nullptr;
- auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
- grpc_alts_credentials* creds =
- reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds);
- GPR_ASSERT(alts_tsi_handshaker_create(
- creds->options, c->target_name, creds->handshaker_service_url,
- true, interested_parties, &handshaker) == TSI_OK);
- grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
- handshaker, &sc->base));
-}
-
-static void alts_server_add_handshakers(
- grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_manager) {
- tsi_handshaker* handshaker = nullptr;
- auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc);
- grpc_alts_server_credentials* creds =
- reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds);
- GPR_ASSERT(alts_tsi_handshaker_create(
- creds->options, nullptr, creds->handshaker_service_url, false,
- interested_parties, &handshaker) == TSI_OK);
- grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
- handshaker, &sc->base));
-}
-
-static void alts_set_rpc_protocol_versions(
+void alts_set_rpc_protocol_versions(
grpc_gcp_rpc_protocol_versions* rpc_versions) {
grpc_gcp_rpc_protocol_versions_set_max(rpc_versions,
GRPC_PROTOCOL_VERSION_MAX_MAJOR,
@@ -102,17 +48,131 @@ static void alts_set_rpc_protocol_versions(
GRPC_PROTOCOL_VERSION_MIN_MINOR);
}
+void alts_check_peer(tsi_peer peer,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) {
+ *auth_context =
+ grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer);
+ tsi_peer_destruct(&peer);
+ grpc_error* error =
+ *auth_context != nullptr
+ ? GRPC_ERROR_NONE
+ : GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "Could not get ALTS auth context from TSI peer");
+ GRPC_CLOSURE_SCHED(on_peer_checked, error);
+}
+
+class grpc_alts_channel_security_connector final
+ : public grpc_channel_security_connector {
+ public:
+ grpc_alts_channel_security_connector(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name)
+ : grpc_channel_security_connector(/*url_scheme=*/nullptr,
+ std::move(channel_creds),
+ std::move(request_metadata_creds)),
+ target_name_(gpr_strdup(target_name)) {
+ grpc_alts_credentials* creds =
+ static_cast<grpc_alts_credentials*>(mutable_channel_creds());
+ alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
+ }
+
+ ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); }
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_manager) override {
+ tsi_handshaker* handshaker = nullptr;
+ const grpc_alts_credentials* creds =
+ static_cast<const grpc_alts_credentials*>(channel_creds());
+ GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_,
+ creds->handshaker_service_url(), true,
+ interested_parties,
+ &handshaker) == TSI_OK);
+ grpc_handshake_manager_add(
+ handshake_manager, grpc_security_handshaker_create(handshaker, this));
+ }
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ alts_check_peer(peer, auth_context, on_peer_checked);
+ }
+
+ int cmp(const grpc_security_connector* other_sc) const override {
+ auto* other =
+ reinterpret_cast<const grpc_alts_channel_security_connector*>(other_sc);
+ int c = channel_security_connector_cmp(other);
+ if (c != 0) return c;
+ return strcmp(target_name_, other->target_name_);
+ }
+
+ bool check_call_host(const char* host, grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) override {
+ if (host == nullptr || strcmp(host, target_name_) != 0) {
+ *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "ALTS call host does not match target name");
+ }
+ return true;
+ }
+
+ void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) override {
+ GRPC_ERROR_UNREF(error);
+ }
+
+ private:
+ char* target_name_;
+};
+
+class grpc_alts_server_security_connector final
+ : public grpc_server_security_connector {
+ public:
+ grpc_alts_server_security_connector(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
+ : grpc_server_security_connector(/*url_scheme=*/nullptr,
+ std::move(server_creds)) {
+ grpc_alts_server_credentials* creds =
+ reinterpret_cast<grpc_alts_server_credentials*>(mutable_server_creds());
+ alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
+ }
+ ~grpc_alts_server_security_connector() override = default;
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_manager) override {
+ tsi_handshaker* handshaker = nullptr;
+ const grpc_alts_server_credentials* creds =
+ static_cast<const grpc_alts_server_credentials*>(server_creds());
+ GPR_ASSERT(alts_tsi_handshaker_create(
+ creds->options(), nullptr, creds->handshaker_service_url(),
+ false, interested_parties, &handshaker) == TSI_OK);
+ grpc_handshake_manager_add(
+ handshake_manager, grpc_security_handshaker_create(handshaker, this));
+ }
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ alts_check_peer(peer, auth_context, on_peer_checked);
+ }
+
+ int cmp(const grpc_security_connector* other) const override {
+ return server_security_connector_cmp(
+ static_cast<const grpc_server_security_connector*>(other));
+ }
+};
+} // namespace
+
namespace grpc_core {
namespace internal {
-
-grpc_security_status grpc_alts_auth_context_from_tsi_peer(
- const tsi_peer* peer, grpc_auth_context** ctx) {
- if (peer == nullptr || ctx == nullptr) {
+grpc_core::RefCountedPtr<grpc_auth_context>
+grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) {
+ if (peer == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to grpc_alts_auth_context_from_tsi_peer()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- *ctx = nullptr;
/* Validate certificate type. */
const tsi_peer_property* cert_type_prop =
tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY);
@@ -120,14 +180,14 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE,
cert_type_prop->value.length) != 0) {
gpr_log(GPR_ERROR, "Invalid or missing certificate type property.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
/* Validate RPC protocol versions. */
const tsi_peer_property* rpc_versions_prop =
tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS);
if (rpc_versions_prop == nullptr) {
gpr_log(GPR_ERROR, "Missing rpc protocol versions property.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
grpc_gcp_rpc_protocol_versions local_versions, peer_versions;
alts_set_rpc_protocol_versions(&local_versions);
@@ -138,19 +198,19 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
grpc_slice_unref_internal(slice);
if (!decode_result) {
gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
/* TODO: Pass highest common rpc protocol version to grpc caller. */
bool check_result = grpc_gcp_rpc_protocol_versions_check(
&local_versions, &peer_versions, nullptr);
if (!check_result) {
gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
/* Create auth context. */
- *ctx = grpc_auth_context_create(nullptr);
+ auto ctx = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
- *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
+ ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_ALTS_TRANSPORT_SECURITY_TYPE);
size_t i = 0;
for (i = 0; i < peer->property_count; i++) {
@@ -158,132 +218,47 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
/* Add service account to auth context. */
if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) {
grpc_auth_context_add_property(
- *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, tsi_prop->value.data,
- tsi_prop->value.length);
+ ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY,
+ tsi_prop->value.data, tsi_prop->value.length);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
- *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1);
+ ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1);
}
}
- if (!grpc_auth_context_peer_is_authenticated(*ctx)) {
+ if (!grpc_auth_context_peer_is_authenticated(ctx.get())) {
gpr_log(GPR_ERROR, "Invalid unauthenticated peer.");
- GRPC_AUTH_CONTEXT_UNREF(*ctx, "test");
- *ctx = nullptr;
- return GRPC_SECURITY_ERROR;
+ ctx.reset(DEBUG_LOCATION, "test");
+ return nullptr;
}
- return GRPC_SECURITY_OK;
+ return ctx;
}
} // namespace internal
} // namespace grpc_core
-static void alts_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- grpc_security_status status;
- status = grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(
- &peer, auth_context);
- tsi_peer_destruct(&peer);
- grpc_error* error =
- status == GRPC_SECURITY_OK
- ? GRPC_ERROR_NONE
- : GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "Could not get ALTS auth context from TSI peer");
- GRPC_CLOSURE_SCHED(on_peer_checked, error);
-}
-
-static int alts_channel_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_alts_channel_security_connector* c1 =
- reinterpret_cast<grpc_alts_channel_security_connector*>(sc1);
- grpc_alts_channel_security_connector* c2 =
- reinterpret_cast<grpc_alts_channel_security_connector*>(sc2);
- int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
- if (c != 0) return c;
- return strcmp(c1->target_name, c2->target_name);
-}
-
-static int alts_server_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_alts_server_security_connector* c1 =
- reinterpret_cast<grpc_alts_server_security_connector*>(sc1);
- grpc_alts_server_security_connector* c2 =
- reinterpret_cast<grpc_alts_server_security_connector*>(sc2);
- return grpc_server_security_connector_cmp(&c1->base, &c2->base);
-}
-
-static grpc_security_connector_vtable alts_channel_vtable = {
- alts_channel_destroy, alts_check_peer, alts_channel_cmp};
-
-static grpc_security_connector_vtable alts_server_vtable = {
- alts_server_destroy, alts_check_peer, alts_server_cmp};
-
-static bool alts_check_call_host(grpc_channel_security_connector* sc,
- const char* host,
- grpc_auth_context* auth_context,
- grpc_closure* on_call_host_checked,
- grpc_error** error) {
- grpc_alts_channel_security_connector* alts_sc =
- reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
- if (host == nullptr || alts_sc == nullptr ||
- strcmp(host, alts_sc->target_name) != 0) {
- *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "ALTS call host does not match target name");
- }
- return true;
-}
-
-static void alts_cancel_check_call_host(grpc_channel_security_connector* sc,
- grpc_closure* on_call_host_checked,
- grpc_error* error) {
- GRPC_ERROR_UNREF(error);
-}
-
-grpc_security_status grpc_alts_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds, const char* target_name,
- grpc_channel_security_connector** sc) {
- if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) {
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_alts_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name) {
+ if (channel_creds == nullptr || target_name == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_alts_channel_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- auto c = static_cast<grpc_alts_channel_security_connector*>(
- gpr_zalloc(sizeof(grpc_alts_channel_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &alts_channel_vtable;
- c->base.add_handshakers = alts_channel_add_handshakers;
- c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
- c->base.request_metadata_creds =
- grpc_call_credentials_ref(request_metadata_creds);
- c->base.check_call_host = alts_check_call_host;
- c->base.cancel_check_call_host = alts_cancel_check_call_host;
- grpc_alts_credentials* creds =
- reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds);
- alts_set_rpc_protocol_versions(&creds->options->rpc_versions);
- c->target_name = gpr_strdup(target_name);
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return grpc_core::MakeRefCounted<grpc_alts_channel_security_connector>(
+ std::move(channel_creds), std::move(request_metadata_creds), target_name);
}
-grpc_security_status grpc_alts_server_security_connector_create(
- grpc_server_credentials* server_creds,
- grpc_server_security_connector** sc) {
- if (server_creds == nullptr || sc == nullptr) {
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_alts_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
+ if (server_creds == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_alts_server_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- auto c = static_cast<grpc_alts_server_security_connector*>(
- gpr_zalloc(sizeof(grpc_alts_server_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &alts_server_vtable;
- c->base.server_creds = grpc_server_credentials_ref(server_creds);
- c->base.add_handshakers = alts_server_add_handshakers;
- grpc_alts_server_credentials* creds =
- reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds);
- alts_set_rpc_protocol_versions(&creds->options->rpc_versions);
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return grpc_core::MakeRefCounted<grpc_alts_server_security_connector>(
+ std::move(server_creds));
}