 include/grpcpp/impl/codegen/call.h                    |   5
 include/grpcpp/impl/codegen/server_interface.h        |   8
 src/cpp/common/completion_queue_cc.cc                 |   1
 src/cpp/server/server_cc.cc                           |  20
 test/cpp/end2end/interceptors_util.h                  | 131
 test/cpp/end2end/server_interceptors_end2end_test.cc  | 156
 6 files changed, 312 insertions(+), 9 deletions(-)
diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h
index 505055e7e6..c6894003b6 100644
--- a/include/grpcpp/impl/codegen/call.h
+++ b/include/grpcpp/impl/codegen/call.h
@@ -1189,13 +1189,14 @@ class CallOpSet : public CallOpSetInterface,
}
bool FinalizeResult(void** tag, bool* status) override {
- // gpr_log(GPR_ERROR, "finalizing result %p", this);
+ gpr_log(GPR_ERROR, "finalizing result");
if (done_intercepting_) {
// We have already finished intercepting and filling in the results. This
// round trip from the core needed to be made because interceptors were
// run
// gpr_log(GPR_ERROR, "done intercepting");
*tag = return_tag_;
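+ // Return the status saved when the ops originally finished; the event that
+ // requeued this tag after the interceptors does not reflect the original
+ // ops' result.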
+ *status = saved_status_;
g_core_codegen_interface->grpc_call_unref(call_.call());
return true;
}
@@ -1206,6 +1207,7 @@ class CallOpSet : public CallOpSetInterface,
this->Op4::FinishOp(status);
this->Op5::FinishOp(status);
this->Op6::FinishOp(status);
+ saved_status_ = *status;
// gpr_log(GPR_ERROR, "done finish ops");
if (RunInterceptorsPostRecv()) {
*tag = return_tag_;
@@ -1301,6 +1303,7 @@ class CallOpSet : public CallOpSetInterface,
Call call_;
bool done_intercepting_ = false;
InterceptorBatchMethodsImpl interceptor_methods_;
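+ // Result of the ops, captured in FinalizeResult so it can be reported after
+ // the interceptors' extra pass through the completion queue.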
+ bool saved_status_;
};
} // namespace internal
diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h
index aa7dbe8b70..532fff7a54 100644
--- a/include/grpcpp/impl/codegen/server_interface.h
+++ b/include/grpcpp/impl/codegen/server_interface.h
@@ -173,8 +173,6 @@ class ServerInterface : public internal::CallHook {
internal::Call call_wrapper_;
internal::InterceptorBatchMethodsImpl interceptor_methods_;
bool done_intercepting_;
- void* dummy_alarm_; /* This should have been Alarm, but we cannot depend on
- alarm.h here */
};
class RegisteredAsyncRequest : public BaseAsyncRequest {
@@ -186,6 +184,7 @@ class ServerInterface : public internal::CallHook {
const char* name);
virtual bool FinalizeResult(void** tag, bool* status) override {
+ gpr_log(GPR_ERROR, "finalize registeredasyncrequest");
/* If we are done intercepting, then there is nothing more for us to do */
if (done_intercepting_) {
return BaseAsyncRequest::FinalizeResult(tag, status);
@@ -239,6 +238,7 @@ class ServerInterface : public internal::CallHook {
notification_cq_(notification_cq),
tag_(tag),
request_(request) {
+ gpr_log(GPR_ERROR, "new payload request");
IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(),
notification_cq);
}
@@ -248,6 +248,7 @@ class ServerInterface : public internal::CallHook {
}
bool FinalizeResult(void** tag, bool* status) override {
+ gpr_log(GPR_ERROR, "finalize PayloadAsyncRequest");
/* If we are done intercepting, then there is nothing more for us to do */
if (done_intercepting_) {
return RegisteredAsyncRequest::FinalizeResult(tag, status);
@@ -312,6 +313,7 @@ class ServerInterface : public internal::CallHook {
ServerCompletionQueue* notification_cq, void* tag,
Message* message) {
GPR_CODEGEN_ASSERT(method);
+ gpr_log(GPR_ERROR, "request async method with payload");
new PayloadAsyncRequest<Message>(method, this, context, stream, call_cq,
notification_cq, tag, message);
}
@@ -322,6 +324,7 @@ class ServerInterface : public internal::CallHook {
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag) {
GPR_CODEGEN_ASSERT(method);
+ gpr_log(GPR_ERROR, "request async method with no payload");
new NoPayloadAsyncRequest(method, this, context, stream, call_cq,
notification_cq, tag);
}
@@ -331,6 +334,7 @@ class ServerInterface : public internal::CallHook {
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq,
void* tag) {
+ gpr_log(GPR_ERROR, "request async generic call");
new GenericAsyncRequest(this, context, stream, call_cq, notification_cq,
tag, true);
}
diff --git a/src/cpp/common/completion_queue_cc.cc b/src/cpp/common/completion_queue_cc.cc
index 6893201e2e..5dfcfa2984 100644
--- a/src/cpp/common/completion_queue_cc.cc
+++ b/src/cpp/common/completion_queue_cc.cc
@@ -64,6 +64,7 @@ CompletionQueue::NextStatus CompletionQueue::AsyncNextInternal(
*ok = ev.success != 0;
*tag = cq_tag;
if (cq_tag->FinalizeResult(tag, ok)) {
+ gpr_log(GPR_ERROR, "alright got tag %p", *tag);
return GOT_EVENT;
}
break;
diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc
index 9f4ec3e4ab..2f5493dbf8 100644
--- a/src/cpp/server/server_cc.cc
+++ b/src/cpp/server/server_cc.cc
@@ -41,8 +41,10 @@
#include <grpcpp/support/time.h>
#include "src/core/ext/transport/inproc/inproc_transport.h"
+#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/profiling/timers.h"
#include "src/core/lib/surface/call.h"
+#include "src/core/lib/surface/completion_queue.h"
#include "src/cpp/client/create_channel_internal.h"
#include "src/cpp/server/health/default_health_check_service.h"
#include "src/cpp/thread_manager/thread_manager.h"
@@ -753,6 +755,7 @@ ServerInterface::BaseAsyncRequest::BaseAsyncRequest(
/* Set up interception state partially for the receive ops. call_wrapper_ is
* not filled at this point, but it will be filled before the interceptors are
* run. */
+ gpr_log(GPR_ERROR, "Created base async request");
interceptor_methods_.SetCall(&call_wrapper_);
interceptor_methods_.SetReverse();
call_cq_->RegisterAvalanching(); // This op will trigger more ops
@@ -764,9 +767,9 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() {
bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
bool* status) {
+ gpr_log(GPR_ERROR, "in finalize result");
if (done_intercepting_) {
- delete static_cast<Alarm*>(dummy_alarm_);
- dummy_alarm_ = nullptr;
+ gpr_log(GPR_ERROR, "done running interceptors");
*tag = tag_;
if (delete_on_finalize_) {
delete this;
@@ -785,6 +788,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
stream_->BindCall(&call_wrapper_);
if (*status && call_ && call_wrapper_.server_rpc_info()) {
+ gpr_log(GPR_ERROR, "here");
done_intercepting_ = true;
// Set interception point for RECV INITIAL METADATA
interceptor_methods_.AddInterceptionHookPoint(
@@ -799,6 +803,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
// There were interceptors to be run, so
// ContinueFinalizeResultAfterInterception will be run when interceptors
// are done.
+ gpr_log(GPR_ERROR, "don't return this tag");
return false;
}
}
@@ -814,12 +819,15 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
void ServerInterface::BaseAsyncRequest::
ContinueFinalizeResultAfterInterception() {
+ gpr_log(GPR_ERROR, "continue finalize result");
context_->BeginCompletionOp(&call_wrapper_);
// Queue a tag which will be returned immediately
- dummy_alarm_ = new Alarm();
- static_cast<Alarm*>(dummy_alarm_)
- ->Set(notification_cq_,
- g_core_codegen_interface->gpr_time_0(GPR_CLOCK_MONOTONIC), this);
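+ // Post this tag directly on the notification completion queue instead of
+ // going through an Alarm. The core CQ calls below require an ExecCtx, and
+ // the done callback frees the heap-allocated grpc_cq_completion.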
+ grpc_core::ExecCtx exec_ctx;
+ grpc_cq_begin_op(notification_cq_->cq(), this);
+ grpc_cq_end_op(
+ notification_cq_->cq(), this, GRPC_ERROR_NONE,
+ [](void* arg, grpc_cq_completion* completion) { delete completion; },
+ nullptr, new grpc_cq_completion());
}
ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest(
diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h
index bc6211517d..c44a025f82 100644
--- a/test/cpp/end2end/interceptors_util.h
+++ b/test/cpp/end2end/interceptors_util.h
@@ -174,5 +174,136 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel) {
}
}
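+// Returns true if the metadata map contains an entry whose key and value
+// begin with the given strings.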
+bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first.starts_with(key) &&
+ pair.second.starts_with(value)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void* tag(int i) { return (void*)static_cast<intptr_t>(i); }
+int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); }
+
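+// Tracks which tags are expected to come off a completion queue and the ok
+// value each should carry, e.g. Verifier().Expect(1, true).Verify(cq);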
+class Verifier {
+ public:
+ Verifier() : lambda_run_(false) {}
+ // Expect sets the expected ok value for a specific tag
+ Verifier& Expect(int i, bool expect_ok) {
+ return ExpectUnless(i, expect_ok, false);
+ }
+ // ExpectUnless sets the expected ok value for a specific tag
+ // unless the tag was already marked seen (as a result of ExpectMaybe)
+ Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
+ if (!seen) {
+ expectations_[tag(i)] = expect_ok;
+ }
+ return *this;
+ }
+ // ExpectMaybe sets the expected ok value for a specific tag, but does not
+ // require it to appear
+ // If it does, sets *seen to true
+ Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
+ if (!*seen) {
+ maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
+ }
+ return *this;
+ }
+
+ // Next waits for 1 async tag to complete, checks its
+ // expectations, and returns the tag
+ int Next(CompletionQueue* cq, bool ignore_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ GotTag(got_tag, ok, ignore_ok);
+ return detag(got_tag);
+ }
+
+ template <typename T>
+ CompletionQueue::NextStatus DoOnceThenAsyncNext(
+ CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
+ std::function<void(void)> lambda) {
+ if (lambda_run_) {
+ return cq->AsyncNext(got_tag, ok, deadline);
+ } else {
+ lambda_run_ = true;
+ return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
+ }
+ }
+
+ // Verify keeps calling Next until all currently set
+ // expected tags are complete
+ void Verify(CompletionQueue* cq) { Verify(cq, false); }
+
+ // This version of Verify allows optionally ignoring the
+ // outcome of the expectation
+ void Verify(CompletionQueue* cq, bool ignore_ok) {
+ GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
+ while (!expectations_.empty()) {
+ Next(cq, ignore_ok);
+ }
+ }
+
+ // This version of Verify stops after a certain deadline and uses the
+ // DoThenAsyncNext API to call the lambda.
+ void Verify(CompletionQueue* cq,
+ std::chrono::system_clock::time_point deadline,
+ const std::function<void(void)>& lambda) {
+ if (expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::TIMEOUT);
+ } else {
+ while (!expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::GOT_EVENT);
+ GotTag(got_tag, ok, false);
+ }
+ }
+ }
+
+ private:
+ void GotTag(void* got_tag, bool ok, bool ignore_ok) {
+ auto it = expectations_.find(got_tag);
+ if (it != expectations_.end()) {
+ if (!ignore_ok) {
+ EXPECT_EQ(it->second, ok);
+ }
+ expectations_.erase(it);
+ } else {
+ auto it2 = maybe_expectations_.find(got_tag);
+ if (it2 != maybe_expectations_.end()) {
+ if (it2->second.seen != nullptr) {
+ EXPECT_FALSE(*it2->second.seen);
+ *it2->second.seen = true;
+ }
+ if (!ignore_ok) {
+ EXPECT_EQ(it2->second.ok, ok);
+ }
+ } else {
+ gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
+ abort();
+ }
+ }
+ }
+
+ struct MaybeExpect {
+ bool ok;
+ bool* seen;
+ };
+
+ std::map<void*, bool> expectations_;
+ std::map<void*, MaybeExpect> maybe_expectations_;
+ bool lambda_run_;
+};
+
} // namespace testing
} // namespace grpc
\ No newline at end of file
diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc
index 57b85a479e..956aec9359 100644
--- a/test/cpp/end2end/server_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/server_interceptors_end2end_test.cc
@@ -269,6 +269,162 @@ TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
+class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {};
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ EchoTestService::AsyncService service;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ builder.RegisterService(&service);
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ auto channel = CreateChannel(server_address, InsecureChannelCredentials());
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ cli_ctx.AddMetadata("testkey", "testvalue");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub->AsyncEcho(&cli_ctx, send_request, cq.get()));
+
+ service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(),
+ cq.get(), tag(2));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
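+ // Tag 2 is returned only after the server-side interceptors have run.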
+ Verifier().Expect(2, true).Verify(cq.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
+ srv_ctx.AddTrailingMetadata("testkey", "testvalue");
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
+ "testvalue"));
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ grpc_recycle_unused_port(port);
+}
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) {
+ DummyInterceptor::Reset();
+ int port = grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + std::to_string(port);
+ ServerBuilder builder;
+ EchoTestService::AsyncService service;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ builder.RegisterService(&service);
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ auto channel = CreateChannel(server_address, InsecureChannelCredentials());
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message("Hello");
+ cli_ctx.AddMetadata("testkey", "testvalue");
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1)));
+
+ service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq.get());
+
+ EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
+ srv_ctx.AddTrailingMetadata("testkey", "testvalue");
+
+ cli_stream->Write(send_request, tag(3));
+ srv_stream.Read(&recv_request, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
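+ // After WritesDone, the outstanding server-side Read completes with ok == false.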
+ cli_stream->WritesDone(tag(7));
+ srv_stream.Read(&recv_request, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq.get());
+
+ srv_stream.Finish(Status::OK, tag(9));
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(9, true).Expect(10, true).Verify(cq.get());
+
+ EXPECT_TRUE(recv_status.ok());
+ EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
+ "testvalue"));
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ grpc_recycle_unused_port(port);
+}
+
} // namespace
} // namespace testing
} // namespace grpc