diff options
author | 2018-10-08 19:03:33 -0700 | |
---|---|---|
committer | 2018-10-16 14:10:02 -0700 | |
commit | 63bdf4e2363a3c55edf8ddb9d089da88c31963f2 (patch) | |
tree | adbadc9a7d9adc34a6110360d85d958332fee8b6 | |
parent | 5d831da9d135d7f1c58ff61bacb6e5a2787f05c9 (diff) |
More changes for client interception
-rw-r--r-- | include/grpcpp/impl/codegen/async_stream.h | 27 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/async_unary_call.h | 9 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/call.h | 739 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/client_callback.h | 2 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/client_context.h | 9 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/client_interceptor.h | 84 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/client_unary_call.h | 5 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/core_codegen.h | 3 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/core_codegen_interface.h | 3 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor.h | 53 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/service_type.h | 2 | ||||
-rw-r--r-- | src/cpp/client/channel_cc.cc | 13 | ||||
-rw-r--r-- | src/cpp/common/core_codegen.cc | 7 | ||||
-rw-r--r-- | src/cpp/server/server_cc.cc | 21 | ||||
-rw-r--r-- | src/cpp/server/server_context.cc | 28 |
15 files changed, 771 insertions, 234 deletions
diff --git a/include/grpcpp/impl/codegen/async_stream.h b/include/grpcpp/impl/codegen/async_stream.h index 6e58fd0eef..258b73535e 100644 --- a/include/grpcpp/impl/codegen/async_stream.h +++ b/include/grpcpp/impl/codegen/async_stream.h @@ -188,7 +188,7 @@ class ClientAsyncReaderFactory { ::grpc::internal::Call call = channel->CreateCall(method, context, cq); return new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientAsyncReader<R>))) - ClientAsyncReader<R>(call, context, request, start, tag); + ClientAsyncReader<R>(std::move(call), context, request, start, tag); } }; } // namespace internal @@ -264,7 +264,7 @@ class ClientAsyncReader final : public ClientAsyncReaderInterface<R> { template <class W> ClientAsyncReader(::grpc::internal::Call call, ClientContext* context, const W& request, bool start, void* tag) - : context_(context), call_(call), started_(start) { + : context_(context), call_(std::move(call)), started_(start) { // TODO(ctiller): don't assert GPR_CODEGEN_ASSERT(init_ops_.SendMessage(request).ok()); init_ops_.ClientSendClose(); @@ -336,7 +336,7 @@ class ClientAsyncWriterFactory { ::grpc::internal::Call call = channel->CreateCall(method, context, cq); return new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientAsyncWriter<W>))) - ClientAsyncWriter<W>(call, context, response, start, tag); + ClientAsyncWriter<W>(std::move(call), context, response, start, tag); } }; } // namespace internal @@ -430,7 +430,7 @@ class ClientAsyncWriter final : public ClientAsyncWriterInterface<W> { template <class R> ClientAsyncWriter(::grpc::internal::Call call, ClientContext* context, R* response, bool start, void* tag) - : context_(context), call_(call), started_(start) { + : context_(context), call_(std::move(call)), started_(start) { finish_ops_.RecvMessage(response); finish_ops_.AllowNoMessage(); if (start) { @@ -501,7 +501,7 @@ class ClientAsyncReaderWriterFactory { return new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientAsyncReaderWriter<W, R>))) - ClientAsyncReaderWriter<W, R>(call, context, start, tag); + ClientAsyncReaderWriter<W, R>(std::move(call), context, start, tag); } }; } // namespace internal @@ -603,7 +603,7 @@ class ClientAsyncReaderWriter final friend class internal::ClientAsyncReaderWriterFactory<W, R>; ClientAsyncReaderWriter(::grpc::internal::Call call, ClientContext* context, bool start, void* tag) - : context_(context), call_(call), started_(start) { + : context_(context), call_(std::move(call)), started_(start) { if (start) { StartCallInternal(tag); } else { @@ -781,7 +781,10 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { } private: - void BindCall(::grpc::internal::Call* call) override { call_ = *call; } + ::grpc::internal::Call* BindCall(::grpc::internal::Call call) override { + call_ = std::move(call); + return &call_; + } ::grpc::internal::Call call_; ServerContext* ctx_; @@ -927,7 +930,10 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { } private: - void BindCall(::grpc::internal::Call* call) override { call_ = *call; } + ::grpc::internal::Call* BindCall(::grpc::internal::Call call) override { + call_ = std::move(call); + return &call_; + } template <class T> void EnsureInitialMetadataSent(T* ops) { @@ -1101,7 +1107,10 @@ class ServerAsyncReaderWriter final private: friend class ::grpc::Server; - void BindCall(::grpc::internal::Call* call) override { call_ = *call; } + ::grpc::internal::Call* BindCall(::grpc::internal::Call call) override { + call_ = std::move(call); + return &call_; + } template <class T> void EnsureInitialMetadataSent(T* ops) { diff --git a/include/grpcpp/impl/codegen/async_unary_call.h b/include/grpcpp/impl/codegen/async_unary_call.h index 60ff8e2f05..f34bc6b792 100644 --- a/include/grpcpp/impl/codegen/async_unary_call.h +++ b/include/grpcpp/impl/codegen/async_unary_call.h @@ -87,7 +87,7 @@ class ClientAsyncResponseReaderFactory { ::grpc::internal::Call call = channel->CreateCall(method, context, cq); return new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientAsyncResponseReader<R>))) - ClientAsyncResponseReader<R>(call, context, request, start); + ClientAsyncResponseReader<R>(std::move(call), context, request, start); } }; } // namespace internal @@ -165,7 +165,7 @@ class ClientAsyncResponseReader final template <class W> ClientAsyncResponseReader(::grpc::internal::Call call, ClientContext* context, const W& request, bool start) - : context_(context), call_(call), started_(start) { + : context_(context), call_(std::move(call)), started_(start) { // Bind the metadata at time of StartCallInternal but set up the rest here // TODO(ctiller): don't assert GPR_CODEGEN_ASSERT(single_buf.SendMessage(request).ok()); @@ -286,7 +286,10 @@ class ServerAsyncResponseWriter final } private: - void BindCall(::grpc::internal::Call* call) override { call_ = *call; } + ::grpc::internal::Call* BindCall(::grpc::internal::Call call) override { + call_ = std::move(call); + return &call_; + } ::grpc::internal::Call call_; ServerContext* ctx_; diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 771fc22d46..b2f133a94e 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -52,58 +52,6 @@ namespace internal { class Call; class CallHook; -/// Straightforward wrapping of the C call object -class Call final { - public: - /** call is owned by the caller */ - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(-1), - rpc_info_(nullptr, nullptr, nullptr) {} - - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, - experimental::ClientRpcInfo rpc_info, - const std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>& - creators) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(-1), - rpc_info_(rpc_info) { - for (const auto& creator : creators) { - interceptors_.push_back(creator->CreateClientInterceptor(&rpc_info_)); - } - } - - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, - int max_receive_message_size) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(max_receive_message_size), - rpc_info_(nullptr, nullptr, nullptr) {} - - void PerformOps(CallOpSetInterface* ops) { - call_hook_->PerformOpsOnCall(ops, this); - } - - grpc_call* call() const { return call_; } - CompletionQueue* cq() const { return cq_; } - - int max_receive_message_size() const { return max_receive_message_size_; } - - private: - CallHook* call_hook_; - CompletionQueue* cq_; - grpc_call* call_; - int max_receive_message_size_; - experimental::ClientRpcInfo rpc_info_; - std::vector<experimental::ClientInterceptor*> interceptors_; -}; - // TODO(yangg) if the map is changed before we send, the pointers will be a // mess. Make sure it does not happen. inline grpc_metadata* FillMetadataArray( @@ -255,45 +203,20 @@ class WriteOptions { }; namespace internal { - -class InterceptorBatchMethodsImpl - : public experimental::InterceptorBatchMethods { - public: - InterceptorBatchMethodsImpl() {} - - virtual ~InterceptorBatchMethodsImpl() {} - - virtual bool QueryInterceptionHookPoint( - experimental::InterceptionHookPoints type) override { - return hooks_[static_cast<int>(type)]; - } - - virtual void Proceed() override { /* fill this */ - } - - virtual void Hijack() override { /* fill this */ - } - - void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { - hooks_[static_cast<int>(type)]; - } - - private: - std::array<bool, - static_cast<int>( - experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> - hooks_; -}; - /// Default argument for CallOpSet. I is unused by the class, but can be /// used for generating multiple names for the same thing. template <int I> class CallNoOp { protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) {} - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) {} + void AddOp(grpc_op* ops, size_t* nops) {} + void FinishOp(bool* status) {} + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) {} }; class CallOpSendInitialMetadata { @@ -318,9 +241,8 @@ class CallOpSendInitialMetadata { } protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_INITIAL_METADATA; op->flags = flags_; @@ -333,16 +255,31 @@ class CallOpSendInitialMetadata { op->data.send_initial_metadata.maybe_compression_level.level = maybe_compression_level_.level; } - interceptor_methods->AddInterceptionHookPoint( - experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA); } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_) return; + void FinishOp(bool* status) { + if (!send_ || hijacked_) return; g_core_codegen_interface->gpr_free(initial_metadata_); send_ = false; } + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA); + interceptor_methods->SetSendInitialMetadata(initial_metadata_, + &initial_metadata_count_); + } + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + bool hijacked_ = false; bool send_; uint32_t flags_; size_t initial_metadata_count_; @@ -367,9 +304,8 @@ class CallOpSendMessage { Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_buf_.Valid()) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_buf_.Valid() || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_MESSAGE; op->flags = write_options_.flags(); @@ -377,15 +313,27 @@ class CallOpSendMessage { op->data.send_message.send_message = send_buf_.c_buffer(); // Flags are per-message: clear them after use. write_options_.Clear(); + } + void FinishOp(bool* status) { send_buf_.Clear(); } + + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (!send_buf_.Valid()) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); + interceptor_methods->SetSendMessage(send_buf_.c_buffer()); } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - send_buf_.Clear(); + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; } private: + bool hijacked_ = false; ByteBuffer send_buf_; WriteOptions write_options_; }; @@ -427,9 +375,8 @@ class CallOpRecvMessage { bool got_message; protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (message_ == nullptr) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (message_ == nullptr || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_MESSAGE; op->flags = 0; @@ -437,9 +384,8 @@ class CallOpRecvMessage { op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (message_ == nullptr) return; + void FinishOp(bool* status) { + if (message_ == nullptr || hijacked_) return; if (recv_buf_.Valid()) { if (*status) { got_message = *status = @@ -457,6 +403,23 @@ class CallOpRecvMessage { } } message_ = nullptr; + } + + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (message_ == nullptr || !got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (message_ == nullptr || !got_message) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); } @@ -465,6 +428,7 @@ class CallOpRecvMessage { R* message_; ByteBuffer recv_buf_; bool allow_not_getting_message_; + bool hijacked_ = false; }; class DeserializeFunc { @@ -498,6 +462,7 @@ class CallOpGenericRecvMessage { // following unique_ptr::reset for some old implementations. DeserializeFunc* func = new DeserializeFuncType<R>(message); deserialize_.reset(func); + message_ = message; } // Do not change status if no message is received. @@ -506,9 +471,8 @@ class CallOpGenericRecvMessage { bool got_message; protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!deserialize_) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (!deserialize_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_MESSAGE; op->flags = 0; @@ -516,9 +480,8 @@ class CallOpGenericRecvMessage { op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!deserialize_) return; + void FinishOp(bool* status) { + if (!deserialize_ || hijacked_) return; if (recv_buf_.Valid()) { if (*status) { got_message = true; @@ -535,11 +498,30 @@ class CallOpGenericRecvMessage { } } deserialize_.reset(); + } + + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (!deserialize_ || !got_message) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); } + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (!deserialize_ || !got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + } private: + void* message_; + bool hijacked_ = false; std::unique_ptr<DeserializeFunc> deserialize_; ByteBuffer recv_buf_; bool allow_not_getting_message_; @@ -552,20 +534,32 @@ class CallOpClientSendClose { void ClientSendClose() { send_ = true; } protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; op->flags = 0; op->reserved = NULL; } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - send_ = false; + void FinishOp(bool* status) { send_ = false; } + + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE); + } + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; } private: + bool hijacked_ = false; bool send_; }; @@ -585,9 +579,8 @@ class CallOpServerSendStatus { } protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_status_available_) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_status_available_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; op->data.send_status_from_server.trailing_metadata_count = @@ -599,18 +592,35 @@ class CallOpServerSendStatus { send_error_message_.empty() ? nullptr : &error_message_slice_; op->flags = 0; op->reserved = NULL; - interceptor_methods->AddInterceptionHookPoint( - experimental::InterceptionHookPoints::PRE_SEND_STATUS); } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_status_available_) return; + void FinishOp(bool* status) { + if (!send_status_available_ || hijacked_) return; g_core_codegen_interface->gpr_free(trailing_metadata_); send_status_available_ = false; } + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (!send_status_available_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS); + interceptor_methods->SetSendTrailingMetadata(trailing_metadata_, + &trailing_metadata_count_); + interceptor_methods->SetSendStatus(&send_status_code_, &send_error_details_, + &send_error_message_); + } + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + private: + bool hijacked_ = false; bool send_status_available_; grpc_status_code send_status_code_; grpc::string send_error_details_; @@ -630,9 +640,8 @@ class CallOpRecvInitialMetadata { } protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (metadata_map_ == nullptr) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (metadata_map_ == nullptr || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_INITIAL_METADATA; op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr(); @@ -640,15 +649,32 @@ class CallOpRecvInitialMetadata { op->reserved = NULL; } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (metadata_map_ == nullptr) return; + void FinishOp(bool* status) { + if (metadata_map_ == nullptr || hijacked_) return; metadata_map_ = nullptr; + } + + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (metadata_map_ == nullptr) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); } + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA); + interceptor_methods->SetRecvInitialMetadata(metadata_map_->arr()); + } + private: + bool hijacked_ = false; MetadataMap* metadata_map_; }; @@ -665,9 +691,8 @@ class CallOpClientRecvStatus { } protected: - void AddOp(grpc_op* ops, size_t* nops, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (recv_status_ == nullptr) return; + void AddOp(grpc_op* ops, size_t* nops) { + if (recv_status_ == nullptr || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr(); @@ -678,9 +703,8 @@ class CallOpClientRecvStatus { op->reserved = NULL; } - void FinishOp(bool* status, - InterceptorBatchMethodsImpl* interceptor_methods) { - if (recv_status_ == nullptr) return; + void FinishOp(bool* status) { + if (recv_status_ == nullptr || hijacked_) return; grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); *recv_status_ = Status(static_cast<StatusCode>(status_code_), @@ -696,11 +720,30 @@ class CallOpClientRecvStatus { g_core_codegen_interface->gpr_free((void*)debug_error_string_); } recv_status_ = nullptr; + } + + void SetInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) {} + + void SetFinishInterceptionHookPoint( + experimental::InterceptorBatchMethods* interceptor_methods) { + if (recv_status_ == nullptr) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS); + interceptor_methods->SetRecvStatus(recv_status_); + interceptor_methods->SetRecvTrailingMetadata(metadata_map_->arr()); + } + + void SetHijackingState( + experimental::InterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS); } private: + bool hijacked_ = false; ClientContext* client_context_; MetadataMap* metadata_map_; Status* recv_status_; @@ -709,6 +752,57 @@ class CallOpClientRecvStatus { grpc_slice error_message_; }; +/// Straightforward wrapping of the C call object +class Call final { + public: + Call() + : call_hook_(nullptr), + cq_(nullptr), + call_(nullptr), + max_receive_message_size_(-1), + rpc_info_(nullptr) {} + /** call is owned by the caller */ + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(-1), + rpc_info_(nullptr) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + experimental::ClientRpcInfo* rpc_info) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(-1), + rpc_info_(rpc_info) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + int max_receive_message_size) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(max_receive_message_size), + rpc_info_(nullptr) {} + + void PerformOps(CallOpSetInterface* ops) { + call_hook_->PerformOpsOnCall(ops, this); + } + + grpc_call* call() const { return call_; } + CompletionQueue* cq() const { return cq_; } + + int max_receive_message_size() const { return max_receive_message_size_; } + experimental::ClientRpcInfo* rpc_info() const { return rpc_info_; } + + private: + CallHook* call_hook_; + CompletionQueue* cq_; + grpc_call* call_; + int max_receive_message_size_; + experimental::ClientRpcInfo* rpc_info_; +}; + /// An abstract collection of call ops, used to generate the /// grpc_call_op structure to pass down to the lower layers, /// and as it is-a CompletionQueueTag, also massages the final @@ -718,12 +812,261 @@ class CallOpSetInterface : public CompletionQueueTag { public: /// Fills in grpc_op, starting from ops[*nops] and moving /// upwards. - virtual void FillOps(internal::Call* call, grpc_op* ops, size_t* nops) = 0; + virtual void FillOps(internal::Call* call) = 0; /// Get the tag to be used at the core completion queue. Generally, the /// value of cq_tag will be "this". However, it can be overridden if we /// want core to process the tag differently (e.g., as a core callback) virtual void* cq_tag() = 0; + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + virtual void SetHijackingState() = 0; + + /* Should be called after interceptors are done running */ + virtual void ContinueFillOpsAfterInterception() = 0; + + /* Should be called after interceptors are done running on the finalize result + * path */ + virtual void ContinueFinalizeResultAfterInterception() = 0; +}; + +template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>, + class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, + class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> +class CallOpSet; + +class InterceptorBatchMethodsImpl + : public experimental::InterceptorBatchMethods { + public: + InterceptorBatchMethodsImpl() { + for (auto i = 0; + i < static_cast<int>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS); + i++) { + hooks_[i] = false; + } + } + + virtual ~InterceptorBatchMethodsImpl() {} + + virtual bool QueryInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + return hooks_[static_cast<int>(type)]; + } + + virtual void Proceed() override { /* fill this */ + curr_iteration_ = reverse_ ? curr_iteration_ - 1 : curr_iteration_ + 1; + auto* rpc_info = call_->rpc_info(); + if (rpc_info->hijacked_ && + (!reverse_ && curr_iteration_ == rpc_info->hijacked_interceptor_ + 1)) { + /* We now need to provide hijacked recv ops to this interceptor */ + ClearHookPoints(); + ops_->SetHijackingState(); + rpc_info->RunInterceptor(this, curr_iteration_ - 1); + return; + } + if (!reverse_) { + /* We are going down the stack of interceptors */ + if (curr_iteration_ < static_cast<long>(rpc_info->interceptors_.size())) { + if (rpc_info->hijacked_ && + curr_iteration_ > rpc_info->hijacked_interceptor_) { + /* This is a hijacked RPC and we are done with hijacking */ + ops_->ContinueFillOpsAfterInterception(); + } else { + rpc_info->RunInterceptor(this, curr_iteration_); + } + } else { + /* we are done running all the interceptors without any hijacking */ + ops_->ContinueFillOpsAfterInterception(); + } + } else { + /* We are going up the stack of interceptors */ + if (curr_iteration_ >= 0) { + if (rpc_info->hijacked_ && + curr_iteration_ < rpc_info->hijacked_interceptor_) { + /* This is a hijacked RPC and we are done running the hijacking + * interceptor. */ + ops_->ContinueFinalizeResultAfterInterception(); + } else { + rpc_info->RunInterceptor(this, curr_iteration_); + } + } else { + /* we are done running all the interceptors without any hijacking */ + ops_->ContinueFinalizeResultAfterInterception(); + } + } + } + + virtual void Hijack() override { /* fill this */ + GPR_ASSERT(!reverse_); + auto* rpc_info = call_->rpc_info(); + rpc_info->hijacked_ = true; + rpc_info->hijacked_interceptor_ = curr_iteration_; + ClearHookPoints(); + ops_->SetHijackingState(); + curr_iteration_++; // increment so that we recognize that we have already + // run the hijacking interceptor + rpc_info->RunInterceptor(this, curr_iteration_ - 1); + } + + virtual void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + hooks_[static_cast<int>(type)]; + } + + virtual void GetSendMessage(grpc_byte_buffer** buf) override { + *buf = send_message_; + } + + virtual void GetSendInitialMetadata(grpc_metadata** metadata, + size_t** count) override { + *metadata = send_initial_metadata_; + *count = send_initial_metadata_count_; + } + + virtual void GetSendStatus(grpc_status_code** code, + grpc::string** error_details, + grpc::string** error_message) override { + *code = code_; + *error_details = error_details_; + *error_message = error_message_; + } + + virtual void GetSendTrailingMetadata(grpc_metadata** metadata, + size_t** count) override { + *metadata = send_trailing_metadata_; + *count = send_trailing_metadata_count_; + } + + virtual void GetRecvMessage(void** message) override { + *message = recv_message_; + } + + virtual void GetRecvInitialMetadata(grpc_metadata_array** array) override { + *array = recv_initial_metadata_; + } + + virtual void GetRecvStatus(Status** status) override { + *status = recv_status_; + } + + virtual void GetRecvTrailingMetadata(grpc_metadata_array** map) override { + *map = recv_trailing_metadata_; + } + + virtual void SetSendMessage(grpc_byte_buffer* buf) override { + send_message_ = buf; + } + + virtual void SetSendInitialMetadata(grpc_metadata* metadata, + size_t* count) override { + send_initial_metadata_ = metadata; + send_initial_metadata_count_ = count; + } + + virtual void SetSendStatus(grpc_status_code* code, + grpc::string* error_details, + grpc::string* error_message) override { + code_ = code; + error_details_ = error_details; + error_message_ = error_message; + } + + virtual void SetSendTrailingMetadata(grpc_metadata* metadata, + size_t* count) override { + send_trailing_metadata_ = metadata; + send_trailing_metadata_count_ = count; + } + + virtual void SetRecvMessage(void* message) override { + recv_message_ = message; + } + + virtual void SetRecvInitialMetadata(grpc_metadata_array* array) override { + recv_initial_metadata_ = array; + } + + virtual void SetRecvStatus(Status* status) override { recv_status_ = status; } + + virtual void SetRecvTrailingMetadata(grpc_metadata_array* map) override { + recv_trailing_metadata_ = map; + } + + /* Prepares for Post_recv operations */ + void SetReverse() { + reverse_ = true; + ClearHookPoints(); + curr_iteration_ = 0; + } + + /* This needs to be set before interceptors are run */ + void SetCall(Call* call) { call_ = call; } + + void SetCallOpSet(CallOpSetInterface* ops) { ops_ = ops; } + + /* Returns true if no interceptors are run */ + bool RunInterceptors() { + auto* rpc_info = call_->rpc_info(); + if (rpc_info == nullptr || rpc_info->interceptors_.size() == 0) { + return true; + } + if (!reverse_) { + rpc_info->RunInterceptor(this, 0); + } else { + if (rpc_info->hijacked_) { + rpc_info->RunInterceptor(this, rpc_info->hijacked_interceptor_); + } else { + rpc_info->RunInterceptor(this, rpc_info->interceptors_.size() - 1); + } + } + return false; + } + + private: + void ClearHookPoints() { + for (auto i = 0; + i < static_cast<int>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS); + i++) { + hooks_[i] = false; + } + } + + std::array<bool, + static_cast<int>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> + hooks_; + + int curr_iteration_ = 0; // Current iterator + bool reverse_ = false; + Call* call_ = + nullptr; // The Call object is present along with CallOpSet object + CallOpSetInterface* ops_ = nullptr; + + grpc_byte_buffer* send_message_ = nullptr; + + grpc_metadata* send_initial_metadata_ = nullptr; + size_t* send_initial_metadata_count_ = nullptr; + + grpc_status_code* code_ = nullptr; + grpc::string* error_details_ = nullptr; + grpc::string* error_message_ = nullptr; + + grpc_metadata* send_trailing_metadata_ = nullptr; + size_t* send_trailing_metadata_count_ = nullptr; + + void* recv_message_ = nullptr; + + grpc_metadata_array* recv_initial_metadata_ = nullptr; + + Status* recv_status_ = nullptr; + + grpc_metadata_array* recv_trailing_metadata_ = nullptr; + + // void (*hijacking_state_setter_)(); + // void (*continue_after_interception_)(); + // void (*continue_after_reverse_interception_)(); }; /// Primary implementation of CallOpSetInterface. @@ -732,9 +1075,7 @@ class CallOpSetInterface : public CompletionQueueTag { /// empty base class optimization to slim this class (especially /// when there are many unused slots used). To avoid duplicate base classes, /// the template parmeter for CallNoOp is varied by argument position. -template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>, - class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, - class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> +template <class Op1, class Op2, class Op3, class Op4, class Op5, class Op6> class CallOpSet : public CallOpSetInterface, public Op1, public Op2, @@ -743,29 +1084,45 @@ class CallOpSet : public CallOpSetInterface, public Op5, public Op6 { public: - CallOpSet() : cq_tag_(this), return_tag_(this), call_(nullptr) {} - void FillOps(Call* call, grpc_op* ops, size_t* nops) override { - this->Op1::AddOp(ops, nops, &interceptor_methods_); - this->Op2::AddOp(ops, nops, &interceptor_methods_); - this->Op3::AddOp(ops, nops, &interceptor_methods_); - this->Op4::AddOp(ops, nops, &interceptor_methods_); - this->Op5::AddOp(ops, nops, &interceptor_methods_); - this->Op6::AddOp(ops, nops, &interceptor_methods_); + CallOpSet() : cq_tag_(this), return_tag_(this) {} + void FillOps(Call* call) override { g_core_codegen_interface->grpc_call_ref(call->call()); - call_ = call; + call_ = + *call; // It's fine to create a copy of call since it's just pointers + + if (RunInterceptors()) { + ContinueFillOpsAfterInterception(); + } else { + /* After the interceptors are run, ContinueFillOpsAfterInterception will + * be run */ + } } bool FinalizeResult(void** tag, bool* status) override { - this->Op1::FinishOp(status, &interceptor_methods_); - this->Op2::FinishOp(status, &interceptor_methods_); - this->Op3::FinishOp(status, &interceptor_methods_); - this->Op4::FinishOp(status, &interceptor_methods_); - this->Op5::FinishOp(status, &interceptor_methods_); - this->Op6::FinishOp(status, &interceptor_methods_); + if (done_intercepting_) { + /* We have already finished intercepting and filling in the results. This + * round trip from the core needed to be made because interceptors were + * run */ + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + + this->Op1::FinishOp(status); + this->Op2::FinishOp(status); + this->Op3::FinishOp(status); + this->Op4::FinishOp(status); + this->Op5::FinishOp(status); + this->Op6::FinishOp(status); *tag = return_tag_; - g_core_codegen_interface->grpc_call_unref(call_->call()); - return true; + if (RunInterceptorsPostRecv()) { + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + + /* Interceptors are going to be run, so we can't return the tag just yet. + After the interceptors are run, ContinueFinalizeResultAfterInterception */ + return false; } void set_output_tag(void* return_tag) { return_tag_ = return_tag; } @@ -778,10 +1135,72 @@ class CallOpSet : public CallOpSetInterface, /// function (such as FinalizeResult) void set_cq_tag(void* cq_tag) { cq_tag_ = cq_tag; } + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + this->Op1::SetHijackingState(&interceptor_methods_); + this->Op2::SetHijackingState(&interceptor_methods_); + this->Op3::SetHijackingState(&interceptor_methods_); + this->Op4::SetHijackingState(&interceptor_methods_); + this->Op5::SetHijackingState(&interceptor_methods_); + this->Op6::SetHijackingState(&interceptor_methods_); + } + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override { + static const size_t MAX_OPS = 6; + grpc_op ops[MAX_OPS]; + size_t nops = 0; + this->Op1::AddOp(ops, &nops); + this->Op2::AddOp(ops, &nops); + this->Op3::AddOp(ops, &nops); + this->Op4::AddOp(ops, &nops); + this->Op5::AddOp(ops, &nops); + this->Op6::AddOp(ops, &nops); + GPR_ASSERT(GRPC_CALL_OK == g_core_codegen_interface->grpc_call_start_batch( + call_.call(), ops, nops, cq_tag(), nullptr)); + } + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + GPR_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, cq_tag(), nullptr)); + } + private: + /* Returns true if no interceptors need to be run */ + bool RunInterceptors() { + this->Op1::SetInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetInterceptionHookPoint(&interceptor_methods_); + interceptor_methods_.SetCallOpSet(this); + interceptor_methods_.SetCall(&call_); + // interceptor_methods_.SetFunctions(ContinueFillOpsAfterInterception, + // SetHijackingState, ContinueFinalizeResultAfterInterception); + return interceptor_methods_.RunInterceptors(); + } + /* Returns true if no interceptors need to be run */ + bool RunInterceptorsPostRecv() { + this->Op1::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetFinishInterceptionHookPoint(&interceptor_methods_); + interceptor_methods_.SetReverse(); + return interceptor_methods_.RunInterceptors(); + } + void* cq_tag_; void* return_tag_; - Call* call_; + Call call_; + bool done_intercepting_ = false; InterceptorBatchMethodsImpl interceptor_methods_; }; diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 4d4faea063..ae9a8a95f9 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -57,7 +57,7 @@ class CallbackUnaryCallImpl { std::function<void(Status)> on_completion) { CompletionQueue* cq = channel->CallbackCQ(); GPR_CODEGEN_ASSERT(cq != nullptr); - Call call(channel->CreateCall(method, context, cq)); + Call call = channel->CreateCall(method, context, cq); using FullCallOpSet = CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage, diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 24f5c431ce..95462dfff3 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -41,6 +41,7 @@ #include <grpc/impl/codegen/compression_types.h> #include <grpc/impl/codegen/propagation_bits.h> +#include <grpcpp/impl/codegen/client_interceptor.h> #include <grpcpp/impl/codegen/config.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/create_auth_context.h> @@ -402,6 +403,12 @@ class ClientContext { grpc_call* call() const { return call_; } void set_call(grpc_call* call, const std::shared_ptr<Channel>& channel); + experimental::ClientRpcInfo* set_client_rpc_info( + experimental::ClientRpcInfo client_rpc_info) { + rpc_info_ = std::move(client_rpc_info); + return &rpc_info_; + } + uint32_t initial_metadata_flags() const { return (idempotent_ ? GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST : 0) | (wait_for_ready_ ? GRPC_INITIAL_METADATA_WAIT_FOR_READY : 0) | @@ -439,6 +446,8 @@ class ClientContext { bool initial_metadata_corked_; grpc::string debug_error_string_; + + experimental::ClientRpcInfo rpc_info_; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f7963a57d5..6b02a89012 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -19,12 +19,25 @@ #ifndef GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H -#include <grpcpp/impl/codegen/client_context.h> +#include <vector> + +#include <grpc/impl/codegen/log.h> #include <grpcpp/impl/codegen/interceptor.h> #include <grpcpp/impl/codegen/string_ref.h> namespace grpc { + +class ClientContext; +class Channel; + +namespace internal { +template <int I> +class CallNoOp; +} + namespace experimental { +class ClientRpcInfo; + class ClientInterceptor { public: virtual ~ClientInterceptor() {} @@ -32,28 +45,39 @@ class ClientInterceptor { virtual void Intercept(InterceptorBatchMethods* methods) = 0; }; +class ClientInterceptorFactoryInterface { + public: + virtual ~ClientInterceptorFactoryInterface() {} + virtual ClientInterceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; +}; + class ClientRpcInfo { public: + ClientRpcInfo() {} ClientRpcInfo(grpc::ClientContext* ctx, const char* method, - const grpc::Channel* channel) - : ctx_(ctx), method_(method), channel_(channel) {} + const grpc::Channel* channel, + const std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>& creators) + : ctx_(ctx), method_(method), channel_(channel) { + for (const auto& creator : creators) { + interceptors_.push_back(std::unique_ptr<experimental::ClientInterceptor>( + creator->CreateClientInterceptor(this))); + } + } ~ClientRpcInfo(){}; + ClientRpcInfo(const ClientRpcInfo&) = delete; + ClientRpcInfo(ClientRpcInfo&&) = default; + ClientRpcInfo& operator=(ClientRpcInfo&&) = default; + // Getter methods const char* method() { return method_; } - string peer() { return ctx_->peer(); } const Channel* channel() { return channel_; } // const grpc::InterceptedMessage& outgoing_message(); // grpc::InterceptedMessage *mutable_outgoing_message(); // const grpc::InterceptedMessage& received_message(); // grpc::InterceptedMessage *mutable_received_message(); - std::shared_ptr<const AuthContext> auth_context() { - return ctx_->auth_context(); - } - const struct census_context* census_context() { - return ctx_->census_context(); - } - gpr_timespec deadline() { return ctx_->raw_deadline(); } + // const std::multimap<grpc::string, grpc::string>* client_initial_metadata() // { return &ctx_->send_initial_metadata_; } const // std::multimap<grpc::string_ref, grpc::string_ref>* @@ -62,14 +86,6 @@ class ClientRpcInfo { // server_trailing_metadata() { return &ctx_->GetServerTrailingMetadata(); } // const Status *status(); - // Setter methods - template <typename T> - void set_deadline(const T& deadline) { - ctx_->set_deadline(deadline); - } - void set_census_context(struct census_context* cc) { - ctx_->set_census_context(cc); - } // template <class M> // void set_outgoing_message(M* msg); // edit outgoing message // template <class M> @@ -83,16 +99,30 @@ class ClientRpcInfo { // grpc::string>& overwrite); void set_server_trailing_metadata(const // std::multimap<grpc::string, grpc::string>& overwrite); void // set_status(Status status); - private: - grpc::ClientContext* ctx_; - const char* method_; - const grpc::Channel* channel_; -}; + public: + /* Runs interceptor at pos \a pos. If \a reverse is set, the interceptor order + * is the reverse */ + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, + unsigned int pos) { + GPR_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + grpc::ClientContext* ctx_ = nullptr; + const char* method_ = nullptr; + const grpc::Channel* channel_ = nullptr; -class ClientInterceptorFactoryInterface { public: - virtual ~ClientInterceptorFactoryInterface() {} - virtual ClientInterceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; + std::vector<std::unique_ptr<experimental::ClientInterceptor>> interceptors_; + bool hijacked_ = false; + int hijacked_interceptor_ = false; + // template <class Op1 = internal::CallNoOp<1>, class Op2 = + // internal::CallNoOp<2>, + // class Op3 = internal::CallNoOp<3>, class Op4 = + // internal::CallNoOp<4>, class Op5 = internal::CallNoOp<5>, class Op6 + // = internal::CallNoOp<6>> + // friend class internal::InterceptorBatchMethodsImpl; }; } // namespace experimental diff --git a/include/grpcpp/impl/codegen/client_unary_call.h b/include/grpcpp/impl/codegen/client_unary_call.h index e4e8364e07..dad31546cc 100644 --- a/include/grpcpp/impl/codegen/client_unary_call.h +++ b/include/grpcpp/impl/codegen/client_unary_call.h @@ -52,7 +52,7 @@ class BlockingUnaryCallImpl { CompletionQueue cq(grpc_completion_queue_attributes{ GRPC_CQ_CURRENT_VERSION, GRPC_CQ_PLUCK, GRPC_CQ_DEFAULT_POLLING, nullptr}); // Pluckable completion queue - Call call(channel->CreateCall(method, context, &cq)); + call_ = std::move(channel->CreateCall(method, context, &cq)); CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage, CallOpRecvInitialMetadata, CallOpRecvMessage<OutputMessage>, CallOpClientSendClose, CallOpClientRecvStatus> @@ -68,7 +68,7 @@ class BlockingUnaryCallImpl { ops.AllowNoMessage(); ops.ClientSendClose(); ops.ClientRecvStatus(context, &status_); - call.PerformOps(&ops); + call_.PerformOps(&ops); if (cq.Pluck(&ops)) { if (!ops.got_message && status_.ok()) { status_ = Status(StatusCode::UNIMPLEMENTED, @@ -82,6 +82,7 @@ class BlockingUnaryCallImpl { private: Status status_; + Call call_; }; } // namespace internal diff --git a/include/grpcpp/impl/codegen/core_codegen.h b/include/grpcpp/impl/codegen/core_codegen.h index e9df96bf04..6ef184d01a 100644 --- a/include/grpcpp/impl/codegen/core_codegen.h +++ b/include/grpcpp/impl/codegen/core_codegen.h @@ -63,6 +63,9 @@ class CoreCodegen final : public CoreCodegenInterface { void gpr_cv_signal(gpr_cv* cv) override; void gpr_cv_broadcast(gpr_cv* cv) override; + grpc_call_error grpc_call_start_batch(grpc_call* call, const grpc_op* ops, + size_t nops, void* tag, + void* reserved) override; grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/core_codegen_interface.h b/include/grpcpp/impl/codegen/core_codegen_interface.h index 1167a188a2..25e3abccca 100644 --- a/include/grpcpp/impl/codegen/core_codegen_interface.h +++ b/include/grpcpp/impl/codegen/core_codegen_interface.h @@ -100,6 +100,9 @@ class CoreCodegenInterface { virtual grpc_slice grpc_slice_new_with_len(void* p, size_t len, void (*destroy)(void*, size_t)) = 0; + virtual grpc_call_error grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, size_t nops, + void* tag, void* reserved) = 0; virtual grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 84dce42f97..0b6d796d35 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -19,7 +19,17 @@ #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H +#include <grpc/impl/codegen/grpc_types.h> +#include <grpcpp/impl/codegen/config.h> + +// struct grpc_byte_buffer; +// struct grpc_status_code; +// struct grpc_metadata; + namespace grpc { + +class Status; + namespace experimental { class InterceptedMessage { public: @@ -35,6 +45,7 @@ enum class InterceptionHookPoints { PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, PRE_SEND_STATUS /* server only */, + PRE_SEND_CLOSE /* client only */, /* The following three are for hijacked clients only and can only be registered by the global interceptor */ PRE_RECV_INITIAL_METADATA, @@ -60,6 +71,48 @@ class InterceptorBatchMethods { // Calling this indicates that the interceptor has hijacked the RPC (only // valid if the batch contains send_initial_metadata on the client side) virtual void Hijack() = 0; + + virtual void AddInterceptionHookPoint(InterceptionHookPoints type) = 0; + + virtual void GetSendMessage(grpc_byte_buffer** buf) = 0; + + virtual void GetSendInitialMetadata(grpc_metadata** metadata, + size_t** count) = 0; + + virtual void GetSendStatus(grpc_status_code** code, + grpc::string** error_details, + grpc::string** error_message) = 0; + + virtual void GetSendTrailingMetadata(grpc_metadata** metadata, + size_t** count) = 0; + + virtual void GetRecvMessage(void** message) = 0; + + virtual void GetRecvInitialMetadata(grpc_metadata_array** array) = 0; + + virtual void GetRecvStatus(Status** status) = 0; + + virtual void GetRecvTrailingMetadata(grpc_metadata_array** map) = 0; + + virtual void SetSendMessage(grpc_byte_buffer* buf) = 0; + + virtual void SetSendInitialMetadata(grpc_metadata* metadata, + size_t* count) = 0; + + virtual void SetSendStatus(grpc_status_code* code, + grpc::string* error_details, + grpc::string* error_message) = 0; + + virtual void SetSendTrailingMetadata(grpc_metadata* metadata, + size_t* count) = 0; + + virtual void SetRecvMessage(void* message) = 0; + + virtual void SetRecvInitialMetadata(grpc_metadata_array* array) = 0; + + virtual void SetRecvStatus(Status* status) = 0; + + virtual void SetRecvTrailingMetadata(grpc_metadata_array* map) = 0; }; } // namespace experimental } // namespace grpc diff --git a/include/grpcpp/impl/codegen/service_type.h b/include/grpcpp/impl/codegen/service_type.h index 9f1a052168..e6ade93087 100644 --- a/include/grpcpp/impl/codegen/service_type.h +++ b/include/grpcpp/impl/codegen/service_type.h @@ -50,7 +50,7 @@ class ServerAsyncStreamingInterface { private: friend class ::grpc::ServerInterface; - virtual void BindCall(Call* call) = 0; + virtual Call* BindCall(Call call) = 0; }; } // namespace internal diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index eba92f00e9..d56e88d035 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -148,18 +148,15 @@ internal::Call Channel::CreateCall(const internal::RpcMethod& method, grpc_census_call_set_context(c_call, context->census_context()); context->set_call(c_call, shared_from_this()); - experimental::ClientRpcInfo info(context, method.name(), this); - return internal::Call(c_call, this, cq, info, interceptor_creators_); + auto* info = context->set_client_rpc_info(experimental::ClientRpcInfo( + context, method.name(), this, interceptor_creators_)); + return std::move(internal::Call(c_call, this, cq, info)); } void Channel::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call, cops, &nops); - GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), cops, nops, - ops->cq_tag(), nullptr)); + ops->FillOps( + call); // Make a copy of call. It's fine since Call just has pointers } void* Channel::RegisterMethod(const char* method) { diff --git a/src/cpp/common/core_codegen.cc b/src/cpp/common/core_codegen.cc index 619aacadaa..cfaa2e7b19 100644 --- a/src/cpp/common/core_codegen.cc +++ b/src/cpp/common/core_codegen.cc @@ -102,6 +102,13 @@ size_t CoreCodegen::grpc_byte_buffer_length(grpc_byte_buffer* bb) { return ::grpc_byte_buffer_length(bb); } +grpc_call_error CoreCodegen::grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, + size_t nops, void* tag, + void* reserved) { + return ::grpc_call_start_batch(call, ops, nops, tag, reserved); +} + grpc_call_error CoreCodegen::grpc_call_cancel_with_status( grpc_call* call, grpc_status_code status, const char* description, void* reserved) { diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 27629f2be0..2faaf618a5 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -667,18 +667,7 @@ void Server::Wait() { void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call, cops, &nops); - // TODO(vjpai): Use ops->cq_tag once this case supports callbacks - auto result = grpc_call_start_batch(call->call(), cops, nops, ops, nullptr); - if (result != GRPC_CALL_OK) { - gpr_log(GPR_ERROR, "Fatal: grpc_call_start_batch returned %d", result); - grpc_call_log_batch(__FILE__, __LINE__, GPR_LOG_SEVERITY_ERROR, - call->call(), cops, nops, ops); - abort(); - } + ops->FillOps(call); } ServerInterface::BaseAsyncRequest::BaseAsyncRequest( @@ -705,11 +694,13 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, context_->cq_ = call_cq_; internal::Call call(call_, server_, call_cq_, server_->max_receive_message_size()); + + // just the pointers inside call are copied here + auto* new_call = stream_->BindCall(std::move(call)); if (*status && call_) { - context_->BeginCompletionOp(&call); + context_->BeginCompletionOp(new_call); } - // just the pointers inside call are copied here - stream_->BindCall(&call); + *tag = tag_; if (delete_on_finalize_) { delete this; diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index cfa6c8d7e8..dd94a44e1d 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -47,7 +47,7 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { finalized_(false), cancelled_(0) {} - void FillOps(internal::Call* call, grpc_op* ops, size_t* nops) override; + void FillOps(internal::Call* call) override; bool FinalizeResult(void** tag, bool* status) override; bool CheckCancelled(CompletionQueue* cq) { @@ -66,6 +66,17 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { void Unref(); + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override {} + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override {} + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override {} + private: bool CheckCancelledNoPluck() { std::lock_guard<std::mutex> g(mu_); @@ -88,13 +99,14 @@ void ServerContext::CompletionOp::Unref() { } } -void ServerContext::CompletionOp::FillOps(internal::Call* call, grpc_op* ops, - size_t* nops) { - ops->op = GRPC_OP_RECV_CLOSE_ON_SERVER; - ops->data.recv_close_on_server.cancelled = &cancelled_; - ops->flags = 0; - ops->reserved = nullptr; - *nops = 1; +void ServerContext::CompletionOp::FillOps(internal::Call* call) { + grpc_op ops; + ops.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + ops.data.recv_close_on_server.cancelled = &cancelled_; + ops.flags = 0; + ops.reserved = nullptr; + GPR_ASSERT(GRPC_CALL_OK == + grpc_call_start_batch(call->call(), &ops, 1, cq_tag(), nullptr)); } bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { |