aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yash Tibrewal <yashkt@google.com>2018-10-21 23:06:21 -0700
committerGravatar Yash Tibrewal <yashkt@google.com>2018-10-21 23:06:21 -0700
commit3a17f5b05ec6adce638fd03168a923e727759969 (patch)
tree9d8cb4f288ac2723761baeff868daa1de7540f86
parent52765e9cb1e18780fcd4701e91e019bf78c9b957 (diff)
Working on tests
-rw-r--r--include/grpcpp/impl/codegen/client_interceptor.h6
-rw-r--r--src/cpp/server/server_cc.cc29
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc132
3 files changed, 139 insertions, 28 deletions
diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h
index 6feec224ce..06f009e7d3 100644
--- a/include/grpcpp/impl/codegen/client_interceptor.h
+++ b/include/grpcpp/impl/codegen/client_interceptor.h
@@ -47,7 +47,7 @@ class ClientRpcInfo {
public:
ClientRpcInfo() {}
ClientRpcInfo(grpc::ClientContext* ctx, const char* method,
- const grpc::Channel* channel,
+ grpc::Channel* channel,
const std::vector<std::unique_ptr<
experimental::ClientInterceptorFactoryInterface>>& creators)
: ctx_(ctx), method_(method), channel_(channel) {
@@ -64,7 +64,7 @@ class ClientRpcInfo {
// Getter methods
const char* method() { return method_; }
- const Channel* channel() { return channel_; }
+ Channel* channel() { return channel_; }
grpc::ClientContext* client_context() { return ctx_; }
public:
@@ -79,7 +79,7 @@ class ClientRpcInfo {
private:
grpc::ClientContext* ctx_ = nullptr;
const char* method_ = nullptr;
- const grpc::Channel* channel_ = nullptr;
+ grpc::Channel* channel_ = nullptr;
std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_;
bool hijacked_ = false;
int hijacked_interceptor_ = false;
diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc
index d53c3534a9..5124044a8b 100644
--- a/src/cpp/server/server_cc.cc
+++ b/src/cpp/server/server_cc.cc
@@ -243,13 +243,13 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
interceptor_methods_.SetCall(&call_);
interceptor_methods_.SetReverse();
- /* Set interception point for RECV INITIAL METADATA */
+ // Set interception point for RECV INITIAL METADATA
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_);
if (has_request_payload_) {
- /* Set interception point for RECV MESSAGE */
+ // Set interception point for RECV MESSAGE
auto* handler = resources_ ? method_->handler()
: server_->resource_exhausted_handler_.get();
request_ = handler->Deserialize(request_payload_, &request_status_);
@@ -264,8 +264,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
if (interceptor_methods_.RunInterceptors(f)) {
ContinueRunAfterInterception();
} else {
- /* There were interceptors to be run, so ContinueRunAfterInterception
- will be run when interceptors are done. */
+ // There were interceptors to be run, so ContinueRunAfterInterception
+ // will be run when interceptors are done.
}
}
@@ -318,7 +318,6 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
grpc_metadata_array request_metadata_;
grpc_byte_buffer* request_payload_;
grpc_completion_queue* cq_;
- bool done_intercepting_ = false;
};
// Implementation of ThreadManager. Each instance of SyncRequestThreadManager
@@ -763,7 +762,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
context_->set_call(call_);
context_->cq_ = call_cq_;
if (call_wrapper_.call() == nullptr) {
- /* Fill it since it is empty. */
+ // Fill it since it is empty.
call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(), nullptr);
}
@@ -773,7 +772,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
if (*status && call_ && call_wrapper_.server_rpc_info()) {
done_intercepting_ = true;
- /* Set interception point for RECV INITIAL METADATA */
+ // Set interception point for RECV INITIAL METADATA
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_);
@@ -781,11 +780,11 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
ContinueFinalizeResultAfterInterception,
this);
if (interceptor_methods_.RunInterceptors(f)) {
- /* There are no interceptors to run. Continue */
+ // There are no interceptors to run. Continue
} else {
- /* There were interceptors to be run, so
- ContinueFinalizeResultAfterInterception will be run when interceptors are
- done. */
+ // There were interceptors to be run, so
+ // ContinueFinalizeResultAfterInterception will be run when interceptors
+ // are done.
return false;
}
}
@@ -802,7 +801,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
void ServerInterface::BaseAsyncRequest::
ContinueFinalizeResultAfterInterception() {
context_->BeginCompletionOp(&call_wrapper_);
- /* Queue a tag which will be returned immediately */
+ // Queue a tag which will be returned immediately
dummy_alarm_ = new Alarm();
static_cast<Alarm*>(dummy_alarm_)
->Set(notification_cq_,
@@ -844,7 +843,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest(
bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
bool* status) {
- /* If we are done intercepting, there is nothing more for us to do */
+ // If we are done intercepting, there is nothing more for us to do
if (done_intercepting_) {
return BaseAsyncRequest::FinalizeResult(tag, status);
}
@@ -870,7 +869,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag,
bool* status) {
if (GenericAsyncRequest::FinalizeResult(tag, status)) {
- /* We either had no interceptors run or we are done interceptinh */
+ // We either had no interceptors run or we are done intercepting
if (*status) {
new UnimplementedAsyncRequest(server_, cq_);
new UnimplementedAsyncResponse(this);
@@ -878,7 +877,7 @@ bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag,
delete this;
}
} else {
- /* The tag was swallowed due to interception. We will see it again. */
+ // The tag was swallowed due to interception. We will see it again.
}
return false;
}
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index 8537f35602..2e0db8a9b9 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -60,6 +60,8 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test {
std::unique_ptr<Server> server_;
};
+/* This interceptor does nothing. Just keeps a global count on the number of
+ * times it was invoked. */
class DummyInterceptor : public experimental::Interceptor {
public:
DummyInterceptor(experimental::ClientRpcInfo* info) {}
@@ -91,6 +93,7 @@ class DummyInterceptorFactory
}
};
+/* Hijacks Echo RPC and fills in the expected values */
class HijackingInterceptor : public experimental::Interceptor {
public:
HijackingInterceptor(experimental::ClientRpcInfo* info) {
@@ -195,6 +198,111 @@ class HijackingInterceptorFactory
}
};
+class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
+ public:
+ HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ // Make sure it is the right method
+ EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ gpr_log(GPR_ERROR, "ran this");
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Check that we can see the test metadata
+ ASSERT_EQ(map->size(), 1);
+ auto iterator = map->begin();
+ EXPECT_EQ("testkey", iterator->first);
+ EXPECT_EQ("testvalue", iterator->second);
+ hijack = true;
+ // Make a copy of the map
+ metadata_map_ = *map;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSendMessage();
+ auto copied_buffer = *buffer;
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req);
+ EXPECT_EQ(req.message(), "Hello");
+ auto stub = grpc::testing::EchoTestService::NewStub(
+ std::shared_ptr<Channel>(info_->channel()));
+ ClientContext ctx;
+ EchoResponse resp;
+ Status s = stub->Echo(&ctx, req, &resp);
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here for now
+ EXPECT_EQ(map->size(), 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ // Check that we got the hijacked message, and re-insert the expected
+ // message
+ EXPECT_EQ(resp->message(), "Hello1");
+ resp->set_message("Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.starts_with("testkey") &&
+ pair.second.starts_with("testvalue");
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here at the moment
+ EXPECT_EQ(map->size(), 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ // Insert a different message than expected
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message("Hello1");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), 0);
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ std::multimap<grpc::string, grpc::string> metadata_map_;
+};
+
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) {
@@ -268,6 +376,19 @@ class LoggingInterceptorFactory
}
};
+void MakeCall(std::shared_ptr<Channel> channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ Status s = stub->Echo(&ctx, req, &resp);
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+}
+
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
@@ -284,16 +405,7 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
}
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
- auto stub = grpc::testing::EchoTestService::NewStub(channel);
- ClientContext ctx;
- EchoRequest req;
- req.mutable_param()->set_echo_metadata(true);
- ctx.AddMetadata("testkey", "testvalue");
- req.set_message("Hello");
- EchoResponse resp;
- Status s = stub->Echo(&ctx, req, &resp);
- EXPECT_EQ(s.ok(), true);
- EXPECT_EQ(resp.message(), "Hello");
+ MakeCall(channel);
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}