diff options
-rw-r--r-- | include/grpcpp/impl/codegen/call.h | 5 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/server_interface.h | 8 | ||||
-rw-r--r-- | src/cpp/common/completion_queue_cc.cc | 1 | ||||
-rw-r--r-- | src/cpp/server/server_cc.cc | 20 | ||||
-rw-r--r-- | test/cpp/end2end/interceptors_util.h | 131 | ||||
-rw-r--r-- | test/cpp/end2end/server_interceptors_end2end_test.cc | 156 |
6 files changed, 312 insertions, 9 deletions
diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 505055e7e6..c6894003b6 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -1189,13 +1189,14 @@ class CallOpSet : public CallOpSetInterface, } bool FinalizeResult(void** tag, bool* status) override { - // gpr_log(GPR_ERROR, "finalizing result %p", this); + gpr_log(GPR_ERROR, "finalizing result"); if (done_intercepting_) { // We have already finished intercepting and filling in the results. This // round trip from the core needed to be made because interceptors were // run // gpr_log(GPR_ERROR, "done intercepting"); *tag = return_tag_; + *status = saved_status_; g_core_codegen_interface->grpc_call_unref(call_.call()); return true; } @@ -1206,6 +1207,7 @@ class CallOpSet : public CallOpSetInterface, this->Op4::FinishOp(status); this->Op5::FinishOp(status); this->Op6::FinishOp(status); + saved_status_ = *status; // gpr_log(GPR_ERROR, "done finish ops"); if (RunInterceptorsPostRecv()) { *tag = return_tag_; @@ -1301,6 +1303,7 @@ class CallOpSet : public CallOpSetInterface, Call call_; bool done_intercepting_ = false; InterceptorBatchMethodsImpl interceptor_methods_; + bool saved_status_; }; } // namespace internal diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index aa7dbe8b70..532fff7a54 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -173,8 +173,6 @@ class ServerInterface : public internal::CallHook { internal::Call call_wrapper_; internal::InterceptorBatchMethodsImpl interceptor_methods_; bool done_intercepting_; - void* dummy_alarm_; /* This should have been Alarm, but we cannot depend on - alarm.h here */ }; class RegisteredAsyncRequest : public BaseAsyncRequest { @@ -186,6 +184,7 @@ class ServerInterface : public internal::CallHook { const char* name); virtual bool FinalizeResult(void** tag, bool* status) override { + gpr_log(GPR_ERROR, "finalize registeredasyncrequest"); /* If we are done intercepting, then there is nothing more for us to do */ if (done_intercepting_) { return BaseAsyncRequest::FinalizeResult(tag, status); @@ -239,6 +238,7 @@ class ServerInterface : public internal::CallHook { notification_cq_(notification_cq), tag_(tag), request_(request) { + gpr_log(GPR_ERROR, "new payload request"); IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(), notification_cq); } @@ -248,6 +248,7 @@ class ServerInterface : public internal::CallHook { } bool FinalizeResult(void** tag, bool* status) override { + gpr_log(GPR_ERROR, "finalize PayloadAsyncRequest"); /* If we are done intercepting, then there is nothing more for us to do */ if (done_intercepting_) { return RegisteredAsyncRequest::FinalizeResult(tag, status); @@ -312,6 +313,7 @@ class ServerInterface : public internal::CallHook { ServerCompletionQueue* notification_cq, void* tag, Message* message) { GPR_CODEGEN_ASSERT(method); + gpr_log(GPR_ERROR, "request async method with payload"); new PayloadAsyncRequest<Message>(method, this, context, stream, call_cq, notification_cq, tag, message); } @@ -322,6 +324,7 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { GPR_CODEGEN_ASSERT(method); + gpr_log(GPR_ERROR, "request async method with no payload"); new NoPayloadAsyncRequest(method, this, context, stream, call_cq, notification_cq, tag); } @@ -331,6 +334,7 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { + gpr_log(GPR_ERROR, "request async generic call"); new GenericAsyncRequest(this, context, stream, call_cq, notification_cq, tag, true); } diff --git a/src/cpp/common/completion_queue_cc.cc b/src/cpp/common/completion_queue_cc.cc index 6893201e2e..5dfcfa2984 100644 --- a/src/cpp/common/completion_queue_cc.cc +++ b/src/cpp/common/completion_queue_cc.cc @@ -64,6 +64,7 @@ CompletionQueue::NextStatus CompletionQueue::AsyncNextInternal( *ok = ev.success != 0; *tag = cq_tag; if (cq_tag->FinalizeResult(tag, ok)) { + gpr_log(GPR_ERROR, "alright got tag %p", *tag); return GOT_EVENT; } break; diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 9f4ec3e4ab..2f5493dbf8 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -41,8 +41,10 @@ #include <grpcpp/support/time.h> #include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/profiling/timers.h" #include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/completion_queue.h" #include "src/cpp/client/create_channel_internal.h" #include "src/cpp/server/health/default_health_check_service.h" #include "src/cpp/thread_manager/thread_manager.h" @@ -753,6 +755,7 @@ ServerInterface::BaseAsyncRequest::BaseAsyncRequest( /* Set up interception state partially for the receive ops. call_wrapper_ is * not filled at this point, but it will be filled before the interceptors are * run. */ + gpr_log(GPR_ERROR, "Created base async request"); interceptor_methods_.SetCall(&call_wrapper_); interceptor_methods_.SetReverse(); call_cq_->RegisterAvalanching(); // This op will trigger more ops @@ -764,9 +767,9 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + gpr_log(GPR_ERROR, "in finalize result"); if (done_intercepting_) { - delete static_cast<Alarm*>(dummy_alarm_); - dummy_alarm_ = nullptr; + gpr_log(GPR_ERROR, "done running interceptors"); *tag = tag_; if (delete_on_finalize_) { delete this; @@ -785,6 +788,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, stream_->BindCall(&call_wrapper_); if (*status && call_ && call_wrapper_.server_rpc_info()) { + gpr_log(GPR_ERROR, "here"); done_intercepting_ = true; // Set interception point for RECV INITIAL METADATA interceptor_methods_.AddInterceptionHookPoint( @@ -799,6 +803,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, // There were interceptors to be run, so // ContinueFinalizeResultAfterInterception will be run when interceptors // are done. + gpr_log(GPR_ERROR, "don't return this tag"); return false; } } @@ -814,12 +819,15 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, void ServerInterface::BaseAsyncRequest:: ContinueFinalizeResultAfterInterception() { + gpr_log(GPR_ERROR, "continue finalize result"); context_->BeginCompletionOp(&call_wrapper_); // Queue a tag which will be returned immediately - dummy_alarm_ = new Alarm(); - static_cast<Alarm*>(dummy_alarm_) - ->Set(notification_cq_, - g_core_codegen_interface->gpr_time_0(GPR_CLOCK_MONOTONIC), this); + grpc_core::ExecCtx exec_ctx; + grpc_cq_begin_op(notification_cq_->cq(), this); + grpc_cq_end_op( + notification_cq_->cq(), this, GRPC_ERROR_NONE, + [](void* arg, grpc_cq_completion* completion) { delete completion; }, + nullptr, new grpc_cq_completion()); } ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index bc6211517d..c44a025f82 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -174,5 +174,136 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel) { } } +bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map, + string key, string value) { + for (const auto& pair : map) { + if (pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue")) { + return true; + } + } + return false; +} + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } +int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); } + +class Verifier { + public: + Verifier() : lambda_run_(false) {} + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + return ExpectUnless(i, expect_ok, false); + } + // ExpectUnless sets the expected ok value for a specific tag + // unless the tag was already marked seen (as a result of ExpectMaybe) + Verifier& ExpectUnless(int i, bool expect_ok, bool seen) { + if (!seen) { + expectations_[tag(i)] = expect_ok; + } + return *this; + } + // ExpectMaybe sets the expected ok value for a specific tag, but does not + // require it to appear + // If it does, sets *seen to true + Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) { + if (!*seen) { + maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen}; + } + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + template <typename T> + CompletionQueue::NextStatus DoOnceThenAsyncNext( + CompletionQueue* cq, void** got_tag, bool* ok, T deadline, + std::function<void(void)> lambda) { + if (lambda_run_) { + return cq->AsyncNext(got_tag, ok, deadline); + } else { + lambda_run_ = true; + return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline); + } + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + // This version of Verify allows optionally ignoring the + // outcome of the expectation + void Verify(CompletionQueue* cq, bool ignore_ok) { + GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, ignore_ok); + } + } + + // This version of Verify stops after a certain deadline, and uses the + // DoThenAsyncNext API + // to call the lambda + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline, + const std::function<void(void)>& lambda) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } else { + auto it2 = maybe_expectations_.find(got_tag); + if (it2 != maybe_expectations_.end()) { + if (it2->second.seen != nullptr) { + EXPECT_FALSE(*it2->second.seen); + *it2->second.seen = true; + } + if (!ignore_ok) { + EXPECT_EQ(it2->second.ok, ok); + } + } else { + gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag); + abort(); + } + } + } + + struct MaybeExpect { + bool ok; + bool* seen; + }; + + std::map<void*, bool> expectations_; + std::map<void*, MaybeExpect> maybe_expectations_; + bool lambda_run_; +}; + } // namespace testing } // namespace grpc
\ No newline at end of file diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc index 57b85a479e..956aec9359 100644 --- a/test/cpp/end2end/server_interceptors_end2end_test.cc +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -269,6 +269,162 @@ TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {}; + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, cq.get())); + + service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(), + cq.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1))); + + service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq.get()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq.get()); + + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + } // namespace } // namespace testing } // namespace grpc |