diff options
author | Yang Gao <yangg@google.com> | 2015-02-18 16:01:13 -0800 |
---|---|---|
committer | Yang Gao <yangg@google.com> | 2015-02-18 16:01:13 -0800 |
commit | d9f3dfe7ebb251baab7494f5512c5fa68399ae7a (patch) | |
tree | 32007f224f72365bcda80ac7e6c92b30dbf40dd4 | |
parent | 646f60153a83a8fc7d733d36d7bf933811e1c8ea (diff) | |
parent | 2627e4e0a917cc438bff186d0bea2bee030ac98a (diff) |
Merge pull request #581 from ctiller/an-update-on-c++
Server side cancellation receive support for C++
-rw-r--r-- | include/grpc++/async_unary_call.h | 4 | ||||
-rw-r--r-- | include/grpc++/completion_queue.h | 9 | ||||
-rw-r--r-- | include/grpc++/impl/call.h | 4 | ||||
-rw-r--r-- | include/grpc++/server_context.h | 11 | ||||
-rw-r--r-- | include/grpc++/stream.h | 8 | ||||
-rw-r--r-- | src/compiler/cpp_generator.cc | 2 | ||||
-rw-r--r-- | src/cpp/client/client_unary_call.cc | 1 | ||||
-rw-r--r-- | src/cpp/common/call.cc | 3 | ||||
-rw-r--r-- | src/cpp/common/completion_queue.cc | 42 | ||||
-rw-r--r-- | src/cpp/server/server.cc | 17 | ||||
-rw-r--r-- | src/cpp/server/server_context.cc | 71 | ||||
-rw-r--r-- | test/cpp/end2end/async_end2end_test.cc | 12 |
12 files changed, 144 insertions, 40 deletions
diff --git a/include/grpc++/async_unary_call.h b/include/grpc++/async_unary_call.h index 105250ce9d..b4a654c4a9 100644 --- a/include/grpc++/async_unary_call.h +++ b/include/grpc++/async_unary_call.h @@ -111,8 +111,6 @@ class ServerAsyncResponseWriter final : public ServerAsyncStreamingInterface { if (status.IsOk()) { finish_buf_.AddSendMessage(msg); } - bool cancelled = false; - finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } @@ -124,8 +122,6 @@ class ServerAsyncResponseWriter final : public ServerAsyncStreamingInterface { finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); ctx_->sent_initial_metadata_ = true; } - bool cancelled = false; - finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } diff --git a/include/grpc++/completion_queue.h b/include/grpc++/completion_queue.h index f1b4962d1b..0075482d71 100644 --- a/include/grpc++/completion_queue.h +++ b/include/grpc++/completion_queue.h @@ -55,6 +55,7 @@ class ServerReaderWriter; class CompletionQueue; class Server; +class ServerContext; class CompletionQueueTag { public: @@ -62,7 +63,9 @@ class CompletionQueueTag { // Called prior to returning from Next(), return value // is the status of the operation (return status is the default thing // to do) - virtual void FinalizeResult(void **tag, bool *status) = 0; + // If this function returns false, the tag is dropped and not returned + // from the completion queue + virtual bool FinalizeResult(void **tag, bool *status) = 0; }; // grpc_completion_queue wrapper class @@ -99,6 +102,7 @@ class CompletionQueue { template <class R, class W> friend class ::grpc::ServerReaderWriter; friend class ::grpc::Server; + friend class ::grpc::ServerContext; friend Status BlockingUnaryCall(ChannelInterface *channel, const RpcMethod &method, ClientContext *context, @@ -109,6 +113,9 @@ class CompletionQueue { // Cannot be mixed with calls to Next(). bool Pluck(CompletionQueueTag *tag); + // Does a single polling pluck on tag + void TryPluck(CompletionQueueTag *tag); + grpc_completion_queue *cq_; // owned }; diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h index 7ba5d16bf3..341710f7a2 100644 --- a/include/grpc++/impl/call.h +++ b/include/grpc++/impl/call.h @@ -65,7 +65,7 @@ class CallOpBuffer : public CompletionQueueTag { void AddSendInitialMetadata( std::multimap<grpc::string, grpc::string> *metadata); void AddSendInitialMetadata(ClientContext *ctx); - void AddRecvInitialMetadata(ClientContext* ctx); + void AddRecvInitialMetadata(ClientContext *ctx); void AddSendMessage(const google::protobuf::Message &message); void AddRecvMessage(google::protobuf::Message *message); void AddClientSendClose(); @@ -80,7 +80,7 @@ class CallOpBuffer : public CompletionQueueTag { void FillOps(grpc_op *ops, size_t *nops); // Called by completion queue just prior to returning from Next() or Pluck() - void FinalizeResult(void **tag, bool *status) override; + bool FinalizeResult(void **tag, bool *status) override; bool got_message = false; diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h index 520278f949..d327d8b41e 100644 --- a/include/grpc++/server_context.h +++ b/include/grpc++/server_context.h @@ -60,7 +60,9 @@ class ServerWriter; template <class R, class W> class ServerReaderWriter; +class Call; class CallOpBuffer; +class CompletionQueue; class Server; // Interface of server side rpc context. @@ -76,6 +78,8 @@ class ServerContext final { void AddInitialMetadata(const grpc::string& key, const grpc::string& value); void AddTrailingMetadata(const grpc::string& key, const grpc::string& value); + bool IsCancelled(); + const std::multimap<grpc::string, grpc::string>& client_metadata() { return client_metadata_; } @@ -97,11 +101,18 @@ class ServerContext final { template <class R, class W> friend class ::grpc::ServerReaderWriter; + class CompletionOp; + + void BeginCompletionOp(Call* call); + ServerContext(gpr_timespec deadline, grpc_metadata* metadata, size_t metadata_count); + CompletionOp* completion_op_ = nullptr; + std::chrono::system_clock::time_point deadline_; grpc_call* call_ = nullptr; + CompletionQueue* cq_ = nullptr; bool sent_initial_metadata_ = false; std::multimap<grpc::string, grpc::string> client_metadata_; std::multimap<grpc::string, grpc::string> initial_metadata_; diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h index 01deac2ce1..cd95ff7c92 100644 --- a/include/grpc++/stream.h +++ b/include/grpc++/stream.h @@ -582,8 +582,6 @@ class ServerAsyncReader : public ServerAsyncStreamingInterface, if (status.IsOk()) { finish_buf_.AddSendMessage(msg); } - bool cancelled = false; - finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } @@ -595,8 +593,6 @@ class ServerAsyncReader : public ServerAsyncStreamingInterface, finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); ctx_->sent_initial_metadata_ = true; } - bool cancelled = false; - finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } @@ -643,8 +639,6 @@ class ServerAsyncWriter : public ServerAsyncStreamingInterface, finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); ctx_->sent_initial_metadata_ = true; } - bool cancelled = false; - finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } @@ -699,8 +693,6 @@ class ServerAsyncReaderWriter : public ServerAsyncStreamingInterface, finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); ctx_->sent_initial_metadata_ = true; } - bool cancelled = false; - finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index b73b000a1c..f10824e6b0 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -386,7 +386,7 @@ void PrintSourceClientMethod(google::protobuf::io::Printer *printer, "const $Request$& request, " "::grpc::CompletionQueue* cq, void* tag) {\n"); printer->Print(*vars, - " return new ClientAsyncResponseReader< $Response$>(" + " return new ::grpc::ClientAsyncResponseReader< $Response$>(" "channel(), cq, " "::grpc::RpcMethod($Service$_method_names[$Idx$]), " "context, request, tag);\n" diff --git a/src/cpp/client/client_unary_call.cc b/src/cpp/client/client_unary_call.cc index 08491f40f7..684b3cbadb 100644 --- a/src/cpp/client/client_unary_call.cc +++ b/src/cpp/client/client_unary_call.cc @@ -60,4 +60,5 @@ Status BlockingUnaryCall(ChannelInterface *channel, const RpcMethod &method, GPR_ASSERT((cq.Pluck(&buf) && buf.got_message) || !status.IsOk()); return status; } + } // namespace grpc diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc index e29c6a053d..e6a20a252d 100644 --- a/src/cpp/common/call.cc +++ b/src/cpp/common/call.cc @@ -231,7 +231,7 @@ void CallOpBuffer::FillOps(grpc_op* ops, size_t* nops) { } } -void CallOpBuffer::FinalizeResult(void** tag, bool* status) { +bool CallOpBuffer::FinalizeResult(void** tag, bool* status) { // Release send buffers. if (send_message_buf_) { grpc_byte_buffer_destroy(send_message_buf_); @@ -274,6 +274,7 @@ void CallOpBuffer::FinalizeResult(void** tag, bool* status) { if (recv_closed_) { *recv_closed_ = cancelled_buf_ != 0; } + return true; } Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) diff --git a/src/cpp/common/completion_queue.cc b/src/cpp/common/completion_queue.cc index c7d883d5b0..414966c1cd 100644 --- a/src/cpp/common/completion_queue.cc +++ b/src/cpp/common/completion_queue.cc @@ -43,7 +43,7 @@ namespace grpc { CompletionQueue::CompletionQueue() { cq_ = grpc_completion_queue_create(); } -CompletionQueue::CompletionQueue(grpc_completion_queue *take) : cq_(take) {} +CompletionQueue::CompletionQueue(grpc_completion_queue* take) : cq_(take) {} CompletionQueue::~CompletionQueue() { grpc_completion_queue_destroy(cq_); } @@ -52,34 +52,48 @@ void CompletionQueue::Shutdown() { grpc_completion_queue_shutdown(cq_); } // Helper class so we can declare a unique_ptr with grpc_event class EventDeleter { public: - void operator()(grpc_event *ev) { + void operator()(grpc_event* ev) { if (ev) grpc_event_finish(ev); } }; -bool CompletionQueue::Next(void **tag, bool *ok) { +bool CompletionQueue::Next(void** tag, bool* ok) { std::unique_ptr<grpc_event, EventDeleter> ev; - ev.reset(grpc_completion_queue_next(cq_, gpr_inf_future)); - if (ev->type == GRPC_QUEUE_SHUTDOWN) { - return false; + for (;;) { + ev.reset(grpc_completion_queue_next(cq_, gpr_inf_future)); + if (ev->type == GRPC_QUEUE_SHUTDOWN) { + return false; + } + auto cq_tag = static_cast<CompletionQueueTag*>(ev->tag); + *ok = ev->data.op_complete == GRPC_OP_OK; + *tag = cq_tag; + if (cq_tag->FinalizeResult(tag, ok)) { + return true; + } } - auto cq_tag = static_cast<CompletionQueueTag *>(ev->tag); - *ok = ev->data.op_complete == GRPC_OP_OK; - *tag = cq_tag; - cq_tag->FinalizeResult(tag, ok); - return true; } -bool CompletionQueue::Pluck(CompletionQueueTag *tag) { +bool CompletionQueue::Pluck(CompletionQueueTag* tag) { std::unique_ptr<grpc_event, EventDeleter> ev; ev.reset(grpc_completion_queue_pluck(cq_, tag, gpr_inf_future)); bool ok = ev->data.op_complete == GRPC_OP_OK; - void *ignored = tag; - tag->FinalizeResult(&ignored, &ok); + void* ignored = tag; + GPR_ASSERT(tag->FinalizeResult(&ignored, &ok)); GPR_ASSERT(ignored == tag); return ok; } +void CompletionQueue::TryPluck(CompletionQueueTag* tag) { + std::unique_ptr<grpc_event, EventDeleter> ev; + + ev.reset(grpc_completion_queue_pluck(cq_, tag, gpr_inf_past)); + if (!ev) return; + bool ok = ev->data.op_complete == GRPC_OP_OK; + void* ignored = tag; + // the tag must be swallowed if using TryPluck + GPR_ASSERT(!tag->FinalizeResult(&ignored, &ok)); +} + } // namespace grpc diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index da98cf5ce0..7cccc58afd 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -163,10 +163,11 @@ class Server::SyncRequest final : public CompletionQueueTag { this)); } - void FinalizeResult(void** tag, bool* status) override { + bool FinalizeResult(void** tag, bool* status) override { if (!*status) { grpc_completion_queue_destroy(cq_); } + return true; } class CallData final { @@ -204,6 +205,7 @@ class Server::SyncRequest final : public CompletionQueueTag { if (has_response_payload_) { res.reset(method_->AllocateResponseProto()); } + ctx_.BeginCompletionOp(&call_); auto status = method_->handler()->RunHandler( MethodHandler::HandlerParameter(&call_, &ctx_, req.get(), res.get())); CallOpBuffer buf; @@ -214,10 +216,12 @@ class Server::SyncRequest final : public CompletionQueueTag { buf.AddSendMessage(*res); } buf.AddServerSendStatus(&ctx_.trailing_metadata_, status); - bool cancelled; - buf.AddServerRecvClose(&cancelled); call_.PerformOps(&buf); GPR_ASSERT(cq_.Pluck(&buf)); + void* ignored_tag; + bool ignored_ok; + cq_.Shutdown(); + GPR_ASSERT(cq_.Next(&ignored_tag, &ignored_ok) == false); } private: @@ -310,11 +314,11 @@ class Server::AsyncRequest final : public CompletionQueueTag { grpc_metadata_array_destroy(&array_); } - void FinalizeResult(void** tag, bool* status) override { + bool FinalizeResult(void** tag, bool* status) override { *tag = tag_; if (*status && request_) { if (payload_) { - *status = *status && DeserializeProto(payload_, request_); + *status = DeserializeProto(payload_, request_); } else { *status = false; } @@ -331,8 +335,11 @@ class Server::AsyncRequest final : public CompletionQueueTag { } ctx_->call_ = call_; Call call(call_, server_, cq_); + ctx_->BeginCompletionOp(&call); + // just the pointers inside call are copied here stream_->BindCall(&call); delete this; + return true; } private: diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 10cce450d7..1aa18bcac5 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -32,15 +32,67 @@ */ #include <grpc++/server_context.h> + +#include <mutex> + #include <grpc++/impl/call.h> #include <grpc/grpc.h> +#include <grpc/support/log.h> #include "src/cpp/util/time.h" namespace grpc { +// CompletionOp + +class ServerContext::CompletionOp final : public CallOpBuffer { + public: + CompletionOp(); + bool FinalizeResult(void** tag, bool* status) override; + + bool CheckCancelled(CompletionQueue* cq); + + void Unref(); + + private: + std::mutex mu_; + int refs_ = 2; // initial refs: one in the server context, one in the cq + bool finalized_ = false; + bool cancelled_ = false; +}; + +ServerContext::CompletionOp::CompletionOp() { AddServerRecvClose(&cancelled_); } + +void ServerContext::CompletionOp::Unref() { + std::unique_lock<std::mutex> lock(mu_); + if (--refs_ == 0) { + lock.unlock(); + delete this; + } +} + +bool ServerContext::CompletionOp::CheckCancelled(CompletionQueue* cq) { + cq->TryPluck(this); + std::lock_guard<std::mutex> g(mu_); + return finalized_ ? cancelled_ : false; +} + +bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { + GPR_ASSERT(CallOpBuffer::FinalizeResult(tag, status)); + std::unique_lock<std::mutex> lock(mu_); + finalized_ = true; + if (!*status) cancelled_ = true; + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return false; +} + +// ServerContext body + ServerContext::ServerContext() {} -ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata *metadata, +ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata* metadata, size_t metadata_count) : deadline_(Timespec2Timepoint(deadline)) { for (size_t i = 0; i < metadata_count; i++) { @@ -55,16 +107,29 @@ ServerContext::~ServerContext() { if (call_) { grpc_call_destroy(call_); } + if (completion_op_) { + completion_op_->Unref(); + } +} + +void ServerContext::BeginCompletionOp(Call* call) { + GPR_ASSERT(!completion_op_); + completion_op_ = new CompletionOp(); + call->PerformOps(completion_op_); } void ServerContext::AddInitialMetadata(const grpc::string& key, - const grpc::string& value) { + const grpc::string& value) { initial_metadata_.insert(std::make_pair(key, value)); } void ServerContext::AddTrailingMetadata(const grpc::string& key, - const grpc::string& value) { + const grpc::string& value) { trailing_metadata_.insert(std::make_pair(key, value)); } +bool ServerContext::IsCancelled() { + return completion_op_ && completion_op_->CheckCancelled(cq_); +} + } // namespace grpc diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index fe31289661..79160bfaa5 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -91,7 +91,17 @@ class AsyncEnd2endTest : public ::testing::Test { server_ = builder.BuildAndStart(); } - void TearDown() override { server_->Shutdown(); } + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cli_cq_.Shutdown(); + srv_cq_.Shutdown(); + while (cli_cq_.Next(&ignored_tag, &ignored_ok)) + ; + while (srv_cq_.Next(&ignored_tag, &ignored_ok)) + ; + } void ResetStub() { std::shared_ptr<ChannelInterface> channel = |