aboutsummaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
authorGravatar Yash Tibrewal <yashkt@google.com>2018-10-28 23:36:59 -0700
committerGravatar GitHub <noreply@github.com>2018-10-28 23:36:59 -0700
commit01313976e1a44b5c9625d3a349fffa55471beff4 (patch)
tree805596796cce33154e0d875c4c1ade918ba958f2 /test
parentffac9d90b18cb076b1c952faa55ce4e049cbc9a6 (diff)
parent395edbfa24968b8406a0c157874d3cb473076df5 (diff)
Merge pull request #16842 from yashykt/interceptors
Experimental API for Client and Server Interception
Diffstat (limited to 'test')
-rw-r--r--test/cpp/end2end/BUILD53
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc606
-rw-r--r--test/cpp/end2end/interceptors_util.h308
-rw-r--r--test/cpp/end2end/server_interceptors_end2end_test.cc623
-rw-r--r--test/cpp/interop/client_helper.h2
5 files changed, 1592 insertions, 0 deletions
diff --git a/test/cpp/end2end/BUILD b/test/cpp/end2end/BUILD
index 0415efc1ef..235249e8bf 100644
--- a/test/cpp/end2end/BUILD
+++ b/test/cpp/end2end/BUILD
@@ -35,6 +35,19 @@ grpc_cc_library(
],
)
+grpc_cc_library(
+ name = "interceptors_util",
+ testonly = True,
+ hdrs = ["interceptors_util.h"],
+ external_deps = [
+ "gtest",
+ ],
+ deps = [
+ "//src/proto/grpc/testing:echo_proto",
+ "//test/cpp/util:test_util",
+ ],
+)
+
grpc_cc_test(
name = "async_end2end_test",
srcs = ["async_end2end_test.cc"],
@@ -117,6 +130,26 @@ grpc_cc_test(
],
)
+grpc_cc_test(
+ name = "client_interceptors_end2end_test",
+ srcs = ["client_interceptors_end2end_test.cc"],
+ external_deps = [
+ "gtest",
+ ],
+ deps = [
+ ":interceptors_util",
+ ":test_service_impl",
+ "//:gpr",
+ "//:grpc",
+ "//:grpc++",
+ "//src/proto/grpc/testing:echo_messages_proto",
+ "//src/proto/grpc/testing:echo_proto",
+ "//test/core/util:gpr_test_util",
+ "//test/core/util:grpc_test_util",
+ "//test/cpp/util:test_util",
+ ],
+)
+
grpc_cc_library(
name = "end2end_test_lib",
testonly = True,
@@ -470,6 +503,26 @@ grpc_cc_binary(
)
grpc_cc_test(
+ name = "server_interceptors_end2end_test",
+ srcs = ["server_interceptors_end2end_test.cc"],
+ external_deps = [
+ "gtest",
+ ],
+ deps = [
+ ":interceptors_util",
+ ":test_service_impl",
+ "//:gpr",
+ "//:grpc",
+ "//:grpc++",
+ "//src/proto/grpc/testing:echo_messages_proto",
+ "//src/proto/grpc/testing:echo_proto",
+ "//test/core/util:gpr_test_util",
+ "//test/core/util:grpc_test_util",
+ "//test/cpp/util:test_util",
+ ],
+)
+
+grpc_cc_test(
name = "server_load_reporting_end2end_test",
srcs = ["server_load_reporting_end2end_test.cc"],
external_deps = [
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
new file mode 100644
index 0000000000..5720e87478
--- /dev/null
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -0,0 +1,606 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <vector>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/client_interceptor.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/interceptors_util.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
+ protected:
+ ClientInterceptorsStreamingEnd2endTest() {
+ int port = grpc_pick_unused_port_or_die();
+
+ ServerBuilder builder;
+ server_address_ = "localhost:" + std::to_string(port);
+ builder.AddListeningPort(server_address_, InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ }
+
+ ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); }
+
+ std::string server_address_;
+ EchoTestServiceStreamingImpl service_;
+ std::unique_ptr<Server> server_;
+};
+
+class ClientInterceptorsEnd2endTest : public ::testing::Test {
+ protected:
+ ClientInterceptorsEnd2endTest() {
+ int port = grpc_pick_unused_port_or_die();
+
+ ServerBuilder builder;
+ server_address_ = "localhost:" + std::to_string(port);
+ builder.AddListeningPort(server_address_, InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ }
+
+ ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
+
+ std::string server_address_;
+ TestServiceImpl service_;
+ std::unique_ptr<Server> server_;
+};
+
+/* This interceptor does nothing. Just keeps a global count on the number of
+ * times it was invoked. */
+class DummyInterceptor : public experimental::Interceptor {
+ public:
+ DummyInterceptor(experimental::ClientRpcInfo* info) {}
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ num_times_run_++;
+ } else if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::
+ POST_RECV_INITIAL_METADATA)) {
+ num_times_run_reverse_++;
+ }
+ methods->Proceed();
+ }
+
+ static void Reset() {
+ num_times_run_.store(0);
+ num_times_run_reverse_.store(0);
+ }
+
+ static int GetNumTimesRun() {
+ EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
+ return num_times_run_.load();
+ }
+
+ private:
+ static std::atomic<int> num_times_run_;
+ static std::atomic<int> num_times_run_reverse_;
+};
+
+std::atomic<int> DummyInterceptor::num_times_run_;
+std::atomic<int> DummyInterceptor::num_times_run_reverse_;
+
+class DummyInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new DummyInterceptor(info);
+ }
+};
+
+/* Hijacks Echo RPC and fills in the expected values */
+class HijackingInterceptor : public experimental::Interceptor {
+ public:
+ HijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ // Make sure it is the right method
+ EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Check that we can see the test metadata
+ ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+ auto iterator = map->begin();
+ EXPECT_EQ("testkey", iterator->first);
+ EXPECT_EQ("testvalue", iterator->second);
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSendMessage();
+ auto copied_buffer = *buffer;
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req);
+ EXPECT_EQ(req.message(), "Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here for now
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ // Check that we got the hijacked message, and re-insert the expected
+ // message
+ EXPECT_EQ(resp->message(), "Hello1");
+ resp->set_message("Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.starts_with("testkey") &&
+ pair.second.starts_with("testvalue");
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here at the moment
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ // Insert a different message than expected
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message("Hello1");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+};
+
+class HijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new HijackingInterceptor(info);
+ }
+};
+
+class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
+ public:
+ HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ // Make sure it is the right method
+ EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Check that we can see the test metadata
+ ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+ auto iterator = map->begin();
+ EXPECT_EQ("testkey", iterator->first);
+ EXPECT_EQ("testvalue", iterator->second);
+ // Make a copy of the map
+ metadata_map_ = *map;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSendMessage();
+ auto copied_buffer = *buffer;
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req);
+ EXPECT_EQ(req.message(), "Hello");
+ req_ = req;
+ stub_ = grpc::testing::EchoTestService::NewStub(
+ methods->GetInterceptedChannel());
+ ctx_.AddMetadata(metadata_map_.begin()->first,
+ metadata_map_.begin()->second);
+ stub_->experimental_async()->Echo(&ctx_, &req_, &resp_,
+ [this, methods](Status s) {
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp_.message(), "Hello");
+ methods->Hijack();
+ });
+ // There isn't going to be any other interesting operation in this batch,
+ // so it is fine to return
+ return;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here for now
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ // Check that we got the hijacked message, and re-insert the expected
+ // message
+ EXPECT_EQ(resp->message(), "Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.starts_with("testkey") &&
+ pair.second.starts_with("testvalue");
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here at the moment
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ // Insert a different message than expected
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message(resp_.message());
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+
+ methods->Proceed();
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ std::multimap<grpc::string, grpc::string> metadata_map_;
+ ClientContext ctx_;
+ EchoRequest req_;
+ EchoResponse resp_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+};
+
+class HijackingInterceptorMakesAnotherCallFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new HijackingInterceptorMakesAnotherCall(info);
+ }
+};
+
+class LoggingInterceptor : public experimental::Interceptor {
+ public:
+ LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Check that we can see the test metadata
+ ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+ auto iterator = map->begin();
+ EXPECT_EQ("testkey", iterator->first);
+ EXPECT_EQ("testvalue", iterator->second);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSendMessage();
+ auto copied_buffer = *buffer;
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req);
+ EXPECT_TRUE(req.message().find("Hello") == 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ // Got nothing better to do here for now
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ EXPECT_TRUE(resp->message().find("Hello") == 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.starts_with("testkey") &&
+ pair.second.starts_with("testvalue");
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ methods->Proceed();
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+};
+
+class LoggingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new LoggingInterceptor(info);
+ }
+};
+
+TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ creators->push_back(std::unique_ptr<LoggingInterceptorFactory>(
+ new LoggingInterceptorFactory()));
+ // Add 20 dummy interceptors
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ // Add 20 dummy interceptors before hijacking interceptor
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ creators->push_back(std::unique_ptr<HijackingInterceptorFactory>(
+ new HijackingInterceptorFactory()));
+ // Add 20 dummy interceptors after hijacking interceptor
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+ MakeCall(channel);
+ // Make sure only 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
+ ChannelArguments args;
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ creators->push_back(std::unique_ptr<LoggingInterceptorFactory>(
+ new LoggingInterceptorFactory()));
+ creators->push_back(std::unique_ptr<HijackingInterceptorFactory>(
+ new HijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+ MakeCall(channel);
+}
+
+TEST_F(ClientInterceptorsEnd2endTest,
+ ClientInterceptorHijackingMakesAnotherCallTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ // Add 5 dummy interceptors before hijacking interceptor
+ for (auto i = 0; i < 5; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ creators->push_back(
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
+ new HijackingInterceptorMakesAnotherCallFactory()));
+ // Add 7 dummy interceptors after hijacking interceptor
+ for (auto i = 0; i < 7; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = server_->experimental().InProcessChannelWithInterceptors(
+ args, std::move(creators));
+
+ MakeCall(channel);
+ // Make sure all interceptors were run once, since the hijacking interceptor
+ // makes an RPC on the intercepted channel
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12);
+}
+
+TEST_F(ClientInterceptorsEnd2endTest,
+ ClientInterceptorLoggingTestWithCallback) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ creators->push_back(std::unique_ptr<LoggingInterceptorFactory>(
+ new LoggingInterceptorFactory()));
+ // Add 20 dummy interceptors
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = server_->experimental().InProcessChannelWithInterceptors(
+ args, std::move(creators));
+ MakeCallbackCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ creators->push_back(std::unique_ptr<LoggingInterceptorFactory>(
+ new LoggingInterceptorFactory()));
+ // Add 20 dummy interceptors
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeClientStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ creators->push_back(std::unique_ptr<LoggingInterceptorFactory>(
+ new LoggingInterceptorFactory()));
+ // Add 20 dummy interceptors
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeServerStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto creators = std::unique_ptr<std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
+ new std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
+ creators->push_back(std::unique_ptr<LoggingInterceptorFactory>(
+ new LoggingInterceptorFactory()));
+ // Add 20 dummy interceptors
+ for (auto i = 0; i < 20; i++) {
+ creators->push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeBidiStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc_test_init(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h
new file mode 100644
index 0000000000..5f0aa37dc0
--- /dev/null
+++ b/test/cpp/end2end/interceptors_util.h
@@ -0,0 +1,308 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+class EchoTestServiceStreamingImpl : public EchoTestService::Service {
+ public:
+ ~EchoTestServiceStreamingImpl() override {}
+
+ Status BidiStream(
+ ServerContext* context,
+ grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+
+ while (stream->Read(&req)) {
+ resp.set_message(req.message());
+ EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
+ }
+ return Status::OK;
+ }
+
+ Status RequestStream(ServerContext* context,
+ ServerReader<EchoRequest>* reader,
+ EchoResponse* resp) override {
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+
+ EchoRequest req;
+ string response_str = "";
+ while (reader->Read(&req)) {
+ response_str += req.message();
+ }
+ resp->set_message(response_str);
+ return Status::OK;
+ }
+
+ Status ResponseStream(ServerContext* context, const EchoRequest* req,
+ ServerWriter<EchoResponse>* writer) override {
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+
+ EchoResponse resp;
+ resp.set_message(req->message());
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(writer->Write(resp));
+ }
+ return Status::OK;
+ }
+};
+
+void MakeCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ Status s = stub->Echo(&ctx, req, &resp);
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+}
+
+void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ string expected_resp = "";
+ auto writer = stub->RequestStream(&ctx, &resp);
+ for (int i = 0; i < 10; i++) {
+ writer->Write(req);
+ expected_resp += "Hello";
+ }
+ writer->WritesDone();
+ Status s = writer->Finish();
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), expected_resp);
+}
+
+void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ string expected_resp = "";
+ auto reader = stub->ResponseStream(&ctx, req);
+ int count = 0;
+ while (reader->Read(&resp)) {
+ EXPECT_EQ(resp.message(), "Hello");
+ count++;
+ }
+ ASSERT_EQ(count, 10);
+ Status s = reader->Finish();
+ EXPECT_EQ(s.ok(), true);
+}
+
+void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ ctx.AddMetadata("testkey", "testvalue");
+ auto stream = stub->BidiStream(&ctx);
+ for (auto i = 0; i < 10; i++) {
+ req.set_message("Hello" + std::to_string(i));
+ stream->Write(req);
+ stream->Read(&resp);
+ EXPECT_EQ(req.message(), resp.message());
+ }
+ ASSERT_TRUE(stream->WritesDone());
+ Status s = stream->Finish();
+ EXPECT_EQ(s.ok(), true);
+}
+
+void MakeCallbackCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ stub->experimental_async()->Echo(&ctx, &req, &resp,
+ [&resp, &mu, &done, &cv](Status s) {
+ // gpr_log(GPR_ERROR, "got the callback");
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+}
+
+bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first.starts_with(key) && pair.second.starts_with(value)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void* tag(int i) { return (void*)static_cast<intptr_t>(i); }
+int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); }
+
+class Verifier {
+ public:
+ Verifier() : lambda_run_(false) {}
+ // Expect sets the expected ok value for a specific tag
+ Verifier& Expect(int i, bool expect_ok) {
+ return ExpectUnless(i, expect_ok, false);
+ }
+ // ExpectUnless sets the expected ok value for a specific tag
+ // unless the tag was already marked seen (as a result of ExpectMaybe)
+ Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
+ if (!seen) {
+ expectations_[tag(i)] = expect_ok;
+ }
+ return *this;
+ }
+ // ExpectMaybe sets the expected ok value for a specific tag, but does not
+ // require it to appear
+ // If it does, sets *seen to true
+ Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
+ if (!*seen) {
+ maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
+ }
+ return *this;
+ }
+
+ // Next waits for 1 async tag to complete, checks its
+ // expectations, and returns the tag
+ int Next(CompletionQueue* cq, bool ignore_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ GotTag(got_tag, ok, ignore_ok);
+ return detag(got_tag);
+ }
+
+ template <typename T>
+ CompletionQueue::NextStatus DoOnceThenAsyncNext(
+ CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
+ std::function<void(void)> lambda) {
+ if (lambda_run_) {
+ return cq->AsyncNext(got_tag, ok, deadline);
+ } else {
+ lambda_run_ = true;
+ return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
+ }
+ }
+
+ // Verify keeps calling Next until all currently set
+ // expected tags are complete
+ void Verify(CompletionQueue* cq) { Verify(cq, false); }
+
+ // This version of Verify allows optionally ignoring the
+ // outcome of the expectation
+ void Verify(CompletionQueue* cq, bool ignore_ok) {
+ GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
+ while (!expectations_.empty()) {
+ Next(cq, ignore_ok);
+ }
+ }
+
+ // This version of Verify stops after a certain deadline, and uses the
+ // DoThenAsyncNext API
+ // to call the lambda
+ void Verify(CompletionQueue* cq,
+ std::chrono::system_clock::time_point deadline,
+ const std::function<void(void)>& lambda) {
+ if (expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::TIMEOUT);
+ } else {
+ while (!expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::GOT_EVENT);
+ GotTag(got_tag, ok, false);
+ }
+ }
+ }
+
+ private:
+ void GotTag(void* got_tag, bool ok, bool ignore_ok) {
+ auto it = expectations_.find(got_tag);
+ if (it != expectations_.end()) {
+ if (!ignore_ok) {
+ EXPECT_EQ(it->second, ok);
+ }
+ expectations_.erase(it);
+ } else {
+ auto it2 = maybe_expectations_.find(got_tag);
+ if (it2 != maybe_expectations_.end()) {
+ if (it2->second.seen != nullptr) {
+ EXPECT_FALSE(*it2->second.seen);
+ *it2->second.seen = true;
+ }
+ if (!ignore_ok) {
+ EXPECT_EQ(it2->second.ok, ok);
+ }
+ } else {
+ gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
+ abort();
+ }
+ }
+ }
+
+ struct MaybeExpect {
+ bool ok;
+ bool* seen;
+ };
+
+ std::map<void*, bool> expectations_;
+ std::map<void*, MaybeExpect> maybe_expectations_;
+ bool lambda_run_;
+};
+
+} // namespace testing
+} // namespace grpc
diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc
new file mode 100644
index 0000000000..44ba2a6009
--- /dev/null
+++ b/test/cpp/end2end/server_interceptors_end2end_test.cc
@@ -0,0 +1,623 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <vector>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/impl/codegen/server_interceptor.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/interceptors_util.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+/* This interceptor does nothing. Just keeps a global count on the number of
+ * times it was invoked. */
+class DummyInterceptor : public experimental::Interceptor {
+ public:
+ DummyInterceptor(experimental::ServerRpcInfo* info) {}
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ num_times_run_++;
+ } else if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::
+ POST_RECV_INITIAL_METADATA)) {
+ num_times_run_reverse_++;
+ }
+ methods->Proceed();
+ }
+
+ static void Reset() {
+ num_times_run_.store(0);
+ num_times_run_reverse_.store(0);
+ }
+
+ static int GetNumTimesRun() {
+ EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
+ return num_times_run_.load();
+ }
+
+ private:
+ static std::atomic<int> num_times_run_;
+ static std::atomic<int> num_times_run_reverse_;
+};
+
+std::atomic<int> DummyInterceptor::num_times_run_;
+std::atomic<int> DummyInterceptor::num_times_run_reverse_;
+
+class DummyInterceptorFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new DummyInterceptor(info);
+ }
+};
+
+class LoggingInterceptor : public experimental::Interceptor {
+ public:
+ LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Got nothing better to do here for now
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSendMessage();
+ auto copied_buffer = *buffer;
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req);
+ EXPECT_TRUE(req.message().find("Hello") == 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_STATUS)) {
+ auto* map = methods->GetSendTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.find("testkey") == 0 &&
+ pair.second.find("testvalue") == 0;
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto status = methods->GetSendStatus();
+ EXPECT_EQ(status.ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.find("testkey") == 0 &&
+ pair.second.find("testvalue") == 0;
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ EXPECT_TRUE(resp->message().find("Hello") == 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_CLOSE)) {
+ // Got nothing interesting to do here
+ }
+ methods->Proceed();
+ }
+
+ private:
+ experimental::ServerRpcInfo* info_;
+};
+
+class LoggingInterceptorFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new LoggingInterceptor(info);
+ }
+};
+
+void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ ctx.AddMetadata("testkey", "testvalue");
+ auto stream = stub->BidiStream(&ctx);
+ for (auto i = 0; i < 10; i++) {
+ req.set_message("Hello" + std::to_string(i));
+ stream->Write(req);
+ stream->Read(&resp);
+ EXPECT_EQ(req.message(), resp.message());
+ }
+ ASSERT_TRUE(stream->WritesDone());
+ Status s = stream->Finish();
+ EXPECT_EQ(s.ok(), true);
+}
+
+class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test {
+ protected:
+ ServerInterceptorsEnd2endSyncUnaryTest() {
+ int port = grpc_pick_unused_port_or_die();
+
+ ServerBuilder builder;
+ server_address_ = "localhost:" + std::to_string(port);
+ builder.AddListeningPort(server_address_, InsecureServerCredentials());
+ builder.RegisterService(&service_);
+
+ std::vector<
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ server_ = builder.BuildAndStart();
+ }
+ std::string server_address_;
+ TestServiceImpl service_;
+ std::unique_ptr<Server> server_;
+};
+
+TEST_F(ServerInterceptorsEnd2endSyncUnaryTest, UnaryTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel = CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test {
+ protected:
+ ServerInterceptorsEnd2endSyncStreamingTest() {
+ int port = grpc_pick_unused_port_or_die();
+
+ ServerBuilder builder;
+ server_address_ = "localhost:" + std::to_string(port);
+ builder.AddListeningPort(server_address_, InsecureServerCredentials());
+ builder.RegisterService(&service_);
+
+ std::vector<
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ server_ = builder.BuildAndStart();
+ }
+ std::string server_address_;
+ EchoTestServiceStreamingImpl service_;
+ std::unique_ptr<Server> server_;
+};
+
+TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ClientStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel = CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeClientStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ServerStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel = CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeServerStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel = CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeBidiStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {};
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ EchoTestService::AsyncService service;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ builder.RegisterService(&service);
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ auto channel = CreateChannel(server_address, InsecureChannelCredentials());
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ cli_ctx.AddMetadata("testkey", "testvalue");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub->AsyncEcho(&cli_ctx, send_request, cq.get()));
+
+ service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(),
+ cq.get(), tag(2));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ Verifier().Expect(2, true).Verify(cq.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
+ srv_ctx.AddTrailingMetadata("testkey", "testvalue");
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
+ "testvalue"));
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ grpc_recycle_unused_port(port);
+}
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ EchoTestService::AsyncService service;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ builder.RegisterService(&service);
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ auto channel = CreateChannel(server_address, InsecureChannelCredentials());
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message("Hello");
+ cli_ctx.AddMetadata("testkey", "testvalue");
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1)));
+
+ service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq.get());
+
+ EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
+ srv_ctx.AddTrailingMetadata("testkey", "testvalue");
+
+ cli_stream->Write(send_request, tag(3));
+ srv_stream.Read(&recv_request, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ cli_stream->WritesDone(tag(7));
+ srv_stream.Read(&recv_request, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq.get());
+
+ srv_stream.Finish(Status::OK, tag(9));
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(9, true).Expect(10, true).Verify(cq.get());
+
+ EXPECT_TRUE(recv_status.ok());
+ EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
+ "testvalue"));
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ grpc_recycle_unused_port(port);
+}
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, GenericRPCTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ AsyncGenericService service;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ builder.RegisterAsyncGenericService(&service);
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ auto channel = CreateChannel(server_address, InsecureChannelCredentials());
+ GenericStub generic_stub(channel);
+
+ const grpc::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ GenericServerContext srv_ctx;
+ GenericServerAsyncReaderWriter stream(&srv_ctx);
+
+ // The string needs to be long enough to test heap-based slice.
+ send_request.set_message("Hello");
+ cli_ctx.AddMetadata("testkey", "testvalue");
+
+ std::unique_ptr<GenericClientAsyncReaderWriter> call =
+ generic_stub.PrepareCall(&cli_ctx, kMethodName, cq.get());
+ call->StartCall(tag(1));
+ Verifier().Expect(1, true).Verify(cq.get());
+ std::unique_ptr<ByteBuffer> send_buffer =
+ SerializeToByteBuffer(&send_request);
+ call->Write(*send_buffer, tag(2));
+ // Send ByteBuffer can be destroyed after calling Write.
+ send_buffer.reset();
+ Verifier().Expect(2, true).Verify(cq.get());
+ call->WritesDone(tag(3));
+ Verifier().Expect(3, true).Verify(cq.get());
+
+ service.RequestCall(&srv_ctx, &stream, cq.get(), cq.get(), tag(4));
+
+ Verifier().Expect(4, true).Verify(cq.get());
+ EXPECT_EQ(kMethodName, srv_ctx.method());
+ EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
+ srv_ctx.AddTrailingMetadata("testkey", "testvalue");
+
+ ByteBuffer recv_buffer;
+ stream.Read(&recv_buffer, tag(5));
+ Verifier().Expect(5, true).Verify(cq.get());
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ send_buffer = SerializeToByteBuffer(&send_response);
+ stream.Write(*send_buffer, tag(6));
+ send_buffer.reset();
+ Verifier().Expect(6, true).Verify(cq.get());
+
+ stream.Finish(Status::OK, tag(7));
+ Verifier().Expect(7, true).Verify(cq.get());
+
+ recv_buffer.Clear();
+ call->Read(&recv_buffer, tag(8));
+ Verifier().Expect(8, true).Verify(cq.get());
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
+
+ call->Finish(&recv_status, tag(9));
+ Verifier().Expect(9, true).Verify(cq.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
+ "testvalue"));
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ grpc_recycle_unused_port(port);
+}
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, UnimplementedRpcTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ std::shared_ptr<Channel> channel =
+ CreateChannel(server_address, InsecureChannelCredentials());
+ std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
+ stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
+ EchoRequest send_request;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ send_request.set_message("Hello");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub->AsyncUnimplemented(&cli_ctx, send_request, cq.get()));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier().Expect(4, true).Verify(cq.get());
+
+ EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+ EXPECT_EQ("", recv_status.error_message());
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ grpc_recycle_unused_port(port);
+}
+
+class ServerInterceptorsSyncUnimplementedEnd2endTest : public ::testing::Test {
+};
+
+TEST_F(ServerInterceptorsSyncUnimplementedEnd2endTest, UnimplementedRpcTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ TestServiceImpl service;
+ builder.RegisterService(&service);
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ std::shared_ptr<Channel> channel =
+ CreateChannel(server_address, InsecureChannelCredentials());
+ std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
+ stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
+ EchoRequest send_request;
+ EchoResponse recv_response;
+
+ ClientContext cli_ctx;
+ send_request.set_message("Hello");
+ Status recv_status =
+ stub->Unimplemented(&cli_ctx, send_request, &recv_response);
+
+ EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+ EXPECT_EQ("", recv_status.error_message());
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ grpc_recycle_unused_port(port);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc_test_init(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/test/cpp/interop/client_helper.h b/test/cpp/interop/client_helper.h
index eada2f671f..7dee85cc98 100644
--- a/test/cpp/interop/client_helper.h
+++ b/test/cpp/interop/client_helper.h
@@ -19,10 +19,12 @@
#ifndef GRPC_TEST_CPP_INTEROP_CLIENT_HELPER_H
#define GRPC_TEST_CPP_INTEROP_CLIENT_HELPER_H
+#include <functional>
#include <memory>
#include <unordered_map>
#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
#include "src/core/lib/surface/call_test_only.h"