diff options
author | Yash Tibrewal <yashkt@google.com> | 2018-10-28 23:36:59 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-28 23:36:59 -0700 |
commit | 01313976e1a44b5c9625d3a349fffa55471beff4 (patch) | |
tree | 805596796cce33154e0d875c4c1ade918ba958f2 /include/grpcpp | |
parent | ffac9d90b18cb076b1c952faa55ce4e049cbc9a6 (diff) | |
parent | 395edbfa24968b8406a0c157874d3cb473076df5 (diff) |
Merge pull request #16842 from yashykt/interceptors
Experimental API for Client and Server Interception
Diffstat (limited to 'include/grpcpp')
28 files changed, 1986 insertions, 760 deletions
diff --git a/include/grpcpp/channel.h b/include/grpcpp/channel.h index b7c9e354de..14209b85ee 100644 --- a/include/grpcpp/channel.h +++ b/include/grpcpp/channel.h @@ -20,6 +20,7 @@ #define GRPCPP_CHANNEL_H #include <memory> +#include <mutex> #include <grpc/grpc.h> #include <grpcpp/impl/call.h> @@ -67,6 +68,7 @@ class Channel final : public ChannelInterface, std::unique_ptr<std::vector< std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> interceptor_creators); + friend class internal::InterceptedChannel; Channel(const grpc::string& host, grpc_channel* c_channel, std::unique_ptr<std::vector< std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> @@ -87,6 +89,10 @@ class Channel final : public ChannelInterface, CompletionQueue* CallbackCQ() override; + internal::Call CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, CompletionQueue* cq, + int interceptor_pos) override; + const grpc::string host_; grpc_channel* const c_channel_; // owned diff --git a/include/grpcpp/impl/codegen/async_stream.h b/include/grpcpp/impl/codegen/async_stream.h index 6e58fd0eef..bfb2df4f23 100644 --- a/include/grpcpp/impl/codegen/async_stream.h +++ b/include/grpcpp/impl/codegen/async_stream.h @@ -276,7 +276,7 @@ class ClientAsyncReader final : public ClientAsyncReaderInterface<R> { } void StartCallInternal(void* tag) { - init_ops_.SendInitialMetadata(context_->send_initial_metadata_, + init_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); init_ops_.set_output_tag(tag); call_.PerformOps(&init_ops_); @@ -441,7 +441,7 @@ class ClientAsyncWriter final : public ClientAsyncWriterInterface<W> { } void StartCallInternal(void* tag) { - write_ops_.SendInitialMetadata(context_->send_initial_metadata_, + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); // if corked bit is set in context, we just keep the initial metadata // buffered up to coalesce with later message send. No op is performed. @@ -612,7 +612,7 @@ class ClientAsyncReaderWriter final } void StartCallInternal(void* tag) { - write_ops_.SendInitialMetadata(context_->send_initial_metadata_, + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); // if corked bit is set in context, we just keep the initial metadata // buffered up to coalesce with later message send. No op is performed. @@ -710,7 +710,7 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -739,7 +739,7 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { void Finish(const W& msg, const Status& status, void* tag) override { finish_ops_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_ops_.SendInitialMetadata(ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_ops_.set_compression_level(ctx_->compression_level()); @@ -748,10 +748,10 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { } // The response is dropped if the status is not OK. if (status.ok()) { - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, finish_ops_.SendMessage(msg)); } else { - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); } call_.PerformOps(&finish_ops_); } @@ -769,14 +769,14 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { GPR_CODEGEN_ASSERT(!status.ok()); finish_ops_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_ops_.SendInitialMetadata(ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_ops_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -859,7 +859,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -904,7 +904,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { EnsureInitialMetadataSent(&write_ops_); options.set_buffer_hint(); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg, options).ok()); - write_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + write_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&write_ops_); } @@ -922,7 +922,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { void Finish(const Status& status, void* tag) override { finish_ops_.set_output_tag(tag); EnsureInitialMetadataSent(&finish_ops_); - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -932,7 +932,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { template <class T> void EnsureInitialMetadataSent(T* ops) { if (!ctx_->sent_initial_metadata_) { - ops->SendInitialMetadata(ctx_->initial_metadata_, + ops->SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops->set_compression_level(ctx_->compression_level()); @@ -1025,7 +1025,7 @@ class ServerAsyncReaderWriter final GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -1075,7 +1075,7 @@ class ServerAsyncReaderWriter final EnsureInitialMetadataSent(&write_ops_); options.set_buffer_hint(); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg, options).ok()); - write_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + write_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&write_ops_); } @@ -1094,7 +1094,7 @@ class ServerAsyncReaderWriter final finish_ops_.set_output_tag(tag); EnsureInitialMetadataSent(&finish_ops_); - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -1106,7 +1106,7 @@ class ServerAsyncReaderWriter final template <class T> void EnsureInitialMetadataSent(T* ops) { if (!ctx_->sent_initial_metadata_) { - ops->SendInitialMetadata(ctx_->initial_metadata_, + ops->SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops->set_compression_level(ctx_->compression_level()); diff --git a/include/grpcpp/impl/codegen/async_unary_call.h b/include/grpcpp/impl/codegen/async_unary_call.h index 60ff8e2f05..744b128141 100644 --- a/include/grpcpp/impl/codegen/async_unary_call.h +++ b/include/grpcpp/impl/codegen/async_unary_call.h @@ -174,7 +174,7 @@ class ClientAsyncResponseReader final } void StartCallInternal() { - single_buf.SendInitialMetadata(context_->send_initial_metadata_, + single_buf.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); } @@ -214,7 +214,7 @@ class ServerAsyncResponseWriter final GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_buf_.set_output_tag(tag); - meta_buf_.SendInitialMetadata(ctx_->initial_metadata_, + meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_buf_.set_compression_level(ctx_->compression_level()); @@ -240,8 +240,9 @@ class ServerAsyncResponseWriter final /// metadata. void Finish(const W& msg, const Status& status, void* tag) { finish_buf_.set_output_tag(tag); + finish_buf_.set_cq_tag(&finish_buf_); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(ctx_->initial_metadata_, + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_buf_.set_compression_level(ctx_->compression_level()); @@ -250,10 +251,10 @@ class ServerAsyncResponseWriter final } // The response is dropped if the status is not OK. if (status.ok()) { - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, finish_buf_.SendMessage(msg)); } else { - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, status); } call_.PerformOps(&finish_buf_); } @@ -274,14 +275,14 @@ class ServerAsyncResponseWriter final GPR_CODEGEN_ASSERT(!status.ok()); finish_buf_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(ctx_->initial_metadata_, + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_buf_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h index 8cc5158115..d54ae31852 100644 --- a/include/grpcpp/impl/codegen/byte_buffer.h +++ b/include/grpcpp/impl/codegen/byte_buffer.h @@ -50,6 +50,11 @@ class ErrorMethodHandler; template <class R> class DeserializeFuncType; class GrpcByteBufferPeer; +template <class ServiceType, class RequestType, class ResponseType> +class RpcMethodHandler; +template <class ServiceType, class RequestType, class ResponseType> +class ServerStreamingHandler; + } // namespace internal /// A sequence of bytes. class ByteBuffer final { @@ -141,7 +146,10 @@ class ByteBuffer final { template <class R> friend class internal::CallOpRecvMessage; friend class internal::CallOpGenericRecvMessage; - friend class internal::MethodHandler; + template <class ServiceType, class RequestType, class ResponseType> + friend class RpcMethodHandler; + template <class ServiceType, class RequestType, class ResponseType> + friend class ServerStreamingHandler; template <class ServiceType, class RequestType, class ResponseType> friend class internal::RpcMethodHandler; template <class ServiceType, class RequestType, class ResponseType> diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 789ea805a3..c040c30dd9 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -1,6 +1,6 @@ /* * - * Copyright 2015 gRPC authors. + * Copyright 2018 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,671 +15,31 @@ * limitations under the License. * */ - #ifndef GRPCPP_IMPL_CODEGEN_CALL_H #define GRPCPP_IMPL_CODEGEN_CALL_H -#include <assert.h> -#include <cstring> -#include <functional> -#include <map> -#include <memory> - -#include <grpcpp/impl/codegen/byte_buffer.h> -#include <grpcpp/impl/codegen/call_hook.h> -#include <grpcpp/impl/codegen/client_context.h> -#include <grpcpp/impl/codegen/completion_queue_tag.h> -#include <grpcpp/impl/codegen/config.h> -#include <grpcpp/impl/codegen/core_codegen_interface.h> -#include <grpcpp/impl/codegen/serialization_traits.h> -#include <grpcpp/impl/codegen/slice.h> -#include <grpcpp/impl/codegen/status.h> -#include <grpcpp/impl/codegen/string_ref.h> - -#include <grpc/impl/codegen/atm.h> -#include <grpc/impl/codegen/compression_types.h> #include <grpc/impl/codegen/grpc_types.h> +#include <grpcpp/impl/codegen/call_hook.h> namespace grpc { - -class ByteBuffer; class CompletionQueue; -extern CoreCodegenInterface* g_core_codegen_interface; +namespace experimental { +class ClientRpcInfo; +class ServerRpcInfo; +} // namespace experimental namespace internal { -class Call; class CallHook; - -// TODO(yangg) if the map is changed before we send, the pointers will be a -// mess. Make sure it does not happen. -inline grpc_metadata* FillMetadataArray( - const std::multimap<grpc::string, grpc::string>& metadata, - size_t* metadata_count, const grpc::string& optional_error_details) { - *metadata_count = metadata.size() + (optional_error_details.empty() ? 0 : 1); - if (*metadata_count == 0) { - return nullptr; - } - grpc_metadata* metadata_array = - (grpc_metadata*)(g_core_codegen_interface->gpr_malloc( - (*metadata_count) * sizeof(grpc_metadata))); - size_t i = 0; - for (auto iter = metadata.cbegin(); iter != metadata.cend(); ++iter, ++i) { - metadata_array[i].key = SliceReferencingString(iter->first); - metadata_array[i].value = SliceReferencingString(iter->second); - } - if (!optional_error_details.empty()) { - metadata_array[i].key = - g_core_codegen_interface->grpc_slice_from_static_buffer( - kBinaryErrorDetailsKey, sizeof(kBinaryErrorDetailsKey) - 1); - metadata_array[i].value = SliceReferencingString(optional_error_details); - } - return metadata_array; -} -} // namespace internal - -/// Per-message write options. -class WriteOptions { - public: - WriteOptions() : flags_(0), last_message_(false) {} - WriteOptions(const WriteOptions& other) - : flags_(other.flags_), last_message_(other.last_message_) {} - - /// Clear all flags. - inline void Clear() { flags_ = 0; } - - /// Returns raw flags bitset. - inline uint32_t flags() const { return flags_; } - - /// Sets flag for the disabling of compression for the next message write. - /// - /// \sa GRPC_WRITE_NO_COMPRESS - inline WriteOptions& set_no_compression() { - SetBit(GRPC_WRITE_NO_COMPRESS); - return *this; - } - - /// Clears flag for the disabling of compression for the next message write. - /// - /// \sa GRPC_WRITE_NO_COMPRESS - inline WriteOptions& clear_no_compression() { - ClearBit(GRPC_WRITE_NO_COMPRESS); - return *this; - } - - /// Get value for the flag indicating whether compression for the next - /// message write is forcefully disabled. - /// - /// \sa GRPC_WRITE_NO_COMPRESS - inline bool get_no_compression() const { - return GetBit(GRPC_WRITE_NO_COMPRESS); - } - - /// Sets flag indicating that the write may be buffered and need not go out on - /// the wire immediately. - /// - /// \sa GRPC_WRITE_BUFFER_HINT - inline WriteOptions& set_buffer_hint() { - SetBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - /// Clears flag indicating that the write may be buffered and need not go out - /// on the wire immediately. - /// - /// \sa GRPC_WRITE_BUFFER_HINT - inline WriteOptions& clear_buffer_hint() { - ClearBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - /// Get value for the flag indicating that the write may be buffered and need - /// not go out on the wire immediately. - /// - /// \sa GRPC_WRITE_BUFFER_HINT - inline bool get_buffer_hint() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } - - /// corked bit: aliases set_buffer_hint currently, with the intent that - /// set_buffer_hint will be removed in the future - inline WriteOptions& set_corked() { - SetBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - inline WriteOptions& clear_corked() { - ClearBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - inline bool is_corked() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } - - /// last-message bit: indicates this is the last message in a stream - /// client-side: makes Write the equivalent of performing Write, WritesDone - /// in a single step - /// server-side: hold the Write until the service handler returns (sync api) - /// or until Finish is called (async api) - inline WriteOptions& set_last_message() { - last_message_ = true; - return *this; - } - - /// Clears flag indicating that this is the last message in a stream, - /// disabling coalescing. - inline WriteOptions& clear_last_message() { - last_message_ = false; - return *this; - } - - /// Guarantee that all bytes have been written to the socket before completing - /// this write (usually writes are completed when they pass flow control). - inline WriteOptions& set_write_through() { - SetBit(GRPC_WRITE_THROUGH); - return *this; - } - - inline bool is_write_through() const { return GetBit(GRPC_WRITE_THROUGH); } - - /// Get value for the flag indicating that this is the last message, and - /// should be coalesced with trailing metadata. - /// - /// \sa GRPC_WRITE_LAST_MESSAGE - bool is_last_message() const { return last_message_; } - - WriteOptions& operator=(const WriteOptions& rhs) { - flags_ = rhs.flags_; - return *this; - } - - private: - void SetBit(const uint32_t mask) { flags_ |= mask; } - - void ClearBit(const uint32_t mask) { flags_ &= ~mask; } - - bool GetBit(const uint32_t mask) const { return (flags_ & mask) != 0; } - - uint32_t flags_; - bool last_message_; -}; - -namespace internal { -/// Default argument for CallOpSet. I is unused by the class, but can be -/// used for generating multiple names for the same thing. -template <int I> -class CallNoOp { - protected: - void AddOp(grpc_op* ops, size_t* nops) {} - void FinishOp(bool* status) {} -}; - -class CallOpSendInitialMetadata { - public: - CallOpSendInitialMetadata() : send_(false) { - maybe_compression_level_.is_set = false; - } - - void SendInitialMetadata( - const std::multimap<grpc::string, grpc::string>& metadata, - uint32_t flags) { - maybe_compression_level_.is_set = false; - send_ = true; - flags_ = flags; - initial_metadata_ = - FillMetadataArray(metadata, &initial_metadata_count_, ""); - } - - void set_compression_level(grpc_compression_level level) { - maybe_compression_level_.is_set = true; - maybe_compression_level_.level = level; - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_INITIAL_METADATA; - op->flags = flags_; - op->reserved = NULL; - op->data.send_initial_metadata.count = initial_metadata_count_; - op->data.send_initial_metadata.metadata = initial_metadata_; - op->data.send_initial_metadata.maybe_compression_level.is_set = - maybe_compression_level_.is_set; - if (maybe_compression_level_.is_set) { - op->data.send_initial_metadata.maybe_compression_level.level = - maybe_compression_level_.level; - } - } - void FinishOp(bool* status) { - if (!send_) return; - g_core_codegen_interface->gpr_free(initial_metadata_); - send_ = false; - } - - bool send_; - uint32_t flags_; - size_t initial_metadata_count_; - grpc_metadata* initial_metadata_; - struct { - bool is_set; - grpc_compression_level level; - } maybe_compression_level_; -}; - -class CallOpSendMessage { - public: - CallOpSendMessage() : send_buf_() {} - - /// Send \a message using \a options for the write. The \a options are cleared - /// after use. - template <class M> - Status SendMessage(const M& message, - WriteOptions options) GRPC_MUST_USE_RESULT; - - template <class M> - Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_buf_.Valid()) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_MESSAGE; - op->flags = write_options_.flags(); - op->reserved = NULL; - op->data.send_message.send_message = send_buf_.c_buffer(); - // Flags are per-message: clear them after use. - write_options_.Clear(); - } - void FinishOp(bool* status) { send_buf_.Clear(); } - - private: - ByteBuffer send_buf_; - WriteOptions write_options_; -}; - -template <class M> -Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) { - write_options_ = options; - bool own_buf; - // TODO(vjpai): Remove the void below when possible - // The void in the template parameter below should not be needed - // (since it should be implicit) but is needed due to an observed - // difference in behavior between clang and gcc for certain internal users - Status result = SerializationTraits<M, void>::Serialize( - message, send_buf_.bbuf_ptr(), &own_buf); - if (!own_buf) { - send_buf_.Duplicate(); - } - return result; -} - -template <class M> -Status CallOpSendMessage::SendMessage(const M& message) { - return SendMessage(message, WriteOptions()); -} - -template <class R> -class CallOpRecvMessage { - public: - CallOpRecvMessage() - : got_message(false), - message_(nullptr), - allow_not_getting_message_(false) {} - - void RecvMessage(R* message) { message_ = message; } - - // Do not change status if no message is received. - void AllowNoMessage() { allow_not_getting_message_ = true; } - - bool got_message; - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (message_ == nullptr) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_MESSAGE; - op->flags = 0; - op->reserved = NULL; - op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); - } - - void FinishOp(bool* status) { - if (message_ == nullptr) return; - if (recv_buf_.Valid()) { - if (*status) { - got_message = *status = - SerializationTraits<R>::Deserialize(recv_buf_.bbuf_ptr(), message_) - .ok(); - recv_buf_.Release(); - } else { - got_message = false; - recv_buf_.Clear(); - } - } else { - got_message = false; - if (!allow_not_getting_message_) { - *status = false; - } - } - message_ = nullptr; - } - - private: - R* message_; - ByteBuffer recv_buf_; - bool allow_not_getting_message_; -}; - -class DeserializeFunc { - public: - virtual Status Deserialize(ByteBuffer* buf) = 0; - virtual ~DeserializeFunc() {} -}; - -template <class R> -class DeserializeFuncType final : public DeserializeFunc { - public: - DeserializeFuncType(R* message) : message_(message) {} - Status Deserialize(ByteBuffer* buf) override { - return SerializationTraits<R>::Deserialize(buf->bbuf_ptr(), message_); - } - - ~DeserializeFuncType() override {} - - private: - R* message_; // Not a managed pointer because management is external to this -}; - -class CallOpGenericRecvMessage { - public: - CallOpGenericRecvMessage() - : got_message(false), allow_not_getting_message_(false) {} - - template <class R> - void RecvMessage(R* message) { - // Use an explicit base class pointer to avoid resolution error in the - // following unique_ptr::reset for some old implementations. - DeserializeFunc* func = new DeserializeFuncType<R>(message); - deserialize_.reset(func); - } - - // Do not change status if no message is received. - void AllowNoMessage() { allow_not_getting_message_ = true; } - - bool got_message; - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!deserialize_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_MESSAGE; - op->flags = 0; - op->reserved = NULL; - op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); - } - - void FinishOp(bool* status) { - if (!deserialize_) return; - if (recv_buf_.Valid()) { - if (*status) { - got_message = true; - *status = deserialize_->Deserialize(&recv_buf_).ok(); - recv_buf_.Release(); - } else { - got_message = false; - recv_buf_.Clear(); - } - } else { - got_message = false; - if (!allow_not_getting_message_) { - *status = false; - } - } - deserialize_.reset(); - } - - private: - std::unique_ptr<DeserializeFunc> deserialize_; - ByteBuffer recv_buf_; - bool allow_not_getting_message_; -}; - -class CallOpClientSendClose { - public: - CallOpClientSendClose() : send_(false) {} - - void ClientSendClose() { send_ = true; } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; - op->flags = 0; - op->reserved = NULL; - } - void FinishOp(bool* status) { send_ = false; } - - private: - bool send_; -}; - -class CallOpServerSendStatus { - public: - CallOpServerSendStatus() : send_status_available_(false) {} - - void ServerSendStatus( - const std::multimap<grpc::string, grpc::string>& trailing_metadata, - const Status& status) { - send_error_details_ = status.error_details(); - trailing_metadata_ = FillMetadataArray( - trailing_metadata, &trailing_metadata_count_, send_error_details_); - send_status_available_ = true; - send_status_code_ = static_cast<grpc_status_code>(status.error_code()); - send_error_message_ = status.error_message(); - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_status_available_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; - op->data.send_status_from_server.trailing_metadata_count = - trailing_metadata_count_; - op->data.send_status_from_server.trailing_metadata = trailing_metadata_; - op->data.send_status_from_server.status = send_status_code_; - error_message_slice_ = SliceReferencingString(send_error_message_); - op->data.send_status_from_server.status_details = - send_error_message_.empty() ? nullptr : &error_message_slice_; - op->flags = 0; - op->reserved = NULL; - } - - void FinishOp(bool* status) { - if (!send_status_available_) return; - g_core_codegen_interface->gpr_free(trailing_metadata_); - send_status_available_ = false; - } - - private: - bool send_status_available_; - grpc_status_code send_status_code_; - grpc::string send_error_details_; - grpc::string send_error_message_; - size_t trailing_metadata_count_; - grpc_metadata* trailing_metadata_; - grpc_slice error_message_slice_; -}; - -class CallOpRecvInitialMetadata { - public: - CallOpRecvInitialMetadata() : metadata_map_(nullptr) {} - - void RecvInitialMetadata(ClientContext* context) { - context->initial_metadata_received_ = true; - metadata_map_ = &context->recv_initial_metadata_; - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (metadata_map_ == nullptr) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_INITIAL_METADATA; - op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr(); - op->flags = 0; - op->reserved = NULL; - } - - void FinishOp(bool* status) { - if (metadata_map_ == nullptr) return; - metadata_map_ = nullptr; - } - - private: - MetadataMap* metadata_map_; -}; - -class CallOpClientRecvStatus { - public: - CallOpClientRecvStatus() - : recv_status_(nullptr), debug_error_string_(nullptr) {} - - void ClientRecvStatus(ClientContext* context, Status* status) { - client_context_ = context; - metadata_map_ = &client_context_->trailing_metadata_; - recv_status_ = status; - error_message_ = g_core_codegen_interface->grpc_empty_slice(); - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (recv_status_ == nullptr) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; - op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr(); - op->data.recv_status_on_client.status = &status_code_; - op->data.recv_status_on_client.status_details = &error_message_; - op->data.recv_status_on_client.error_string = &debug_error_string_; - op->flags = 0; - op->reserved = NULL; - } - - void FinishOp(bool* status) { - if (recv_status_ == nullptr) return; - grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); - *recv_status_ = - Status(static_cast<StatusCode>(status_code_), - GRPC_SLICE_IS_EMPTY(error_message_) - ? grpc::string() - : grpc::string(GRPC_SLICE_START_PTR(error_message_), - GRPC_SLICE_END_PTR(error_message_)), - binary_error_details); - client_context_->set_debug_error_string( - debug_error_string_ != nullptr ? debug_error_string_ : ""); - g_core_codegen_interface->grpc_slice_unref(error_message_); - if (debug_error_string_ != nullptr) { - g_core_codegen_interface->gpr_free((void*)debug_error_string_); - } - recv_status_ = nullptr; - } - - private: - ClientContext* client_context_; - MetadataMap* metadata_map_; - Status* recv_status_; - const char* debug_error_string_; - grpc_status_code status_code_; - grpc_slice error_message_; -}; - -/// An abstract collection of call ops, used to generate the -/// grpc_call_op structure to pass down to the lower layers, -/// and as it is-a CompletionQueueTag, also massages the final -/// completion into the correct form for consumption in the C++ -/// API. -class CallOpSetInterface : public CompletionQueueTag { - public: - /// Fills in grpc_op, starting from ops[*nops] and moving - /// upwards. - virtual void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) = 0; - - /// Get the tag to be used at the core completion queue. Generally, the - /// value of cq_tag will be "this". However, it can be overridden if we - /// want core to process the tag differently (e.g., as a core callback) - virtual void* cq_tag() = 0; -}; - -/// Primary implementation of CallOpSetInterface. -/// Since we cannot use variadic templates, we declare slots up to -/// the maximum count of ops we'll need in a set. We leverage the -/// empty base class optimization to slim this class (especially -/// when there are many unused slots used). To avoid duplicate base classes, -/// the template parmeter for CallNoOp is varied by argument position. -template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>, - class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, - class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> -class CallOpSet : public CallOpSetInterface, - public Op1, - public Op2, - public Op3, - public Op4, - public Op5, - public Op6 { - public: - CallOpSet() : cq_tag_(this), return_tag_(this), call_(nullptr) {} - - // The copy constructor and assignment operator reset the value of - // cq_tag_ and return_tag_ since those are only meaningful on a specific - // object, not across objects. - CallOpSet(const CallOpSet& other) - : cq_tag_(this), return_tag_(this), call_(other.call_) {} - CallOpSet& operator=(const CallOpSet& other) { - cq_tag_ = this; - return_tag_ = this; - call_ = other.call_; - return *this; - } - - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override { - this->Op1::AddOp(ops, nops); - this->Op2::AddOp(ops, nops); - this->Op3::AddOp(ops, nops); - this->Op4::AddOp(ops, nops); - this->Op5::AddOp(ops, nops); - this->Op6::AddOp(ops, nops); - g_core_codegen_interface->grpc_call_ref(call); - call_ = call; - } - - bool FinalizeResult(void** tag, bool* status) override { - this->Op1::FinishOp(status); - this->Op2::FinishOp(status); - this->Op3::FinishOp(status); - this->Op4::FinishOp(status); - this->Op5::FinishOp(status); - this->Op6::FinishOp(status); - *tag = return_tag_; - - g_core_codegen_interface->grpc_call_unref(call_); - return true; - } - - void set_output_tag(void* return_tag) { return_tag_ = return_tag; } - - void* cq_tag() override { return cq_tag_; } - - /// set_cq_tag is used to provide a different core CQ tag than "this". - /// This is used for callback-based tags, where the core tag is the core - /// callback function. It does not change the use or behavior of any other - /// function (such as FinalizeResult) - void set_cq_tag(void* cq_tag) { cq_tag_ = cq_tag; } - - private: - void* cq_tag_; - void* return_tag_; - grpc_call* call_; -}; +class CallOpSetInterface; /// Straightforward wrapping of the C call object class Call final { public: + Call() + : call_hook_(nullptr), + cq_(nullptr), + call_(nullptr), + max_receive_message_size_(-1) {} /** call is owned by the caller */ Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) : call_hook_(call_hook), @@ -688,11 +48,20 @@ class Call final { max_receive_message_size_(-1) {} Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, - int max_receive_message_size) + experimental::ClientRpcInfo* rpc_info) : call_hook_(call_hook), cq_(cq), call_(call), - max_receive_message_size_(max_receive_message_size) {} + max_receive_message_size_(-1), + client_rpc_info_(rpc_info) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + int max_receive_message_size, experimental::ServerRpcInfo* rpc_info) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(max_receive_message_size), + server_rpc_info_(rpc_info) {} void PerformOps(CallOpSetInterface* ops) { call_hook_->PerformOpsOnCall(ops, this); @@ -703,11 +72,21 @@ class Call final { int max_receive_message_size() const { return max_receive_message_size_; } + experimental::ClientRpcInfo* client_rpc_info() const { + return client_rpc_info_; + } + + experimental::ServerRpcInfo* server_rpc_info() const { + return server_rpc_info_; + } + private: CallHook* call_hook_; CompletionQueue* cq_; grpc_call* call_; int max_receive_message_size_; + experimental::ClientRpcInfo* client_rpc_info_ = nullptr; + experimental::ServerRpcInfo* server_rpc_info_ = nullptr; }; } // namespace internal } // namespace grpc diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h new file mode 100644 index 0000000000..785688e67f --- /dev/null +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -0,0 +1,920 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_CALL_OP_SET_H +#define GRPCPP_IMPL_CODEGEN_CALL_OP_SET_H + +#include <assert.h> +#include <array> +#include <cstring> +#include <functional> +#include <map> +#include <memory> +#include <vector> + +#include <grpcpp/impl/codegen/byte_buffer.h> +#include <grpcpp/impl/codegen/call.h> +#include <grpcpp/impl/codegen/call_hook.h> +#include <grpcpp/impl/codegen/call_op_set_interface.h> +#include <grpcpp/impl/codegen/client_context.h> +#include <grpcpp/impl/codegen/completion_queue_tag.h> +#include <grpcpp/impl/codegen/config.h> +#include <grpcpp/impl/codegen/core_codegen_interface.h> +#include <grpcpp/impl/codegen/intercepted_channel.h> +#include <grpcpp/impl/codegen/interceptor_common.h> +#include <grpcpp/impl/codegen/serialization_traits.h> +#include <grpcpp/impl/codegen/slice.h> +#include <grpcpp/impl/codegen/string_ref.h> + +#include <grpc/impl/codegen/atm.h> +#include <grpc/impl/codegen/compression_types.h> +#include <grpc/impl/codegen/grpc_types.h> + +namespace grpc { + +class CompletionQueue; +extern CoreCodegenInterface* g_core_codegen_interface; + +namespace internal { +class Call; +class CallHook; + +// TODO(yangg) if the map is changed before we send, the pointers will be a +// mess. Make sure it does not happen. +inline grpc_metadata* FillMetadataArray( + const std::multimap<grpc::string, grpc::string>& metadata, + size_t* metadata_count, const grpc::string& optional_error_details) { + *metadata_count = metadata.size() + (optional_error_details.empty() ? 0 : 1); + if (*metadata_count == 0) { + return nullptr; + } + grpc_metadata* metadata_array = + (grpc_metadata*)(g_core_codegen_interface->gpr_malloc( + (*metadata_count) * sizeof(grpc_metadata))); + size_t i = 0; + for (auto iter = metadata.cbegin(); iter != metadata.cend(); ++iter, ++i) { + metadata_array[i].key = SliceReferencingString(iter->first); + metadata_array[i].value = SliceReferencingString(iter->second); + } + if (!optional_error_details.empty()) { + metadata_array[i].key = + g_core_codegen_interface->grpc_slice_from_static_buffer( + kBinaryErrorDetailsKey, sizeof(kBinaryErrorDetailsKey) - 1); + metadata_array[i].value = SliceReferencingString(optional_error_details); + } + return metadata_array; +} +} // namespace internal + +/// Per-message write options. +class WriteOptions { + public: + WriteOptions() : flags_(0), last_message_(false) {} + WriteOptions(const WriteOptions& other) + : flags_(other.flags_), last_message_(other.last_message_) {} + + /// Clear all flags. + inline void Clear() { flags_ = 0; } + + /// Returns raw flags bitset. + inline uint32_t flags() const { return flags_; } + + /// Sets flag for the disabling of compression for the next message write. + /// + /// \sa GRPC_WRITE_NO_COMPRESS + inline WriteOptions& set_no_compression() { + SetBit(GRPC_WRITE_NO_COMPRESS); + return *this; + } + + /// Clears flag for the disabling of compression for the next message write. + /// + /// \sa GRPC_WRITE_NO_COMPRESS + inline WriteOptions& clear_no_compression() { + ClearBit(GRPC_WRITE_NO_COMPRESS); + return *this; + } + + /// Get value for the flag indicating whether compression for the next + /// message write is forcefully disabled. + /// + /// \sa GRPC_WRITE_NO_COMPRESS + inline bool get_no_compression() const { + return GetBit(GRPC_WRITE_NO_COMPRESS); + } + + /// Sets flag indicating that the write may be buffered and need not go out on + /// the wire immediately. + /// + /// \sa GRPC_WRITE_BUFFER_HINT + inline WriteOptions& set_buffer_hint() { + SetBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + /// Clears flag indicating that the write may be buffered and need not go out + /// on the wire immediately. + /// + /// \sa GRPC_WRITE_BUFFER_HINT + inline WriteOptions& clear_buffer_hint() { + ClearBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + /// Get value for the flag indicating that the write may be buffered and need + /// not go out on the wire immediately. + /// + /// \sa GRPC_WRITE_BUFFER_HINT + inline bool get_buffer_hint() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } + + /// corked bit: aliases set_buffer_hint currently, with the intent that + /// set_buffer_hint will be removed in the future + inline WriteOptions& set_corked() { + SetBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + inline WriteOptions& clear_corked() { + ClearBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + inline bool is_corked() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } + + /// last-message bit: indicates this is the last message in a stream + /// client-side: makes Write the equivalent of performing Write, WritesDone + /// in a single step + /// server-side: hold the Write until the service handler returns (sync api) + /// or until Finish is called (async api) + inline WriteOptions& set_last_message() { + last_message_ = true; + return *this; + } + + /// Clears flag indicating that this is the last message in a stream, + /// disabling coalescing. + inline WriteOptions& clear_last_message() { + last_message_ = false; + return *this; + } + + /// Guarantee that all bytes have been written to the socket before completing + /// this write (usually writes are completed when they pass flow control). + inline WriteOptions& set_write_through() { + SetBit(GRPC_WRITE_THROUGH); + return *this; + } + + inline bool is_write_through() const { return GetBit(GRPC_WRITE_THROUGH); } + + /// Get value for the flag indicating that this is the last message, and + /// should be coalesced with trailing metadata. + /// + /// \sa GRPC_WRITE_LAST_MESSAGE + bool is_last_message() const { return last_message_; } + + WriteOptions& operator=(const WriteOptions& rhs) { + flags_ = rhs.flags_; + return *this; + } + + private: + void SetBit(const uint32_t mask) { flags_ |= mask; } + + void ClearBit(const uint32_t mask) { flags_ &= ~mask; } + + bool GetBit(const uint32_t mask) const { return (flags_ & mask) != 0; } + + uint32_t flags_; + bool last_message_; +}; + +namespace internal { + +/// Default argument for CallOpSet. I is unused by the class, but can be +/// used for generating multiple names for the same thing. +template <int I> +class CallNoOp { + protected: + void AddOp(grpc_op* ops, size_t* nops) {} + void FinishOp(bool* status) {} + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + } +}; + +class CallOpSendInitialMetadata { + public: + CallOpSendInitialMetadata() : send_(false) { + maybe_compression_level_.is_set = false; + } + + void SendInitialMetadata(std::multimap<grpc::string, grpc::string>* metadata, + uint32_t flags) { + maybe_compression_level_.is_set = false; + send_ = true; + flags_ = flags; + metadata_map_ = metadata; + } + + void set_compression_level(grpc_compression_level level) { + maybe_compression_level_.is_set = true; + maybe_compression_level_.level = level; + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_ || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->flags = flags_; + op->reserved = NULL; + initial_metadata_ = + FillMetadataArray(*metadata_map_, &initial_metadata_count_, ""); + op->data.send_initial_metadata.count = initial_metadata_count_; + op->data.send_initial_metadata.metadata = initial_metadata_; + op->data.send_initial_metadata.maybe_compression_level.is_set = + maybe_compression_level_.is_set; + if (maybe_compression_level_.is_set) { + op->data.send_initial_metadata.maybe_compression_level.level = + maybe_compression_level_.level; + } + } + void FinishOp(bool* status) { + if (!send_ || hijacked_) return; + g_core_codegen_interface->gpr_free(initial_metadata_); + send_ = false; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA); + interceptor_methods->SetSendInitialMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + bool hijacked_ = false; + bool send_; + uint32_t flags_; + size_t initial_metadata_count_; + std::multimap<grpc::string, grpc::string>* metadata_map_; + grpc_metadata* initial_metadata_; + struct { + bool is_set; + grpc_compression_level level; + } maybe_compression_level_; +}; + +class CallOpSendMessage { + public: + CallOpSendMessage() : send_buf_() {} + + /// Send \a message using \a options for the write. The \a options are cleared + /// after use. + template <class M> + Status SendMessage(const M& message, + WriteOptions options) GRPC_MUST_USE_RESULT; + + template <class M> + Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_buf_.Valid() || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_MESSAGE; + op->flags = write_options_.flags(); + op->reserved = NULL; + op->data.send_message.send_message = send_buf_.c_buffer(); + // Flags are per-message: clear them after use. + write_options_.Clear(); + } + void FinishOp(bool* status) { send_buf_.Clear(); } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_buf_.Valid()) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); + interceptor_methods->SetSendMessage(&send_buf_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + private: + bool hijacked_ = false; + ByteBuffer send_buf_; + WriteOptions write_options_; +}; + +template <class M> +Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) { + write_options_ = options; + bool own_buf; + // TODO(vjpai): Remove the void below when possible + // The void in the template parameter below should not be needed + // (since it should be implicit) but is needed due to an observed + // difference in behavior between clang and gcc for certain internal users + Status result = SerializationTraits<M, void>::Serialize( + message, send_buf_.bbuf_ptr(), &own_buf); + if (!own_buf) { + send_buf_.Duplicate(); + } + return result; +} + +template <class M> +Status CallOpSendMessage::SendMessage(const M& message) { + return SendMessage(message, WriteOptions()); +} + +template <class R> +class CallOpRecvMessage { + public: + CallOpRecvMessage() + : got_message(false), + message_(nullptr), + allow_not_getting_message_(false) {} + + void RecvMessage(R* message) { message_ = message; } + + // Do not change status if no message is received. + void AllowNoMessage() { allow_not_getting_message_ = true; } + + bool got_message; + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (message_ == nullptr || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_MESSAGE; + op->flags = 0; + op->reserved = NULL; + op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); + } + + void FinishOp(bool* status) { + if (message_ == nullptr || hijacked_) return; + if (recv_buf_.Valid()) { + if (*status) { + got_message = *status = + SerializationTraits<R>::Deserialize(recv_buf_.bbuf_ptr(), message_) + .ok(); + recv_buf_.Release(); + } else { + got_message = false; + recv_buf_.Clear(); + } + } else { + got_message = false; + if (!allow_not_getting_message_) { + *status = false; + } + } + message_ = nullptr; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (message_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + got_message = true; + } + + private: + R* message_; + ByteBuffer recv_buf_; + bool allow_not_getting_message_; + bool hijacked_ = false; +}; + +class DeserializeFunc { + public: + virtual Status Deserialize(ByteBuffer* buf) = 0; + virtual ~DeserializeFunc() {} +}; + +template <class R> +class DeserializeFuncType final : public DeserializeFunc { + public: + DeserializeFuncType(R* message) : message_(message) {} + Status Deserialize(ByteBuffer* buf) override { + return SerializationTraits<R>::Deserialize(buf->bbuf_ptr(), message_); + } + + ~DeserializeFuncType() override {} + + private: + R* message_; // Not a managed pointer because management is external to this +}; + +class CallOpGenericRecvMessage { + public: + CallOpGenericRecvMessage() + : got_message(false), allow_not_getting_message_(false) {} + + template <class R> + void RecvMessage(R* message) { + // Use an explicit base class pointer to avoid resolution error in the + // following unique_ptr::reset for some old implementations. + DeserializeFunc* func = new DeserializeFuncType<R>(message); + deserialize_.reset(func); + message_ = message; + } + + // Do not change status if no message is received. + void AllowNoMessage() { allow_not_getting_message_ = true; } + + bool got_message; + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!deserialize_ || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_MESSAGE; + op->flags = 0; + op->reserved = NULL; + op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); + } + + void FinishOp(bool* status) { + if (!deserialize_ || hijacked_) return; + if (recv_buf_.Valid()) { + if (*status) { + got_message = true; + *status = deserialize_->Deserialize(&recv_buf_).ok(); + recv_buf_.Release(); + } else { + got_message = false; + recv_buf_.Clear(); + } + } else { + got_message = false; + if (!allow_not_getting_message_) { + *status = false; + } + } + deserialize_.reset(); + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (!deserialize_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + } + + private: + void* message_; + bool hijacked_ = false; + std::unique_ptr<DeserializeFunc> deserialize_; + ByteBuffer recv_buf_; + bool allow_not_getting_message_; +}; + +class CallOpClientSendClose { + public: + CallOpClientSendClose() : send_(false) {} + + void ClientSendClose() { send_ = true; } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_ || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = NULL; + } + void FinishOp(bool* status) { send_ = false; } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + private: + bool hijacked_ = false; + bool send_; +}; + +class CallOpServerSendStatus { + public: + CallOpServerSendStatus() : send_status_available_(false) {} + + void ServerSendStatus( + std::multimap<grpc::string, grpc::string>* trailing_metadata, + const Status& status) { + send_error_details_ = status.error_details(); + metadata_map_ = trailing_metadata; + send_status_available_ = true; + send_status_code_ = static_cast<grpc_status_code>(status.error_code()); + send_error_message_ = status.error_message(); + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_status_available_ || hijacked_) return; + trailing_metadata_ = FillMetadataArray( + *metadata_map_, &trailing_metadata_count_, send_error_details_); + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = + trailing_metadata_count_; + op->data.send_status_from_server.trailing_metadata = trailing_metadata_; + op->data.send_status_from_server.status = send_status_code_; + error_message_slice_ = SliceReferencingString(send_error_message_); + op->data.send_status_from_server.status_details = + send_error_message_.empty() ? nullptr : &error_message_slice_; + op->flags = 0; + op->reserved = NULL; + } + + void FinishOp(bool* status) { + if (!send_status_available_ || hijacked_) return; + g_core_codegen_interface->gpr_free(trailing_metadata_); + send_status_available_ = false; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_status_available_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS); + interceptor_methods->SetSendTrailingMetadata(metadata_map_); + interceptor_methods->SetSendStatus(&send_status_code_, &send_error_details_, + &send_error_message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + private: + bool hijacked_ = false; + bool send_status_available_; + grpc_status_code send_status_code_; + grpc::string send_error_details_; + grpc::string send_error_message_; + size_t trailing_metadata_count_; + std::multimap<grpc::string, grpc::string>* metadata_map_; + grpc_metadata* trailing_metadata_; + grpc_slice error_message_slice_; +}; + +class CallOpRecvInitialMetadata { + public: + CallOpRecvInitialMetadata() : metadata_map_(nullptr) {} + + void RecvInitialMetadata(ClientContext* context) { + context->initial_metadata_received_ = true; + metadata_map_ = &context->recv_initial_metadata_; + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (metadata_map_ == nullptr || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr(); + op->flags = 0; + op->reserved = NULL; + } + + void FinishOp(bool* status) { + if (metadata_map_ == nullptr || hijacked_) return; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvInitialMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + metadata_map_ = nullptr; + } + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA); + } + + private: + bool hijacked_ = false; + MetadataMap* metadata_map_; +}; + +class CallOpClientRecvStatus { + public: + CallOpClientRecvStatus() + : recv_status_(nullptr), debug_error_string_(nullptr) {} + + void ClientRecvStatus(ClientContext* context, Status* status) { + client_context_ = context; + metadata_map_ = &client_context_->trailing_metadata_; + recv_status_ = status; + error_message_ = g_core_codegen_interface->grpc_empty_slice(); + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (recv_status_ == nullptr || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr(); + op->data.recv_status_on_client.status = &status_code_; + op->data.recv_status_on_client.status_details = &error_message_; + op->data.recv_status_on_client.error_string = &debug_error_string_; + op->flags = 0; + op->reserved = NULL; + } + + void FinishOp(bool* status) { + if (recv_status_ == nullptr || hijacked_) return; + grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); + *recv_status_ = + Status(static_cast<StatusCode>(status_code_), + GRPC_SLICE_IS_EMPTY(error_message_) + ? grpc::string() + : grpc::string(GRPC_SLICE_START_PTR(error_message_), + GRPC_SLICE_END_PTR(error_message_)), + binary_error_details); + client_context_->set_debug_error_string( + debug_error_string_ != nullptr ? debug_error_string_ : ""); + g_core_codegen_interface->grpc_slice_unref(error_message_); + if (debug_error_string_ != nullptr) { + g_core_codegen_interface->gpr_free((void*)debug_error_string_); + } + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvStatus(recv_status_); + interceptor_methods->SetRecvTrailingMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS); + recv_status_ = nullptr; + } + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS); + } + + private: + bool hijacked_ = false; + ClientContext* client_context_; + MetadataMap* metadata_map_; + Status* recv_status_; + const char* debug_error_string_; + grpc_status_code status_code_; + grpc_slice error_message_; +}; + +template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>, + class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, + class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> +class CallOpSet; + +/// Primary implementation of CallOpSetInterface. +/// Since we cannot use variadic templates, we declare slots up to +/// the maximum count of ops we'll need in a set. We leverage the +/// empty base class optimization to slim this class (especially +/// when there are many unused slots used). To avoid duplicate base classes, +/// the template parmeter for CallNoOp is varied by argument position. +template <class Op1, class Op2, class Op3, class Op4, class Op5, class Op6> +class CallOpSet : public CallOpSetInterface, + public Op1, + public Op2, + public Op3, + public Op4, + public Op5, + public Op6 { + public: + CallOpSet() : cq_tag_(this), return_tag_(this) {} + // The copy constructor and assignment operator reset the value of + // cq_tag_, return_tag_, done_intercepting_ and interceptor_methods_ since + // those are only meaningful on a specific object, not across objects. + CallOpSet(const CallOpSet& other) + : cq_tag_(this), + return_tag_(this), + call_(other.call_), + done_intercepting_(false), + interceptor_methods_(InterceptorBatchMethodsImpl()) {} + + CallOpSet& operator=(const CallOpSet& other) { + cq_tag_ = this; + return_tag_ = this; + call_ = other.call_; + done_intercepting_ = false; + interceptor_methods_ = InterceptorBatchMethodsImpl(); + return *this; + } + + void FillOps(Call* call) override { + done_intercepting_ = false; + g_core_codegen_interface->grpc_call_ref(call->call()); + call_ = + *call; // It's fine to create a copy of call since it's just pointers + + if (RunInterceptors()) { + ContinueFillOpsAfterInterception(); + } else { + // After the interceptors are run, ContinueFillOpsAfterInterception will + // be run + } + } + + bool FinalizeResult(void** tag, bool* status) override { + if (done_intercepting_) { + // We have already finished intercepting and filling in the results. This + // round trip from the core needed to be made because interceptors were + // run + *tag = return_tag_; + *status = saved_status_; + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + + this->Op1::FinishOp(status); + this->Op2::FinishOp(status); + this->Op3::FinishOp(status); + this->Op4::FinishOp(status); + this->Op5::FinishOp(status); + this->Op6::FinishOp(status); + saved_status_ = *status; + if (RunInterceptorsPostRecv()) { + *tag = return_tag_; + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + // Interceptors are going to be run, so we can't return the tag just yet. + // After the interceptors are run, ContinueFinalizeResultAfterInterception + return false; + } + + void set_output_tag(void* return_tag) { return_tag_ = return_tag; } + + void* cq_tag() override { return cq_tag_; } + + /// set_cq_tag is used to provide a different core CQ tag than "this". + /// This is used for callback-based tags, where the core tag is the core + /// callback function. It does not change the use or behavior of any other + /// function (such as FinalizeResult) + void set_cq_tag(void* cq_tag) { cq_tag_ = cq_tag; } + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + this->Op1::SetHijackingState(&interceptor_methods_); + this->Op2::SetHijackingState(&interceptor_methods_); + this->Op3::SetHijackingState(&interceptor_methods_); + this->Op4::SetHijackingState(&interceptor_methods_); + this->Op5::SetHijackingState(&interceptor_methods_); + this->Op6::SetHijackingState(&interceptor_methods_); + } + + // Should be called after interceptors are done running + void ContinueFillOpsAfterInterception() override { + static const size_t MAX_OPS = 6; + grpc_op ops[MAX_OPS]; + size_t nops = 0; + this->Op1::AddOp(ops, &nops); + this->Op2::AddOp(ops, &nops); + this->Op3::AddOp(ops, &nops); + this->Op4::AddOp(ops, &nops); + this->Op5::AddOp(ops, &nops); + this->Op6::AddOp(ops, &nops); + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), ops, nops, cq_tag(), nullptr)); + } + + // Should be called after interceptors are done running on the finalize result + // path + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, cq_tag(), nullptr)); + } + + private: + // Returns true if no interceptors need to be run + bool RunInterceptors() { + interceptor_methods_.ClearState(); + interceptor_methods_.SetCallOpSetInterface(this); + interceptor_methods_.SetCall(&call_); + this->Op1::SetInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetInterceptionHookPoint(&interceptor_methods_); + return interceptor_methods_.RunInterceptors(); + } + // Returns true if no interceptors need to be run + bool RunInterceptorsPostRecv() { + // Call and OpSet had already been set on the set state. + // SetReverse also clears previously set hook points + interceptor_methods_.SetReverse(); + this->Op1::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetFinishInterceptionHookPoint(&interceptor_methods_); + return interceptor_methods_.RunInterceptors(); + } + + void* cq_tag_; + void* return_tag_; + Call call_; + bool done_intercepting_ = false; + InterceptorBatchMethodsImpl interceptor_methods_; + bool saved_status_; +}; + +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_CALL_OP_SET_H diff --git a/include/grpcpp/impl/codegen/call_op_set_interface.h b/include/grpcpp/impl/codegen/call_op_set_interface.h new file mode 100644 index 0000000000..815227a299 --- /dev/null +++ b/include/grpcpp/impl/codegen/call_op_set_interface.h @@ -0,0 +1,59 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_CALL_OP_SET_INTERFACE_H +#define GRPCPP_IMPL_CODEGEN_CALL_OP_SET_INTERFACE_H + +#include <grpcpp/impl/codegen/completion_queue_tag.h> + +namespace grpc { +namespace internal { + +class Call; + +/// An abstract collection of call ops, used to generate the +/// grpc_call_op structure to pass down to the lower layers, +/// and as it is-a CompletionQueueTag, also massages the final +/// completion into the correct form for consumption in the C++ +/// API. +class CallOpSetInterface : public CompletionQueueTag { + public: + /// Fills in grpc_op, starting from ops[*nops] and moving + /// upwards. + virtual void FillOps(internal::Call* call) = 0; + + /// Get the tag to be used at the core completion queue. Generally, the + /// value of cq_tag will be "this". However, it can be overridden if we + /// want core to process the tag differently (e.g., as a core callback) + virtual void* cq_tag() = 0; + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + virtual void SetHijackingState() = 0; + + // Should be called after interceptors are done running + virtual void ContinueFillOpsAfterInterception() = 0; + + // Should be called after interceptors are done running on the finalize result + // path + virtual void ContinueFinalizeResultAfterInterception() = 0; +}; +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_CALL_OP_SET_INTERFACE_H diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h index a9835973ac..eba9ec6edc 100644 --- a/include/grpcpp/impl/codegen/callback_common.h +++ b/include/grpcpp/impl/codegen/callback_common.h @@ -94,7 +94,10 @@ class CallbackWithStatusTag void Run(bool ok) { void* ignored = ops_; - GPR_CODEGEN_ASSERT(ops_->FinalizeResult(&ignored, &ok)); + if (!ops_->FinalizeResult(&ignored, &ok)) { + // The tag was swallowed + return; + } GPR_CODEGEN_ASSERT(ignored == ops_); // Last use of func_ or status_, so ok to move them out diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index b257acc1ab..6fd1dd1d9b 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -20,6 +20,7 @@ #define GRPCPP_IMPL_CODEGEN_CHANNEL_INTERFACE_H #include <grpc/impl/codegen/connectivity_state.h> +#include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/status.h> #include <grpcpp/impl/codegen/time.h> @@ -51,6 +52,7 @@ template <class W, class R> class ClientAsyncReaderWriterFactory; template <class R> class ClientAsyncResponseReaderFactory; +class InterceptedChannel; } // namespace internal /// Codegen interface for \a grpc::Channel. @@ -108,6 +110,7 @@ class ChannelInterface { template <class InputMessage, class OutputMessage> friend class ::grpc::internal::CallbackUnaryCallImpl; friend class ::grpc::internal::RpcMethod; + friend class ::grpc::internal::InterceptedChannel; virtual internal::Call CreateCall(const internal::RpcMethod& method, ClientContext* context, CompletionQueue* cq) = 0; @@ -121,6 +124,20 @@ class ChannelInterface { gpr_timespec deadline) = 0; // EXPERIMENTAL + // This is needed to keep codegen_test_minimal happy. InterceptedChannel needs + // to make use of this but can't directly call Channel's implementation + // because of the test. + // Returns an empty Call object (rather than being pure) since this is a new + // method and adding a new pure method to an interface would be a breaking + // change (even though this is private and non-API) + virtual internal::Call CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq, + int interceptor_pos) { + return internal::Call(); + } + + // EXPERIMENTAL // A method to get the callbackable completion queue associated with this // channel. If the return value is nullptr, this channel doesn't support // callback operations. diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 4d4faea063..ecb00a0769 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -77,7 +77,7 @@ class CallbackUnaryCallImpl { tag->force_run(s); return; } - ops->SendInitialMetadata(context->send_initial_metadata_, + ops->SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); ops->RecvInitialMetadata(context); ops->RecvMessage(result); diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 24f5c431ce..f53b744dcf 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -41,6 +41,7 @@ #include <grpc/impl/codegen/compression_types.h> #include <grpc/impl/codegen/propagation_bits.h> +#include <grpcpp/impl/codegen/client_interceptor.h> #include <grpcpp/impl/codegen/config.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/create_auth_context.h> @@ -402,6 +403,17 @@ class ClientContext { grpc_call* call() const { return call_; } void set_call(grpc_call* call, const std::shared_ptr<Channel>& channel); + experimental::ClientRpcInfo* set_client_rpc_info( + const char* method, grpc::ChannelInterface* channel, + const std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>& + creators, + size_t interceptor_pos) { + rpc_info_ = experimental::ClientRpcInfo(this, method, channel); + rpc_info_.RegisterInterceptors(creators, interceptor_pos); + return &rpc_info_; + } + uint32_t initial_metadata_flags() const { return (idempotent_ ? GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST : 0) | (wait_for_ready_ ? GRPC_INITIAL_METADATA_WAIT_FOR_READY : 0) | @@ -439,6 +451,8 @@ class ClientContext { bool initial_metadata_corked_; grpc::string debug_error_string_; + + experimental::ClientRpcInfo rpc_info_; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f460c5ac0c..00113f04aa 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -19,23 +19,76 @@ #ifndef GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H +#include <vector> + +#include <grpc/impl/codegen/log.h> #include <grpcpp/impl/codegen/interceptor.h> +#include <grpcpp/impl/codegen/string_ref.h> namespace grpc { -namespace experimental { -class ClientInterceptor { - public: - virtual ~ClientInterceptor() {} - virtual void Intercept(InterceptorBatchMethods* methods) = 0; -}; +class ClientContext; +class Channel; -class ClientRpcInfo {}; +namespace internal { +class InterceptorBatchMethodsImpl; +} + +namespace experimental { +class ClientRpcInfo; class ClientInterceptorFactoryInterface { public: virtual ~ClientInterceptorFactoryInterface() {} - virtual ClientInterceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; + virtual Interceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; +}; + +class ClientRpcInfo { + public: + ClientRpcInfo() {} + + ~ClientRpcInfo(){}; + + ClientRpcInfo(const ClientRpcInfo&) = delete; + ClientRpcInfo(ClientRpcInfo&&) = default; + ClientRpcInfo& operator=(ClientRpcInfo&&) = default; + + // Getter methods + const char* method() { return method_; } + ChannelInterface* channel() { return channel_; } + grpc::ClientContext* client_context() { return ctx_; } + + private: + ClientRpcInfo(grpc::ClientContext* ctx, const char* method, + grpc::ChannelInterface* channel) + : ctx_(ctx), method_(method), channel_(channel) {} + // Runs interceptor at pos \a pos. + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { + GPR_CODEGEN_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + void RegisterInterceptors( + const std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>& creators, + int interceptor_pos) { + for (auto it = creators.begin() + interceptor_pos; it != creators.end(); + ++it) { + interceptors_.push_back(std::unique_ptr<experimental::Interceptor>( + (*it)->CreateClientInterceptor(this))); + } + } + + grpc::ClientContext* ctx_ = nullptr; + const char* method_ = nullptr; + grpc::ChannelInterface* channel_ = nullptr; + std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_; + bool hijacked_ = false; + size_t hijacked_interceptor_ = 0; + + friend class internal::InterceptorBatchMethodsImpl; + friend class grpc::ClientContext; }; } // namespace experimental diff --git a/include/grpcpp/impl/codegen/client_unary_call.h b/include/grpcpp/impl/codegen/client_unary_call.h index e4e8364e07..b1c80764f2 100644 --- a/include/grpcpp/impl/codegen/client_unary_call.h +++ b/include/grpcpp/impl/codegen/client_unary_call.h @@ -61,7 +61,7 @@ class BlockingUnaryCallImpl { if (!status_.ok()) { return; } - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); ops.RecvInitialMetadata(context); ops.RecvMessage(result); diff --git a/include/grpcpp/impl/codegen/completion_queue.h b/include/grpcpp/impl/codegen/completion_queue.h index 4ca19f4672..5eef2c281f 100644 --- a/include/grpcpp/impl/codegen/completion_queue.h +++ b/include/grpcpp/impl/codegen/completion_queue.h @@ -300,14 +300,17 @@ class CompletionQueue : private GrpcLibraryCodegen { bool Pluck(internal::CompletionQueueTag* tag) { auto deadline = g_core_codegen_interface->gpr_inf_future(GPR_CLOCK_REALTIME); - auto ev = g_core_codegen_interface->grpc_completion_queue_pluck( - cq_, tag, deadline, nullptr); - bool ok = ev.success != 0; - void* ignored = tag; - GPR_CODEGEN_ASSERT(tag->FinalizeResult(&ignored, &ok)); - GPR_CODEGEN_ASSERT(ignored == tag); - // Ignore mutations by FinalizeResult: Pluck returns the C API status - return ev.success != 0; + while (true) { + auto ev = g_core_codegen_interface->grpc_completion_queue_pluck( + cq_, tag, deadline, nullptr); + bool ok = ev.success != 0; + void* ignored = tag; + if (tag->FinalizeResult(&ignored, &ok)) { + GPR_CODEGEN_ASSERT(ignored == tag); + // Ignore mutations by FinalizeResult: Pluck returns the C API status + return ev.success != 0; + } + } } /// Performs a single polling pluck on \a tag. diff --git a/include/grpcpp/impl/codegen/core_codegen.h b/include/grpcpp/impl/codegen/core_codegen.h index e9df96bf04..6ef184d01a 100644 --- a/include/grpcpp/impl/codegen/core_codegen.h +++ b/include/grpcpp/impl/codegen/core_codegen.h @@ -63,6 +63,9 @@ class CoreCodegen final : public CoreCodegenInterface { void gpr_cv_signal(gpr_cv* cv) override; void gpr_cv_broadcast(gpr_cv* cv) override; + grpc_call_error grpc_call_start_batch(grpc_call* call, const grpc_op* ops, + size_t nops, void* tag, + void* reserved) override; grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/core_codegen_interface.h b/include/grpcpp/impl/codegen/core_codegen_interface.h index 1167a188a2..25e3abccca 100644 --- a/include/grpcpp/impl/codegen/core_codegen_interface.h +++ b/include/grpcpp/impl/codegen/core_codegen_interface.h @@ -100,6 +100,9 @@ class CoreCodegenInterface { virtual grpc_slice grpc_slice_new_with_len(void* p, size_t len, void (*destroy)(void*, size_t)) = 0; + virtual grpc_call_error grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, size_t nops, + void* tag, void* reserved) = 0; virtual grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/intercepted_channel.h b/include/grpcpp/impl/codegen/intercepted_channel.h new file mode 100644 index 0000000000..612e56d862 --- /dev/null +++ b/include/grpcpp/impl/codegen/intercepted_channel.h @@ -0,0 +1,80 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H +#define GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H + +#include <grpcpp/impl/codegen/channel_interface.h> + +namespace grpc { + +namespace internal { + +class InterceptorBatchMethodsImpl; + +/// An InterceptedChannel is available to client Interceptors. An +/// InterceptedChannel is unique to an interceptor, and when an RPC is started +/// on this channel, only those interceptors that come after this interceptor +/// see the RPC. +class InterceptedChannel : public ChannelInterface { + public: + virtual ~InterceptedChannel() { channel_ = nullptr; } + + /// Get the current channel state. If the channel is in IDLE and + /// \a try_to_connect is set to true, try to connect. + grpc_connectivity_state GetState(bool try_to_connect) override { + return channel_->GetState(try_to_connect); + } + + private: + InterceptedChannel(ChannelInterface* channel, int pos) + : channel_(channel), interceptor_pos_(pos) {} + + Call CreateCall(const RpcMethod& method, ClientContext* context, + CompletionQueue* cq) override { + return channel_->CreateCallInternal(method, context, cq, interceptor_pos_); + } + + void PerformOpsOnCall(CallOpSetInterface* ops, Call* call) override { + return channel_->PerformOpsOnCall(ops, call); + } + void* RegisterMethod(const char* method) override { + return channel_->RegisterMethod(method); + } + + void NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, CompletionQueue* cq, + void* tag) override { + return channel_->NotifyOnStateChangeImpl(last_observed, deadline, cq, tag); + } + bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) override { + return channel_->WaitForStateChangeImpl(last_observed, deadline); + } + + CompletionQueue* CallbackCQ() override { return channel_->CallbackCQ(); } + + ChannelInterface* channel_; + int interceptor_pos_; + + friend class InterceptorBatchMethodsImpl; +}; +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 6402a3a946..cdd34b80d1 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -19,7 +19,17 @@ #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H +#include <grpc/impl/codegen/grpc_types.h> +#include <grpcpp/impl/codegen/byte_buffer.h> +#include <grpcpp/impl/codegen/channel_interface.h> +#include <grpcpp/impl/codegen/config.h> +#include <grpcpp/impl/codegen/core_codegen_interface.h> +#include <grpcpp/impl/codegen/metadata_map.h> + namespace grpc { + +class Status; + namespace experimental { class InterceptedMessage { public: @@ -35,6 +45,7 @@ enum class InterceptionHookPoints { PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, PRE_SEND_STATUS /* server only */, + PRE_SEND_CLOSE /* client only */, /* The following three are for hijacked clients only and can only be registered by the global interceptor */ PRE_RECV_INITIAL_METADATA, @@ -50,7 +61,7 @@ enum class InterceptionHookPoints { class InterceptorBatchMethods { public: - virtual ~InterceptorBatchMethods(); + virtual ~InterceptorBatchMethods(){}; // Queries to check whether the current batch has an interception hook point // of type \a type virtual bool QueryInterceptionHookPoint(InterceptionHookPoints type) = 0; @@ -60,7 +71,53 @@ class InterceptorBatchMethods { // Calling this indicates that the interceptor has hijacked the RPC (only // valid if the batch contains send_initial_metadata on the client side) virtual void Hijack() = 0; + + // Returns a modifable ByteBuffer holding serialized form of the message to be + // sent + virtual ByteBuffer* GetSendMessage() = 0; + + // Returns a modifiable multimap of the initial metadata to be sent + virtual std::multimap<grpc::string, grpc::string>* + GetSendInitialMetadata() = 0; + + // Returns the status to be sent + virtual Status GetSendStatus() = 0; + + // Modifies the status with \a status + virtual void ModifySendStatus(const Status& status) = 0; + + // Returns a modifiable multimap of the trailing metadata to be sent + virtual std::multimap<grpc::string, grpc::string>* + GetSendTrailingMetadata() = 0; + + // Returns a pointer to the modifiable received message. Note that the message + // is already deserialized + virtual void* GetRecvMessage() = 0; + + // Returns a modifiable multimap of the received initial metadata + virtual std::multimap<grpc::string_ref, grpc::string_ref>* + GetRecvInitialMetadata() = 0; + + // Returns a modifiable view of the received status + virtual Status* GetRecvStatus() = 0; + + // Returns a modifiable multimap of the received trailing metadata + virtual std::multimap<grpc::string_ref, grpc::string_ref>* + GetRecvTrailingMetadata() = 0; + + // Gets an intercepted channel. When a call is started on this interceptor, + // only interceptors after the current interceptor are created from the + // factory objects registered with the channel. + virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0; }; + +class Interceptor { + public: + virtual ~Interceptor() {} + + virtual void Intercept(InterceptorBatchMethods* methods) = 0; +}; + } // namespace experimental } // namespace grpc diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h new file mode 100644 index 0000000000..cf564977f6 --- /dev/null +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -0,0 +1,383 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H +#define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H + +#include <grpcpp/impl/codegen/client_interceptor.h> +#include <grpcpp/impl/codegen/server_interceptor.h> + +#include <grpc/impl/codegen/grpc_types.h> + +namespace grpc { +namespace internal { + +/// Internal methods for setting the state +class InternalInterceptorBatchMethods + : public experimental::InterceptorBatchMethods { + public: + virtual ~InternalInterceptorBatchMethods() {} + + virtual void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) = 0; + + virtual void SetSendMessage(ByteBuffer* buf) = 0; + + virtual void SetSendInitialMetadata( + std::multimap<grpc::string, grpc::string>* metadata) = 0; + + virtual void SetSendStatus(grpc_status_code* code, + grpc::string* error_details, + grpc::string* error_message) = 0; + + virtual void SetSendTrailingMetadata( + std::multimap<grpc::string, grpc::string>* metadata) = 0; + + virtual void SetRecvMessage(void* message) = 0; + + virtual void SetRecvInitialMetadata(MetadataMap* map) = 0; + + virtual void SetRecvStatus(Status* status) = 0; + + virtual void SetRecvTrailingMetadata(MetadataMap* map) = 0; +}; + +class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { + public: + InterceptorBatchMethodsImpl() { + for (auto i = static_cast<experimental::InterceptionHookPoints>(0); + i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; + i = static_cast<experimental::InterceptionHookPoints>( + static_cast<size_t>(i) + 1)) { + hooks_[static_cast<size_t>(i)] = false; + } + } + + ~InterceptorBatchMethodsImpl() {} + + bool QueryInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + return hooks_[static_cast<size_t>(type)]; + } + + void Proceed() override { /* fill this */ + if (call_->client_rpc_info() != nullptr) { + return ProceedClient(); + } + GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); + ProceedServer(); + } + + void Hijack() override { + // Only the client can hijack when sending down initial metadata + GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && + call_->client_rpc_info() != nullptr); + // It is illegal to call Hijack twice + GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); + auto* rpc_info = call_->client_rpc_info(); + rpc_info->hijacked_ = true; + rpc_info->hijacked_interceptor_ = current_interceptor_index_; + ClearHookPoints(); + ops_->SetHijackingState(); + ran_hijacking_interceptor_ = true; + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + + void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + hooks_[static_cast<size_t>(type)] = true; + } + + ByteBuffer* GetSendMessage() override { return send_message_; } + + std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { + return send_initial_metadata_; + } + + Status GetSendStatus() override { + return Status(static_cast<StatusCode>(*code_), *error_message_, + *error_details_); + } + + void ModifySendStatus(const Status& status) override { + *code_ = static_cast<grpc_status_code>(status.error_code()); + *error_details_ = status.error_details(); + *error_message_ = status.error_message(); + } + + std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata() + override { + return send_trailing_metadata_; + } + + void* GetRecvMessage() override { return recv_message_; } + + std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() + override { + return recv_initial_metadata_->map(); + } + + Status* GetRecvStatus() override { return recv_status_; } + + std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() + override { + return recv_trailing_metadata_->map(); + } + + void SetSendMessage(ByteBuffer* buf) override { send_message_ = buf; } + + void SetSendInitialMetadata( + std::multimap<grpc::string, grpc::string>* metadata) override { + send_initial_metadata_ = metadata; + } + + void SetSendStatus(grpc_status_code* code, grpc::string* error_details, + grpc::string* error_message) override { + code_ = code; + error_details_ = error_details; + error_message_ = error_message; + } + + void SetSendTrailingMetadata( + std::multimap<grpc::string, grpc::string>* metadata) override { + send_trailing_metadata_ = metadata; + } + + void SetRecvMessage(void* message) override { recv_message_ = message; } + + void SetRecvInitialMetadata(MetadataMap* map) override { + recv_initial_metadata_ = map; + } + + void SetRecvStatus(Status* status) override { recv_status_ = status; } + + void SetRecvTrailingMetadata(MetadataMap* map) override { + recv_trailing_metadata_ = map; + } + + std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { + auto* info = call_->client_rpc_info(); + if (info == nullptr) { + return std::unique_ptr<ChannelInterface>(nullptr); + } + // The intercepted channel starts from the interceptor just after the + // current interceptor + return std::unique_ptr<ChannelInterface>(new InterceptedChannel( + info->channel(), current_interceptor_index_ + 1)); + } + + // Clears all state + void ClearState() { + reverse_ = false; + ran_hijacking_interceptor_ = false; + ClearHookPoints(); + } + + // Prepares for Post_recv operations + void SetReverse() { + reverse_ = true; + ran_hijacking_interceptor_ = false; + ClearHookPoints(); + } + + // This needs to be set before interceptors are run + void SetCall(Call* call) { call_ = call; } + + // This needs to be set before interceptors are run using RunInterceptors(). + // Alternatively, RunInterceptors(std::function<void(void)> f) can be used. + void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } + + // Returns true if no interceptors are run. This should be used only by + // subclasses of CallOpSetInterface. SetCall and SetCallOpSetInterface should + // have been called before this. After all the interceptors are done running, + // either ContinueFillOpsAfterInterception or + // ContinueFinalizeOpsAfterInterception will be called. Note that neither of + // them is invoked if there were no interceptors registered. + bool RunInterceptors() { + GPR_CODEGEN_ASSERT(ops_); + auto* client_rpc_info = call_->client_rpc_info(); + if (client_rpc_info != nullptr) { + if (client_rpc_info->interceptors_.size() == 0) { + return true; + } else { + RunClientInterceptors(); + return false; + } + } + + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + RunServerInterceptors(); + return false; + } + + // Returns true if no interceptors are run. Returns false otherwise if there + // are interceptors registered. After the interceptors are done running \a f + // will be invoked. This is to be used only by BaseAsyncRequest and + // SyncRequest. + bool RunInterceptors(std::function<void(void)> f) { + // This is used only by the server for initial call request + GPR_CODEGEN_ASSERT(reverse_ == true); + GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + callback_ = std::move(f); + RunServerInterceptors(); + return false; + } + + private: + void RunClientInterceptors() { + auto* rpc_info = call_->client_rpc_info(); + if (!reverse_) { + current_interceptor_index_ = 0; + } else { + if (rpc_info->hijacked_) { + current_interceptor_index_ = rpc_info->hijacked_interceptor_; + } else { + current_interceptor_index_ = rpc_info->interceptors_.size() - 1; + } + } + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + + void RunServerInterceptors() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + current_interceptor_index_ = 0; + } else { + current_interceptor_index_ = rpc_info->interceptors_.size() - 1; + } + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + + void ProceedClient() { + auto* rpc_info = call_->client_rpc_info(); + if (rpc_info->hijacked_ && !reverse_ && + current_interceptor_index_ == rpc_info->hijacked_interceptor_ && + !ran_hijacking_interceptor_) { + // We now need to provide hijacked recv ops to this interceptor + ClearHookPoints(); + ops_->SetHijackingState(); + ran_hijacking_interceptor_ = true; + rpc_info->RunInterceptor(this, current_interceptor_index_); + return; + } + if (!reverse_) { + current_interceptor_index_++; + // We are going down the stack of interceptors + if (current_interceptor_index_ < rpc_info->interceptors_.size()) { + if (rpc_info->hijacked_ && + current_interceptor_index_ > rpc_info->hijacked_interceptor_) { + // This is a hijacked RPC and we are done with hijacking + ops_->ContinueFillOpsAfterInterception(); + } else { + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + } else { + // we are done running all the interceptors without any hijacking + ops_->ContinueFillOpsAfterInterception(); + } + } else { + // We are going up the stack of interceptors + if (current_interceptor_index_ > 0) { + // Continue running interceptors + current_interceptor_index_--; + rpc_info->RunInterceptor(this, current_interceptor_index_); + } else { + // we are done running all the interceptors without any hijacking + ops_->ContinueFinalizeResultAfterInterception(); + } + } + } + + void ProceedServer() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + current_interceptor_index_++; + if (current_interceptor_index_ < rpc_info->interceptors_.size()) { + return rpc_info->RunInterceptor(this, current_interceptor_index_); + } else if (ops_) { + return ops_->ContinueFillOpsAfterInterception(); + } + } else { + // We are going up the stack of interceptors + if (current_interceptor_index_ > 0) { + // Continue running interceptors + current_interceptor_index_--; + return rpc_info->RunInterceptor(this, current_interceptor_index_); + } else if (ops_) { + return ops_->ContinueFinalizeResultAfterInterception(); + } + } + GPR_CODEGEN_ASSERT(callback_); + callback_(); + } + + void ClearHookPoints() { + for (auto i = static_cast<experimental::InterceptionHookPoints>(0); + i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; + i = static_cast<experimental::InterceptionHookPoints>( + static_cast<size_t>(i) + 1)) { + hooks_[static_cast<size_t>(i)] = false; + } + } + + std::array<bool, + static_cast<size_t>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> + hooks_; + + size_t current_interceptor_index_ = 0; // Current iterator + bool reverse_ = false; + bool ran_hijacking_interceptor_ = false; + Call* call_ = nullptr; // The Call object is present along with CallOpSet + // object/callback + CallOpSetInterface* ops_ = nullptr; + std::function<void(void)> callback_; + + ByteBuffer* send_message_ = nullptr; + + std::multimap<grpc::string, grpc::string>* send_initial_metadata_; + + grpc_status_code* code_ = nullptr; + grpc::string* error_details_ = nullptr; + grpc::string* error_message_ = nullptr; + Status send_status_; + + std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr; + + void* recv_message_ = nullptr; + + MetadataMap* recv_initial_metadata_ = nullptr; + + Status* recv_status_ = nullptr; + + MetadataMap* recv_trailing_metadata_ = nullptr; +}; + +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H diff --git a/include/grpcpp/impl/codegen/metadata_map.h b/include/grpcpp/impl/codegen/metadata_map.h index 5e062a50f8..0bba3ed4e3 100644 --- a/include/grpcpp/impl/codegen/metadata_map.h +++ b/include/grpcpp/impl/codegen/metadata_map.h @@ -19,6 +19,8 @@ #ifndef GRPCPP_IMPL_CODEGEN_METADATA_MAP_H #define GRPCPP_IMPL_CODEGEN_METADATA_MAP_H +#include <map> + #include <grpc/impl/codegen/log.h> #include <grpcpp/impl/codegen/slice.h> diff --git a/include/grpcpp/impl/codegen/method_handler_impl.h b/include/grpcpp/impl/codegen/method_handler_impl.h index 53117f941b..4f02e3e39b 100644 --- a/include/grpcpp/impl/codegen/method_handler_impl.h +++ b/include/grpcpp/impl/codegen/method_handler_impl.h @@ -59,21 +59,21 @@ class RpcMethodHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits<RequestType>::Deserialize( - param.request.bbuf_ptr(), &req); ResponseType rsp; + Status status = param.status; if (status.ok()) { - status = CatchingFunctionHandler([this, ¶m, &req, &rsp] { - return func_(service_, param.server_context, &req, &rsp); + status = CatchingFunctionHandler([this, ¶m, &rsp] { + return func_(service_, param.server_context, + static_cast<RequestType*>(param.request), &rsp); }); + delete static_cast<RequestType*>(param.request); } GPR_CODEGEN_ASSERT(!param.server_context->sent_initial_metadata_); CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage, CallOpServerSendStatus> ops; - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -81,11 +81,24 @@ class RpcMethodHandler : public MethodHandler { if (status.ok()) { status = ops.SendMessage(rsp); } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new RequestType(); + *status = SerializationTraits<RequestType>::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + delete request; + return nullptr; + } + private: /// Application provided rpc handler function. std::function<Status(ServiceType*, ServerContext*, const RequestType*, @@ -117,7 +130,7 @@ class ClientStreamingHandler : public MethodHandler { CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -126,7 +139,7 @@ class ClientStreamingHandler : public MethodHandler { if (status.ok()) { status = ops.SendMessage(rsp); } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } @@ -150,26 +163,25 @@ class ServerStreamingHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits<RequestType>::Deserialize( - param.request.bbuf_ptr(), &req); - + Status status = param.status; if (status.ok()) { ServerWriter<ResponseType> writer(param.call, param.server_context); - status = CatchingFunctionHandler([this, ¶m, &req, &writer] { - return func_(service_, param.server_context, &req, &writer); + status = CatchingFunctionHandler([this, ¶m, &writer] { + return func_(service_, param.server_context, + static_cast<RequestType*>(param.request), &writer); }); + delete static_cast<RequestType*>(param.request); } CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); } } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -177,6 +189,19 @@ class ServerStreamingHandler : public MethodHandler { param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new RequestType(); + *status = SerializationTraits<RequestType>::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + delete request; + return nullptr; + } + private: std::function<Status(ServiceType*, ServerContext*, const RequestType*, ServerWriter<ResponseType>*)> @@ -206,7 +231,7 @@ class TemplatedBidiStreamingHandler : public MethodHandler { CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -218,7 +243,7 @@ class TemplatedBidiStreamingHandler : public MethodHandler { "Service did not provide response message"); } } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -281,14 +306,14 @@ class ErrorMethodHandler : public MethodHandler { static void FillOps(ServerContext* context, T* ops) { Status status(code, ""); if (!context->sent_initial_metadata_) { - ops->SendInitialMetadata(context->initial_metadata_, + ops->SendInitialMetadata(&context->initial_metadata_, context->initial_metadata_flags()); if (context->compression_level_set()) { ops->set_compression_level(context->compression_level()); } context->sent_initial_metadata_ = true; } - ops->ServerSendStatus(context->trailing_metadata_, status); + ops->ServerSendStatus(&context->trailing_metadata_, status); } void RunHandler(const HandlerParameter& param) final { @@ -296,11 +321,14 @@ class ErrorMethodHandler : public MethodHandler { FillOps(param.server_context, &ops); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); - // We also have to destroy any request payload in the handler parameter - ByteBuffer* payload = param.request.bbuf_ptr(); - if (payload != nullptr) { - payload->Clear(); + } + + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + // We have to destroy any request payload + if (req != nullptr) { + g_core_codegen_interface->grpc_byte_buffer_destroy(req); } + return nullptr; } }; diff --git a/include/grpcpp/impl/codegen/rpc_service_method.h b/include/grpcpp/impl/codegen/rpc_service_method.h index 5cf88e216f..44da2bd768 100644 --- a/include/grpcpp/impl/codegen/rpc_service_method.h +++ b/include/grpcpp/impl/codegen/rpc_service_method.h @@ -40,17 +40,26 @@ class MethodHandler { public: virtual ~MethodHandler() {} struct HandlerParameter { - HandlerParameter(Call* c, ServerContext* context, grpc_byte_buffer* req) - : call(c), server_context(context) { - request.set_buffer(req); - } - ~HandlerParameter() { request.Release(); } + HandlerParameter(Call* c, ServerContext* context, void* req, + Status req_status) + : call(c), server_context(context), request(req), status(req_status) {} + ~HandlerParameter() {} Call* call; ServerContext* server_context; - // Handler required to destroy these contents - ByteBuffer request; + void* request; + Status status; }; virtual void RunHandler(const HandlerParameter& param) = 0; + + /* Returns a pointer to the deserialized request. \a status reflects the + result of deserialization. This pointer and the status should be filled in + a HandlerParameter and passed to RunHandler. It is illegal to access the + pointer after calling RunHandler. Ownership of the deserialized request is + retained by the handler. Returns nullptr if deserialization failed. */ + virtual void* Deserialize(grpc_byte_buffer* req, Status* status) { + GPR_CODEGEN_ASSERT(req == nullptr); + return nullptr; + } }; /// Server side rpc method class diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index d53c09aa1b..7559fb3b34 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -26,11 +26,13 @@ #include <grpc/impl/codegen/compression_types.h> #include <grpcpp/impl/codegen/call.h> +#include <grpcpp/impl/codegen/call_op_set.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> #include <grpcpp/impl/codegen/config.h> #include <grpcpp/impl/codegen/create_auth_context.h> #include <grpcpp/impl/codegen/metadata_map.h> #include <grpcpp/impl/codegen/security/auth_context.h> +#include <grpcpp/impl/codegen/server_interceptor.h> #include <grpcpp/impl/codegen/string_ref.h> #include <grpcpp/impl/codegen/time.h> @@ -285,6 +287,18 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } + experimental::ServerRpcInfo* set_server_rpc_info( + const char* method, + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>& + creators) { + if (creators.size() != 0) { + rpc_info_ = new experimental::ServerRpcInfo(this, method); + rpc_info_->RegisterInterceptors(creators); + } + return rpc_info_; + } + CompletionOp* completion_op_; bool has_notify_when_done_tag_; void* async_notify_when_done_tag_; @@ -306,6 +320,8 @@ class ServerContext { internal::CallOpSendMessage> pending_ops_; bool has_pending_ops_; + + experimental::ServerRpcInfo* rpc_info_ = nullptr; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/server_interceptor.h b/include/grpcpp/impl/codegen/server_interceptor.h new file mode 100644 index 0000000000..c39e9a988d --- /dev/null +++ b/include/grpcpp/impl/codegen/server_interceptor.h @@ -0,0 +1,100 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H +#define GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H + +#include <atomic> +#include <vector> + +#include <grpc/impl/codegen/log.h> +#include <grpcpp/impl/codegen/interceptor.h> +#include <grpcpp/impl/codegen/string_ref.h> + +namespace grpc { + +class ServerContext; + +namespace internal { +class InterceptorBatchMethodsImpl; +} + +namespace experimental { +class ServerRpcInfo; + +class ServerInterceptorFactoryInterface { + public: + virtual ~ServerInterceptorFactoryInterface() {} + virtual Interceptor* CreateServerInterceptor(ServerRpcInfo* info) = 0; +}; + +class ServerRpcInfo { + public: + ~ServerRpcInfo(){}; + + ServerRpcInfo(const ServerRpcInfo&) = delete; + ServerRpcInfo(ServerRpcInfo&&) = default; + ServerRpcInfo& operator=(ServerRpcInfo&&) = default; + + // Getter methods + const char* method() { return method_; } + grpc::ServerContext* server_context() { return ctx_; } + + private: + ServerRpcInfo(grpc::ServerContext* ctx, const char* method) + : ctx_(ctx), method_(method) { + ref_.store(1); + } + + // Runs interceptor at pos \a pos. + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { + GPR_CODEGEN_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + void RegisterInterceptors( + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>& + creators) { + for (const auto& creator : creators) { + interceptors_.push_back(std::unique_ptr<experimental::Interceptor>( + creator->CreateServerInterceptor(this))); + } + } + + void Ref() { ref_++; } + void Unref() { + if (--ref_ == 0) { + delete this; + } + } + + grpc::ServerContext* ctx_ = nullptr; + const char* method_ = nullptr; + std::atomic_int ref_; + std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_; + + friend class internal::InterceptorBatchMethodsImpl; + friend class grpc::ServerContext; +}; + +} // namespace experimental +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 237991cde6..92c87a5f7e 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -21,10 +21,12 @@ #include <grpc/impl/codegen/grpc_types.h> #include <grpcpp/impl/codegen/byte_buffer.h> +#include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/call_hook.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/rpc_service_method.h> +#include <grpcpp/impl/codegen/server_context.h> namespace grpc { @@ -148,44 +150,67 @@ class ServerInterface : public internal::CallHook { public: BaseAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize); virtual ~BaseAsyncRequest(); bool FinalizeResult(void** tag, bool* status) override; + private: + void ContinueFinalizeResultAfterInterception(); + protected: ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; const bool delete_on_finalize_; grpc_call* call_; + internal::Call call_wrapper_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; + bool done_intercepting_; }; class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag); - - // uses BaseAsyncRequest::FinalizeResult + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, + const char* name); + + virtual bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(name_, + *server_->interceptor_creators())); + return BaseAsyncRequest::FinalizeResult(tag, status); + } protected: void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); + const char* name_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { public: - NoPayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + NoPayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag) { - IssueRequest(registered_method, nullptr, notification_cq); + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()) { + IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } // uses RegisteredAsyncRequest::FinalizeResult @@ -194,13 +219,15 @@ class ServerInterface : public internal::CallHook { template <class Message> class PayloadAsyncRequest final : public RegisteredAsyncRequest { public: - PayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + PayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag), + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()), registered_method_(registered_method), server_(server), context_(context), @@ -209,7 +236,8 @@ class ServerInterface : public internal::CallHook { notification_cq_(notification_cq), tag_(tag), request_(request) { - IssueRequest(registered_method, payload_.bbuf_ptr(), notification_cq); + IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(), + notification_cq); } ~PayloadAsyncRequest() { @@ -217,6 +245,10 @@ class ServerInterface : public internal::CallHook { } bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return RegisteredAsyncRequest::FinalizeResult(tag, status); + } if (*status) { if (!payload_.Valid() || !SerializationTraits<Message>::Deserialize( payload_.bbuf_ptr(), request_) @@ -235,15 +267,20 @@ class ServerInterface : public internal::CallHook { return false; } } + /* Set interception point for recv message */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); return RegisteredAsyncRequest::FinalizeResult(tag, status); } private: - void* const registered_method_; + internal::RpcServiceMethod* const registered_method_; ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; Message* const request_; @@ -272,9 +309,8 @@ class ServerInterface : public internal::CallHook { ServerCompletionQueue* notification_cq, void* tag, Message* message) { GPR_CODEGEN_ASSERT(method); - new PayloadAsyncRequest<Message>(method->server_tag(), this, context, - stream, call_cq, notification_cq, tag, - message); + new PayloadAsyncRequest<Message>(method, this, context, stream, call_cq, + notification_cq, tag, message); } void RequestAsyncCall(internal::RpcServiceMethod* method, @@ -283,8 +319,8 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { GPR_CODEGEN_ASSERT(method); - new NoPayloadAsyncRequest(method->server_tag(), this, context, stream, - call_cq, notification_cq, tag); + new NoPayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag); } void RequestAsyncGenericCall(GenericServerContext* context, @@ -295,6 +331,13 @@ class ServerInterface : public internal::CallHook { new GenericAsyncRequest(this, context, stream, call_cq, notification_cq, tag, true); } + + private: + virtual const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* + interceptor_creators() { + return nullptr; + } }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/sync_stream.h b/include/grpcpp/impl/codegen/sync_stream.h index cbfcf25d0a..6981076f04 100644 --- a/include/grpcpp/impl/codegen/sync_stream.h +++ b/include/grpcpp/impl/codegen/sync_stream.h @@ -250,7 +250,7 @@ class ClientReader final : public ClientReaderInterface<R> { ::grpc::internal::CallOpSendMessage, ::grpc::internal::CallOpClientSendClose> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); // TODO(ctiller): don't assert GPR_CODEGEN_ASSERT(ops.SendMessage(request).ok()); @@ -327,7 +327,7 @@ class ClientWriter : public ClientWriterInterface<W> { ops.ClientSendClose(); } if (context_->initial_metadata_corked_) { - ops.SendInitialMetadata(context_->send_initial_metadata_, + ops.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); context_->set_initial_metadata_corked(false); } @@ -386,7 +386,7 @@ class ClientWriter : public ClientWriterInterface<W> { if (!context_->initial_metadata_corked_) { ::grpc::internal::CallOpSet<::grpc::internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); call_.PerformOps(&ops); cq_.Pluck(&ops); @@ -498,7 +498,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface<W, R> { ops.ClientSendClose(); } if (context_->initial_metadata_corked_) { - ops.SendInitialMetadata(context_->send_initial_metadata_, + ops.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); context_->set_initial_metadata_corked(false); } @@ -557,7 +557,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface<W, R> { if (!context_->initial_metadata_corked_) { ::grpc::internal::CallOpSet<::grpc::internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); call_.PerformOps(&ops); cq_.Pluck(&ops); @@ -583,7 +583,7 @@ class ServerReader final : public ServerReaderInterface<R> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); internal::CallOpSet<internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -635,7 +635,7 @@ class ServerWriter final : public ServerWriterInterface<W> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); internal::CallOpSet<internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -660,7 +660,7 @@ class ServerWriter final : public ServerWriterInterface<W> { return false; } if (!ctx_->sent_initial_metadata_) { - ctx_->pending_ops_.SendInitialMetadata(ctx_->initial_metadata_, + ctx_->pending_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ctx_->pending_ops_.set_compression_level(ctx_->compression_level()); @@ -708,7 +708,7 @@ class ServerReaderWriterBody final { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); CallOpSet<CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -738,7 +738,7 @@ class ServerReaderWriterBody final { return false; } if (!ctx_->sent_initial_metadata_) { - ctx_->pending_ops_.SendInitialMetadata(ctx_->initial_metadata_, + ctx_->pending_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ctx_->pending_ops_.set_compression_level(ctx_->compression_level()); diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index 8d3e856502..2b89ffd317 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -174,7 +174,11 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq = nullptr); + grpc_resource_quota* server_rq = nullptr, + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators = std::vector<std::unique_ptr< + experimental::ServerInterceptorFactoryInterface>>()); /// Start the server. /// @@ -187,6 +191,12 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { grpc_server* server() override { return server_; }; private: + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* + interceptor_creators() override { + return &interceptor_creators_; + } + friend class AsyncGenericService; friend class ServerBuilder; friend class ServerInitializer; @@ -251,6 +261,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { // A special handler for resource exhausted in sync case std::unique_ptr<internal::MethodHandler> resource_exhausted_handler_; + + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators_; }; } // namespace grpc diff --git a/include/grpcpp/server_builder.h b/include/grpcpp/server_builder.h index a58a59c2d8..028b8cffaa 100644 --- a/include/grpcpp/server_builder.h +++ b/include/grpcpp/server_builder.h @@ -28,6 +28,7 @@ #include <grpc/support/cpu.h> #include <grpc/support/workaround_list.h> #include <grpcpp/impl/channel_argument_option.h> +#include <grpcpp/impl/codegen/server_interceptor.h> #include <grpcpp/impl/server_builder_option.h> #include <grpcpp/impl/server_builder_plugin.h> #include <grpcpp/support/config.h> @@ -212,6 +213,29 @@ class ServerBuilder { /// doc/workarounds.md. ServerBuilder& EnableWorkaround(grpc_workaround_list id); + /// NOTE: class experimental_type is not part of the public API of this class. + /// TODO(yashykt): Integrate into public API when this is no longer + /// experimental. + class experimental_type { + public: + explicit experimental_type(ServerBuilder* builder) : builder_(builder) {} + + void SetInterceptorCreators( + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators) { + builder_->interceptor_creators_ = std::move(interceptor_creators); + } + + private: + ServerBuilder* builder_; + }; + + /// NOTE: The function experimental() is not stable public API. It is a view + /// to the experimental components of this class. It may be changed or removed + /// at any time. + experimental_type experimental() { return experimental_type(this); } + protected: /// Experimental, to be deprecated struct Port { @@ -297,6 +321,8 @@ class ServerBuilder { grpc_compression_algorithm algorithm; } maybe_default_compression_algorithm_; uint32_t enabled_compression_algorithms_bitset_; + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators_; }; } // namespace grpc |