diff options
Diffstat (limited to 'src/cpp/client/secure_credentials.cc')
-rw-r--r-- | src/cpp/client/secure_credentials.cc | 80 |
1 files changed, 67 insertions, 13 deletions
diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index 057a058a3f..13bbc3075d 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -21,6 +21,7 @@ #include <grpc++/impl/grpc_library.h> #include <grpc++/support/channel_arguments.h> #include <grpc/support/log.h> +#include <grpc/support/string_util.h> #include "src/cpp/client/create_channel_internal.h" #include "src/cpp/common/secure_auth_context.h" @@ -150,6 +151,18 @@ std::shared_ptr<ChannelCredentials> CompositeChannelCredentials( return nullptr; } +std::shared_ptr<CallCredentials> CompositeCallCredentials( + const std::shared_ptr<CallCredentials>& creds1, + const std::shared_ptr<CallCredentials>& creds2) { + SecureCallCredentials* s_creds1 = creds1->AsSecureCredentials(); + SecureCallCredentials* s_creds2 = creds2->AsSecureCredentials(); + if (s_creds1 != nullptr && s_creds2 != nullptr) { + return WrapCallCredentials(grpc_composite_call_credentials_create( + s_creds1->GetRawCreds(), s_creds2->GetRawCreds(), nullptr)); + } + return nullptr; +} + void MetadataCredentialsPluginWrapper::Destroy(void* wrapper) { if (wrapper == nullptr) return; MetadataCredentialsPluginWrapper* w = @@ -157,28 +170,50 @@ void MetadataCredentialsPluginWrapper::Destroy(void* wrapper) { delete w; } -void MetadataCredentialsPluginWrapper::GetMetadata( +int MetadataCredentialsPluginWrapper::GetMetadata( void* wrapper, grpc_auth_metadata_context context, - grpc_credentials_plugin_metadata_cb cb, void* user_data) { + grpc_credentials_plugin_metadata_cb cb, void* user_data, + grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], + size_t* num_creds_md, grpc_status_code* status, + const char** error_details) { GPR_ASSERT(wrapper); MetadataCredentialsPluginWrapper* w = reinterpret_cast<MetadataCredentialsPluginWrapper*>(wrapper); if (!w->plugin_) { - cb(user_data, NULL, 0, GRPC_STATUS_OK, NULL); - return; + *num_creds_md = 0; + *status = GRPC_STATUS_OK; + *error_details = nullptr; + return true; } if (w->plugin_->IsBlocking()) { + // Asynchronous return. w->thread_pool_->Add( std::bind(&MetadataCredentialsPluginWrapper::InvokePlugin, w, context, - cb, user_data)); + cb, user_data, nullptr, nullptr, nullptr, nullptr)); + return 0; } else { - w->InvokePlugin(context, cb, user_data); + // Synchronous return. + w->InvokePlugin(context, cb, user_data, creds_md, num_creds_md, status, + error_details); + return 1; + } +} + +namespace { + +void UnrefMetadata(const std::vector<grpc_metadata>& md) { + for (auto it = md.begin(); it != md.end(); ++it) { + grpc_slice_unref(it->key); + grpc_slice_unref(it->value); } } +} // namespace + void MetadataCredentialsPluginWrapper::InvokePlugin( grpc_auth_metadata_context context, grpc_credentials_plugin_metadata_cb cb, - void* user_data) { + void* user_data, grpc_metadata creds_md[4], size_t* num_creds_md, + grpc_status_code* status_code, const char** error_details) { std::multimap<grpc::string, grpc::string> metadata; // const_cast is safe since the SecureAuthContext does not take owndership and @@ -196,12 +231,31 @@ void MetadataCredentialsPluginWrapper::InvokePlugin( md_entry.flags = 0; md.push_back(md_entry); } - cb(user_data, md.empty() ? nullptr : &md[0], md.size(), - static_cast<grpc_status_code>(status.error_code()), - status.error_message().c_str()); - for (auto it = md.begin(); it != md.end(); ++it) { - grpc_slice_unref(it->key); - grpc_slice_unref(it->value); + if (creds_md != nullptr) { + // Synchronous return. + if (md.size() > GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX) { + *num_creds_md = 0; + *status_code = GRPC_STATUS_INTERNAL; + *error_details = gpr_strdup( + "blocking plugin credentials returned too many metadata keys"); + UnrefMetadata(md); + } else { + for (const auto& elem : md) { + creds_md[*num_creds_md].key = elem.key; + creds_md[*num_creds_md].value = elem.value; + creds_md[*num_creds_md].flags = elem.flags; + ++(*num_creds_md); + } + *status_code = static_cast<grpc_status_code>(status.error_code()); + *error_details = + status.ok() ? nullptr : gpr_strdup(status.error_message().c_str()); + } + } else { + // Asynchronous return. + cb(user_data, md.empty() ? nullptr : &md[0], md.size(), + static_cast<grpc_status_code>(status.error_code()), + status.error_message().c_str()); + UnrefMetadata(md); } } |