aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--examples/pubsub/main.cc2
-rw-r--r--include/grpc++/client_context.h14
-rw-r--r--include/grpc++/create_channel.h2
-rw-r--r--include/grpc++/credentials.h35
-rw-r--r--src/cpp/client/client_context.cc12
-rw-r--r--src/cpp/client/create_channel.cc2
-rw-r--r--src/cpp/client/insecure_credentials.cc6
-rw-r--r--src/cpp/client/secure_credentials.cc32
-rw-r--r--src/cpp/client/secure_credentials.h1
-rw-r--r--test/cpp/client/credentials_test.cc3
-rw-r--r--test/cpp/end2end/end2end_test.cc50
-rw-r--r--test/cpp/interop/client_helper.cc6
-rw-r--r--test/cpp/util/create_test_channel.cc6
-rw-r--r--test/cpp/util/create_test_channel.h2
-rw-r--r--test/cpp/util/grpc_cli.cc2
15 files changed, 123 insertions, 52 deletions
diff --git a/examples/pubsub/main.cc b/examples/pubsub/main.cc
index 68620e64c5..b1898f18d9 100644
--- a/examples/pubsub/main.cc
+++ b/examples/pubsub/main.cc
@@ -71,7 +71,7 @@ int main(int argc, char** argv) {
ss << FLAGS_server_host << ":" << FLAGS_server_port;
- std::unique_ptr<grpc::Credentials> creds = grpc::GoogleDefaultCredentials();
+ std::shared_ptr<grpc::Credentials> creds = grpc::GoogleDefaultCredentials();
std::shared_ptr<grpc::ChannelInterface> channel =
grpc::CreateChannel(ss.str(), creds, grpc::ChannelArguments());
diff --git a/include/grpc++/client_context.h b/include/grpc++/client_context.h
index a58e9872e6..6d9015f278 100644
--- a/include/grpc++/client_context.h
+++ b/include/grpc++/client_context.h
@@ -51,6 +51,7 @@ namespace grpc {
class CallOpBuffer;
class ChannelInterface;
class CompletionQueue;
+class Credentials;
class RpcMethod;
class Status;
template <class R>
@@ -102,6 +103,11 @@ class ClientContext {
void set_authority(const grpc::string& authority) { authority_ = authority; }
+ // Set credentials for the rpc.
+ void set_credentials(const std::shared_ptr<Credentials>& creds) {
+ creds_ = creds;
+ }
+
void TryCancel();
private:
@@ -127,11 +133,8 @@ class ClientContext {
friend class ::grpc::ClientAsyncResponseReader;
grpc_call* call() { return call_; }
- void set_call(grpc_call* call, const std::shared_ptr<ChannelInterface>& channel) {
- GPR_ASSERT(call_ == nullptr);
- call_ = call;
- channel_ = channel;
- }
+ void set_call(grpc_call* call,
+ const std::shared_ptr<ChannelInterface>& channel);
grpc_completion_queue* cq() { return cq_; }
void set_cq(grpc_completion_queue* cq) { cq_ = cq; }
@@ -144,6 +147,7 @@ class ClientContext {
grpc_completion_queue* cq_;
gpr_timespec deadline_;
grpc::string authority_;
+ std::shared_ptr<Credentials> creds_;
std::multimap<grpc::string, grpc::string> send_initial_metadata_;
std::multimap<grpc::string, grpc::string> recv_initial_metadata_;
std::multimap<grpc::string, grpc::string> trailing_metadata_;
diff --git a/include/grpc++/create_channel.h b/include/grpc++/create_channel.h
index da375b97db..424a93a64c 100644
--- a/include/grpc++/create_channel.h
+++ b/include/grpc++/create_channel.h
@@ -45,7 +45,7 @@ class ChannelInterface;
// If creds does not hold an object or is invalid, a lame channel is returned.
std::shared_ptr<ChannelInterface> CreateChannel(
- const grpc::string& target, const std::unique_ptr<Credentials>& creds,
+ const grpc::string& target, const std::shared_ptr<Credentials>& creds,
const ChannelArguments& args);
} // namespace grpc
diff --git a/include/grpc++/credentials.h b/include/grpc++/credentials.h
index 61c4094691..7a40cd199d 100644
--- a/include/grpc++/credentials.h
+++ b/include/grpc++/credentials.h
@@ -47,17 +47,18 @@ class SecureCredentials;
class Credentials : public GrpcLibrary {
public:
~Credentials() GRPC_OVERRIDE;
+ virtual bool ApplyToCall(grpc_call* call) = 0;
protected:
- friend std::unique_ptr<Credentials> CompositeCredentials(
- const std::unique_ptr<Credentials>& creds1,
- const std::unique_ptr<Credentials>& creds2);
+ friend std::shared_ptr<Credentials> CompositeCredentials(
+ const std::shared_ptr<Credentials>& creds1,
+ const std::shared_ptr<Credentials>& creds2);
virtual SecureCredentials* AsSecureCredentials() = 0;
private:
friend std::shared_ptr<ChannelInterface> CreateChannel(
- const grpc::string& target, const std::unique_ptr<Credentials>& creds,
+ const grpc::string& target, const std::shared_ptr<Credentials>& creds,
const ChannelArguments& args);
virtual std::shared_ptr<ChannelInterface> CreateChannel(
@@ -80,20 +81,20 @@ struct SslCredentialsOptions {
};
// Factories for building different types of Credentials
-// The functions may return empty unique_ptr when credentials cannot be created.
+// The functions may return empty shared_ptr when credentials cannot be created.
// If a Credentials pointer is returned, it can still be invalid when used to
// create a channel. A lame channel will be created then and all rpcs will
// fail on it.
// Builds credentials with reasonable defaults.
-std::unique_ptr<Credentials> GoogleDefaultCredentials();
+std::shared_ptr<Credentials> GoogleDefaultCredentials();
// Builds SSL Credentials given SSL specific options
-std::unique_ptr<Credentials> SslCredentials(
+std::shared_ptr<Credentials> SslCredentials(
const SslCredentialsOptions& options);
// Builds credentials for use when running in GCE
-std::unique_ptr<Credentials> ComputeEngineCredentials();
+std::shared_ptr<Credentials> ComputeEngineCredentials();
// Builds service account credentials.
// json_key is the JSON key string containing the client's private key.
@@ -101,7 +102,7 @@ std::unique_ptr<Credentials> ComputeEngineCredentials();
// token_lifetime_seconds is the lifetime in seconds of each token acquired
// through this service account credentials. It should be positive and should
// not exceed grpc_max_auth_token_lifetime or will be cropped to this value.
-std::unique_ptr<Credentials> ServiceAccountCredentials(
+std::shared_ptr<Credentials> ServiceAccountCredentials(
const grpc::string& json_key, const grpc::string& scope,
long token_lifetime_seconds);
@@ -110,27 +111,27 @@ std::unique_ptr<Credentials> ServiceAccountCredentials(
// token_lifetime_seconds is the lifetime in seconds of each Json Web Token
// (JWT) created with this credentials. It should not exceed
// grpc_max_auth_token_lifetime or will be cropped to this value.
-std::unique_ptr<Credentials> JWTCredentials(
- const grpc::string& json_key, long token_lifetime_seconds);
+std::shared_ptr<Credentials> JWTCredentials(const grpc::string& json_key,
+ long token_lifetime_seconds);
// Builds refresh token credentials.
// json_refresh_token is the JSON string containing the refresh token along
// with a client_id and client_secret.
-std::unique_ptr<Credentials> RefreshTokenCredentials(
+std::shared_ptr<Credentials> RefreshTokenCredentials(
const grpc::string& json_refresh_token);
// Builds IAM credentials.
-std::unique_ptr<Credentials> IAMCredentials(
+std::shared_ptr<Credentials> IAMCredentials(
const grpc::string& authorization_token,
const grpc::string& authority_selector);
// Combines two credentials objects into a composite credentials
-std::unique_ptr<Credentials> CompositeCredentials(
- const std::unique_ptr<Credentials>& creds1,
- const std::unique_ptr<Credentials>& creds2);
+std::shared_ptr<Credentials> CompositeCredentials(
+ const std::shared_ptr<Credentials>& creds1,
+ const std::shared_ptr<Credentials>& creds2);
// Credentials for an unencrypted, unauthenticated channel
-std::unique_ptr<Credentials> InsecureCredentials();
+std::shared_ptr<Credentials> InsecureCredentials();
} // namespace grpc
diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc
index f38a694734..72cdd49d19 100644
--- a/src/cpp/client/client_context.cc
+++ b/src/cpp/client/client_context.cc
@@ -34,6 +34,7 @@
#include <grpc++/client_context.h>
#include <grpc/grpc.h>
+#include <grpc++/credentials.h>
#include <grpc++/time.h>
namespace grpc {
@@ -63,6 +64,17 @@ void ClientContext::AddMetadata(const grpc::string& meta_key,
send_initial_metadata_.insert(std::make_pair(meta_key, meta_value));
}
+void ClientContext::set_call(grpc_call* call,
+ const std::shared_ptr<ChannelInterface>& channel) {
+ GPR_ASSERT(call_ == nullptr);
+ call_ = call;
+ channel_ = channel;
+ if (creds_ && !creds_->ApplyToCall(call_)) {
+ grpc_call_cancel_with_status(call, GRPC_STATUS_CANCELLED,
+ "Failed to set credentials to rpc.");
+ }
+}
+
void ClientContext::TryCancel() {
if (call_) {
grpc_call_cancel(call_);
diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc
index 301430572a..510af2bb00 100644
--- a/src/cpp/client/create_channel.cc
+++ b/src/cpp/client/create_channel.cc
@@ -41,7 +41,7 @@ namespace grpc {
class ChannelArguments;
std::shared_ptr<ChannelInterface> CreateChannel(
- const grpc::string& target, const std::unique_ptr<Credentials>& creds,
+ const grpc::string& target, const std::shared_ptr<Credentials>& creds,
const ChannelArguments& args) {
return creds ? creds->CreateChannel(target, args)
: std::shared_ptr<ChannelInterface>(
diff --git a/src/cpp/client/insecure_credentials.cc b/src/cpp/client/insecure_credentials.cc
index 8945b038de..668ea2e873 100644
--- a/src/cpp/client/insecure_credentials.cc
+++ b/src/cpp/client/insecure_credentials.cc
@@ -52,12 +52,14 @@ class InsecureCredentialsImpl GRPC_FINAL : public Credentials {
target, grpc_channel_create(target.c_str(), &channel_args)));
}
+ bool ApplyToCall(grpc_call* call) GRPC_OVERRIDE { return true; }
+
SecureCredentials* AsSecureCredentials() GRPC_OVERRIDE { return nullptr; }
};
} // namespace
-std::unique_ptr<Credentials> InsecureCredentials() {
- return std::unique_ptr<Credentials>(new InsecureCredentialsImpl());
+std::shared_ptr<Credentials> InsecureCredentials() {
+ return std::shared_ptr<Credentials>(new InsecureCredentialsImpl());
}
} // namespace grpc
diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc
index 48bf7430b2..b5134b3140 100644
--- a/src/cpp/client/secure_credentials.cc
+++ b/src/cpp/client/secure_credentials.cc
@@ -49,20 +49,24 @@ std::shared_ptr<grpc::ChannelInterface> SecureCredentials::CreateChannel(
grpc_secure_channel_create(c_creds_, target.c_str(), &channel_args)));
}
+bool SecureCredentials::ApplyToCall(grpc_call* call) {
+ return grpc_call_set_credentials(call, c_creds_) == GRPC_CALL_OK;
+}
+
namespace {
-std::unique_ptr<Credentials> WrapCredentials(grpc_credentials* creds) {
+std::shared_ptr<Credentials> WrapCredentials(grpc_credentials* creds) {
return creds == nullptr
? nullptr
- : std::unique_ptr<Credentials>(new SecureCredentials(creds));
+ : std::shared_ptr<Credentials>(new SecureCredentials(creds));
}
} // namespace
-std::unique_ptr<Credentials> GoogleDefaultCredentials() {
+std::shared_ptr<Credentials> GoogleDefaultCredentials() {
return WrapCredentials(grpc_google_default_credentials_create());
}
// Builds SSL Credentials given SSL specific options
-std::unique_ptr<Credentials> SslCredentials(
+std::shared_ptr<Credentials> SslCredentials(
const SslCredentialsOptions& options) {
grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {
options.pem_private_key.c_str(), options.pem_cert_chain.c_str()};
@@ -74,12 +78,12 @@ std::unique_ptr<Credentials> SslCredentials(
}
// Builds credentials for use when running in GCE
-std::unique_ptr<Credentials> ComputeEngineCredentials() {
+std::shared_ptr<Credentials> ComputeEngineCredentials() {
return WrapCredentials(grpc_compute_engine_credentials_create());
}
// Builds service account credentials.
-std::unique_ptr<Credentials> ServiceAccountCredentials(
+std::shared_ptr<Credentials> ServiceAccountCredentials(
const grpc::string& json_key, const grpc::string& scope,
long token_lifetime_seconds) {
if (token_lifetime_seconds <= 0) {
@@ -94,8 +98,8 @@ std::unique_ptr<Credentials> ServiceAccountCredentials(
}
// Builds JWT credentials.
-std::unique_ptr<Credentials> JWTCredentials(
- const grpc::string& json_key, long token_lifetime_seconds) {
+std::shared_ptr<Credentials> JWTCredentials(const grpc::string& json_key,
+ long token_lifetime_seconds) {
if (token_lifetime_seconds <= 0) {
gpr_log(GPR_ERROR,
"Trying to create JWTCredentials with non-positive lifetime");
@@ -107,14 +111,14 @@ std::unique_ptr<Credentials> JWTCredentials(
}
// Builds refresh token credentials.
-std::unique_ptr<Credentials> RefreshTokenCredentials(
+std::shared_ptr<Credentials> RefreshTokenCredentials(
const grpc::string& json_refresh_token) {
return WrapCredentials(
grpc_refresh_token_credentials_create(json_refresh_token.c_str()));
}
// Builds IAM credentials.
-std::unique_ptr<Credentials> IAMCredentials(
+std::shared_ptr<Credentials> IAMCredentials(
const grpc::string& authorization_token,
const grpc::string& authority_selector) {
return WrapCredentials(grpc_iam_credentials_create(
@@ -122,10 +126,10 @@ std::unique_ptr<Credentials> IAMCredentials(
}
// Combines two credentials objects into a composite credentials.
-std::unique_ptr<Credentials> CompositeCredentials(
- const std::unique_ptr<Credentials>& creds1,
- const std::unique_ptr<Credentials>& creds2) {
- // Note that we are not saving unique_ptrs to the two credentials
+std::shared_ptr<Credentials> CompositeCredentials(
+ const std::shared_ptr<Credentials>& creds1,
+ const std::shared_ptr<Credentials>& creds2) {
+ // Note that we are not saving shared_ptrs to the two credentials
// passed in here. This is OK because the underlying C objects (i.e.,
// creds1 and creds2) into grpc_composite_credentials_create will see their
// refcounts incremented.
diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h
index 77d575813e..ddf69911b5 100644
--- a/src/cpp/client/secure_credentials.h
+++ b/src/cpp/client/secure_credentials.h
@@ -46,6 +46,7 @@ class SecureCredentials GRPC_FINAL : public Credentials {
explicit SecureCredentials(grpc_credentials* c_creds) : c_creds_(c_creds) {}
~SecureCredentials() GRPC_OVERRIDE { grpc_credentials_release(c_creds_); }
grpc_credentials* GetRawCreds() { return c_creds_; }
+ bool ApplyToCall(grpc_call* call) GRPC_OVERRIDE;
std::shared_ptr<grpc::ChannelInterface> CreateChannel(
const string& target, const grpc::ChannelArguments& args) GRPC_OVERRIDE;
diff --git a/test/cpp/client/credentials_test.cc b/test/cpp/client/credentials_test.cc
index 6840418989..ee94f455a4 100644
--- a/test/cpp/client/credentials_test.cc
+++ b/test/cpp/client/credentials_test.cc
@@ -46,8 +46,7 @@ class CredentialsTest : public ::testing::Test {
};
TEST_F(CredentialsTest, InvalidServiceAccountCreds) {
- std::unique_ptr<Credentials> bad1 =
- ServiceAccountCredentials("", "", 1);
+ std::shared_ptr<Credentials> bad1 = ServiceAccountCredentials("", "", 1);
EXPECT_EQ(nullptr, bad1.get());
}
diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc
index f35b16fe55..a98d7c23b9 100644
--- a/test/cpp/end2end/end2end_test.cc
+++ b/test/cpp/end2end/end2end_test.cc
@@ -429,7 +429,7 @@ TEST_F(End2endTest, DiffPackageServices) {
// rpc and stream should fail on bad credentials.
TEST_F(End2endTest, BadCredentials) {
- std::unique_ptr<Credentials> bad_creds = ServiceAccountCredentials("", "", 1);
+ std::shared_ptr<Credentials> bad_creds = ServiceAccountCredentials("", "", 1);
EXPECT_EQ(nullptr, bad_creds.get());
std::shared_ptr<ChannelInterface> channel =
CreateChannel(server_address_.str(), bad_creds, ChannelArguments());
@@ -588,6 +588,54 @@ TEST_F(End2endTest, RpcMaxMessageSize) {
EXPECT_FALSE(s.IsOk());
}
+TEST_F(End2endTest, SetPerCallCredentials) {
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ std::shared_ptr<Credentials> creds =
+ IAMCredentials("fake_token", "fake_selector");
+ context.set_credentials(creds);
+ grpc::string msg("Hello");
+
+ Status s = stub_->Echo(&context, request, &response);
+ // TODO(yangg) verify creds at the server side.
+ EXPECT_EQ(request.message(), response.message());
+ EXPECT_TRUE(s.IsOk());
+}
+
+TEST_F(End2endTest, InsecurePerCallCredentials) {
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ std::shared_ptr<Credentials> creds = InsecureCredentials();
+ context.set_credentials(creds);
+ grpc::string msg("Hello");
+
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_EQ(request.message(), response.message());
+ EXPECT_TRUE(s.IsOk());
+}
+
+TEST_F(End2endTest, OverridePerCallCredentials) {
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ std::shared_ptr<Credentials> creds1 = InsecureCredentials();
+ context.set_credentials(creds1);
+ std::shared_ptr<Credentials> creds2 =
+ IAMCredentials("fake_token", "fake_selector");
+ context.set_credentials(creds2);
+ grpc::string msg("Hello");
+
+ Status s = stub_->Echo(&context, request, &response);
+ // TODO(yangg) verify creds at the server side.
+ EXPECT_EQ(request.message(), response.message());
+ EXPECT_TRUE(s.IsOk());
+}
+
} // namespace testing
} // namespace grpc
diff --git a/test/cpp/interop/client_helper.cc b/test/cpp/interop/client_helper.cc
index a1dea383e6..09fd1c8913 100644
--- a/test/cpp/interop/client_helper.cc
+++ b/test/cpp/interop/client_helper.cc
@@ -82,7 +82,7 @@ std::shared_ptr<ChannelInterface> CreateChannelForTestCase(
FLAGS_server_port);
if (test_case == "service_account_creds") {
- std::unique_ptr<Credentials> creds;
+ std::shared_ptr<Credentials> creds;
GPR_ASSERT(FLAGS_enable_ssl);
grpc::string json_key = GetServiceAccountJsonKey();
std::chrono::seconds token_lifetime = std::chrono::hours(1);
@@ -91,13 +91,13 @@ std::shared_ptr<ChannelInterface> CreateChannelForTestCase(
return CreateTestChannel(host_port, FLAGS_server_host_override,
FLAGS_enable_ssl, FLAGS_use_prod_roots, creds);
} else if (test_case == "compute_engine_creds") {
- std::unique_ptr<Credentials> creds;
+ std::shared_ptr<Credentials> creds;
GPR_ASSERT(FLAGS_enable_ssl);
creds = ComputeEngineCredentials();
return CreateTestChannel(host_port, FLAGS_server_host_override,
FLAGS_enable_ssl, FLAGS_use_prod_roots, creds);
} else if (test_case == "jwt_token_creds") {
- std::unique_ptr<Credentials> creds;
+ std::shared_ptr<Credentials> creds;
GPR_ASSERT(FLAGS_enable_ssl);
grpc::string json_key = GetServiceAccountJsonKey();
std::chrono::seconds token_lifetime = std::chrono::hours(1);
diff --git a/test/cpp/util/create_test_channel.cc b/test/cpp/util/create_test_channel.cc
index f040acc4b1..dc48fa2d87 100644
--- a/test/cpp/util/create_test_channel.cc
+++ b/test/cpp/util/create_test_channel.cc
@@ -58,13 +58,13 @@ namespace grpc {
std::shared_ptr<ChannelInterface> CreateTestChannel(
const grpc::string& server, const grpc::string& override_hostname,
bool enable_ssl, bool use_prod_roots,
- const std::unique_ptr<Credentials>& creds) {
+ const std::shared_ptr<Credentials>& creds) {
ChannelArguments channel_args;
if (enable_ssl) {
const char* roots_certs = use_prod_roots ? "" : test_root_cert;
SslCredentialsOptions ssl_opts = {roots_certs, "", ""};
- std::unique_ptr<Credentials> channel_creds = SslCredentials(ssl_opts);
+ std::shared_ptr<Credentials> channel_creds = SslCredentials(ssl_opts);
if (!server.empty() && !override_hostname.empty()) {
channel_args.SetSslTargetNameOverride(override_hostname);
@@ -84,7 +84,7 @@ std::shared_ptr<ChannelInterface> CreateTestChannel(
const grpc::string& server, const grpc::string& override_hostname,
bool enable_ssl, bool use_prod_roots) {
return CreateTestChannel(server, override_hostname, enable_ssl,
- use_prod_roots, std::unique_ptr<Credentials>());
+ use_prod_roots, std::shared_ptr<Credentials>());
}
// Shortcut for end2end and interop tests.
diff --git a/test/cpp/util/create_test_channel.h b/test/cpp/util/create_test_channel.h
index 5c298ce850..5f2609ddd8 100644
--- a/test/cpp/util/create_test_channel.h
+++ b/test/cpp/util/create_test_channel.h
@@ -52,7 +52,7 @@ std::shared_ptr<ChannelInterface> CreateTestChannel(
std::shared_ptr<ChannelInterface> CreateTestChannel(
const grpc::string& server, const grpc::string& override_hostname,
bool enable_ssl, bool use_prod_roots,
- const std::unique_ptr<Credentials>& creds);
+ const std::shared_ptr<Credentials>& creds);
} // namespace grpc
diff --git a/test/cpp/util/grpc_cli.cc b/test/cpp/util/grpc_cli.cc
index d71a7a0b77..ad3c0af877 100644
--- a/test/cpp/util/grpc_cli.cc
+++ b/test/cpp/util/grpc_cli.cc
@@ -104,7 +104,7 @@ int main(int argc, char** argv) {
std::stringstream input_stream;
input_stream << input_file.rdbuf();
- std::unique_ptr<grpc::Credentials> creds;
+ std::shared_ptr<grpc::Credentials> creds;
if (!FLAGS_enable_ssl) {
creds = grpc::InsecureCredentials();
} else {