aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yash Tibrewal <yashkt@google.com>2019-01-04 15:29:00 -0800
committerGravatar Yash Tibrewal <yashkt@google.com>2019-01-04 15:29:00 -0800
commit8ba5922e8767e4c10ca38cb46b19d5878c310598 (patch)
tree4501b2474162197ef55059521067bc5b045d9f8a
parent4886c937c1aa2c52abab23d476f342d076010836 (diff)
parent84742e565ae7f34722b8907dd4fa5f2bdb8560bb (diff)
Merge branch 'master' into failhijackedsend
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc91
-rw-r--r--test/cpp/end2end/interceptors_util.cc10
-rw-r--r--test/cpp/end2end/interceptors_util.h3
3 files changed, 104 insertions, 0 deletions
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index 9ec89ce1ab..f8728fc595 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -270,6 +270,75 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
+class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
+ public:
+ BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSerializedSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_EQ(req.message().find("Hello"), 0u);
+ msg = req.message();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
+ "testvalue");
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message(msg);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
+ ->message()
+ .find("Hello"),
+ 0u);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ grpc::string msg;
+};
+
class ClientStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
@@ -324,6 +393,15 @@ class ClientStreamingRpcHijackingInterceptorFactory
}
};
+class BidiStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new BidiStreamingRpcHijackingInterceptor(info);
+ }
+};
+
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@@ -633,6 +711,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
+ new BidiStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeBidiStreamingCall(channel);
+}
+
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc
index e0ad7d1526..900f02b5f3 100644
--- a/test/cpp/end2end/interceptors_util.cc
+++ b/test/cpp/end2end/interceptors_util.cc
@@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
return false;
}
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first == key && pair.second == value) {
+ return true;
+ }
+ }
+ return false;
+}
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors() {
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h
index 659e613d2e..419845e5f6 100644
--- a/test/cpp/end2end/interceptors_util.h
+++ b/test/cpp/end2end/interceptors_util.h
@@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
const string& key, const string& value);
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value);
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();