-rw-r--r--  include/grpcpp/impl/codegen/byte_buffer.h     |   8
-rw-r--r--  include/grpcpp/impl/codegen/callback_common.h |  19
-rw-r--r--  include/grpcpp/impl/codegen/server_callback.h | 774
-rw-r--r--  include/grpcpp/impl/codegen/server_context.h  |  21
-rw-r--r--  src/compiler/cpp_generator.cc                 |  69
-rw-r--r--  src/cpp/server/server_cc.cc                   |   7
-rw-r--r--  src/cpp/server/server_context.cc              |  47
-rw-r--r--  test/cpp/codegen/compiler_test_golden         |  56
-rw-r--r--  test/cpp/end2end/end2end_test.cc              |  90
-rw-r--r--  test/cpp/end2end/test_service_impl.cc         | 314
-rw-r--r--  test/cpp/end2end/test_service_impl.h          |  21
11 files changed, 1261 insertions(+), 165 deletions(-)
diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h
index abba5549b8..53ecb53371 100644
--- a/include/grpcpp/impl/codegen/byte_buffer.h
+++ b/include/grpcpp/impl/codegen/byte_buffer.h
@@ -45,8 +45,10 @@ template <class ServiceType, class RequestType, class ResponseType>
class RpcMethodHandler;
template <class ServiceType, class RequestType, class ResponseType>
class ServerStreamingHandler;
-template <class ServiceType, class RequestType, class ResponseType>
+template <class RequestType, class ResponseType>
class CallbackUnaryHandler;
+template <class RequestType, class ResponseType>
+class CallbackServerStreamingHandler;
template <StatusCode code>
class ErrorMethodHandler;
template <class R>
@@ -156,8 +158,10 @@ class ByteBuffer final {
friend class internal::RpcMethodHandler;
template <class ServiceType, class RequestType, class ResponseType>
friend class internal::ServerStreamingHandler;
- template <class ServiceType, class RequestType, class ResponseType>
+ template <class RequestType, class ResponseType>
friend class internal::CallbackUnaryHandler;
+ template <class RequestType, class ResponseType>
+ friend class ::grpc::internal::CallbackServerStreamingHandler;
template <StatusCode code>
friend class internal::ErrorMethodHandler;
template <class R>
diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h
index f7a24204dc..a3c8c41246 100644
--- a/include/grpcpp/impl/codegen/callback_common.h
+++ b/include/grpcpp/impl/codegen/callback_common.h
@@ -32,6 +32,8 @@ namespace grpc {
namespace internal {
/// An exception-safe way of invoking a user-specified callback function
+// TODO(vjpai): decide whether it is better for this to take a const lvalue
+// parameter or an rvalue parameter, or if it even matters
template <class Func, class... Args>
void CatchingCallback(Func&& func, Args&&... args) {
#if GRPC_ALLOW_EXCEPTIONS
@@ -45,6 +47,20 @@ void CatchingCallback(Func&& func, Args&&... args) {
#endif // GRPC_ALLOW_EXCEPTIONS
}
+template <class ReturnType, class Func, class... Args>
+ReturnType* CatchingReactorCreator(Func&& func, Args&&... args) {
+#if GRPC_ALLOW_EXCEPTIONS
+ try {
+ return func(std::forward<Args>(args)...);
+ } catch (...) {
+ // fail the RPC, don't crash the library
+ return nullptr;
+ }
+#else // GRPC_ALLOW_EXCEPTIONS
+ return func(std::forward<Args>(args)...);
+#endif // GRPC_ALLOW_EXCEPTIONS
+}
+
// The contract on these tags is that they are single-shot. They must be
// constructed and then fired at exactly one point. There is no expectation
// that they can be reused without reconstruction.
@@ -185,8 +201,9 @@ class CallbackWithSuccessTag
void* ignored = ops_;
// Allow a "false" return value from FinalizeResult to silence the
// callback, just as it silences a CQ tag in the async cases
+ auto* ops = ops_;
bool do_callback = ops_->FinalizeResult(&ignored, &ok);
- GPR_CODEGEN_ASSERT(ignored == ops_);
+ GPR_CODEGEN_ASSERT(ignored == ops);
if (do_callback) {
CatchingCallback(func_, ok);
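
CatchingReactorCreator differs from CatchingCallback in that it must produce a value: if the user-supplied reactor factory throws, it returns nullptr so the handler can fail the RPC rather than crash the library. Below is a minimal standalone sketch of that contract; the Reactor type and make_reactor factory are hypothetical and not part of the gRPC API.

```cpp
#include <exception>
#include <utility>

struct Reactor {};  // hypothetical stand-in for a server reactor type

// Sketch of the CatchingReactorCreator semantics shown in the diff above.
template <class ReturnType, class Func, class... Args>
ReturnType* CatchingReactorCreatorSketch(Func&& func, Args&&... args) {
  try {
    // Invoke the user-supplied factory; on success, hand its reactor back.
    return func(std::forward<Args>(args)...);
  } catch (...) {
    // A throwing factory must not take down the library; a null reactor
    // tells the handler to substitute an Unimplemented*Reactor and fail
    // the RPC instead.
    return nullptr;
  }
}

int main() {
  auto make_reactor = []() -> Reactor* { throw std::exception(); };
  Reactor* r = CatchingReactorCreatorSketch<Reactor>(make_reactor);
  return r == nullptr ? 0 : 1;  // r is null because the factory threw
}
```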
diff --git a/include/grpcpp/impl/codegen/server_callback.h b/include/grpcpp/impl/codegen/server_callback.h
index b866fc16dc..1854f6ef2f 100644
--- a/include/grpcpp/impl/codegen/server_callback.h
+++ b/include/grpcpp/impl/codegen/server_callback.h
@@ -19,7 +19,9 @@
#ifndef GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H
#define GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H
+#include <atomic>
#include <functional>
+#include <type_traits>
#include <grpcpp/impl/codegen/call.h>
#include <grpcpp/impl/codegen/call_op_set.h>
@@ -32,19 +34,33 @@
namespace grpc {
-// forward declarations
+// Declare base class of all reactors as internal
namespace internal {
-template <class ServiceType, class RequestType, class ResponseType>
-class CallbackUnaryHandler;
+
+class ServerReactor {
+ public:
+ virtual ~ServerReactor() = default;
+ virtual void OnDone() {}
+ virtual void OnCancel() {}
+};
+
} // namespace internal
namespace experimental {
+// Forward declarations
+template <class Request, class Response>
+class ServerReadReactor;
+template <class Request, class Response>
+class ServerWriteReactor;
+template <class Request, class Response>
+class ServerBidiReactor;
+
// For unary RPCs, the exposed controller class is only an interface
// and the actual implementation is an internal class.
class ServerCallbackRpcController {
public:
- virtual ~ServerCallbackRpcController() {}
+ virtual ~ServerCallbackRpcController() = default;
// The method handler must call this function when it is done so that
// the library knows to free its resources
@@ -55,18 +71,193 @@ class ServerCallbackRpcController {
virtual void SendInitialMetadata(std::function<void(bool)>) = 0;
};
+// NOTE: The actual streaming object classes are provided
+// as API only to support mocking. There are no implementations of
+// these class interfaces in the API.
+template <class Request>
+class ServerCallbackReader {
+ public:
+ virtual ~ServerCallbackReader() {}
+ virtual void Finish(Status s) = 0;
+ virtual void SendInitialMetadata() = 0;
+ virtual void Read(Request* msg) = 0;
+
+ protected:
+ template <class Response>
+ void BindReactor(ServerReadReactor<Request, Response>* reactor) {
+ reactor->BindReader(this);
+ }
+};
+
+template <class Response>
+class ServerCallbackWriter {
+ public:
+ virtual ~ServerCallbackWriter() {}
+
+ virtual void Finish(Status s) = 0;
+ virtual void SendInitialMetadata() = 0;
+ virtual void Write(const Response* msg, WriteOptions options) = 0;
+ virtual void WriteAndFinish(const Response* msg, WriteOptions options,
+ Status s) {
+ // Default implementation that can/should be overridden
+ Write(msg, std::move(options));
+ Finish(std::move(s));
+ };
+
+ protected:
+ template <class Request>
+ void BindReactor(ServerWriteReactor<Request, Response>* reactor) {
+ reactor->BindWriter(this);
+ }
+};
+
+template <class Request, class Response>
+class ServerCallbackReaderWriter {
+ public:
+ virtual ~ServerCallbackReaderWriter() {}
+
+ virtual void Finish(Status s) = 0;
+ virtual void SendInitialMetadata() = 0;
+ virtual void Read(Request* msg) = 0;
+ virtual void Write(const Response* msg, WriteOptions options) = 0;
+ virtual void WriteAndFinish(const Response* msg, WriteOptions options,
+ Status s) {
+ // Default implementation that can/should be overridden
+ Write(msg, std::move(options));
+ Finish(std::move(s));
+ };
+
+ protected:
+ void BindReactor(ServerBidiReactor<Request, Response>* reactor) {
+ reactor->BindStream(this);
+ }
+};
+
+// The following classes are reactors that are to be implemented
+// by the user, returned as the result of the method handler for
+// a callback method, and activated by the call to OnStarted
+template <class Request, class Response>
+class ServerBidiReactor : public internal::ServerReactor {
+ public:
+ ~ServerBidiReactor() = default;
+ virtual void OnStarted(ServerContext*) {}
+ virtual void OnSendInitialMetadataDone(bool ok) {}
+ virtual void OnReadDone(bool ok) {}
+ virtual void OnWriteDone(bool ok) {}
+
+ void StartSendInitialMetadata() { stream_->SendInitialMetadata(); }
+ void StartRead(Request* msg) { stream_->Read(msg); }
+ void StartWrite(const Response* msg) { StartWrite(msg, WriteOptions()); }
+ void StartWrite(const Response* msg, WriteOptions options) {
+ stream_->Write(msg, std::move(options));
+ }
+ void StartWriteAndFinish(const Response* msg, WriteOptions options,
+ Status s) {
+ stream_->WriteAndFinish(msg, std::move(options), std::move(s));
+ }
+ void StartWriteLast(const Response* msg, WriteOptions options) {
+ StartWrite(msg, std::move(options.set_last_message()));
+ }
+ void Finish(Status s) { stream_->Finish(std::move(s)); }
+
+ private:
+ friend class ServerCallbackReaderWriter<Request, Response>;
+ void BindStream(ServerCallbackReaderWriter<Request, Response>* stream) {
+ stream_ = stream;
+ }
+
+ ServerCallbackReaderWriter<Request, Response>* stream_;
+};
+
+template <class Request, class Response>
+class ServerReadReactor : public internal::ServerReactor {
+ public:
+ ~ServerReadReactor() = default;
+ virtual void OnStarted(ServerContext*, Response* resp) {}
+ virtual void OnSendInitialMetadataDone(bool ok) {}
+ virtual void OnReadDone(bool ok) {}
+
+ void StartSendInitialMetadata() { reader_->SendInitialMetadata(); }
+ void StartRead(Request* msg) { reader_->Read(msg); }
+ void Finish(Status s) { reader_->Finish(std::move(s)); }
+
+ private:
+ friend class ServerCallbackReader<Request>;
+ void BindReader(ServerCallbackReader<Request>* reader) { reader_ = reader; }
+
+ ServerCallbackReader<Request>* reader_;
+};
+
+template <class Request, class Response>
+class ServerWriteReactor : public internal::ServerReactor {
+ public:
+ ~ServerWriteReactor() = default;
+ virtual void OnStarted(ServerContext*, const Request* req) {}
+ virtual void OnSendInitialMetadataDone(bool ok) {}
+ virtual void OnWriteDone(bool ok) {}
+
+ void StartSendInitialMetadata() { writer_->SendInitialMetadata(); }
+ void StartWrite(const Response* msg) { StartWrite(msg, WriteOptions()); }
+ void StartWrite(const Response* msg, WriteOptions options) {
+ writer_->Write(msg, std::move(options));
+ }
+ void StartWriteAndFinish(const Response* msg, WriteOptions options,
+ Status s) {
+ writer_->WriteAndFinish(msg, std::move(options), std::move(s));
+ }
+ void StartWriteLast(const Response* msg, WriteOptions options) {
+ StartWrite(msg, std::move(options.set_last_message()));
+ }
+ void Finish(Status s) { writer_->Finish(std::move(s)); }
+
+ private:
+ friend class ServerCallbackWriter<Response>;
+ void BindWriter(ServerCallbackWriter<Response>* writer) { writer_ = writer; }
+
+ ServerCallbackWriter<Response>* writer_;
+};
+
} // namespace experimental
namespace internal {
-template <class ServiceType, class RequestType, class ResponseType>
+template <class Request, class Response>
+class UnimplementedReadReactor
+ : public experimental::ServerReadReactor<Request, Response> {
+ public:
+ void OnDone() override { delete this; }
+ void OnStarted(ServerContext*, Response*) override {
+ this->Finish(Status(StatusCode::UNIMPLEMENTED, ""));
+ }
+};
+
+template <class Request, class Response>
+class UnimplementedWriteReactor
+ : public experimental::ServerWriteReactor<Request, Response> {
+ public:
+ void OnDone() override { delete this; }
+ void OnStarted(ServerContext*, const Request*) override {
+ this->Finish(Status(StatusCode::UNIMPLEMENTED, ""));
+ }
+};
+
+template <class Request, class Response>
+class UnimplementedBidiReactor
+ : public experimental::ServerBidiReactor<Request, Response> {
+ public:
+ void OnDone() override { delete this; }
+ void OnStarted(ServerContext*) override {
+ this->Finish(Status(StatusCode::UNIMPLEMENTED, ""));
+ }
+};
+
+template <class RequestType, class ResponseType>
class CallbackUnaryHandler : public MethodHandler {
public:
CallbackUnaryHandler(
std::function<void(ServerContext*, const RequestType*, ResponseType*,
experimental::ServerCallbackRpcController*)>
- func,
- ServiceType* service)
+ func)
: func_(func) {}
void RunHandler(const HandlerParameter& param) final {
// Arena allocate a controller structure (that includes request/response)
@@ -81,9 +272,8 @@ class CallbackUnaryHandler : public MethodHandler {
if (status.ok()) {
// Call the actual function handler and expect the user to call finish
- CatchingCallback(std::move(func_), param.server_context,
- controller->request(), controller->response(),
- controller);
+ CatchingCallback(func_, param.server_context, controller->request(),
+ controller->response(), controller);
} else {
// if deserialization failed, we need to fail the call
controller->Finish(status);
@@ -117,79 +307,579 @@ class CallbackUnaryHandler : public MethodHandler {
: public experimental::ServerCallbackRpcController {
public:
void Finish(Status s) override {
- finish_tag_.Set(
- call_.call(),
- [this](bool) {
- grpc_call* call = call_.call();
- auto call_requester = std::move(call_requester_);
- this->~ServerCallbackRpcControllerImpl(); // explicitly call
- // destructor
- g_core_codegen_interface->grpc_call_unref(call);
- call_requester();
- },
- &finish_buf_);
+ finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
+ &finish_ops_);
if (!ctx_->sent_initial_metadata_) {
- finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_,
+ finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
- finish_buf_.set_compression_level(ctx_->compression_level());
+ finish_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
// The response is dropped if the status is not OK.
if (s.ok()) {
- finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_,
- finish_buf_.SendMessage(resp_));
+ finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_,
+ finish_ops_.SendMessage(resp_));
} else {
- finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, s);
+ finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
}
- finish_buf_.set_core_cq_tag(&finish_tag_);
- call_.PerformOps(&finish_buf_);
+ finish_ops_.set_core_cq_tag(&finish_tag_);
+ call_.PerformOps(&finish_ops_);
}
void SendInitialMetadata(std::function<void(bool)> f) override {
GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
-
- meta_tag_.Set(call_.call(), std::move(f), &meta_buf_);
- meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_,
+ callbacks_outstanding_++;
+ // TODO(vjpai): Consider taking f as a move-capture if we adopt C++14
+ // and if performance of this operation matters
+ meta_tag_.Set(call_.call(),
+ [this, f](bool ok) {
+ f(ok);
+ MaybeDone();
+ },
+ &meta_ops_);
+ meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
- meta_buf_.set_compression_level(ctx_->compression_level());
+ meta_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
- meta_buf_.set_core_cq_tag(&meta_tag_);
- call_.PerformOps(&meta_buf_);
+ meta_ops_.set_core_cq_tag(&meta_tag_);
+ call_.PerformOps(&meta_ops_);
}
private:
- template <class SrvType, class ReqType, class RespType>
- friend class CallbackUnaryHandler;
+ friend class CallbackUnaryHandler<RequestType, ResponseType>;
ServerCallbackRpcControllerImpl(ServerContext* ctx, Call* call,
- RequestType* req,
+ const RequestType* req,
std::function<void()> call_requester)
: ctx_(ctx),
call_(*call),
req_(req),
- call_requester_(std::move(call_requester)) {}
+ call_requester_(std::move(call_requester)) {
+ ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, nullptr);
+ }
~ServerCallbackRpcControllerImpl() { req_->~RequestType(); }
- RequestType* request() { return req_; }
+ const RequestType* request() { return req_; }
ResponseType* response() { return &resp_; }
- CallOpSet<CallOpSendInitialMetadata> meta_buf_;
+ void MaybeDone() {
+ if (--callbacks_outstanding_ == 0) {
+ grpc_call* call = call_.call();
+ auto call_requester = std::move(call_requester_);
+ this->~ServerCallbackRpcControllerImpl(); // explicitly call destructor
+ g_core_codegen_interface->grpc_call_unref(call);
+ call_requester();
+ }
+ }
+
+ CallOpSet<CallOpSendInitialMetadata> meta_ops_;
CallbackWithSuccessTag meta_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
CallOpServerSendStatus>
- finish_buf_;
+ finish_ops_;
CallbackWithSuccessTag finish_tag_;
ServerContext* ctx_;
Call call_;
- RequestType* req_;
+ const RequestType* req_;
ResponseType resp_;
std::function<void()> call_requester_;
+ std::atomic_int callbacks_outstanding_{
+ 2}; // reserve for Finish and CompletionOp
+ };
+};
+
+template <class RequestType, class ResponseType>
+class CallbackClientStreamingHandler : public MethodHandler {
+ public:
+ CallbackClientStreamingHandler(
+ std::function<
+ experimental::ServerReadReactor<RequestType, ResponseType>*()>
+ func)
+ : func_(std::move(func)) {}
+ void RunHandler(const HandlerParameter& param) final {
+ // Arena allocate a reader structure (that includes response)
+ g_core_codegen_interface->grpc_call_ref(param.call->call());
+
+ experimental::ServerReadReactor<RequestType, ResponseType>* reactor =
+ param.status.ok()
+ ? CatchingReactorCreator<
+ experimental::ServerReadReactor<RequestType, ResponseType>>(
+ func_)
+ : nullptr;
+
+ if (reactor == nullptr) {
+ // if deserialization or reactor creator failed, we need to fail the call
+ reactor = new UnimplementedReadReactor<RequestType, ResponseType>;
+ }
+
+ auto* reader = new (g_core_codegen_interface->grpc_call_arena_alloc(
+ param.call->call(), sizeof(ServerCallbackReaderImpl)))
+ ServerCallbackReaderImpl(param.server_context, param.call,
+ std::move(param.call_requester), reactor);
+
+ reader->BindReactor(reactor);
+ reactor->OnStarted(param.server_context, reader->response());
+ reader->MaybeDone();
+ }
+
+ private:
+ std::function<experimental::ServerReadReactor<RequestType, ResponseType>*()>
+ func_;
+
+ class ServerCallbackReaderImpl
+ : public experimental::ServerCallbackReader<RequestType> {
+ public:
+ void Finish(Status s) override {
+ finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
+ &finish_ops_);
+ if (!ctx_->sent_initial_metadata_) {
+ finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ finish_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ }
+ // The response is dropped if the status is not OK.
+ if (s.ok()) {
+ finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_,
+ finish_ops_.SendMessage(resp_));
+ } else {
+ finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
+ }
+ finish_ops_.set_core_cq_tag(&finish_tag_);
+ call_.PerformOps(&finish_ops_);
+ }
+
+ void SendInitialMetadata() override {
+ GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
+ callbacks_outstanding_++;
+ meta_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnSendInitialMetadataDone(ok);
+ MaybeDone();
+ },
+ &meta_ops_);
+ meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ meta_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ meta_ops_.set_core_cq_tag(&meta_tag_);
+ call_.PerformOps(&meta_ops_);
+ }
+
+ void Read(RequestType* req) override {
+ callbacks_outstanding_++;
+ read_ops_.RecvMessage(req);
+ call_.PerformOps(&read_ops_);
+ }
+
+ private:
+ friend class CallbackClientStreamingHandler<RequestType, ResponseType>;
+
+ ServerCallbackReaderImpl(
+ ServerContext* ctx, Call* call, std::function<void()> call_requester,
+ experimental::ServerReadReactor<RequestType, ResponseType>* reactor)
+ : ctx_(ctx),
+ call_(*call),
+ call_requester_(std::move(call_requester)),
+ reactor_(reactor) {
+ ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor);
+ read_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnReadDone(ok);
+ MaybeDone();
+ },
+ &read_ops_);
+ read_ops_.set_core_cq_tag(&read_tag_);
+ }
+
+ ~ServerCallbackReaderImpl() {}
+
+ ResponseType* response() { return &resp_; }
+
+ void MaybeDone() {
+ if (--callbacks_outstanding_ == 0) {
+ reactor_->OnDone();
+ grpc_call* call = call_.call();
+ auto call_requester = std::move(call_requester_);
+ this->~ServerCallbackReaderImpl(); // explicitly call destructor
+ g_core_codegen_interface->grpc_call_unref(call);
+ call_requester();
+ }
+ }
+
+ CallOpSet<CallOpSendInitialMetadata> meta_ops_;
+ CallbackWithSuccessTag meta_tag_;
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
+ CallOpServerSendStatus>
+ finish_ops_;
+ CallbackWithSuccessTag finish_tag_;
+ CallOpSet<CallOpRecvMessage<RequestType>> read_ops_;
+ CallbackWithSuccessTag read_tag_;
+
+ ServerContext* ctx_;
+ Call call_;
+ ResponseType resp_;
+ std::function<void()> call_requester_;
+ experimental::ServerReadReactor<RequestType, ResponseType>* reactor_;
+ std::atomic_int callbacks_outstanding_{
+ 3}; // reserve for OnStarted, Finish, and CompletionOp
+ };
+};
+
+template <class RequestType, class ResponseType>
+class CallbackServerStreamingHandler : public MethodHandler {
+ public:
+ CallbackServerStreamingHandler(
+ std::function<
+ experimental::ServerWriteReactor<RequestType, ResponseType>*()>
+ func)
+ : func_(std::move(func)) {}
+ void RunHandler(const HandlerParameter& param) final {
+ // Arena allocate a writer structure
+ g_core_codegen_interface->grpc_call_ref(param.call->call());
+
+ experimental::ServerWriteReactor<RequestType, ResponseType>* reactor =
+ param.status.ok()
+ ? CatchingReactorCreator<
+ experimental::ServerWriteReactor<RequestType, ResponseType>>(
+ func_)
+ : nullptr;
+
+ if (reactor == nullptr) {
+ // if deserialization or reactor creator failed, we need to fail the call
+ reactor = new UnimplementedWriteReactor<RequestType, ResponseType>;
+ }
+
+ auto* writer = new (g_core_codegen_interface->grpc_call_arena_alloc(
+ param.call->call(), sizeof(ServerCallbackWriterImpl)))
+ ServerCallbackWriterImpl(param.server_context, param.call,
+ static_cast<RequestType*>(param.request),
+ std::move(param.call_requester), reactor);
+ writer->BindReactor(reactor);
+ reactor->OnStarted(param.server_context, writer->request());
+ writer->MaybeDone();
+ }
+
+ void* Deserialize(grpc_call* call, grpc_byte_buffer* req,
+ Status* status) final {
+ ByteBuffer buf;
+ buf.set_buffer(req);
+ auto* request = new (g_core_codegen_interface->grpc_call_arena_alloc(
+ call, sizeof(RequestType))) RequestType();
+ *status = SerializationTraits<RequestType>::Deserialize(&buf, request);
+ buf.Release();
+ if (status->ok()) {
+ return request;
+ }
+ request->~RequestType();
+ return nullptr;
+ }
+
+ private:
+ std::function<experimental::ServerWriteReactor<RequestType, ResponseType>*()>
+ func_;
+
+ class ServerCallbackWriterImpl
+ : public experimental::ServerCallbackWriter<ResponseType> {
+ public:
+ void Finish(Status s) override {
+ finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
+ &finish_ops_);
+ finish_ops_.set_core_cq_tag(&finish_tag_);
+
+ if (!ctx_->sent_initial_metadata_) {
+ finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ finish_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ }
+ finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
+ call_.PerformOps(&finish_ops_);
+ }
+
+ void SendInitialMetadata() override {
+ GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
+ callbacks_outstanding_++;
+ meta_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnSendInitialMetadataDone(ok);
+ MaybeDone();
+ },
+ &meta_ops_);
+ meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ meta_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ meta_ops_.set_core_cq_tag(&meta_tag_);
+ call_.PerformOps(&meta_ops_);
+ }
+
+ void Write(const ResponseType* resp, WriteOptions options) override {
+ callbacks_outstanding_++;
+ if (options.is_last_message()) {
+ options.set_buffer_hint();
+ }
+ if (!ctx_->sent_initial_metadata_) {
+ write_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ write_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ }
+ // TODO(vjpai): don't assert
+ GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok());
+ call_.PerformOps(&write_ops_);
+ }
+
+ void WriteAndFinish(const ResponseType* resp, WriteOptions options,
+ Status s) override {
+ // This combines the write into the finish callback
+ // Don't send any message if the status is bad
+ if (s.ok()) {
+ // TODO(vjpai): don't assert
+ GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok());
+ }
+ Finish(std::move(s));
+ }
+
+ private:
+ friend class CallbackServerStreamingHandler<RequestType, ResponseType>;
+
+ ServerCallbackWriterImpl(
+ ServerContext* ctx, Call* call, const RequestType* req,
+ std::function<void()> call_requester,
+ experimental::ServerWriteReactor<RequestType, ResponseType>* reactor)
+ : ctx_(ctx),
+ call_(*call),
+ req_(req),
+ call_requester_(std::move(call_requester)),
+ reactor_(reactor) {
+ ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor);
+ write_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnWriteDone(ok);
+ MaybeDone();
+ },
+ &write_ops_);
+ write_ops_.set_core_cq_tag(&write_tag_);
+ }
+ ~ServerCallbackWriterImpl() { req_->~RequestType(); }
+
+ const RequestType* request() { return req_; }
+
+ void MaybeDone() {
+ if (--callbacks_outstanding_ == 0) {
+ reactor_->OnDone();
+ grpc_call* call = call_.call();
+ auto call_requester = std::move(call_requester_);
+ this->~ServerCallbackWriterImpl(); // explicitly call destructor
+ g_core_codegen_interface->grpc_call_unref(call);
+ call_requester();
+ }
+ }
+
+ CallOpSet<CallOpSendInitialMetadata> meta_ops_;
+ CallbackWithSuccessTag meta_tag_;
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
+ CallOpServerSendStatus>
+ finish_ops_;
+ CallbackWithSuccessTag finish_tag_;
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage> write_ops_;
+ CallbackWithSuccessTag write_tag_;
+
+ ServerContext* ctx_;
+ Call call_;
+ const RequestType* req_;
+ std::function<void()> call_requester_;
+ experimental::ServerWriteReactor<RequestType, ResponseType>* reactor_;
+ std::atomic_int callbacks_outstanding_{
+ 3}; // reserve for OnStarted, Finish, and CompletionOp
+ };
+};
+
+template <class RequestType, class ResponseType>
+class CallbackBidiHandler : public MethodHandler {
+ public:
+ CallbackBidiHandler(
+ std::function<
+ experimental::ServerBidiReactor<RequestType, ResponseType>*()>
+ func)
+ : func_(std::move(func)) {}
+ void RunHandler(const HandlerParameter& param) final {
+ g_core_codegen_interface->grpc_call_ref(param.call->call());
+
+ experimental::ServerBidiReactor<RequestType, ResponseType>* reactor =
+ param.status.ok()
+ ? CatchingReactorCreator<
+ experimental::ServerBidiReactor<RequestType, ResponseType>>(
+ func_)
+ : nullptr;
+
+ if (reactor == nullptr) {
+ // if deserialization or reactor creator failed, we need to fail the call
+ reactor = new UnimplementedBidiReactor<RequestType, ResponseType>;
+ }
+
+ auto* stream = new (g_core_codegen_interface->grpc_call_arena_alloc(
+ param.call->call(), sizeof(ServerCallbackReaderWriterImpl)))
+ ServerCallbackReaderWriterImpl(param.server_context, param.call,
+ std::move(param.call_requester),
+ reactor);
+
+ stream->BindReactor(reactor);
+ reactor->OnStarted(param.server_context);
+ stream->MaybeDone();
+ }
+
+ private:
+ std::function<experimental::ServerBidiReactor<RequestType, ResponseType>*()>
+ func_;
+
+ class ServerCallbackReaderWriterImpl
+ : public experimental::ServerCallbackReaderWriter<RequestType,
+ ResponseType> {
+ public:
+ void Finish(Status s) override {
+ finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
+ &finish_ops_);
+ finish_ops_.set_core_cq_tag(&finish_tag_);
+
+ if (!ctx_->sent_initial_metadata_) {
+ finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ finish_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ }
+ finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
+ call_.PerformOps(&finish_ops_);
+ }
+
+ void SendInitialMetadata() override {
+ GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
+ callbacks_outstanding_++;
+ meta_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnSendInitialMetadataDone(ok);
+ MaybeDone();
+ },
+ &meta_ops_);
+ meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ meta_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ meta_ops_.set_core_cq_tag(&meta_tag_);
+ call_.PerformOps(&meta_ops_);
+ }
+
+ void Write(const ResponseType* resp, WriteOptions options) override {
+ callbacks_outstanding_++;
+ if (options.is_last_message()) {
+ options.set_buffer_hint();
+ }
+ if (!ctx_->sent_initial_metadata_) {
+ write_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
+ ctx_->initial_metadata_flags());
+ if (ctx_->compression_level_set()) {
+ write_ops_.set_compression_level(ctx_->compression_level());
+ }
+ ctx_->sent_initial_metadata_ = true;
+ }
+ // TODO(vjpai): don't assert
+ GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok());
+ call_.PerformOps(&write_ops_);
+ }
+
+ void WriteAndFinish(const ResponseType* resp, WriteOptions options,
+ Status s) override {
+ // Don't send any message if the status is bad
+ if (s.ok()) {
+ // TODO(vjpai): don't assert
+ GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok());
+ }
+ Finish(std::move(s));
+ }
+
+ void Read(RequestType* req) override {
+ callbacks_outstanding_++;
+ read_ops_.RecvMessage(req);
+ call_.PerformOps(&read_ops_);
+ }
+
+ private:
+ friend class CallbackBidiHandler<RequestType, ResponseType>;
+
+ ServerCallbackReaderWriterImpl(
+ ServerContext* ctx, Call* call, std::function<void()> call_requester,
+ experimental::ServerBidiReactor<RequestType, ResponseType>* reactor)
+ : ctx_(ctx),
+ call_(*call),
+ call_requester_(std::move(call_requester)),
+ reactor_(reactor) {
+ ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor);
+ write_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnWriteDone(ok);
+ MaybeDone();
+ },
+ &write_ops_);
+ write_ops_.set_core_cq_tag(&write_tag_);
+ read_tag_.Set(call_.call(),
+ [this](bool ok) {
+ reactor_->OnReadDone(ok);
+ MaybeDone();
+ },
+ &read_ops_);
+ read_ops_.set_core_cq_tag(&read_tag_);
+ }
+ ~ServerCallbackReaderWriterImpl() {}
+
+ void MaybeDone() {
+ if (--callbacks_outstanding_ == 0) {
+ reactor_->OnDone();
+ grpc_call* call = call_.call();
+ auto call_requester = std::move(call_requester_);
+ this->~ServerCallbackReaderWriterImpl(); // explicitly call destructor
+ g_core_codegen_interface->grpc_call_unref(call);
+ call_requester();
+ }
+ }
+
+ CallOpSet<CallOpSendInitialMetadata> meta_ops_;
+ CallbackWithSuccessTag meta_tag_;
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
+ CallOpServerSendStatus>
+ finish_ops_;
+ CallbackWithSuccessTag finish_tag_;
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage> write_ops_;
+ CallbackWithSuccessTag write_tag_;
+ CallOpSet<CallOpRecvMessage<RequestType>> read_ops_;
+ CallbackWithSuccessTag read_tag_;
+
+ ServerContext* ctx_;
+ Call call_;
+ std::function<void()> call_requester_;
+ experimental::ServerBidiReactor<RequestType, ResponseType>* reactor_;
+ std::atomic_int callbacks_outstanding_{
+ 3}; // reserve for OnStarted, Finish, and CompletionOp
};
};
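
The reactor classes above are the user-facing surface of the new streaming callback API: the service's method handler returns a reactor, and the library drives it through OnStarted, OnReadDone, OnWriteDone, and OnDone while the application issues operations with StartRead, StartWrite, and Finish. The following is a hedged sketch of an echo-style bidi reactor written against this interface; EchoRequest and EchoResponse stand in for generated protobuf messages and are assumptions for illustration.

```cpp
#include <grpcpp/support/server_callback.h>

// Hypothetical generated message types; only message()/set_message() are used.
class EchoReactor
    : public grpc::experimental::ServerBidiReactor<EchoRequest, EchoResponse> {
 public:
  void OnStarted(grpc::ServerContext* /*ctx*/) override {
    // Kick off the first read; each completed read triggers OnReadDone.
    StartRead(&request_);
  }
  void OnReadDone(bool ok) override {
    if (!ok) {
      // Client finished writing (or the stream broke): finish the RPC.
      Finish(grpc::Status::OK);
      return;
    }
    response_.set_message(request_.message());
    StartWrite(&response_);
  }
  void OnWriteDone(bool ok) override {
    if (ok) StartRead(&request_);  // read the next message after each write
  }
  void OnDone() override { delete this; }  // all outstanding ops have completed

 private:
  EchoRequest request_;
  EchoResponse response_;
};
```

The generated ExperimentalCallbackService (see the golden file later in this diff) exposes a virtual method per streaming RPC that returns such a reactor, so an application simply overrides that method to return `new EchoReactor`.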
diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h
index 82ee862f61..4a5f9e2dd9 100644
--- a/include/grpcpp/impl/codegen/server_context.h
+++ b/include/grpcpp/impl/codegen/server_context.h
@@ -66,13 +66,20 @@ template <class ServiceType, class RequestType, class ResponseType>
class ServerStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
class BidiStreamingHandler;
-template <class ServiceType, class RequestType, class ResponseType>
+template <class RequestType, class ResponseType>
class CallbackUnaryHandler;
+template <class RequestType, class ResponseType>
+class CallbackClientStreamingHandler;
+template <class RequestType, class ResponseType>
+class CallbackServerStreamingHandler;
+template <class RequestType, class ResponseType>
+class CallbackBidiHandler;
template <class Streamer, bool WriteNeeded>
class TemplatedBidiStreamingHandler;
template <StatusCode code>
class ErrorMethodHandler;
class Call;
+class ServerReactor;
} // namespace internal
class CompletionQueue;
@@ -270,8 +277,14 @@ class ServerContext {
friend class ::grpc::internal::ServerStreamingHandler;
template <class Streamer, bool WriteNeeded>
friend class ::grpc::internal::TemplatedBidiStreamingHandler;
- template <class ServiceType, class RequestType, class ResponseType>
+ template <class RequestType, class ResponseType>
friend class ::grpc::internal::CallbackUnaryHandler;
+ template <class RequestType, class ResponseType>
+ friend class ::grpc::internal::CallbackClientStreamingHandler;
+ template <class RequestType, class ResponseType>
+ friend class ::grpc::internal::CallbackServerStreamingHandler;
+ template <class RequestType, class ResponseType>
+ friend class ::grpc::internal::CallbackBidiHandler;
template <StatusCode code>
friend class internal::ErrorMethodHandler;
friend class ::grpc::ClientContext;
@@ -282,7 +295,9 @@ class ServerContext {
class CompletionOp;
- void BeginCompletionOp(internal::Call* call, bool callback);
+ void BeginCompletionOp(internal::Call* call,
+ std::function<void(bool)> callback,
+ internal::ServerReactor* reactor);
/// Return the tag queued by BeginCompletionOp()
internal::CompletionQueueTag* GetCompletionOpTag();
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index a368b47f01..b004687250 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -889,6 +889,11 @@ void PrintHeaderServerCallbackMethodsHelper(
" abort();\n"
" return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
"}\n");
+ printer->Print(*vars,
+ "virtual ::grpc::experimental::ServerReadReactor< "
+ "$RealRequest$, $RealResponse$>* $Method$() {\n"
+ " return new ::grpc::internal::UnimplementedReadReactor<\n"
+ " $RealRequest$, $RealResponse$>;}\n");
} else if (ServerOnlyStreaming(method)) {
printer->Print(
*vars,
@@ -900,6 +905,11 @@ void PrintHeaderServerCallbackMethodsHelper(
" abort();\n"
" return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
"}\n");
+ printer->Print(*vars,
+ "virtual ::grpc::experimental::ServerWriteReactor< "
+ "$RealRequest$, $RealResponse$>* $Method$() {\n"
+ " return new ::grpc::internal::UnimplementedWriteReactor<\n"
+ " $RealRequest$, $RealResponse$>;}\n");
} else if (method->BidiStreaming()) {
printer->Print(
*vars,
@@ -911,6 +921,11 @@ void PrintHeaderServerCallbackMethodsHelper(
" abort();\n"
" return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
"}\n");
+ printer->Print(*vars,
+ "virtual ::grpc::experimental::ServerBidiReactor< "
+ "$RealRequest$, $RealResponse$>* $Method$() {\n"
+ " return new ::grpc::internal::UnimplementedBidiReactor<\n"
+ " $RealRequest$, $RealResponse$>;}\n");
}
}
@@ -939,22 +954,36 @@ void PrintHeaderServerMethodCallback(
*vars,
" ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
" new ::grpc::internal::CallbackUnaryHandler< "
- "ExperimentalWithCallbackMethod_$Method$<BaseClass>, $RealRequest$, "
- "$RealResponse$>(\n"
+ "$RealRequest$, $RealResponse$>(\n"
" [this](::grpc::ServerContext* context,\n"
" const $RealRequest$* request,\n"
" $RealResponse$* response,\n"
" ::grpc::experimental::ServerCallbackRpcController* "
"controller) {\n"
- " this->$"
+ " return this->$"
"Method$(context, request, response, controller);\n"
- " }, this));\n");
+ " }));\n");
} else if (ClientOnlyStreaming(method)) {
- // TODO(vjpai): Add in code generation for all streaming methods
+ printer->Print(
+ *vars,
+ " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
+ " new ::grpc::internal::CallbackClientStreamingHandler< "
+ "$RealRequest$, $RealResponse$>(\n"
+ " [this] { return this->$Method$(); }));\n");
} else if (ServerOnlyStreaming(method)) {
- // TODO(vjpai): Add in code generation for all streaming methods
+ printer->Print(
+ *vars,
+ " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
+ " new ::grpc::internal::CallbackServerStreamingHandler< "
+ "$RealRequest$, $RealResponse$>(\n"
+ " [this] { return this->$Method$(); }));\n");
} else if (method->BidiStreaming()) {
- // TODO(vjpai): Add in code generation for all streaming methods
+ printer->Print(
+ *vars,
+ " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
+ " new ::grpc::internal::CallbackBidiHandler< "
+ "$RealRequest$, $RealResponse$>(\n"
+ " [this] { return this->$Method$(); }));\n");
}
printer->Print(*vars, "}\n");
printer->Print(*vars,
@@ -991,8 +1020,7 @@ void PrintHeaderServerMethodRawCallback(
*vars,
" ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
" new ::grpc::internal::CallbackUnaryHandler< "
- "ExperimentalWithRawCallbackMethod_$Method$<BaseClass>, $RealRequest$, "
- "$RealResponse$>(\n"
+ "$RealRequest$, $RealResponse$>(\n"
" [this](::grpc::ServerContext* context,\n"
" const $RealRequest$* request,\n"
" $RealResponse$* response,\n"
@@ -1000,13 +1028,28 @@ void PrintHeaderServerMethodRawCallback(
"controller) {\n"
" this->$"
"Method$(context, request, response, controller);\n"
- " }, this));\n");
+ " }));\n");
} else if (ClientOnlyStreaming(method)) {
- // TODO(vjpai): Add in code generation for all streaming methods
+ printer->Print(
+ *vars,
+ " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
+ " new ::grpc::internal::CallbackClientStreamingHandler< "
+ "$RealRequest$, $RealResponse$>(\n"
+ " [this] { return this->$Method$(); }));\n");
} else if (ServerOnlyStreaming(method)) {
- // TODO(vjpai): Add in code generation for all streaming methods
+ printer->Print(
+ *vars,
+ " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
+ " new ::grpc::internal::CallbackServerStreamingHandler< "
+ "$RealRequest$, $RealResponse$>(\n"
+ " [this] { return this->$Method$(); }));\n");
} else if (method->BidiStreaming()) {
- // TODO(vjpai): Add in code generation for all streaming methods
+ printer->Print(
+ *vars,
+ " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
+ " new ::grpc::internal::CallbackBidiHandler< "
+ "$RealRequest$, $RealResponse$>(\n"
+ " [this] { return this->$Method$(); }));\n");
}
printer->Print(*vars, "}\n");
printer->Print(*vars,
diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc
index 0a51cf5626..69af43a656 100644
--- a/src/cpp/server/server_cc.cc
+++ b/src/cpp/server/server_cc.cc
@@ -291,7 +291,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
void ContinueRunAfterInterception() {
{
- ctx_.BeginCompletionOp(&call_, false);
+ ctx_.BeginCompletionOp(&call_, nullptr, nullptr);
global_callbacks_->PreSynchronousRequest(&ctx_);
auto* handler = resources_ ? method_->handler()
: server_->resource_exhausted_handler_.get();
@@ -456,7 +456,6 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
}
}
void ContinueRunAfterInterception() {
- req_->ctx_.BeginCompletionOp(call_, true);
req_->method_->handler()->RunHandler(
internal::MethodHandler::HandlerParameter(
call_, &req_->ctx_, req_->request_, req_->request_status_,
@@ -1018,7 +1017,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
}
}
if (*status && call_) {
- context_->BeginCompletionOp(&call_wrapper_, false);
+ context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr);
}
*tag = tag_;
if (delete_on_finalize_) {
@@ -1029,7 +1028,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
void ServerInterface::BaseAsyncRequest::
ContinueFinalizeResultAfterInterception() {
- context_->BeginCompletionOp(&call_wrapper_, false);
+ context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr);
// Queue a tag which will be returned immediately
grpc_core::ExecCtx exec_ctx;
grpc_cq_begin_op(notification_cq_->cq(), this);
diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc
index 9c01f896e6..1b524bc3e8 100644
--- a/src/cpp/server/server_context.cc
+++ b/src/cpp/server/server_context.cc
@@ -17,6 +17,7 @@
*/
#include <grpcpp/server_context.h>
+#include <grpcpp/support/server_callback.h>
#include <algorithm>
#include <mutex>
@@ -41,8 +42,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
public:
// initial refs: one in the server context, one in the cq
// must ref the call before calling constructor and after deleting this
- CompletionOp(internal::Call* call)
+ CompletionOp(internal::Call* call, internal::ServerReactor* reactor)
: call_(*call),
+ reactor_(reactor),
has_tag_(false),
tag_(nullptr),
core_cq_tag_(this),
@@ -124,9 +126,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
return;
}
/* Start a dummy op so that we can return the tag */
- GPR_CODEGEN_ASSERT(GRPC_CALL_OK ==
- g_core_codegen_interface->grpc_call_start_batch(
- call_.call(), nullptr, 0, this, nullptr));
+ GPR_CODEGEN_ASSERT(
+ GRPC_CALL_OK ==
+ grpc_call_start_batch(call_.call(), nullptr, 0, core_cq_tag_, nullptr));
}
private:
@@ -136,13 +138,14 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
}
internal::Call call_;
+ internal::ServerReactor* reactor_;
bool has_tag_;
void* tag_;
void* core_cq_tag_;
std::mutex mu_;
int refs_;
bool finalized_;
- int cancelled_;
+ int cancelled_; // This is an int (not bool) because it is passed to core
bool done_intercepting_;
internal::InterceptorBatchMethodsImpl interceptor_methods_;
};
@@ -190,7 +193,16 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) {
}
finalized_ = true;
- if (!*status) cancelled_ = 1;
+ // If for some reason the incoming status is false, mark that as a
+ // cancellation.
+ // TODO(vjpai): does this ever happen?
+ if (!*status) {
+ cancelled_ = 1;
+ }
+
+ if (cancelled_ && (reactor_ != nullptr)) {
+ reactor_->OnCancel();
+ }
/* Release the lock since we are going to be running through interceptors now
*/
lock.unlock();
@@ -251,21 +263,25 @@ void ServerContext::Clear() {
initial_metadata_.clear();
trailing_metadata_.clear();
client_metadata_.Reset();
- if (call_) {
- grpc_call_unref(call_);
- }
if (completion_op_) {
completion_op_->Unref();
+ completion_op_ = nullptr;
completion_tag_.Clear();
}
if (rpc_info_) {
rpc_info_->Unref();
+ rpc_info_ = nullptr;
+ }
+ if (call_) {
+ auto* call = call_;
+ call_ = nullptr;
+ grpc_call_unref(call);
}
- // Don't need to clear out call_, completion_op_, or rpc_info_ because this is
- // either called from destructor or just before Setup
}
-void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) {
+void ServerContext::BeginCompletionOp(internal::Call* call,
+ std::function<void(bool)> callback,
+ internal::ServerReactor* reactor) {
GPR_ASSERT(!completion_op_);
if (rpc_info_) {
rpc_info_->Ref();
@@ -273,10 +289,11 @@ void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) {
grpc_call_ref(call->call());
completion_op_ =
new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp)))
- CompletionOp(call);
- if (callback) {
- completion_tag_.Set(call->call(), nullptr, completion_op_);
+ CompletionOp(call, reactor);
+ if (callback != nullptr) {
+ completion_tag_.Set(call->call(), std::move(callback), completion_op_);
completion_op_->set_core_cq_tag(&completion_tag_);
+ completion_op_->set_tag(completion_op_);
} else if (has_notify_when_done_tag_) {
completion_op_->set_tag(async_notify_when_done_tag_);
}
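
With the reactor now threaded into CompletionOp, a cancelled RPC reaches the application through ServerReactor::OnCancel(). A hedged sketch of a server-streaming reactor that observes the cancellation and stops issuing writes follows; FeedRequest, FeedResponse, and the atomic flag are illustrative assumptions, not part of this change.

```cpp
#include <atomic>
#include <grpcpp/support/server_callback.h>

class FeedReactor
    : public grpc::experimental::ServerWriteReactor<FeedRequest, FeedResponse> {
 public:
  void OnStarted(grpc::ServerContext* /*ctx*/,
                 const FeedRequest* /*req*/) override {
    NextWrite();
  }
  void OnWriteDone(bool ok) override {
    if (ok && !cancelled_.load()) {
      NextWrite();
    } else {
      Finish(grpc::Status::CANCELLED);
    }
  }
  // Invoked by CompletionOp when the RPC is cancelled (see FinalizeResult above).
  void OnCancel() override { cancelled_.store(true); }
  void OnDone() override { delete this; }

 private:
  void NextWrite() {
    response_.set_payload("tick");
    StartWrite(&response_);
  }
  FeedResponse response_;
  std::atomic<bool> cancelled_{false};
};
```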
diff --git a/test/cpp/codegen/compiler_test_golden b/test/cpp/codegen/compiler_test_golden
index 5f0eb6c35c..1871e1375e 100644
--- a/test/cpp/codegen/compiler_test_golden
+++ b/test/cpp/codegen/compiler_test_golden
@@ -322,13 +322,13 @@ class ServiceA final {
public:
ExperimentalWithCallbackMethod_MethodA1() {
::grpc::Service::experimental().MarkMethodCallback(0,
- new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodA1<BaseClass>, ::grpc::testing::Request, ::grpc::testing::Response>(
+ new ::grpc::internal::CallbackUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this](::grpc::ServerContext* context,
const ::grpc::testing::Request* request,
::grpc::testing::Response* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
- this->MethodA1(context, request, response, controller);
- }, this));
+ return this->MethodA1(context, request, response, controller);
+ }));
}
~ExperimentalWithCallbackMethod_MethodA1() override {
BaseClassMustBeDerivedFromService(this);
@@ -346,6 +346,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithCallbackMethod_MethodA2() {
+ ::grpc::Service::experimental().MarkMethodCallback(1,
+ new ::grpc::internal::CallbackClientStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
+ [this] { return this->MethodA2(); }));
}
~ExperimentalWithCallbackMethod_MethodA2() override {
BaseClassMustBeDerivedFromService(this);
@@ -355,6 +358,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
+ virtual ::grpc::experimental::ServerReadReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA2() {
+ return new ::grpc::internal::UnimplementedReadReactor<
+ ::grpc::testing::Request, ::grpc::testing::Response>;}
};
template <class BaseClass>
class ExperimentalWithCallbackMethod_MethodA3 : public BaseClass {
@@ -362,6 +368,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithCallbackMethod_MethodA3() {
+ ::grpc::Service::experimental().MarkMethodCallback(2,
+ new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
+ [this] { return this->MethodA3(); }));
}
~ExperimentalWithCallbackMethod_MethodA3() override {
BaseClassMustBeDerivedFromService(this);
@@ -371,6 +380,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
+ virtual ::grpc::experimental::ServerWriteReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA3() {
+ return new ::grpc::internal::UnimplementedWriteReactor<
+ ::grpc::testing::Request, ::grpc::testing::Response>;}
};
template <class BaseClass>
class ExperimentalWithCallbackMethod_MethodA4 : public BaseClass {
@@ -378,6 +390,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithCallbackMethod_MethodA4() {
+ ::grpc::Service::experimental().MarkMethodCallback(3,
+ new ::grpc::internal::CallbackBidiHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
+ [this] { return this->MethodA4(); }));
}
~ExperimentalWithCallbackMethod_MethodA4() override {
BaseClassMustBeDerivedFromService(this);
@@ -387,6 +402,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
+ virtual ::grpc::experimental::ServerBidiReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4() {
+ return new ::grpc::internal::UnimplementedBidiReactor<
+ ::grpc::testing::Request, ::grpc::testing::Response>;}
};
typedef ExperimentalWithCallbackMethod_MethodA1<ExperimentalWithCallbackMethod_MethodA2<ExperimentalWithCallbackMethod_MethodA3<ExperimentalWithCallbackMethod_MethodA4<Service > > > > ExperimentalCallbackService;
template <class BaseClass>
@@ -544,13 +562,13 @@ class ServiceA final {
public:
ExperimentalWithRawCallbackMethod_MethodA1() {
::grpc::Service::experimental().MarkMethodRawCallback(0,
- new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodA1<BaseClass>, ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
+ new ::grpc::internal::CallbackUnaryHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this](::grpc::ServerContext* context,
const ::grpc::ByteBuffer* request,
::grpc::ByteBuffer* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
this->MethodA1(context, request, response, controller);
- }, this));
+ }));
}
~ExperimentalWithRawCallbackMethod_MethodA1() override {
BaseClassMustBeDerivedFromService(this);
@@ -568,6 +586,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithRawCallbackMethod_MethodA2() {
+ ::grpc::Service::experimental().MarkMethodRawCallback(1,
+ new ::grpc::internal::CallbackClientStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
+ [this] { return this->MethodA2(); }));
}
~ExperimentalWithRawCallbackMethod_MethodA2() override {
BaseClassMustBeDerivedFromService(this);
@@ -577,6 +598,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
+ virtual ::grpc::experimental::ServerReadReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA2() {
+ return new ::grpc::internal::UnimplementedReadReactor<
+ ::grpc::ByteBuffer, ::grpc::ByteBuffer>;}
};
template <class BaseClass>
class ExperimentalWithRawCallbackMethod_MethodA3 : public BaseClass {
@@ -584,6 +608,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithRawCallbackMethod_MethodA3() {
+ ::grpc::Service::experimental().MarkMethodRawCallback(2,
+ new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
+ [this] { return this->MethodA3(); }));
}
~ExperimentalWithRawCallbackMethod_MethodA3() override {
BaseClassMustBeDerivedFromService(this);
@@ -593,6 +620,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
+ virtual ::grpc::experimental::ServerWriteReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA3() {
+ return new ::grpc::internal::UnimplementedWriteReactor<
+ ::grpc::ByteBuffer, ::grpc::ByteBuffer>;}
};
template <class BaseClass>
class ExperimentalWithRawCallbackMethod_MethodA4 : public BaseClass {
@@ -600,6 +630,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithRawCallbackMethod_MethodA4() {
+ ::grpc::Service::experimental().MarkMethodRawCallback(3,
+ new ::grpc::internal::CallbackBidiHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
+ [this] { return this->MethodA4(); }));
}
~ExperimentalWithRawCallbackMethod_MethodA4() override {
BaseClassMustBeDerivedFromService(this);
@@ -609,6 +642,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
+ virtual ::grpc::experimental::ServerBidiReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA4() {
+ return new ::grpc::internal::UnimplementedBidiReactor<
+ ::grpc::ByteBuffer, ::grpc::ByteBuffer>;}
};
template <class BaseClass>
class WithStreamedUnaryMethod_MethodA1 : public BaseClass {
@@ -752,13 +788,13 @@ class ServiceB final {
public:
ExperimentalWithCallbackMethod_MethodB1() {
::grpc::Service::experimental().MarkMethodCallback(0,
- new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodB1<BaseClass>, ::grpc::testing::Request, ::grpc::testing::Response>(
+ new ::grpc::internal::CallbackUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this](::grpc::ServerContext* context,
const ::grpc::testing::Request* request,
::grpc::testing::Response* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
- this->MethodB1(context, request, response, controller);
- }, this));
+ return this->MethodB1(context, request, response, controller);
+ }));
}
~ExperimentalWithCallbackMethod_MethodB1() override {
BaseClassMustBeDerivedFromService(this);
@@ -815,13 +851,13 @@ class ServiceB final {
public:
ExperimentalWithRawCallbackMethod_MethodB1() {
::grpc::Service::experimental().MarkMethodRawCallback(0,
- new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodB1<BaseClass>, ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
+ new ::grpc::internal::CallbackUnaryHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this](::grpc::ServerContext* context,
const ::grpc::ByteBuffer* request,
::grpc::ByteBuffer* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
this->MethodB1(context, request, response, controller);
- }, this));
+ }));
}
~ExperimentalWithRawCallbackMethod_MethodB1() override {
BaseClassMustBeDerivedFromService(this);
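
The golden file shows that each streaming method of the generated ExperimentalCallbackService now registers a streaming callback handler and exposes a virtual reactor-factory method whose default returns an Unimplemented*Reactor. A hedged sketch of overriding one of them: ServiceA, ExperimentalCallbackService, MethodA2, and the message types come from the golden output above, while UploadReactor is a hypothetical ServerReadReactor implementation supplied by the application.

```cpp
class CallbackServiceA final
    : public ::grpc::testing::ServiceA::ExperimentalCallbackService {
 public:
  ::grpc::experimental::ServerReadReactor< ::grpc::testing::Request,
                                           ::grpc::testing::Response>*
  MethodA2() override {
    // The library drives OnStarted/OnReadDone/OnDone on the returned reactor.
    return new UploadReactor;
  }
  // Methods left unoverridden keep the base-class default, which returns an
  // Unimplemented*Reactor and fails the RPC with StatusCode::UNIMPLEMENTED.
};
```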
diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc
index 03291e1785..4d37bae217 100644
--- a/test/cpp/end2end/end2end_test.cc
+++ b/test/cpp/end2end/end2end_test.cc
@@ -196,16 +196,18 @@ class TestServiceImplDupPkg
class TestScenario {
public:
TestScenario(bool interceptors, bool proxy, bool inproc_stub,
- const grpc::string& creds_type)
+ const grpc::string& creds_type, bool use_callback_server)
: use_interceptors(interceptors),
use_proxy(proxy),
inproc(inproc_stub),
- credentials_type(creds_type) {}
+ credentials_type(creds_type),
+ callback_server(use_callback_server) {}
void Log() const;
bool use_interceptors;
bool use_proxy;
bool inproc;
const grpc::string credentials_type;
+ bool callback_server;
};
static std::ostream& operator<<(std::ostream& out,
@@ -214,6 +216,8 @@ static std::ostream& operator<<(std::ostream& out,
<< (scenario.use_interceptors ? "true" : "false")
<< ", use_proxy=" << (scenario.use_proxy ? "true" : "false")
<< ", inproc=" << (scenario.inproc ? "true" : "false")
+ << ", server_type="
+ << (scenario.callback_server ? "callback" : "sync")
<< ", credentials='" << scenario.credentials_type << "'}";
}
@@ -280,7 +284,11 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
builder.experimental().SetInterceptorCreators(std::move(creators));
}
builder.AddListeningPort(server_address_.str(), server_creds);
- builder.RegisterService(&service_);
+ if (!GetParam().callback_server) {
+ builder.RegisterService(&service_);
+ } else {
+ builder.RegisterService(&callback_service_);
+ }
builder.RegisterService("foo.test.youtube.com", &special_service_);
builder.RegisterService(&dup_pkg_service_);
@@ -362,6 +370,7 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
std::ostringstream server_address_;
const int kMaxMessageSize_;
TestServiceImpl service_;
+ CallbackTestServiceImpl callback_service_;
TestServiceImpl special_service_;
TestServiceImplDupPkg dup_pkg_service_;
grpc::string user_agent_prefix_;
@@ -1016,7 +1025,8 @@ TEST_P(End2endTest, DiffPackageServices) {
EXPECT_TRUE(s.ok());
}
-void CancelRpc(ClientContext* context, int delay_us, TestServiceImpl* service) {
+template <class ServiceType>
+void CancelRpc(ClientContext* context, int delay_us, ServiceType* service) {
gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_micros(delay_us, GPR_TIMESPAN)));
while (!service->signal_client()) {
@@ -1446,7 +1456,24 @@ TEST_P(ProxyEnd2endTest, ClientCancelsRpc) {
request.mutable_param()->set_client_cancel_after_us(kCancelDelayUs);
ClientContext context;
- std::thread cancel_thread(CancelRpc, &context, kCancelDelayUs, &service_);
+ std::thread cancel_thread;
+ if (!GetParam().callback_server) {
+ cancel_thread = std::thread(
+ [&context, this](int delay) { CancelRpc(&context, delay, &service_); },
+ kCancelDelayUs);
+ // Note: the unusual pattern above (and below) is caused by a conflict
+ // between two sets of compiler expectations. clang allows const to be
+ // captured without mention, so there is no need to capture kCancelDelayUs
+ // (and indeed clang-tidy complains if you do so). OTOH, a Windows compiler
+ // in our tests requires an explicit capture even for const. We square this
+ // circle by passing the const value in as an argument to the lambda.
+ } else {
+ cancel_thread = std::thread(
+ [&context, this](int delay) {
+ CancelRpc(&context, delay, &callback_service_);
+ },
+ kCancelDelayUs);
+ }
Status s = stub_->Echo(&context, request, &response);
cancel_thread.join();
EXPECT_EQ(StatusCode::CANCELLED, s.error_code());
@@ -1838,10 +1865,12 @@ TEST_P(ResourceQuotaEnd2endTest, SimpleRequest) {
EXPECT_TRUE(s.ok());
}
+// TODO(vjpai): refactor arguments into a struct if it makes sense
std::vector<TestScenario> CreateTestScenarios(bool use_proxy,
bool test_insecure,
bool test_secure,
- bool test_inproc) {
+ bool test_inproc,
+ bool test_callback_server) {
std::vector<TestScenario> scenarios;
std::vector<grpc::string> credentials_types;
if (test_secure) {
@@ -1857,41 +1886,48 @@ std::vector<TestScenario> CreateTestScenarios(bool use_proxy,
if (test_insecure && insec_ok()) {
credentials_types.push_back(kInsecureCredentialsType);
}
+
+ // For now test callback server only with inproc
GPR_ASSERT(!credentials_types.empty());
for (const auto& cred : credentials_types) {
- scenarios.emplace_back(false, false, false, cred);
- scenarios.emplace_back(true, false, false, cred);
+ scenarios.emplace_back(false, false, false, cred, false);
+ scenarios.emplace_back(true, false, false, cred, false);
if (use_proxy) {
- scenarios.emplace_back(false, true, false, cred);
- scenarios.emplace_back(true, true, false, cred);
+ scenarios.emplace_back(false, true, false, cred, false);
+ scenarios.emplace_back(true, true, false, cred, false);
}
}
if (test_inproc && insec_ok()) {
- scenarios.emplace_back(false, false, true, kInsecureCredentialsType);
- scenarios.emplace_back(true, false, true, kInsecureCredentialsType);
+ scenarios.emplace_back(false, false, true, kInsecureCredentialsType, false);
+ scenarios.emplace_back(true, false, true, kInsecureCredentialsType, false);
+ if (test_callback_server) {
+ scenarios.emplace_back(false, false, true, kInsecureCredentialsType,
+ true);
+ scenarios.emplace_back(true, false, true, kInsecureCredentialsType, true);
+ }
}
return scenarios;
}
-INSTANTIATE_TEST_CASE_P(End2end, End2endTest,
- ::testing::ValuesIn(CreateTestScenarios(false, true,
- true, true)));
+INSTANTIATE_TEST_CASE_P(
+ End2end, End2endTest,
+ ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
-INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest,
- ::testing::ValuesIn(CreateTestScenarios(false, true,
- true, true)));
+INSTANTIATE_TEST_CASE_P(
+ End2endServerTryCancel, End2endServerTryCancelTest,
+ ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
-INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest,
- ::testing::ValuesIn(CreateTestScenarios(true, true,
- true, true)));
+INSTANTIATE_TEST_CASE_P(
+ ProxyEnd2end, ProxyEnd2endTest,
+ ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, false)));
-INSTANTIATE_TEST_CASE_P(SecureEnd2end, SecureEnd2endTest,
- ::testing::ValuesIn(CreateTestScenarios(false, false,
- true, false)));
+INSTANTIATE_TEST_CASE_P(
+ SecureEnd2end, SecureEnd2endTest,
+ ::testing::ValuesIn(CreateTestScenarios(false, false, true, false, true)));
-INSTANTIATE_TEST_CASE_P(ResourceQuotaEnd2end, ResourceQuotaEnd2endTest,
- ::testing::ValuesIn(CreateTestScenarios(false, true,
- true, true)));
+INSTANTIATE_TEST_CASE_P(
+ ResourceQuotaEnd2end, ResourceQuotaEnd2endTest,
+ ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, false)));
} // namespace
} // namespace testing
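As a reading aid for the positional booleans in the instantiations above, here is the End2end call again with the parameter names from the CreateTestScenarios declaration spelled out (a sketch only; the test keeps the terse form):

auto end2end_scenarios = CreateTestScenarios(/*use_proxy=*/false,
                                             /*test_insecure=*/true,
                                             /*test_secure=*/true,
                                             /*test_inproc=*/true,
                                             /*test_callback_server=*/true);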
diff --git a/test/cpp/end2end/test_service_impl.cc b/test/cpp/end2end/test_service_impl.cc
index 1726e87ea6..9d9c01cade 100644
--- a/test/cpp/end2end/test_service_impl.cc
+++ b/test/cpp/end2end/test_service_impl.cc
@@ -71,6 +71,46 @@ void CheckServerAuthContext(
}
} // namespace
+namespace {
+int GetIntValueFromMetadataHelper(
+ const char* key,
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ int default_value) {
+ if (metadata.find(key) != metadata.end()) {
+ std::istringstream iss(ToString(metadata.find(key)->second));
+ iss >> default_value;
+ gpr_log(GPR_INFO, "%s : %d", key, default_value);
+ }
+
+ return default_value;
+}
+
+int GetIntValueFromMetadata(
+ const char* key,
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ int default_value) {
+ return GetIntValueFromMetadataHelper(key, metadata, default_value);
+}
+
+void ServerTryCancel(ServerContext* context) {
+ EXPECT_FALSE(context->IsCancelled());
+ context->TryCancel();
+ gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
+ // Now wait until it's really canceled
+ while (!context->IsCancelled()) {
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(1000, GPR_TIMESPAN)));
+ }
+}
+
+void ServerTryCancelNonblocking(ServerContext* context) {
+ EXPECT_FALSE(context->IsCancelled());
+ context->TryCancel();
+ gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
+}
+
+} // namespace
+
Status TestServiceImpl::Echo(ServerContext* context, const EchoRequest* request,
EchoResponse* response) {
// A bit of sleep to make sure that short deadline tests fail
@@ -195,6 +235,7 @@ void CallbackTestServiceImpl::EchoNonDelayed(
controller->Finish(Status(static_cast<StatusCode>(error.code()),
error.error_message(),
error.binary_error_details()));
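+ // Finish() only schedules completion; return explicitly so the handler
+ // does not fall through and call Finish() a second time below.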
+ return;
}
int server_try_cancel = GetIntValueFromMetadata(
kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
@@ -254,7 +295,7 @@ void CallbackTestServiceImpl::EchoNonDelayed(
alarm_.experimental().Set(
gpr_time_add(
gpr_now(GPR_CLOCK_REALTIME),
- gpr_time_from_micros(request->param().client_cancel_after_us(),
+ gpr_time_from_micros(request->param().server_cancel_after_us(),
GPR_TIMESPAN)),
[controller](bool) { controller->Finish(Status::CANCELLED); });
return;
@@ -279,6 +320,7 @@ void CallbackTestServiceImpl::EchoNonDelayed(
request->param().debug_info().SerializeAsString();
context->AddTrailingMetadata(kDebugInfoTrailerKey, serialized_debug_info);
controller->Finish(Status::CANCELLED);
+ return;
}
}
if (request->has_param() &&
@@ -325,7 +367,7 @@ Status TestServiceImpl::RequestStream(ServerContext* context,
std::thread* server_try_cancel_thd = nullptr;
if (server_try_cancel == CANCEL_DURING_PROCESSING) {
server_try_cancel_thd =
- new std::thread(&TestServiceImpl::ServerTryCancel, this, context);
+ new std::thread([context] { ServerTryCancel(context); });
}
int num_msgs_read = 0;
@@ -380,7 +422,7 @@ Status TestServiceImpl::ResponseStream(ServerContext* context,
std::thread* server_try_cancel_thd = nullptr;
if (server_try_cancel == CANCEL_DURING_PROCESSING) {
server_try_cancel_thd =
- new std::thread(&TestServiceImpl::ServerTryCancel, this, context);
+ new std::thread([context] { ServerTryCancel(context); });
}
for (int i = 0; i < server_responses_to_send; i++) {
@@ -431,7 +473,7 @@ Status TestServiceImpl::BidiStream(
std::thread* server_try_cancel_thd = nullptr;
if (server_try_cancel == CANCEL_DURING_PROCESSING) {
server_try_cancel_thd =
- new std::thread(&TestServiceImpl::ServerTryCancel, this, context);
+ new std::thread([context] { ServerTryCancel(context); });
}
// kServerFinishAfterNReads suggests after how many reads, the server should
@@ -465,44 +507,244 @@ Status TestServiceImpl::BidiStream(
return Status::OK;
}
-namespace {
-int GetIntValueFromMetadataHelper(
- const char* key,
- const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
- int default_value) {
- if (metadata.find(key) != metadata.end()) {
- std::istringstream iss(ToString(metadata.find(key)->second));
- iss >> default_value;
- gpr_log(GPR_INFO, "%s : %d", key, default_value);
- }
+experimental::ServerReadReactor<EchoRequest, EchoResponse>*
+CallbackTestServiceImpl::RequestStream() {
+ class Reactor : public ::grpc::experimental::ServerReadReactor<EchoRequest,
+ EchoResponse> {
+ public:
+ Reactor() {}
+ void OnStarted(ServerContext* context, EchoResponse* response) override {
+ ctx_ = context;
+ response_ = response;
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the
+ // value:
+ //   CANCEL_BEFORE_PROCESSING: the RPC is cancelled before the server
+ //     reads any message from the client
+ //   CANCEL_DURING_PROCESSING: the RPC is cancelled while the server is
+ //     reading messages from the client
+ //   CANCEL_AFTER_PROCESSING: the RPC is cancelled after the server reads
+ //     all the messages from the client
+ server_try_cancel_ = GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+
+ response_->set_message("");
+
+ if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
+ ServerTryCancelNonblocking(ctx_);
+ return;
+ }
- return default_value;
-}
-}; // namespace
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ ctx_->TryCancel();
+ // Don't wait for it here
+ }
-int TestServiceImpl::GetIntValueFromMetadata(
- const char* key,
- const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
- int default_value) {
- return GetIntValueFromMetadataHelper(key, metadata, default_value);
+ StartRead(&request_);
+ }
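+ // OnDone() is the library's final call into the reactor, so deleting
+ // 'this' here is safe; it is also why the reactor is heap-allocated
+ // ('return new Reactor' below).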
+ void OnDone() override { delete this; }
+ void OnCancel() override { FinishOnce(Status::CANCELLED); }
+ void OnReadDone(bool ok) override {
+ if (ok) {
+ response_->mutable_message()->append(request_.message());
+ num_msgs_read_++;
+ StartRead(&request_);
+ } else {
+ gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_);
+
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ // Let OnCancel recover this
+ return;
+ }
+ if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
+ ServerTryCancelNonblocking(ctx_);
+ return;
+ }
+ FinishOnce(Status::OK);
+ }
+ }
+
+ private:
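+ // OnReadDone() and OnCancel() can race to terminate the RPC, but Finish()
+ // must be called at most once, so guard it with a flag.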
+ void FinishOnce(const Status& s) {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ Finish(s);
+ finished_ = true;
+ }
+ }
+
+ ServerContext* ctx_;
+ EchoResponse* response_;
+ EchoRequest request_;
+ int num_msgs_read_{0};
+ int server_try_cancel_;
+ std::mutex finish_mu_;
+ bool finished_{false};
+ };
+
+ return new Reactor;
}
-int CallbackTestServiceImpl::GetIntValueFromMetadata(
- const char* key,
- const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
- int default_value) {
- return GetIntValueFromMetadataHelper(key, metadata, default_value);
+// Returns 'kServerDefaultResponseStreamsToSend' messages by default; the
+// count can be overridden via the 'kServerResponseStreamsToSend' metadata key.
+// TODO(yangg) make it generic by adding a parameter into EchoRequest
+experimental::ServerWriteReactor<EchoRequest, EchoResponse>*
+CallbackTestServiceImpl::ResponseStream() {
+ class Reactor
+ : public ::grpc::experimental::ServerWriteReactor<EchoRequest,
+ EchoResponse> {
+ public:
+ Reactor() {}
+ void OnStarted(ServerContext* context,
+ const EchoRequest* request) override {
+ ctx_ = context;
+ request_ = request;
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the
+ // value:
+ //   CANCEL_BEFORE_PROCESSING: the RPC is cancelled before the server
+ //     writes any response to the client
+ //   CANCEL_DURING_PROCESSING: the RPC is cancelled while the server is
+ //     writing responses to the client
+ //   CANCEL_AFTER_PROCESSING: the RPC is cancelled after the server writes
+ //     all its responses to the client
+ server_try_cancel_ = GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+ server_coalescing_api_ = GetIntValueFromMetadata(
+ kServerUseCoalescingApi, context->client_metadata(), 0);
+ server_responses_to_send_ = GetIntValueFromMetadata(
+ kServerResponseStreamsToSend, context->client_metadata(),
+ kServerDefaultResponseStreamsToSend);
+ if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
+ ServerTryCancelNonblocking(ctx_);
+ return;
+ }
+
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ ctx_->TryCancel();
+ }
+ if (num_msgs_sent_ < server_responses_to_send_) {
+ NextWrite();
+ }
+ }
+ void OnDone() override { delete this; }
+ void OnCancel() override { FinishOnce(Status::CANCELLED); }
+ void OnWriteDone(bool ok) override {
+ if (num_msgs_sent_ < server_responses_to_send_) {
+ NextWrite();
+ } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ // Let OnCancel recover this
+ } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
+ ServerTryCancelNonblocking(ctx_);
+ } else {
+ FinishOnce(Status::OK);
+ }
+ }
+
+ private:
+ void FinishOnce(const Status& s) {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ Finish(s);
+ finished_ = true;
+ }
+ }
+
+ void NextWrite() {
+ response_.set_message(request_->message() +
+ grpc::to_string(num_msgs_sent_));
+ if (num_msgs_sent_ == server_responses_to_send_ - 1 &&
+ server_coalescing_api_ != 0) {
+ num_msgs_sent_++;
+ StartWriteLast(&response_, WriteOptions());
+ } else {
+ num_msgs_sent_++;
+ StartWrite(&response_);
+ }
+ }
+ ServerContext* ctx_;
+ const EchoRequest* request_;
+ EchoResponse response_;
+ int num_msgs_sent_{0};
+ int server_try_cancel_;
+ int server_coalescing_api_;
+ int server_responses_to_send_;
+ std::mutex finish_mu_;
+ bool finished_{false};
+ };
+ return new Reactor;
}
-void TestServiceImpl::ServerTryCancel(ServerContext* context) {
- EXPECT_FALSE(context->IsCancelled());
- context->TryCancel();
- gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
- // Now wait until it's really canceled
- while (!context->IsCancelled()) {
- gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
- gpr_time_from_micros(1000, GPR_TIMESPAN)));
- }
+experimental::ServerBidiReactor<EchoRequest, EchoResponse>*
+CallbackTestServiceImpl::BidiStream() {
+ class Reactor : public ::grpc::experimental::ServerBidiReactor<EchoRequest,
+ EchoResponse> {
+ public:
+ Reactor() {}
+ void OnStarted(ServerContext* context) override {
+ ctx_ = context;
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the
+ // value:
+ //   CANCEL_BEFORE_PROCESSING: the RPC is cancelled before the server
+ //     reads or writes any message
+ //   CANCEL_DURING_PROCESSING: the RPC is cancelled while the server is
+ //     reading/writing messages with the client
+ //   CANCEL_AFTER_PROCESSING: the RPC is cancelled after the server has
+ //     exchanged all its messages with the client
+ server_try_cancel_ = GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+ server_write_last_ = GetIntValueFromMetadata(
+ kServerFinishAfterNReads, context->client_metadata(), 0);
+ if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
+ ServerTryCancelNonblocking(ctx_);
+ return;
+ }
+
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ ctx_->TryCancel();
+ }
+
+ StartRead(&request_);
+ }
+ void OnDone() override { delete this; }
+ void OnCancel() override { FinishOnce(Status::CANCELLED); }
+ void OnReadDone(bool ok) override {
+ if (ok) {
+ num_msgs_read_++;
+ gpr_log(GPR_INFO, "recv msg %s", request_.message().c_str());
+ response_.set_message(request_.message());
+ if (num_msgs_read_ == server_write_last_) {
+ StartWriteLast(&response_, WriteOptions());
+ } else {
+ StartWrite(&response_);
+ }
+ } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ // Let OnCancel handle this
+ } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
+ ServerTryCancelNonblocking(ctx_);
+ } else {
+ FinishOnce(Status::OK);
+ }
+ }
+ void OnWriteDone(bool ok) override { StartRead(&request_); }
+
+ private:
+ void FinishOnce(const Status& s) {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ Finish(s);
+ finished_ = true;
+ }
+ }
+
+ ServerContext* ctx_;
+ EchoRequest request_;
+ EchoResponse response_;
+ int num_msgs_read_{0};
+ int server_try_cancel_;
+ int server_write_last_;
+ std::mutex finish_mu_;
+ bool finished_{false};
+ };
+
+ return new Reactor;
}
} // namespace testing
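For readers new to the callback API, the same ServerBidiReactor pattern with the cancellation plumbing stripped out looks roughly like this; it is a sketch against the API as used above, not part of the test:

class EchoReactor : public ::grpc::experimental::ServerBidiReactor<EchoRequest,
                                                                    EchoResponse> {
 public:
  void OnStarted(ServerContext* /*context*/) override { StartRead(&request_); }
  void OnReadDone(bool ok) override {
    if (!ok) {  // the client is done writing (or the stream broke)
      Finish(Status::OK);
      return;
    }
    response_.set_message(request_.message());
    StartWrite(&response_);  // echo it back; reading resumes in OnWriteDone()
  }
  void OnWriteDone(bool /*ok*/) override { StartRead(&request_); }
  void OnDone() override { delete this; }

 private:
  EchoRequest request_;
  EchoResponse response_;
};

Each Start* call is answered by exactly one matching On*Done callback, and OnDone() fires only after every outstanding operation and Finish() have completed.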
diff --git a/test/cpp/end2end/test_service_impl.h b/test/cpp/end2end/test_service_impl.h
index ddfe94487e..fad7768b87 100644
--- a/test/cpp/end2end/test_service_impl.h
+++ b/test/cpp/end2end/test_service_impl.h
@@ -72,13 +72,6 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
}
private:
- int GetIntValueFromMetadata(
- const char* key,
- const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
- int default_value);
-
- void ServerTryCancel(ServerContext* context);
-
bool signal_client_;
std::mutex mu_;
std::unique_ptr<grpc::string> host_;
@@ -95,6 +88,15 @@ class CallbackTestServiceImpl
EchoResponse* response,
experimental::ServerCallbackRpcController* controller) override;
+ experimental::ServerReadReactor<EchoRequest, EchoResponse>* RequestStream()
+ override;
+
+ experimental::ServerWriteReactor<EchoRequest, EchoResponse>* ResponseStream()
+ override;
+
+ experimental::ServerBidiReactor<EchoRequest, EchoResponse>* BidiStream()
+ override;
+
// Unimplemented is left unimplemented to test the returned error.
bool signal_client() {
std::unique_lock<std::mutex> lock(mu_);
@@ -106,11 +108,6 @@ class CallbackTestServiceImpl
EchoResponse* response,
experimental::ServerCallbackRpcController* controller);
- int GetIntValueFromMetadata(
- const char* key,
- const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
- int default_value);
-
Alarm alarm_;
bool signal_client_;
std::mutex mu_;