diff options
Diffstat (limited to 'include/grpc++/impl/codegen/method_handler_impl.h')
-rw-r--r-- | include/grpc++/impl/codegen/method_handler_impl.h | 64 |
1 files changed, 49 insertions, 15 deletions
diff --git a/include/grpc++/impl/codegen/method_handler_impl.h b/include/grpc++/impl/codegen/method_handler_impl.h index 2f4be644ba..d989263252 100644 --- a/include/grpc++/impl/codegen/method_handler_impl.h +++ b/include/grpc++/impl/codegen/method_handler_impl.h @@ -167,20 +167,22 @@ class ServerStreamingHandler : public MethodHandler { }; // A wrapper class of an application provided bidi-streaming handler. -template <class ServiceType, class RequestType, class ResponseType> -class BidiStreamingHandler : public MethodHandler { +// This also applies to server-streamed implementation of a unary method +// with the additional requirement that such methods must have done a +// write for status to be ok +// Since this is used by more than 1 class, the service is not passed in. +// Instead, it is expected to be an implicitly-captured argument of func +// (through bind or something along those lines) +template <class Streamer, bool WriteNeeded> +class TemplatedBidiStreamingHandler : public MethodHandler { public: - BidiStreamingHandler( - std::function<Status(ServiceType*, ServerContext*, - ServerReaderWriter<ResponseType, RequestType>*)> - func, - ServiceType* service) - : func_(func), service_(service) {} + TemplatedBidiStreamingHandler( + std::function<Status(ServerContext*, Streamer*)> func) + : func_(func), write_needed_(WriteNeeded) {} void RunHandler(const HandlerParameter& param) GRPC_FINAL { - ServerReaderWriter<ResponseType, RequestType> stream(param.call, - param.server_context); - Status status = func_(service_, param.server_context, &stream); + Streamer stream(param.call, param.server_context); + Status status = func_(param.server_context, &stream); CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { @@ -189,6 +191,12 @@ class BidiStreamingHandler : public MethodHandler { if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); } + if (write_needed_ && status.ok()) { + // If we needed a write but never did one, we need to mark the + // status as a fail + status = Status(StatusCode::INTERNAL, + "Service did not provide response message"); + } } ops.ServerSendStatus(param.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); @@ -196,10 +204,36 @@ class BidiStreamingHandler : public MethodHandler { } private: - std::function<Status(ServiceType*, ServerContext*, - ServerReaderWriter<ResponseType, RequestType>*)> - func_; - ServiceType* service_; + std::function<Status(ServerContext*, Streamer*)> func_; + const bool write_needed_; +}; + +template <class ServiceType, class RequestType, class ResponseType> +class BidiStreamingHandler + : public TemplatedBidiStreamingHandler< + ServerReaderWriter<ResponseType, RequestType>, false> { + public: + BidiStreamingHandler( + std::function<Status(ServiceType*, ServerContext*, + ServerReaderWriter<ResponseType, RequestType>*)> + func, + ServiceType* service) + : TemplatedBidiStreamingHandler< + ServerReaderWriter<ResponseType, RequestType>, false>(std::bind( + func, service, std::placeholders::_1, std::placeholders::_2)) {} +}; + +template <class RequestType, class ResponseType> +class StreamedUnaryHandler + : public TemplatedBidiStreamingHandler< + ServerUnaryStreamer<RequestType, ResponseType>, true> { + public: + explicit StreamedUnaryHandler( + std::function<Status(ServerContext*, + ServerUnaryStreamer<RequestType, ResponseType>*)> + func) + : TemplatedBidiStreamingHandler< + ServerUnaryStreamer<RequestType, ResponseType>, true>(func) {} }; // Handle unknown method by returning UNIMPLEMENTED error. |