diff options
-rw-r--r-- | include/grpc++/impl/service_type.h | 66 | ||||
-rw-r--r-- | include/grpc++/server.h | 15 | ||||
-rw-r--r-- | include/grpc++/stream.h | 14 | ||||
-rw-r--r-- | src/compiler/cpp_generator.cc | 15 | ||||
-rw-r--r-- | src/cpp/server/server.cc | 69 |
5 files changed, 143 insertions, 36 deletions
diff --git a/include/grpc++/impl/service_type.h b/include/grpc++/impl/service_type.h index 30654553ad..19432522df 100644 --- a/include/grpc++/impl/service_type.h +++ b/include/grpc++/impl/service_type.h @@ -34,10 +34,18 @@ #ifndef __GRPCPP_IMPL_SERVICE_TYPE_H__ #define __GRPCPP_IMPL_SERVICE_TYPE_H__ +namespace google { +namespace protobuf { +class Message; +} // namespace protobuf +} // namespace google + namespace grpc { class RpcService; class Server; +class ServerContext; +class Status; class SynchronousService { public: @@ -45,19 +53,69 @@ class SynchronousService { virtual RpcService* service() = 0; }; +class ServerAsyncStreamingInterface { + public: + virtual ~ServerAsyncStreamingInterface() {} + + virtual void SendInitialMetadata(void* tag) = 0; + virtual void Finish(const Status& status, void* tag) = 0; +}; + class AsynchronousService { public: - AsynchronousService(CompletionQueue* cq, const char** method_names, size_t method_count) : cq_(cq), method_names_(method_names), method_count_(method_count) {} + // this is Server, but in disguise to avoid a link dependency + class DispatchImpl { + public: + virtual void RequestAsyncCall(void* registered_method, + ServerContext* context, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) = 0; + }; + + AsynchronousService(CompletionQueue* cq, const char** method_names, + size_t method_count) + : cq_(cq), method_names_(method_names), method_count_(method_count) {} + + ~AsynchronousService(); CompletionQueue* completion_queue() const { return cq_; } + protected: + void RequestAsyncUnary(int index, ServerContext* context, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) { + dispatch_impl_->RequestAsyncCall(request_args_[index], context, request, + stream, cq, tag); + } + void RequestClientStreaming(int index, ServerContext* context, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) { + dispatch_impl_->RequestAsyncCall(request_args_[index], context, nullptr, + stream, cq, tag); + } + void RequestServerStreaming(int index, ServerContext* context, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) { + dispatch_impl_->RequestAsyncCall(request_args_[index], context, request, + stream, cq, tag); + } + void RequestBidiStreaming(int index, ServerContext* context, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) { + dispatch_impl_->RequestAsyncCall(request_args_[index], context, nullptr, + stream, cq, tag); + } + private: friend class Server; CompletionQueue* const cq_; - Server* server_ = nullptr; - const char**const method_names_; + DispatchImpl* dispatch_impl_ = nullptr; + const char** const method_names_; size_t method_count_; - std::vector<void*> request_args_; + void** request_args_ = nullptr; }; } // namespace grpc diff --git a/include/grpc++/server.h b/include/grpc++/server.h index 77aac75076..8050ef8c9d 100644 --- a/include/grpc++/server.h +++ b/include/grpc++/server.h @@ -42,6 +42,7 @@ #include <grpc++/completion_queue.h> #include <grpc++/config.h> #include <grpc++/impl/call.h> +#include <grpc++/impl/service_type.h> #include <grpc++/status.h> struct grpc_server; @@ -60,7 +61,8 @@ class ServerCredentials; class ThreadPoolInterface; // Currently it only supports handling rpcs in a single thread. -class Server final : private CallHook { +class Server final : private CallHook, + private AsynchronousService::DispatchImpl { public: ~Server(); @@ -70,7 +72,8 @@ class Server final : private CallHook { private: friend class ServerBuilder; - class MethodRequestData; + class SyncRequest; + class AsyncRequest; // ServerBuilder use only Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned, @@ -91,6 +94,12 @@ class Server final : private CallHook { void PerformOpsOnCall(CallOpBuffer* ops, Call* call) override; + // DispatchImpl + void RequestAsyncCall(void* registered_method, ServerContext* context, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag); + // Completion queue. CompletionQueue cq_; @@ -102,7 +111,7 @@ class Server final : private CallHook { int num_running_cb_; std::condition_variable callback_cv_; - std::list<MethodRequestData> methods_; + std::list<SyncRequest> sync_methods_; // Pointer to the c grpc server. grpc_server* server_; diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h index 6dc05bc9a6..c013afb141 100644 --- a/include/grpc++/stream.h +++ b/include/grpc++/stream.h @@ -39,6 +39,7 @@ #include <grpc++/completion_queue.h> #include <grpc++/server_context.h> #include <grpc++/impl/call.h> +#include <grpc++/impl/service_type.h> #include <grpc++/status.h> #include <grpc/support/log.h> @@ -370,15 +371,6 @@ class ClientAsyncStreamingInterface { virtual void Finish(Status* status, void* tag) = 0; }; -class ServerAsyncStreamingInterface { - public: - virtual ~ServerAsyncStreamingInterface() {} - - virtual void SendInitialMetadata(void* tag) = 0; - - virtual void Finish(const Status& status, void* tag) = 0; -}; - // An interface that yields a sequence of R messages. template <class R> class AsyncReaderInterface { @@ -580,11 +572,11 @@ class ClientAsyncReaderWriter final : public ClientAsyncStreamingInterface, // TODO(yangg) Move out of stream.h template <class W> -class ServerAsyncResponseWriter final { +class ServerAsyncResponseWriter final : public ServerAsyncStreamingInterface { public: explicit ServerAsyncResponseWriter(Call* call) : call_(call) {} - virtual void Write(const W& msg, void* tag) override { + virtual void Write(const W& msg, void* tag) { CallOpBuffer buf; buf.Reset(tag); buf.AddSendMessage(msg); diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index 4a31ff949e..d1a7bd2b88 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -374,7 +374,7 @@ void PrintSourceClientMethod(google::protobuf::io::Printer *printer, "::grpc::ClientContext* context, " "const $Request$& request, $Response$* response) {\n"); printer->Print(*vars, - "return ::grpc::BlockingUnaryCall(channel()," + " return ::grpc::BlockingUnaryCall(channel()," "::grpc::RpcMethod($Service$_method_names[$Idx$]), " "context, request, response);\n" "}\n\n"); @@ -484,6 +484,9 @@ void PrintSourceServerAsyncMethod( "$Request$* request, " "::grpc::ServerAsyncResponseWriter< $Response$>* response, " "::grpc::CompletionQueue* cq, void* tag) {\n"); + printer->Print( + *vars, + " AsynchronousService::RequestAsyncUnary($Idx$, context, request, response, cq, tag);\n"); printer->Print("}\n\n"); } else if (ClientOnlyStreaming(method)) { printer->Print(*vars, @@ -491,6 +494,9 @@ void PrintSourceServerAsyncMethod( "::grpc::ServerContext* context, " "::grpc::ServerAsyncReader< $Request$>* reader, " "::grpc::CompletionQueue* cq, void* tag) {\n"); + printer->Print( + *vars, + " AsynchronousService::RequestClientStreaming($Idx$, context, reader, cq, tag);\n"); printer->Print("}\n\n"); } else if (ServerOnlyStreaming(method)) { printer->Print(*vars, @@ -499,6 +505,9 @@ void PrintSourceServerAsyncMethod( "$Request$* request, " "::grpc::ServerAsyncWriter< $Response$>* writer, " "::grpc::CompletionQueue* cq, void* tag) {\n"); + printer->Print( + *vars, + " AsynchronousService::RequestServerStreaming($Idx$, context, request, writer, cq, tag);\n"); printer->Print("}\n\n"); } else if (BidiStreaming(method)) { printer->Print( @@ -507,6 +516,9 @@ void PrintSourceServerAsyncMethod( "::grpc::ServerContext* context, " "::grpc::ServerAsyncReaderWriter< $Response$, $Request$>* stream, " "::grpc::CompletionQueue* cq, void *tag) {\n"); + printer->Print( + *vars, + " AsynchronousService::RequestBidiStreaming($Idx$, context, stream, cq, tag);\n"); printer->Print("}\n\n"); } } @@ -548,6 +560,7 @@ void PrintSourceService(google::protobuf::io::Printer *printer, " delete service_;\n" "}\n\n"); for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); PrintSourceServerMethod(printer, service->method(i), vars); PrintSourceServerAsyncMethod(printer, service->method(i), vars); } diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index 20dd135a86..b4620868b8 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -93,24 +93,26 @@ bool Server::RegisterService(RpcService* service) { method->name()); return false; } - methods_.emplace_back(method, tag); + sync_methods_.emplace_back(method, tag); } return true; } bool Server::RegisterAsyncService(AsynchronousService* service) { - GPR_ASSERT(service->server_ == nullptr && "Can only register an asynchronous service against one server."); - service->server_ = this; - service->request_args_.reserve(service->method_count_); + GPR_ASSERT(service->dispatch_impl_ == nullptr && + "Can only register an asynchronous service against one server."); + service->dispatch_impl_ = this; + service->request_args_ = new void* [service->method_count_]; for (size_t i = 0; i < service->method_count_; ++i) { - void* tag = grpc_server_register_method(server_, service->method_names_[i], nullptr, - service->completion_queue()->cq()); + void* tag = + grpc_server_register_method(server_, service->method_names_[i], nullptr, + service->completion_queue()->cq()); if (!tag) { gpr_log(GPR_DEBUG, "Attempt to register %s multiple times", service->method_names_[i]); return false; } - service->request_args_.push_back(tag); + service->request_args_[i] = tag; } return true; } @@ -124,9 +126,9 @@ int Server::AddPort(const grpc::string& addr) { } } -class Server::MethodRequestData final : public CompletionQueueTag { +class Server::SyncRequest final : public CompletionQueueTag { public: - MethodRequestData(RpcServiceMethod* method, void* tag) + SyncRequest(RpcServiceMethod* method, void* tag) : method_(method), tag_(tag), has_request_payload_(method->method_type() == RpcMethod::NORMAL_RPC || @@ -138,13 +140,13 @@ class Server::MethodRequestData final : public CompletionQueueTag { grpc_metadata_array_init(&request_metadata_); } - static MethodRequestData* Wait(CompletionQueue* cq, bool* ok) { + static SyncRequest* Wait(CompletionQueue* cq, bool* ok) { void* tag = nullptr; *ok = false; if (!cq->Next(&tag, ok)) { return nullptr; } - auto* mrd = static_cast<MethodRequestData*>(tag); + auto* mrd = static_cast<SyncRequest*>(tag); GPR_ASSERT(mrd->in_flight_); return mrd; } @@ -162,9 +164,9 @@ class Server::MethodRequestData final : public CompletionQueueTag { void FinalizeResult(void** tag, bool* status) override {} - class CallData { + class CallData final { public: - explicit CallData(Server* server, MethodRequestData* mrd) + explicit CallData(Server* server, SyncRequest* mrd) : cq_(mrd->cq_), call_(mrd->call_, server, &cq_), ctx_(mrd->deadline_, mrd->request_metadata_.metadata, @@ -239,8 +241,8 @@ bool Server::Start() { grpc_server_start(server_); // Start processing rpcs. - if (!methods_.empty()) { - for (auto& m : methods_) { + if (!sync_methods_.empty()) { + for (auto& m : sync_methods_) { m.Request(server_); } @@ -275,6 +277,39 @@ void Server::PerformOpsOnCall(CallOpBuffer* buf, Call* call) { grpc_call_start_batch(call->call(), ops, nops, buf)); } +class Server::AsyncRequest final : public CompletionQueueTag { + public: + AsyncRequest(Server* server, void* registered_method, ServerContext* ctx, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, CompletionQueue* cq, + void* tag) + : tag_(tag), request_(request), stream_(stream), ctx_(ctx) { + memset(&array_, 0, sizeof(array_)); + grpc_server_request_registered_call( + server->server_, registered_method, &call_, &deadline_, &array_, + request ? &payload_ : nullptr, cq->cq(), this); + } + + void FinalizeResult(void** tag, bool* status) override {} + + private: + void* const tag_; + ::google::protobuf::Message* const request_; + ServerAsyncStreamingInterface* const stream_; + ServerContext* const ctx_; + grpc_call* call_ = nullptr; + gpr_timespec deadline_; + grpc_metadata_array array_; + grpc_byte_buffer* payload_ = nullptr; +}; + +void Server::RequestAsyncCall(void* registered_method, ServerContext* context, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) { + new AsyncRequest(this, registered_method, context, request, stream, cq, tag); +} + void Server::ScheduleCallback() { { std::unique_lock<std::mutex> lock(mu_); @@ -286,11 +321,11 @@ void Server::ScheduleCallback() { void Server::RunRpc() { // Wait for one more incoming rpc. bool ok; - auto* mrd = MethodRequestData::Wait(&cq_, &ok); + auto* mrd = SyncRequest::Wait(&cq_, &ok); if (mrd) { ScheduleCallback(); if (ok) { - MethodRequestData::CallData cd(this, mrd); + SyncRequest::CallData cd(this, mrd); mrd->Request(server_); cd.Run(); |