diff options
Diffstat (limited to 'src/core/lib/security')
10 files changed, 112 insertions, 24 deletions
diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc index 1f93416b23..14051a3f00 100644 --- a/src/core/lib/security/context/security_context.cc +++ b/src/core/lib/security/context/security_context.cc @@ -326,8 +326,23 @@ grpc_arg grpc_auth_context_to_arg(grpc_auth_context* p) { &auth_context_pointer_vtable); } +grpc_auth_context* grpc_auth_context_from_arg(const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_AUTH_CONTEXT_ARG) != 0) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_AUTH_CONTEXT_ARG); + return nullptr; + } + return static_cast<grpc_auth_context*>(arg->value.pointer.p); +} + grpc_auth_context* grpc_find_auth_context_in_args( - const grpc_channel_args* channel_args) { - return grpc_channel_args_get_pointer<grpc_auth_context>( - channel_args, GRPC_AUTH_CONTEXT_ARG); + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_auth_context* p = grpc_auth_context_from_arg(&args->args[i]); + if (p != nullptr) return p; + } + return nullptr; } diff --git a/src/core/lib/security/context/security_context.h b/src/core/lib/security/context/security_context.h index 2f73a5482c..e782e4f28f 100644 --- a/src/core/lib/security/context/security_context.h +++ b/src/core/lib/security/context/security_context.h @@ -108,6 +108,7 @@ void grpc_server_security_context_destroy(void* ctx); #define GRPC_AUTH_CONTEXT_ARG "grpc.auth_context" grpc_arg grpc_auth_context_to_arg(grpc_auth_context* c); +grpc_auth_context* grpc_auth_context_from_arg(const grpc_arg* arg); grpc_auth_context* grpc_find_auth_context_in_args( const grpc_channel_args* args); diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc index edeea29327..c43cb440eb 100644 --- a/src/core/lib/security/credentials/credentials.cc +++ b/src/core/lib/security/credentials/credentials.cc @@ -168,10 +168,27 @@ grpc_arg grpc_channel_credentials_to_arg( &credentials_pointer_vtable); } +grpc_channel_credentials* grpc_channel_credentials_from_arg( + const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_ARG_CHANNEL_CREDENTIALS)) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_ARG_CHANNEL_CREDENTIALS); + return nullptr; + } + return static_cast<grpc_channel_credentials*>(arg->value.pointer.p); +} + grpc_channel_credentials* grpc_channel_credentials_find_in_args( - const grpc_channel_args* channel_args) { - return grpc_channel_args_get_pointer<grpc_channel_credentials>( - channel_args, GRPC_ARG_CHANNEL_CREDENTIALS); + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_channel_credentials* credentials = + grpc_channel_credentials_from_arg(&args->args[i]); + if (credentials != nullptr) return credentials; + } + return nullptr; } grpc_server_credentials* grpc_server_credentials_ref( @@ -246,8 +263,24 @@ grpc_arg grpc_server_credentials_to_arg(grpc_server_credentials* p) { &cred_ptr_vtable); } +grpc_server_credentials* grpc_server_credentials_from_arg(const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_SERVER_CREDENTIALS_ARG) != 0) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_SERVER_CREDENTIALS_ARG); + return nullptr; + } + return static_cast<grpc_server_credentials*>(arg->value.pointer.p); +} + grpc_server_credentials* grpc_find_server_credentials_in_args( - const grpc_channel_args* channel_args) { - return grpc_channel_args_get_pointer<grpc_server_credentials>( - channel_args, GRPC_SERVER_CREDENTIALS_ARG); + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_server_credentials* p = + grpc_server_credentials_from_arg(&args->args[i]); + if (p != nullptr) return p; + } + return nullptr; } diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h index ba380283cc..b486d25ab2 100644 --- a/src/core/lib/security/credentials/credentials.h +++ b/src/core/lib/security/credentials/credentials.h @@ -131,6 +131,10 @@ grpc_channel_credentials_duplicate_without_call_credentials( /* Util to encapsulate the channel credentials in a channel arg. */ grpc_arg grpc_channel_credentials_to_arg(grpc_channel_credentials* credentials); +/* Util to get the channel credentials from a channel arg. */ +grpc_channel_credentials* grpc_channel_credentials_from_arg( + const grpc_arg* arg); + /* Util to find the channel credentials from channel args. */ grpc_channel_credentials* grpc_channel_credentials_find_in_args( const grpc_channel_args* args); @@ -223,6 +227,7 @@ void grpc_server_credentials_unref(grpc_server_credentials* creds); #define GRPC_SERVER_CREDENTIALS_ARG "grpc.server_credentials" grpc_arg grpc_server_credentials_to_arg(grpc_server_credentials* c); +grpc_server_credentials* grpc_server_credentials_from_arg(const grpc_arg* arg); grpc_server_credentials* grpc_find_server_credentials_in_args( const grpc_channel_args* args); diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index 08321a85c6..858ab6b41b 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -84,8 +84,9 @@ grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) { const char* grpc_fake_transport_get_expected_targets( const grpc_channel_args* args) { - return grpc_channel_args_get_string(args, - GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS); + const grpc_arg* expected_target_arg = + grpc_channel_args_find(args, GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS); + return grpc_channel_arg_get_string(expected_target_arg); } /* -- Metadata-only test credentials. -- */ diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.cc b/src/core/lib/security/credentials/google_default/google_default_credentials.cc index fa565d4ef8..38c9175717 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.cc +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.cc @@ -79,10 +79,13 @@ static grpc_security_status google_default_create_security_connector( grpc_channel_security_connector** sc, grpc_channel_args** new_args) { grpc_google_default_channel_credentials* c = reinterpret_cast<grpc_google_default_channel_credentials*>(creds); - bool is_grpclb_load_balancer = grpc_channel_args_get_bool( - args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER, false); - bool is_backend_from_grpclb_load_balancer = grpc_channel_args_get_bool( - args, GRPC_ARG_ADDRESS_IS_BACKEND_FROM_GRPCLB_LOAD_BALANCER, false); + bool is_grpclb_load_balancer = grpc_channel_arg_get_bool( + grpc_channel_args_find(args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER), + false); + bool is_backend_from_grpclb_load_balancer = grpc_channel_arg_get_bool( + grpc_channel_args_find( + args, GRPC_ARG_ADDRESS_IS_BACKEND_FROM_GRPCLB_LOAD_BALANCER), + false); bool use_alts = is_grpclb_load_balancer || is_backend_from_grpclb_load_balancer; grpc_security_status status = GRPC_SECURITY_ERROR; diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index 13dae19b4b..2b6377d3ec 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -60,12 +60,14 @@ static grpc_security_status ssl_create_security_connector( tsi_ssl_session_cache* ssl_session_cache = nullptr; for (size_t i = 0; args && i < args->num_args; i++) { grpc_arg* arg = &args->args[i]; - if (strcmp(arg->key, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG) == 0) { - overridden_target_name = grpc_channel_arg_get_string(arg); + if (strcmp(arg->key, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG) == 0 && + arg->type == GRPC_ARG_STRING) { + overridden_target_name = arg->value.string; } - if (strcmp(arg->key, GRPC_SSL_SESSION_CACHE_ARG) == 0) { + if (strcmp(arg->key, GRPC_SSL_SESSION_CACHE_ARG) == 0 && + arg->type == GRPC_ARG_POINTER) { ssl_session_cache = - grpc_channel_arg_get_pointer<tsi_ssl_session_cache>(arg); + static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p); } } status = grpc_ssl_channel_security_connector_create( diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index ea001d453d..b54a7643e4 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -255,10 +255,26 @@ grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc) { &connector_arg_vtable); } +grpc_security_connector* grpc_security_connector_from_arg(const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_ARG_SECURITY_CONNECTOR)) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_ARG_SECURITY_CONNECTOR); + return nullptr; + } + return static_cast<grpc_security_connector*>(arg->value.pointer.p); +} + grpc_security_connector* grpc_security_connector_find_in_args( - const grpc_channel_args* channel_args) { - return grpc_channel_args_get_pointer<grpc_security_connector>( - channel_args, GRPC_ARG_SECURITY_CONNECTOR); + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_security_connector* sc = + grpc_security_connector_from_arg(&args->args[i]); + if (sc != nullptr) return sc; + } + return nullptr; } static tsi_client_certificate_request_type diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index 9da66ef01d..f9723166d0 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -99,6 +99,9 @@ int grpc_security_connector_cmp(grpc_security_connector* sc, /* Util to encapsulate the connector in a channel arg. */ grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc); +/* Util to get the connector from a channel arg. */ +grpc_security_connector* grpc_security_connector_from_arg(const grpc_arg* arg); + /* Util to find the connector from channel args. */ grpc_security_connector* grpc_security_connector_find_in_args( const grpc_channel_args* args); diff --git a/src/core/lib/security/transport/target_authority_table.cc b/src/core/lib/security/transport/target_authority_table.cc index 467e681a50..1eeb557f6a 100644 --- a/src/core/lib/security/transport/target_authority_table.cc +++ b/src/core/lib/security/transport/target_authority_table.cc @@ -59,8 +59,17 @@ grpc_arg CreateTargetAuthorityTableChannelArg(TargetAuthorityTable* table) { TargetAuthorityTable* FindTargetAuthorityTableInArgs( const grpc_channel_args* args) { - return grpc_channel_args_get_pointer<TargetAuthorityTable>( - args, GRPC_ARG_TARGET_AUTHORITY_TABLE); + const grpc_arg* arg = + grpc_channel_args_find(args, GRPC_ARG_TARGET_AUTHORITY_TABLE); + if (arg != nullptr) { + if (arg->type == GRPC_ARG_POINTER) { + return static_cast<TargetAuthorityTable*>(arg->value.pointer.p); + } else { + gpr_log(GPR_ERROR, "value of " GRPC_ARG_TARGET_AUTHORITY_TABLE + " channel arg was not pointer type; ignoring"); + } + } + return nullptr; } } // namespace grpc_core |