aboutsummaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
authorGravatar Yash Tibrewal <yashkt@google.com>2018-12-28 16:03:20 -0800
committerGravatar Yash Tibrewal <yashkt@google.com>2018-12-28 16:03:20 -0800
commitaecc5f7285faedec634c99aff0b48eea86d3861a (patch)
treed130098b018c5576000a8598eb50e9f4252d7b7e /test
parentfc7d0911a3a44d7bc926d3db99b7300a0c0f33dc (diff)
Add client interceptor test for bidi streaming hijacking interceptor
Diffstat (limited to 'test')
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc91
-rw-r--r--test/cpp/end2end/interceptors_util.cc10
-rw-r--r--test/cpp/end2end/interceptors_util.h3
3 files changed, 104 insertions, 0 deletions
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index 8abf4eb3f4..ab387aa914 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -270,6 +270,84 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
+class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
+ public:
+ BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_EQ(req.message().find("Hello"), 0);
+ msg = req.message();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
+ "testvalue");
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message(msg);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
+ ->message()
+ .find("Hello"),
+ 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ grpc::string msg;
+};
+
+class BidiStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new BidiStreamingRpcHijackingInterceptor(info);
+ }
+};
+
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@@ -546,6 +624,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
+ new BidiStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeBidiStreamingCall(channel);
+}
+
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc
index e0ad7d1526..900f02b5f3 100644
--- a/test/cpp/end2end/interceptors_util.cc
+++ b/test/cpp/end2end/interceptors_util.cc
@@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
return false;
}
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first == key && pair.second == value) {
+ return true;
+ }
+ }
+ return false;
+}
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors() {
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h
index 659e613d2e..419845e5f6 100644
--- a/test/cpp/end2end/interceptors_util.h
+++ b/test/cpp/end2end/interceptors_util.h
@@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
const string& key, const string& value);
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value);
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();