diff options
41 files changed, 2243 insertions, 259 deletions
@@ -2042,6 +2042,7 @@ grpc_cc_library( "include/grpcpp/impl/codegen/byte_buffer.h", "include/grpcpp/impl/codegen/call.h", "include/grpcpp/impl/codegen/call_hook.h", + "include/grpcpp/impl/codegen/call_wrapper.h", "include/grpcpp/impl/codegen/callback_common.h", "include/grpcpp/impl/codegen/channel_interface.h", "include/grpcpp/impl/codegen/client_callback.h", @@ -2054,6 +2055,7 @@ grpc_cc_library( "include/grpcpp/impl/codegen/core_codegen_interface.h", "include/grpcpp/impl/codegen/create_auth_context.h", "include/grpcpp/impl/codegen/grpc_library.h", + "include/grpcpp/impl/codegen/intercepted_channel.h", "include/grpcpp/impl/codegen/interceptor.h", "include/grpcpp/impl/codegen/metadata_map.h", "include/grpcpp/impl/codegen/method_handler_impl.h", @@ -2062,6 +2064,7 @@ grpc_cc_library( "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", "include/grpcpp/impl/codegen/server_context.h", + "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", "include/grpcpp/impl/codegen/service_type.h", "include/grpcpp/impl/codegen/slice.h", @@ -2081,6 +2084,7 @@ grpc_cc_library( name = "grpc++_codegen_base_src", srcs = [ "src/cpp/codegen/codegen_init.cc", + "src/cpp/codegen/call_wrapper.cc", ], language = "c++", deps = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index e56c64abde..ed1cf5f74f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2870,6 +2870,7 @@ add_library(grpc++ src/cpp/util/status.cc src/cpp/util/string_ref.cc src/cpp/util/time_cc.cc + src/cpp/codegen/call_wrapper.cc src/cpp/codegen/codegen_init.cc ) @@ -3088,6 +3089,7 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_wrapper.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -3100,6 +3102,7 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h @@ -3108,6 +3111,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -3231,6 +3235,7 @@ add_library(grpc++_cronet src/cpp/util/status.cc src/cpp/util/string_ref.cc src/cpp/util/time_cc.cc + src/cpp/codegen/call_wrapper.cc src/cpp/codegen/codegen_init.cc src/core/ext/transport/chttp2/client/insecure/channel_create.cc src/core/ext/transport/chttp2/client/insecure/channel_create_posix.cc @@ -3659,6 +3664,7 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_wrapper.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -3671,6 +3677,7 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h @@ -3679,6 +3686,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -3976,6 +3984,7 @@ add_library(grpc++_test_util test/cpp/util/string_ref_helper.cc test/cpp/util/subprocess.cc test/cpp/util/test_credentials_provider.cc + src/cpp/codegen/call_wrapper.cc src/cpp/codegen/codegen_init.cc ) @@ -4068,6 +4077,7 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_wrapper.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -4080,6 +4090,7 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h @@ -4088,6 +4099,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -4160,6 +4172,7 @@ add_library(grpc++_test_util_unsecure test/cpp/util/byte_buffer_proto_helper.cc test/cpp/util/string_ref_helper.cc test/cpp/util/subprocess.cc + src/cpp/codegen/call_wrapper.cc src/cpp/codegen/codegen_init.cc ) @@ -4249,6 +4262,7 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_wrapper.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -4261,6 +4275,7 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h @@ -4269,6 +4284,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -4354,6 +4370,7 @@ add_library(grpc++_unsecure src/cpp/util/status.cc src/cpp/util/string_ref.cc src/cpp/util/time_cc.cc + src/cpp/codegen/call_wrapper.cc src/cpp/codegen/codegen_init.cc ) @@ -4571,6 +4588,7 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_wrapper.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -4583,6 +4601,7 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h @@ -4591,6 +4610,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -12533,6 +12553,7 @@ add_executable(codegen_test_minimal ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h test/cpp/codegen/codegen_test_minimal.cc + src/cpp/codegen/call_wrapper.cc src/cpp/codegen/codegen_init.cc third_party/googletest/googletest/src/gtest-all.cc third_party/googletest/googlemock/src/gmock-all.cc @@ -5249,6 +5249,7 @@ LIBGRPC++_SRC = \ src/cpp/util/status.cc \ src/cpp/util/string_ref.cc \ src/cpp/util/time_cc.cc \ + src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ PUBLIC_HEADERS_CXX += \ @@ -5432,6 +5433,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -5444,6 +5446,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -5452,6 +5455,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -5620,6 +5624,7 @@ LIBGRPC++_CRONET_SRC = \ src/cpp/util/status.cc \ src/cpp/util/string_ref.cc \ src/cpp/util/time_cc.cc \ + src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ src/core/ext/transport/chttp2/client/insecure/channel_create.cc \ src/core/ext/transport/chttp2/client/insecure/channel_create_posix.cc \ @@ -6012,6 +6017,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6024,6 +6030,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -6032,6 +6039,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6367,6 +6375,7 @@ LIBGRPC++_TEST_UTIL_SRC = \ test/cpp/util/string_ref_helper.cc \ test/cpp/util/subprocess.cc \ test/cpp/util/test_credentials_provider.cc \ + src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ PUBLIC_HEADERS_CXX += \ @@ -6406,6 +6415,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6418,6 +6428,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -6426,6 +6437,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6513,6 +6525,7 @@ $(OBJDIR)/$(CONFIG)/test/cpp/util/create_test_channel.o: $(GENDIR)/src/proto/grp $(OBJDIR)/$(CONFIG)/test/cpp/util/string_ref_helper.o: $(GENDIR)/src/proto/grpc/channelz/channelz.pb.cc $(GENDIR)/src/proto/grpc/channelz/channelz.grpc.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc $(OBJDIR)/$(CONFIG)/test/cpp/util/subprocess.o: $(GENDIR)/src/proto/grpc/channelz/channelz.pb.cc $(GENDIR)/src/proto/grpc/channelz/channelz.grpc.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc $(OBJDIR)/$(CONFIG)/test/cpp/util/test_credentials_provider.o: $(GENDIR)/src/proto/grpc/channelz/channelz.pb.cc $(GENDIR)/src/proto/grpc/channelz/channelz.grpc.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc +$(OBJDIR)/$(CONFIG)/src/cpp/codegen/call_wrapper.o: $(GENDIR)/src/proto/grpc/channelz/channelz.pb.cc $(GENDIR)/src/proto/grpc/channelz/channelz.grpc.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc $(OBJDIR)/$(CONFIG)/src/cpp/codegen/codegen_init.o: $(GENDIR)/src/proto/grpc/channelz/channelz.pb.cc $(GENDIR)/src/proto/grpc/channelz/channelz.grpc.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc @@ -6525,6 +6538,7 @@ LIBGRPC++_TEST_UTIL_UNSECURE_SRC = \ test/cpp/util/byte_buffer_proto_helper.cc \ test/cpp/util/string_ref_helper.cc \ test/cpp/util/subprocess.cc \ + src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ PUBLIC_HEADERS_CXX += \ @@ -6564,6 +6578,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6576,6 +6591,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -6584,6 +6600,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6668,6 +6685,7 @@ $(OBJDIR)/$(CONFIG)/test/cpp/end2end/test_service_impl.o: $(GENDIR)/src/proto/gr $(OBJDIR)/$(CONFIG)/test/cpp/util/byte_buffer_proto_helper.o: $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc $(OBJDIR)/$(CONFIG)/test/cpp/util/string_ref_helper.o: $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc $(OBJDIR)/$(CONFIG)/test/cpp/util/subprocess.o: $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc +$(OBJDIR)/$(CONFIG)/src/cpp/codegen/call_wrapper.o: $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc $(OBJDIR)/$(CONFIG)/src/cpp/codegen/codegen_init.o: $(GENDIR)/src/proto/grpc/health/v1/health.pb.cc $(GENDIR)/src/proto/grpc/health/v1/health.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.pb.cc $(GENDIR)/src/proto/grpc/testing/echo_messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.pb.cc $(GENDIR)/src/proto/grpc/testing/echo.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc $(GENDIR)/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc @@ -6708,6 +6726,7 @@ LIBGRPC++_UNSECURE_SRC = \ src/cpp/util/status.cc \ src/cpp/util/string_ref.cc \ src/cpp/util/time_cc.cc \ + src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ PUBLIC_HEADERS_CXX += \ @@ -6891,6 +6910,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6903,6 +6923,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -6911,6 +6932,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -17366,6 +17388,7 @@ CODEGEN_TEST_MINIMAL_SRC = \ $(GENDIR)/src/proto/grpc/testing/worker_service.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.grpc.pb.cc \ $(GENDIR)/src/proto/grpc/testing/stats.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.grpc.pb.cc \ test/cpp/codegen/codegen_test_minimal.cc \ + src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ CODEGEN_TEST_MINIMAL_OBJS = $(addprefix $(OBJDIR)/$(CONFIG)/, $(addsuffix .o, $(basename $(CODEGEN_TEST_MINIMAL_SRC)))) @@ -17413,6 +17436,8 @@ $(OBJDIR)/$(CONFIG)/src/proto/grpc/testing/stats.o: $(LIBDIR)/$(CONFIG)/libgrpc $(OBJDIR)/$(CONFIG)/test/cpp/codegen/codegen_test_minimal.o: $(LIBDIR)/$(CONFIG)/libgrpc++_core_stats.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a +$(OBJDIR)/$(CONFIG)/src/cpp/codegen/call_wrapper.o: $(LIBDIR)/$(CONFIG)/libgrpc++_core_stats.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a + $(OBJDIR)/$(CONFIG)/src/cpp/codegen/codegen_init.o: $(LIBDIR)/$(CONFIG)/libgrpc++_core_stats.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a deps_codegen_test_minimal: $(CODEGEN_TEST_MINIMAL_OBJS:.o=.dep) @@ -17423,6 +17448,7 @@ ifneq ($(NO_DEPS),true) endif endif $(OBJDIR)/$(CONFIG)/test/cpp/codegen/codegen_test_minimal.o: $(GENDIR)/src/proto/grpc/testing/control.pb.cc $(GENDIR)/src/proto/grpc/testing/control.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/messages.pb.cc $(GENDIR)/src/proto/grpc/testing/messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/payloads.pb.cc $(GENDIR)/src/proto/grpc/testing/payloads.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/benchmark_service.pb.cc $(GENDIR)/src/proto/grpc/testing/benchmark_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/report_qps_scenario_service.pb.cc $(GENDIR)/src/proto/grpc/testing/report_qps_scenario_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.grpc.pb.cc +$(OBJDIR)/$(CONFIG)/src/cpp/codegen/call_wrapper.o: $(GENDIR)/src/proto/grpc/testing/control.pb.cc $(GENDIR)/src/proto/grpc/testing/control.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/messages.pb.cc $(GENDIR)/src/proto/grpc/testing/messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/payloads.pb.cc $(GENDIR)/src/proto/grpc/testing/payloads.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/benchmark_service.pb.cc $(GENDIR)/src/proto/grpc/testing/benchmark_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/report_qps_scenario_service.pb.cc $(GENDIR)/src/proto/grpc/testing/report_qps_scenario_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.grpc.pb.cc $(OBJDIR)/$(CONFIG)/src/cpp/codegen/codegen_init.o: $(GENDIR)/src/proto/grpc/testing/control.pb.cc $(GENDIR)/src/proto/grpc/testing/control.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/messages.pb.cc $(GENDIR)/src/proto/grpc/testing/messages.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/payloads.pb.cc $(GENDIR)/src/proto/grpc/testing/payloads.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/benchmark_service.pb.cc $(GENDIR)/src/proto/grpc/testing/benchmark_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/report_qps_scenario_service.pb.cc $(GENDIR)/src/proto/grpc/testing/report_qps_scenario_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.pb.cc $(GENDIR)/src/proto/grpc/testing/worker_service.grpc.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.pb.cc $(GENDIR)/src/proto/grpc/testing/stats.grpc.pb.cc diff --git a/build.yaml b/build.yaml index 9386048e21..4603132b8c 100644 --- a/build.yaml +++ b/build.yaml @@ -1206,6 +1206,7 @@ filegroups: - include/grpcpp/impl/codegen/byte_buffer.h - include/grpcpp/impl/codegen/call.h - include/grpcpp/impl/codegen/call_hook.h + - include/grpcpp/impl/codegen/call_wrapper.h - include/grpcpp/impl/codegen/callback_common.h - include/grpcpp/impl/codegen/channel_interface.h - include/grpcpp/impl/codegen/client_callback.h @@ -1218,6 +1219,7 @@ filegroups: - include/grpcpp/impl/codegen/core_codegen_interface.h - include/grpcpp/impl/codegen/create_auth_context.h - include/grpcpp/impl/codegen/grpc_library.h + - include/grpcpp/impl/codegen/intercepted_channel.h - include/grpcpp/impl/codegen/interceptor.h - include/grpcpp/impl/codegen/metadata_map.h - include/grpcpp/impl/codegen/method_handler_impl.h @@ -1226,6 +1228,7 @@ filegroups: - include/grpcpp/impl/codegen/security/auth_context.h - include/grpcpp/impl/codegen/serialization_traits.h - include/grpcpp/impl/codegen/server_context.h + - include/grpcpp/impl/codegen/server_interceptor.h - include/grpcpp/impl/codegen/server_interface.h - include/grpcpp/impl/codegen/service_type.h - include/grpcpp/impl/codegen/slice.h @@ -1240,6 +1243,7 @@ filegroups: - name: grpc++_codegen_base_src language: c++ src: + - src/cpp/codegen/call_wrapper.cc - src/cpp/codegen/codegen_init.cc uses: - grpc++_codegen_base diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index 1ab17006e3..e8b98cf84b 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -128,6 +128,7 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/byte_buffer.h', 'include/grpcpp/impl/codegen/call.h', 'include/grpcpp/impl/codegen/call_hook.h', + 'include/grpcpp/impl/codegen/call_wrapper.h', 'include/grpcpp/impl/codegen/callback_common.h', 'include/grpcpp/impl/codegen/channel_interface.h', 'include/grpcpp/impl/codegen/client_callback.h', @@ -140,6 +141,7 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/core_codegen_interface.h', 'include/grpcpp/impl/codegen/create_auth_context.h', 'include/grpcpp/impl/codegen/grpc_library.h', + 'include/grpcpp/impl/codegen/intercepted_channel.h', 'include/grpcpp/impl/codegen/interceptor.h', 'include/grpcpp/impl/codegen/metadata_map.h', 'include/grpcpp/impl/codegen/method_handler_impl.h', @@ -148,6 +150,7 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/security/auth_context.h', 'include/grpcpp/impl/codegen/serialization_traits.h', 'include/grpcpp/impl/codegen/server_context.h', + 'include/grpcpp/impl/codegen/server_interceptor.h', 'include/grpcpp/impl/codegen/server_interface.h', 'include/grpcpp/impl/codegen/service_type.h', 'include/grpcpp/impl/codegen/slice.h', @@ -217,6 +220,7 @@ Pod::Spec.new do |s| 'src/cpp/util/status.cc', 'src/cpp/util/string_ref.cc', 'src/cpp/util/time_cc.cc', + 'src/cpp/codegen/call_wrapper.cc', 'src/cpp/codegen/codegen_init.cc', 'src/core/lib/gpr/alloc.h', 'src/core/lib/gpr/arena.h', @@ -1400,6 +1400,7 @@ 'src/cpp/util/status.cc', 'src/cpp/util/string_ref.cc', 'src/cpp/util/time_cc.cc', + 'src/cpp/codegen/call_wrapper.cc', 'src/cpp/codegen/codegen_init.cc', ], }, @@ -1480,6 +1481,7 @@ 'test/cpp/util/string_ref_helper.cc', 'test/cpp/util/subprocess.cc', 'test/cpp/util/test_credentials_provider.cc', + 'src/cpp/codegen/call_wrapper.cc', 'src/cpp/codegen/codegen_init.cc', ], }, @@ -1500,6 +1502,7 @@ 'test/cpp/util/byte_buffer_proto_helper.cc', 'test/cpp/util/string_ref_helper.cc', 'test/cpp/util/subprocess.cc', + 'src/cpp/codegen/call_wrapper.cc', 'src/cpp/codegen/codegen_init.cc', ], }, @@ -1547,6 +1550,7 @@ 'src/cpp/util/status.cc', 'src/cpp/util/string_ref.cc', 'src/cpp/util/time_cc.cc', + 'src/cpp/codegen/call_wrapper.cc', 'src/cpp/codegen/codegen_init.cc', ], }, diff --git a/include/grpcpp/channel.h b/include/grpcpp/channel.h index b7c9e354de..624000b75f 100644 --- a/include/grpcpp/channel.h +++ b/include/grpcpp/channel.h @@ -67,6 +67,7 @@ class Channel final : public ChannelInterface, std::unique_ptr<std::vector< std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> interceptor_creators); + friend class internal::InterceptedChannel; Channel(const grpc::string& host, grpc_channel* c_channel, std::unique_ptr<std::vector< std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> @@ -87,6 +88,10 @@ class Channel final : public ChannelInterface, CompletionQueue* CallbackCQ() override; + internal::Call CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, CompletionQueue* cq, + int interceptor_pos) override; + const grpc::string host_; grpc_channel* const c_channel_; // owned diff --git a/include/grpcpp/impl/codegen/async_stream.h b/include/grpcpp/impl/codegen/async_stream.h index 6e58fd0eef..bfb2df4f23 100644 --- a/include/grpcpp/impl/codegen/async_stream.h +++ b/include/grpcpp/impl/codegen/async_stream.h @@ -276,7 +276,7 @@ class ClientAsyncReader final : public ClientAsyncReaderInterface<R> { } void StartCallInternal(void* tag) { - init_ops_.SendInitialMetadata(context_->send_initial_metadata_, + init_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); init_ops_.set_output_tag(tag); call_.PerformOps(&init_ops_); @@ -441,7 +441,7 @@ class ClientAsyncWriter final : public ClientAsyncWriterInterface<W> { } void StartCallInternal(void* tag) { - write_ops_.SendInitialMetadata(context_->send_initial_metadata_, + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); // if corked bit is set in context, we just keep the initial metadata // buffered up to coalesce with later message send. No op is performed. @@ -612,7 +612,7 @@ class ClientAsyncReaderWriter final } void StartCallInternal(void* tag) { - write_ops_.SendInitialMetadata(context_->send_initial_metadata_, + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); // if corked bit is set in context, we just keep the initial metadata // buffered up to coalesce with later message send. No op is performed. @@ -710,7 +710,7 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -739,7 +739,7 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { void Finish(const W& msg, const Status& status, void* tag) override { finish_ops_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_ops_.SendInitialMetadata(ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_ops_.set_compression_level(ctx_->compression_level()); @@ -748,10 +748,10 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { } // The response is dropped if the status is not OK. if (status.ok()) { - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, finish_ops_.SendMessage(msg)); } else { - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); } call_.PerformOps(&finish_ops_); } @@ -769,14 +769,14 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface<W, R> { GPR_CODEGEN_ASSERT(!status.ok()); finish_ops_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_ops_.SendInitialMetadata(ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_ops_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -859,7 +859,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -904,7 +904,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { EnsureInitialMetadataSent(&write_ops_); options.set_buffer_hint(); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg, options).ok()); - write_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + write_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&write_ops_); } @@ -922,7 +922,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { void Finish(const Status& status, void* tag) override { finish_ops_.set_output_tag(tag); EnsureInitialMetadataSent(&finish_ops_); - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -932,7 +932,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface<W> { template <class T> void EnsureInitialMetadataSent(T* ops) { if (!ctx_->sent_initial_metadata_) { - ops->SendInitialMetadata(ctx_->initial_metadata_, + ops->SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops->set_compression_level(ctx_->compression_level()); @@ -1025,7 +1025,7 @@ class ServerAsyncReaderWriter final GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -1075,7 +1075,7 @@ class ServerAsyncReaderWriter final EnsureInitialMetadataSent(&write_ops_); options.set_buffer_hint(); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg, options).ok()); - write_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + write_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&write_ops_); } @@ -1094,7 +1094,7 @@ class ServerAsyncReaderWriter final finish_ops_.set_output_tag(tag); EnsureInitialMetadataSent(&finish_ops_); - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -1106,7 +1106,7 @@ class ServerAsyncReaderWriter final template <class T> void EnsureInitialMetadataSent(T* ops) { if (!ctx_->sent_initial_metadata_) { - ops->SendInitialMetadata(ctx_->initial_metadata_, + ops->SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops->set_compression_level(ctx_->compression_level()); diff --git a/include/grpcpp/impl/codegen/async_unary_call.h b/include/grpcpp/impl/codegen/async_unary_call.h index 60ff8e2f05..744b128141 100644 --- a/include/grpcpp/impl/codegen/async_unary_call.h +++ b/include/grpcpp/impl/codegen/async_unary_call.h @@ -174,7 +174,7 @@ class ClientAsyncResponseReader final } void StartCallInternal() { - single_buf.SendInitialMetadata(context_->send_initial_metadata_, + single_buf.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); } @@ -214,7 +214,7 @@ class ServerAsyncResponseWriter final GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_buf_.set_output_tag(tag); - meta_buf_.SendInitialMetadata(ctx_->initial_metadata_, + meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_buf_.set_compression_level(ctx_->compression_level()); @@ -240,8 +240,9 @@ class ServerAsyncResponseWriter final /// metadata. void Finish(const W& msg, const Status& status, void* tag) { finish_buf_.set_output_tag(tag); + finish_buf_.set_cq_tag(&finish_buf_); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(ctx_->initial_metadata_, + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_buf_.set_compression_level(ctx_->compression_level()); @@ -250,10 +251,10 @@ class ServerAsyncResponseWriter final } // The response is dropped if the status is not OK. if (status.ok()) { - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, finish_buf_.SendMessage(msg)); } else { - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, status); } call_.PerformOps(&finish_buf_); } @@ -274,14 +275,14 @@ class ServerAsyncResponseWriter final GPR_CODEGEN_ASSERT(!status.ok()); finish_buf_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(ctx_->initial_metadata_, + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_buf_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h index 8cc5158115..d54ae31852 100644 --- a/include/grpcpp/impl/codegen/byte_buffer.h +++ b/include/grpcpp/impl/codegen/byte_buffer.h @@ -50,6 +50,11 @@ class ErrorMethodHandler; template <class R> class DeserializeFuncType; class GrpcByteBufferPeer; +template <class ServiceType, class RequestType, class ResponseType> +class RpcMethodHandler; +template <class ServiceType, class RequestType, class ResponseType> +class ServerStreamingHandler; + } // namespace internal /// A sequence of bytes. class ByteBuffer final { @@ -141,7 +146,10 @@ class ByteBuffer final { template <class R> friend class internal::CallOpRecvMessage; friend class internal::CallOpGenericRecvMessage; - friend class internal::MethodHandler; + template <class ServiceType, class RequestType, class ResponseType> + friend class RpcMethodHandler; + template <class ServiceType, class RequestType, class ResponseType> + friend class ServerStreamingHandler; template <class ServiceType, class RequestType, class ResponseType> friend class internal::RpcMethodHandler; template <class ServiceType, class RequestType, class ResponseType> diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 789ea805a3..51cc18e2b1 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -20,18 +20,24 @@ #define GRPCPP_IMPL_CODEGEN_CALL_H #include <assert.h> +#include <array> #include <cstring> #include <functional> #include <map> #include <memory> +#include <vector> #include <grpcpp/impl/codegen/byte_buffer.h> #include <grpcpp/impl/codegen/call_hook.h> +#include <grpcpp/impl/codegen/call_wrapper.h> #include <grpcpp/impl/codegen/client_context.h> +#include <grpcpp/impl/codegen/client_interceptor.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> #include <grpcpp/impl/codegen/config.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> +#include <grpcpp/impl/codegen/intercepted_channel.h> #include <grpcpp/impl/codegen/serialization_traits.h> +#include <grpcpp/impl/codegen/server_interceptor.h> #include <grpcpp/impl/codegen/slice.h> #include <grpcpp/impl/codegen/status.h> #include <grpcpp/impl/codegen/string_ref.h> @@ -42,7 +48,6 @@ namespace grpc { -class ByteBuffer; class CompletionQueue; extern CoreCodegenInterface* g_core_codegen_interface; @@ -201,6 +206,38 @@ class WriteOptions { }; namespace internal { + +class InternalInterceptorBatchMethods + : public experimental::InterceptorBatchMethods { + public: + virtual ~InternalInterceptorBatchMethods() {} + + virtual void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) = 0; + + virtual void SetSendMessage(ByteBuffer* buf) = 0; + + virtual void SetSendInitialMetadata( + std::multimap<grpc::string, grpc::string>* metadata) = 0; + + virtual void SetSendStatus(grpc_status_code* code, + grpc::string* error_details, + grpc::string* error_message) = 0; + + virtual void SetSendTrailingMetadata( + std::multimap<grpc::string, grpc::string>* metadata) = 0; + + virtual void SetRecvMessage(void* message) = 0; + + virtual void SetRecvInitialMetadata(internal::MetadataMap* map) = 0; + + virtual void SetRecvStatus(Status* status) = 0; + + virtual void SetRecvTrailingMetadata(internal::MetadataMap* map) = 0; + + virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0; +}; + /// Default argument for CallOpSet. I is unused by the class, but can be /// used for generating multiple names for the same thing. template <int I> @@ -208,6 +245,13 @@ class CallNoOp { protected: void AddOp(grpc_op* ops, size_t* nops) {} void FinishOp(bool* status) {} + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + } }; class CallOpSendInitialMetadata { @@ -216,14 +260,12 @@ class CallOpSendInitialMetadata { maybe_compression_level_.is_set = false; } - void SendInitialMetadata( - const std::multimap<grpc::string, grpc::string>& metadata, - uint32_t flags) { + void SendInitialMetadata(std::multimap<grpc::string, grpc::string>* metadata, + uint32_t flags) { maybe_compression_level_.is_set = false; send_ = true; flags_ = flags; - initial_metadata_ = - FillMetadataArray(metadata, &initial_metadata_count_, ""); + metadata_map_ = metadata; } void set_compression_level(grpc_compression_level level) { @@ -233,11 +275,13 @@ class CallOpSendInitialMetadata { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (!send_) return; + if (!send_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_INITIAL_METADATA; op->flags = flags_; op->reserved = NULL; + initial_metadata_ = + FillMetadataArray(*metadata_map_, &initial_metadata_count_, ""); op->data.send_initial_metadata.count = initial_metadata_count_; op->data.send_initial_metadata.metadata = initial_metadata_; op->data.send_initial_metadata.maybe_compression_level.is_set = @@ -248,14 +292,31 @@ class CallOpSendInitialMetadata { } } void FinishOp(bool* status) { - if (!send_) return; + if (!send_ || hijacked_) return; g_core_codegen_interface->gpr_free(initial_metadata_); send_ = false; } + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA); + interceptor_methods->SetSendInitialMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + bool hijacked_ = false; bool send_; uint32_t flags_; size_t initial_metadata_count_; + std::multimap<grpc::string, grpc::string>* metadata_map_; grpc_metadata* initial_metadata_; struct { bool is_set; @@ -278,7 +339,7 @@ class CallOpSendMessage { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (!send_buf_.Valid()) return; + if (!send_buf_.Valid() || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_MESSAGE; op->flags = write_options_.flags(); @@ -289,7 +350,23 @@ class CallOpSendMessage { } void FinishOp(bool* status) { send_buf_.Clear(); } + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_buf_.Valid()) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); + interceptor_methods->SetSendMessage(&send_buf_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + private: + bool hijacked_ = false; ByteBuffer send_buf_; WriteOptions write_options_; }; @@ -332,7 +409,7 @@ class CallOpRecvMessage { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (message_ == nullptr) return; + if (message_ == nullptr || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_MESSAGE; op->flags = 0; @@ -341,7 +418,7 @@ class CallOpRecvMessage { } void FinishOp(bool* status) { - if (message_ == nullptr) return; + if (message_ == nullptr || hijacked_) return; if (recv_buf_.Valid()) { if (*status) { got_message = *status = @@ -361,10 +438,30 @@ class CallOpRecvMessage { message_ = nullptr; } + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (message_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + got_message = true; + } + private: R* message_; ByteBuffer recv_buf_; bool allow_not_getting_message_; + bool hijacked_ = false; }; class DeserializeFunc { @@ -398,6 +495,7 @@ class CallOpGenericRecvMessage { // following unique_ptr::reset for some old implementations. DeserializeFunc* func = new DeserializeFuncType<R>(message); deserialize_.reset(func); + message_ = message; } // Do not change status if no message is received. @@ -407,7 +505,7 @@ class CallOpGenericRecvMessage { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (!deserialize_) return; + if (!deserialize_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_MESSAGE; op->flags = 0; @@ -416,7 +514,7 @@ class CallOpGenericRecvMessage { } void FinishOp(bool* status) { - if (!deserialize_) return; + if (!deserialize_ || hijacked_) return; if (recv_buf_.Valid()) { if (*status) { got_message = true; @@ -435,7 +533,27 @@ class CallOpGenericRecvMessage { deserialize_.reset(); } + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (!deserialize_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + } + private: + void* message_; + bool hijacked_ = false; std::unique_ptr<DeserializeFunc> deserialize_; ByteBuffer recv_buf_; bool allow_not_getting_message_; @@ -449,7 +567,7 @@ class CallOpClientSendClose { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (!send_) return; + if (!send_ || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; op->flags = 0; @@ -457,7 +575,22 @@ class CallOpClientSendClose { } void FinishOp(bool* status) { send_ = false; } + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + private: + bool hijacked_ = false; bool send_; }; @@ -466,11 +599,10 @@ class CallOpServerSendStatus { CallOpServerSendStatus() : send_status_available_(false) {} void ServerSendStatus( - const std::multimap<grpc::string, grpc::string>& trailing_metadata, + std::multimap<grpc::string, grpc::string>* trailing_metadata, const Status& status) { send_error_details_ = status.error_details(); - trailing_metadata_ = FillMetadataArray( - trailing_metadata, &trailing_metadata_count_, send_error_details_); + metadata_map_ = trailing_metadata; send_status_available_ = true; send_status_code_ = static_cast<grpc_status_code>(status.error_code()); send_error_message_ = status.error_message(); @@ -478,7 +610,9 @@ class CallOpServerSendStatus { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (!send_status_available_) return; + if (!send_status_available_ || hijacked_) return; + trailing_metadata_ = FillMetadataArray( + *metadata_map_, &trailing_metadata_count_, send_error_details_); grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; op->data.send_status_from_server.trailing_metadata_count = @@ -493,17 +627,36 @@ class CallOpServerSendStatus { } void FinishOp(bool* status) { - if (!send_status_available_) return; + if (!send_status_available_ || hijacked_) return; g_core_codegen_interface->gpr_free(trailing_metadata_); send_status_available_ = false; } + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_status_available_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS); + interceptor_methods->SetSendTrailingMetadata(metadata_map_); + interceptor_methods->SetSendStatus(&send_status_code_, &send_error_details_, + &send_error_message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + private: + bool hijacked_ = false; bool send_status_available_; grpc_status_code send_status_code_; grpc::string send_error_details_; grpc::string send_error_message_; size_t trailing_metadata_count_; + std::multimap<grpc::string, grpc::string>* metadata_map_; grpc_metadata* trailing_metadata_; grpc_slice error_message_slice_; }; @@ -519,7 +672,7 @@ class CallOpRecvInitialMetadata { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (metadata_map_ == nullptr) return; + if (metadata_map_ == nullptr || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_INITIAL_METADATA; op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr(); @@ -528,11 +681,31 @@ class CallOpRecvInitialMetadata { } void FinishOp(bool* status) { + if (metadata_map_ == nullptr || hijacked_) return; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvInitialMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); metadata_map_ = nullptr; } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA); + } + private: + bool hijacked_ = false; MetadataMap* metadata_map_; }; @@ -550,7 +723,7 @@ class CallOpClientRecvStatus { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (recv_status_ == nullptr) return; + if (recv_status_ == nullptr || hijacked_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr(); @@ -562,7 +735,7 @@ class CallOpClientRecvStatus { } void FinishOp(bool* status) { - if (recv_status_ == nullptr) return; + if (recv_status_ == nullptr || hijacked_) return; grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); *recv_status_ = Status(static_cast<StatusCode>(status_code_), @@ -577,10 +750,31 @@ class CallOpClientRecvStatus { if (debug_error_string_ != nullptr) { g_core_codegen_interface->gpr_free((void*)debug_error_string_); } + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvStatus(recv_status_); + interceptor_methods->SetRecvTrailingMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS); recv_status_ = nullptr; } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS); + } + private: + bool hijacked_ = false; ClientContext* client_context_; MetadataMap* metadata_map_; Status* recv_status_; @@ -598,12 +792,341 @@ class CallOpSetInterface : public CompletionQueueTag { public: /// Fills in grpc_op, starting from ops[*nops] and moving /// upwards. - virtual void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) = 0; + virtual void FillOps(internal::Call* call) = 0; /// Get the tag to be used at the core completion queue. Generally, the /// value of cq_tag will be "this". However, it can be overridden if we /// want core to process the tag differently (e.g., as a core callback) virtual void* cq_tag() = 0; + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + virtual void SetHijackingState() = 0; + + // Should be called after interceptors are done running + virtual void ContinueFillOpsAfterInterception() = 0; + + // Should be called after interceptors are done running on the finalize result + // path + virtual void ContinueFinalizeResultAfterInterception() = 0; +}; + +template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>, + class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, + class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> +class CallOpSet; + +class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { + public: + InterceptorBatchMethodsImpl() { + for (auto i = 0; + i < static_cast<int>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS); + i++) { + hooks_[i] = false; + } + } + + virtual ~InterceptorBatchMethodsImpl() {} + + virtual bool QueryInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + return hooks_[static_cast<int>(type)]; + } + + virtual void Proceed() override { /* fill this */ + if (call_->client_rpc_info() != nullptr) { + return ProceedClient(); + } + GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); + ProceedServer(); + } + + virtual void Hijack() override { + // Only the client can hijack when sending down initial metadata + GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && + call_->client_rpc_info() != nullptr); + auto* rpc_info = call_->client_rpc_info(); + rpc_info->hijacked_ = true; + rpc_info->hijacked_interceptor_ = curr_iteration_; + ClearHookPoints(); + ops_->SetHijackingState(); + ran_hijacking_interceptor_ = true; + rpc_info->RunInterceptor(this, curr_iteration_); + } + + virtual void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + hooks_[static_cast<int>(type)] = true; + } + + virtual ByteBuffer* GetSendMessage() override { return send_message_; } + + virtual std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() + override { + return send_initial_metadata_; + } + + virtual Status GetSendStatus() override { + return Status(static_cast<StatusCode>(*code_), *error_message_, + *error_details_); + } + + virtual void ModifySendStatus(const Status& status) override { + *code_ = static_cast<grpc_status_code>(status.error_code()); + *error_details_ = status.error_details(); + *error_message_ = status.error_message(); + } + + virtual std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata() + override { + return send_trailing_metadata_; + } + + virtual void* GetRecvMessage() override { return recv_message_; } + + virtual std::multimap<grpc::string_ref, grpc::string_ref>* + GetRecvInitialMetadata() override { + return recv_initial_metadata_->map(); + } + + virtual Status* GetRecvStatus() override { return recv_status_; } + + virtual std::multimap<grpc::string_ref, grpc::string_ref>* + GetRecvTrailingMetadata() override { + return recv_trailing_metadata_->map(); + } + + virtual void SetSendMessage(ByteBuffer* buf) override { send_message_ = buf; } + + virtual void SetSendInitialMetadata( + std::multimap<grpc::string, grpc::string>* metadata) override { + send_initial_metadata_ = metadata; + } + + virtual void SetSendStatus(grpc_status_code* code, + grpc::string* error_details, + grpc::string* error_message) override { + code_ = code; + error_details_ = error_details; + error_message_ = error_message; + } + + virtual void SetSendTrailingMetadata( + std::multimap<grpc::string, grpc::string>* metadata) override { + send_trailing_metadata_ = metadata; + } + + virtual void SetRecvMessage(void* message) override { + recv_message_ = message; + } + + virtual void SetRecvInitialMetadata(internal::MetadataMap* map) override { + recv_initial_metadata_ = map; + } + + virtual void SetRecvStatus(Status* status) override { recv_status_ = status; } + + virtual void SetRecvTrailingMetadata(internal::MetadataMap* map) override { + recv_trailing_metadata_ = map; + } + + virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { + auto* info = call_->client_rpc_info(); + if (info == nullptr) { + return std::unique_ptr<ChannelInterface>(nullptr); + } + // The intercepted channel starts from the interceptor just after the + // current interceptor + return std::unique_ptr<ChannelInterface>(new internal::InterceptedChannel( + reinterpret_cast<grpc::ChannelInterface*>(info->channel()), + curr_iteration_ + 1)); + } + + // Prepares for Post_recv operations + void SetReverse() { + reverse_ = true; + ran_hijacking_interceptor_ = false; + ClearHookPoints(); + } + + // This needs to be set before interceptors are run + void SetCall(Call* call) { call_ = call; } + + void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } + + // Returns true if no interceptors are run. This should be used only by + // subclasses of CallOpSetInterface. SetCall and SetCallOpSetInterface should + // have been called before this. After all the interceptors are done running, + // either ContinueFillOpsAfterInterception or + // ContinueFinalizeOpsAfterInterception will be called. Note that neither of + // them is invoked if there were no interceptors registered. + bool RunInterceptors() { + auto* client_rpc_info = call_->client_rpc_info(); + if (client_rpc_info == nullptr || + client_rpc_info->interceptors_.size() == 0) { + return true; + } else { + RunClientInterceptors(); + return false; + } + + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + RunServerInterceptors(); + return false; + } + + // Returns true if no interceptors are run. Returns false otherwise if there + // are interceptors registered. After the interceptors are done running \a f + // will + // be invoked. This is to be used only by BaseAsyncRequest and SyncRequest. + bool RunInterceptors(std::function<void(void)> f) { + GPR_CODEGEN_ASSERT(reverse_ == true); + GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + callback_ = std::move(f); + RunServerInterceptors(); + return false; + } + + private: + void RunClientInterceptors() { + auto* rpc_info = call_->client_rpc_info(); + if (!reverse_) { + curr_iteration_ = 0; + } else { + if (rpc_info->hijacked_) { + curr_iteration_ = rpc_info->hijacked_interceptor_; + gpr_log(GPR_ERROR, "running from the hijacked %d", + rpc_info->hijacked_interceptor_); + } else { + curr_iteration_ = rpc_info->interceptors_.size() - 1; + } + } + rpc_info->RunInterceptor(this, curr_iteration_); + } + + void RunServerInterceptors() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + curr_iteration_ = 0; + } else { + curr_iteration_ = rpc_info->interceptors_.size() - 1; + } + rpc_info->RunInterceptor(this, curr_iteration_); + } + + void ProceedClient() { + auto* rpc_info = call_->client_rpc_info(); + if (rpc_info->hijacked_ && !reverse_ && + curr_iteration_ == rpc_info->hijacked_interceptor_ && + !ran_hijacking_interceptor_) { + // We now need to provide hijacked recv ops to this interceptor + ClearHookPoints(); + ops_->SetHijackingState(); + ran_hijacking_interceptor_ = true; + rpc_info->RunInterceptor(this, curr_iteration_); + return; + } + if (!reverse_) { + curr_iteration_++; + // We are going down the stack of interceptors + if (curr_iteration_ < static_cast<long>(rpc_info->interceptors_.size())) { + if (rpc_info->hijacked_ && + curr_iteration_ > rpc_info->hijacked_interceptor_) { + // This is a hijacked RPC and we are done with hijacking + ops_->ContinueFillOpsAfterInterception(); + } else { + rpc_info->RunInterceptor(this, curr_iteration_); + } + } else { + // we are done running all the interceptors without any hijacking + ops_->ContinueFillOpsAfterInterception(); + } + } else { + curr_iteration_--; + // We are going up the stack of interceptors + if (curr_iteration_ >= 0) { + // Continue running interceptors + rpc_info->RunInterceptor(this, curr_iteration_); + } else { + // we are done running all the interceptors without any hijacking + ops_->ContinueFinalizeResultAfterInterception(); + } + } + } + + void ProceedServer() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + curr_iteration_++; + if (curr_iteration_ < static_cast<long>(rpc_info->interceptors_.size())) { + return rpc_info->RunInterceptor(this, curr_iteration_); + } + } else { + curr_iteration_--; + // We are going up the stack of interceptors + if (curr_iteration_ >= 0) { + // Continue running interceptors + return rpc_info->RunInterceptor(this, curr_iteration_); + } + } + // we are done running all the interceptors + if (ops_) { + ops_->ContinueFinalizeResultAfterInterception(); + } + GPR_CODEGEN_ASSERT(callback_); + callback_(); + } + + void ClearHookPoints() { + for (auto i = 0; + i < static_cast<int>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS); + i++) { + hooks_[i] = false; + } + } + + std::array<bool, + static_cast<int>( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> + hooks_; + + int curr_iteration_ = 0; // Current iterator + bool reverse_ = false; + bool ran_hijacking_interceptor_ = false; + Call* call_ = + nullptr; // The Call object is present along with CallOpSet object + CallOpSetInterface* ops_ = nullptr; + std::function<void(void)> callback_; + + ByteBuffer* send_message_ = nullptr; + + std::multimap<grpc::string, grpc::string>* send_initial_metadata_; + + grpc_status_code* code_ = nullptr; + grpc::string* error_details_ = nullptr; + grpc::string* error_message_ = nullptr; + Status send_status_; + + std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr; + + void* recv_message_ = nullptr; + + internal::MetadataMap* recv_initial_metadata_ = nullptr; + + Status* recv_status_ = nullptr; + + internal::MetadataMap* recv_trailing_metadata_ = nullptr; }; /// Primary implementation of CallOpSetInterface. @@ -612,9 +1135,7 @@ class CallOpSetInterface : public CompletionQueueTag { /// empty base class optimization to slim this class (especially /// when there are many unused slots used). To avoid duplicate base classes, /// the template parmeter for CallNoOp is varied by argument position. -template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>, - class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, - class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> +template <class Op1, class Op2, class Op3, class Op4, class Op5, class Op6> class CallOpSet : public CallOpSetInterface, public Op1, public Op2, @@ -623,42 +1144,66 @@ class CallOpSet : public CallOpSetInterface, public Op5, public Op6 { public: - CallOpSet() : cq_tag_(this), return_tag_(this), call_(nullptr) {} - + CallOpSet() : cq_tag_(this), return_tag_(this) {} // The copy constructor and assignment operator reset the value of // cq_tag_ and return_tag_ since those are only meaningful on a specific // object, not across objects. CallOpSet(const CallOpSet& other) - : cq_tag_(this), return_tag_(this), call_(other.call_) {} + : cq_tag_(this), + return_tag_(this), + call_(other.call_), + done_intercepting_(false), + interceptor_methods_(InterceptorBatchMethodsImpl()) {} + CallOpSet& operator=(const CallOpSet& other) { cq_tag_ = this; return_tag_ = this; call_ = other.call_; + done_intercepting_ = false; + interceptor_methods_ = InterceptorBatchMethodsImpl(); return *this; } - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override { - this->Op1::AddOp(ops, nops); - this->Op2::AddOp(ops, nops); - this->Op3::AddOp(ops, nops); - this->Op4::AddOp(ops, nops); - this->Op5::AddOp(ops, nops); - this->Op6::AddOp(ops, nops); - g_core_codegen_interface->grpc_call_ref(call); - call_ = call; + void FillOps(Call* call) override { + done_intercepting_ = false; + g_core_codegen_interface->grpc_call_ref(call->call()); + call_ = + *call; // It's fine to create a copy of call since it's just pointers + + if (RunInterceptors()) { + ContinueFillOpsAfterInterception(); + } else { + // After the interceptors are run, ContinueFillOpsAfterInterception will + // be run + } } bool FinalizeResult(void** tag, bool* status) override { + if (done_intercepting_) { + // We have already finished intercepting and filling in the results. This + // round trip from the core needed to be made because interceptors were + // run + *tag = return_tag_; + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + this->Op1::FinishOp(status); this->Op2::FinishOp(status); this->Op3::FinishOp(status); this->Op4::FinishOp(status); this->Op5::FinishOp(status); this->Op6::FinishOp(status); - *tag = return_tag_; - g_core_codegen_interface->grpc_call_unref(call_); - return true; + if (RunInterceptorsPostRecv()) { + *tag = return_tag_; + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + + // Interceptors are going to be run, so we can't return the tag just yet. + // After the interceptors are run, ContinueFinalizeResultAfterInterception + return false; } void set_output_tag(void* return_tag) { return_tag_ = return_tag; } @@ -671,44 +1216,76 @@ class CallOpSet : public CallOpSetInterface, /// function (such as FinalizeResult) void set_cq_tag(void* cq_tag) { cq_tag_ = cq_tag; } - private: - void* cq_tag_; - void* return_tag_; - grpc_call* call_; -}; - -/// Straightforward wrapping of the C call object -class Call final { - public: - /** call is owned by the caller */ - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(-1) {} - - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, - int max_receive_message_size) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(max_receive_message_size) {} - - void PerformOps(CallOpSetInterface* ops) { - call_hook_->PerformOpsOnCall(ops, this); + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + this->Op1::SetHijackingState(&interceptor_methods_); + this->Op2::SetHijackingState(&interceptor_methods_); + this->Op3::SetHijackingState(&interceptor_methods_); + this->Op4::SetHijackingState(&interceptor_methods_); + this->Op5::SetHijackingState(&interceptor_methods_); + this->Op6::SetHijackingState(&interceptor_methods_); } - grpc_call* call() const { return call_; } - CompletionQueue* cq() const { return cq_; } + // Should be called after interceptors are done running + void ContinueFillOpsAfterInterception() override { + static const size_t MAX_OPS = 6; + grpc_op ops[MAX_OPS]; + size_t nops = 0; + this->Op1::AddOp(ops, &nops); + this->Op2::AddOp(ops, &nops); + this->Op3::AddOp(ops, &nops); + this->Op4::AddOp(ops, &nops); + this->Op5::AddOp(ops, &nops); + this->Op6::AddOp(ops, &nops); + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), ops, nops, cq_tag(), nullptr)); + } - int max_receive_message_size() const { return max_receive_message_size_; } + // Should be called after interceptors are done running on the finalize result + // path + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, cq_tag(), nullptr)); + } private: - CallHook* call_hook_; - CompletionQueue* cq_; - grpc_call* call_; - int max_receive_message_size_; + // Returns true if no interceptors need to be run + bool RunInterceptors() { + this->Op1::SetInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetInterceptionHookPoint(&interceptor_methods_); + interceptor_methods_.SetCallOpSetInterface(this); + interceptor_methods_.SetCall(&call_); + // interceptor_methods_.SetFunctions(ContinueFillOpsAfterInterception, + // SetHijackingState, ContinueFinalizeResultAfterInterception); + return interceptor_methods_.RunInterceptors(); + } + // Returns true if no interceptors need to be run + bool RunInterceptorsPostRecv() { + interceptor_methods_.SetReverse(); + this->Op1::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetFinishInterceptionHookPoint(&interceptor_methods_); + return interceptor_methods_.RunInterceptors(); + } + + void* cq_tag_; + void* return_tag_; + Call call_; + bool done_intercepting_ = false; + InterceptorBatchMethodsImpl interceptor_methods_; }; + } // namespace internal } // namespace grpc diff --git a/include/grpcpp/impl/codegen/call_wrapper.h b/include/grpcpp/impl/codegen/call_wrapper.h new file mode 100644 index 0000000000..675e36feb9 --- /dev/null +++ b/include/grpcpp/impl/codegen/call_wrapper.h @@ -0,0 +1,91 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#ifndef GRPCPP_IMPL_CODEGEN_CALL_WRAPPER_H +#define GRPCPP_IMPL_CODEGEN_CALL_WRAPPER_H + +#include <grpc/impl/codegen/grpc_types.h> + +namespace grpc { +class CompletionQueue; + +namespace experimental { +class ClientRpcInfo; +class ServerRpcInfo; +} // namespace experimental +namespace internal { +class CallHook; +class CallOpSetInterface; + +/// Straightforward wrapping of the C call object +class Call final { + public: + Call() + : call_hook_(nullptr), + cq_(nullptr), + call_(nullptr), + max_receive_message_size_(-1) {} + /** call is owned by the caller */ + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(-1) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + experimental::ClientRpcInfo* rpc_info) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(-1), + client_rpc_info_(rpc_info) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + int max_receive_message_size, experimental::ServerRpcInfo* rpc_info) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(max_receive_message_size), + server_rpc_info_(rpc_info) {} + + void PerformOps(CallOpSetInterface* ops); + + grpc_call* call() const { return call_; } + CompletionQueue* cq() const { return cq_; } + + int max_receive_message_size() const { return max_receive_message_size_; } + + experimental::ClientRpcInfo* client_rpc_info() const { + return client_rpc_info_; + } + + experimental::ServerRpcInfo* server_rpc_info() const { + return server_rpc_info_; + } + + private: + CallHook* call_hook_; + CompletionQueue* cq_; + grpc_call* call_; + int max_receive_message_size_; + experimental::ClientRpcInfo* client_rpc_info_ = nullptr; + experimental::ServerRpcInfo* server_rpc_info_ = nullptr; +}; +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_CALL_WRAPPER_H
\ No newline at end of file diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index b257acc1ab..03fa502df6 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -20,6 +20,7 @@ #define GRPCPP_IMPL_CODEGEN_CHANNEL_INTERFACE_H #include <grpc/impl/codegen/connectivity_state.h> +#include <grpcpp/impl/codegen/call_wrapper.h> #include <grpcpp/impl/codegen/status.h> #include <grpcpp/impl/codegen/time.h> @@ -51,6 +52,7 @@ template <class W, class R> class ClientAsyncReaderWriterFactory; template <class R> class ClientAsyncResponseReaderFactory; +class InterceptedChannel; } // namespace internal /// Codegen interface for \a grpc::Channel. @@ -108,6 +110,7 @@ class ChannelInterface { template <class InputMessage, class OutputMessage> friend class ::grpc::internal::CallbackUnaryCallImpl; friend class ::grpc::internal::RpcMethod; + friend class ::grpc::internal::InterceptedChannel; virtual internal::Call CreateCall(const internal::RpcMethod& method, ClientContext* context, CompletionQueue* cq) = 0; @@ -119,6 +122,15 @@ class ChannelInterface { CompletionQueue* cq, void* tag) = 0; virtual bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, gpr_timespec deadline) = 0; + // This is needed to keep codegen_test_minimal happy. InterceptedChannel needs + // to make use of this but can't directly call Channel's implementation + // because of the test. + virtual internal::Call CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq, + int interceptor_pos) { + return internal::Call(); + } // EXPERIMENTAL // A method to get the callbackable completion queue associated with this diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 4d4faea063..ecb00a0769 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -77,7 +77,7 @@ class CallbackUnaryCallImpl { tag->force_run(s); return; } - ops->SendInitialMetadata(context->send_initial_metadata_, + ops->SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); ops->RecvInitialMetadata(context); ops->RecvMessage(result); diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 24f5c431ce..59c61c4f0e 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -41,6 +41,7 @@ #include <grpc/impl/codegen/compression_types.h> #include <grpc/impl/codegen/propagation_bits.h> +#include <grpcpp/impl/codegen/client_interceptor.h> #include <grpcpp/impl/codegen/config.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/create_auth_context.h> @@ -402,6 +403,17 @@ class ClientContext { grpc_call* call() const { return call_; } void set_call(grpc_call* call, const std::shared_ptr<Channel>& channel); + experimental::ClientRpcInfo* set_client_rpc_info( + const char* method, grpc::Channel* channel, + const std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>& + creators, + int interceptor_pos) { + rpc_info_ = experimental::ClientRpcInfo(this, method, channel); + rpc_info_.RegisterInterceptors(creators, interceptor_pos); + return &rpc_info_; + } + uint32_t initial_metadata_flags() const { return (idempotent_ ? GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST : 0) | (wait_for_ready_ ? GRPC_INITIAL_METADATA_WAIT_FOR_READY : 0) | @@ -439,6 +451,8 @@ class ClientContext { bool initial_metadata_corked_; grpc::string debug_error_string_; + + experimental::ClientRpcInfo rpc_info_; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f460c5ac0c..8f32814838 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -19,23 +19,77 @@ #ifndef GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H +#include <vector> + +#include <grpc/impl/codegen/log.h> #include <grpcpp/impl/codegen/interceptor.h> +#include <grpcpp/impl/codegen/string_ref.h> namespace grpc { -namespace experimental { -class ClientInterceptor { - public: - virtual ~ClientInterceptor() {} - virtual void Intercept(InterceptorBatchMethods* methods) = 0; -}; +class ClientContext; +class Channel; -class ClientRpcInfo {}; +namespace internal { +class InterceptorBatchMethodsImpl; +} + +namespace experimental { +class ClientRpcInfo; class ClientInterceptorFactoryInterface { public: virtual ~ClientInterceptorFactoryInterface() {} - virtual ClientInterceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; + virtual Interceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; +}; + +class ClientRpcInfo { + public: + ClientRpcInfo() {} + + ~ClientRpcInfo(){}; + + ClientRpcInfo(const ClientRpcInfo&) = delete; + ClientRpcInfo(ClientRpcInfo&&) = default; + ClientRpcInfo& operator=(ClientRpcInfo&&) = default; + + // Getter methods + const char* method() { return method_; } + Channel* channel() { return channel_; } + grpc::ClientContext* client_context() { return ctx_; } + + private: + ClientRpcInfo(grpc::ClientContext* ctx, const char* method, + grpc::Channel* channel) + : ctx_(ctx), method_(method), channel_(channel) {} + // Runs interceptor at pos \a pos. + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, + unsigned int pos) { + GPR_CODEGEN_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + void RegisterInterceptors( + const std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>& creators, + int interceptor_pos) { + for (auto it = creators.begin() + interceptor_pos; it != creators.end(); + ++it) { + interceptors_.push_back(std::unique_ptr<experimental::Interceptor>( + (*it)->CreateClientInterceptor(this))); + } + } + + grpc::ClientContext* ctx_ = nullptr; + const char* method_ = nullptr; + grpc::Channel* channel_ = nullptr; + std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_; + bool hijacked_ = false; + int hijacked_interceptor_ = false; + + friend class internal::InterceptorBatchMethodsImpl; + friend class grpc::ClientContext; }; } // namespace experimental diff --git a/include/grpcpp/impl/codegen/client_unary_call.h b/include/grpcpp/impl/codegen/client_unary_call.h index e4e8364e07..b1c80764f2 100644 --- a/include/grpcpp/impl/codegen/client_unary_call.h +++ b/include/grpcpp/impl/codegen/client_unary_call.h @@ -61,7 +61,7 @@ class BlockingUnaryCallImpl { if (!status_.ok()) { return; } - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); ops.RecvInitialMetadata(context); ops.RecvMessage(result); diff --git a/include/grpcpp/impl/codegen/completion_queue.h b/include/grpcpp/impl/codegen/completion_queue.h index a62d48abe7..fc1ed3278a 100644 --- a/include/grpcpp/impl/codegen/completion_queue.h +++ b/include/grpcpp/impl/codegen/completion_queue.h @@ -299,14 +299,17 @@ class CompletionQueue : private GrpcLibraryCodegen { bool Pluck(internal::CompletionQueueTag* tag) { auto deadline = g_core_codegen_interface->gpr_inf_future(GPR_CLOCK_REALTIME); - auto ev = g_core_codegen_interface->grpc_completion_queue_pluck( - cq_, tag, deadline, nullptr); - bool ok = ev.success != 0; - void* ignored = tag; - GPR_CODEGEN_ASSERT(tag->FinalizeResult(&ignored, &ok)); - GPR_CODEGEN_ASSERT(ignored == tag); - // Ignore mutations by FinalizeResult: Pluck returns the C API status - return ev.success != 0; + while (true) { + auto ev = g_core_codegen_interface->grpc_completion_queue_pluck( + cq_, tag, deadline, nullptr); + bool ok = ev.success != 0; + void* ignored = tag; + if (tag->FinalizeResult(&ignored, &ok)) { + GPR_CODEGEN_ASSERT(ignored == tag); + // Ignore mutations by FinalizeResult: Pluck returns the C API status + return ev.success != 0; + } + } } /// Performs a single polling pluck on \a tag. diff --git a/include/grpcpp/impl/codegen/core_codegen.h b/include/grpcpp/impl/codegen/core_codegen.h index e9df96bf04..6ef184d01a 100644 --- a/include/grpcpp/impl/codegen/core_codegen.h +++ b/include/grpcpp/impl/codegen/core_codegen.h @@ -63,6 +63,9 @@ class CoreCodegen final : public CoreCodegenInterface { void gpr_cv_signal(gpr_cv* cv) override; void gpr_cv_broadcast(gpr_cv* cv) override; + grpc_call_error grpc_call_start_batch(grpc_call* call, const grpc_op* ops, + size_t nops, void* tag, + void* reserved) override; grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/core_codegen_interface.h b/include/grpcpp/impl/codegen/core_codegen_interface.h index 1167a188a2..25e3abccca 100644 --- a/include/grpcpp/impl/codegen/core_codegen_interface.h +++ b/include/grpcpp/impl/codegen/core_codegen_interface.h @@ -100,6 +100,9 @@ class CoreCodegenInterface { virtual grpc_slice grpc_slice_new_with_len(void* p, size_t len, void (*destroy)(void*, size_t)) = 0; + virtual grpc_call_error grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, size_t nops, + void* tag, void* reserved) = 0; virtual grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/intercepted_channel.h b/include/grpcpp/impl/codegen/intercepted_channel.h new file mode 100644 index 0000000000..91d9cd84e3 --- /dev/null +++ b/include/grpcpp/impl/codegen/intercepted_channel.h @@ -0,0 +1,78 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H +#define GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H + +#include <grpcpp/impl/codegen/channel_interface.h> + +namespace grpc { + +namespace internal { + +class InterceptorBatchMethodsImpl; + +class InterceptedChannel : public ChannelInterface { + public: + virtual ~InterceptedChannel() { channel_ = nullptr; } + + /// Get the current channel state. If the channel is in IDLE and + /// \a try_to_connect is set to true, try to connect. + grpc_connectivity_state GetState(bool try_to_connect) override { + return channel_->GetState(try_to_connect); + } + + private: + InterceptedChannel(ChannelInterface* channel, int pos) + : channel_(channel), interceptor_pos_(pos) {} + + internal::Call CreateCall(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq) override { + return channel_->CreateCallInternal(method, context, cq, interceptor_pos_); + } + + void PerformOpsOnCall(internal::CallOpSetInterface* ops, + internal::Call* call) override { + return channel_->PerformOpsOnCall(ops, call); + } + void* RegisterMethod(const char* method) override { + return channel_->RegisterMethod(method); + } + + void NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, CompletionQueue* cq, + void* tag) override { + return channel_->NotifyOnStateChangeImpl(last_observed, deadline, cq, tag); + } + bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) override { + return channel_->WaitForStateChangeImpl(last_observed, deadline); + } + + CompletionQueue* CallbackCQ() override { return channel_->CallbackCQ(); } + + ChannelInterface* channel_; + int interceptor_pos_; + + friend class InterceptorBatchMethodsImpl; +}; +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H
\ No newline at end of file diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 6402a3a946..2027fd69b1 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -19,7 +19,21 @@ #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H +#include <grpc/impl/codegen/grpc_types.h> +#include <grpcpp/impl/codegen/byte_buffer.h> +#include <grpcpp/impl/codegen/channel_interface.h> +#include <grpcpp/impl/codegen/config.h> +#include <grpcpp/impl/codegen/core_codegen_interface.h> +#include <grpcpp/impl/codegen/metadata_map.h> + +// struct grpc_byte_buffer; +// struct grpc_status_code; +// struct grpc_metadata; + namespace grpc { + +class Status; + namespace experimental { class InterceptedMessage { public: @@ -35,6 +49,7 @@ enum class InterceptionHookPoints { PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, PRE_SEND_STATUS /* server only */, + PRE_SEND_CLOSE /* client only */, /* The following three are for hijacked clients only and can only be registered by the global interceptor */ PRE_RECV_INITIAL_METADATA, @@ -50,7 +65,7 @@ enum class InterceptionHookPoints { class InterceptorBatchMethods { public: - virtual ~InterceptorBatchMethods(); + virtual ~InterceptorBatchMethods(){}; // Queries to check whether the current batch has an interception hook point // of type \a type virtual bool QueryInterceptionHookPoint(InterceptionHookPoints type) = 0; @@ -60,7 +75,53 @@ class InterceptorBatchMethods { // Calling this indicates that the interceptor has hijacked the RPC (only // valid if the batch contains send_initial_metadata on the client side) virtual void Hijack() = 0; + + // Returns a modifable ByteBuffer holding serialized form of the message to be + // sent + virtual ByteBuffer* GetSendMessage() = 0; + + // Returns a modifiable multimap of the initial metadata to be sent + virtual std::multimap<grpc::string, grpc::string>* + GetSendInitialMetadata() = 0; + + // Returns the status to be sent + virtual Status GetSendStatus() = 0; + + // Modifies the status with \a status + virtual void ModifySendStatus(const Status& status) = 0; + + // Returns a modifiable multimap of the trailing metadata to be sent + virtual std::multimap<grpc::string, grpc::string>* + GetSendTrailingMetadata() = 0; + + // Returns a pointer to the modifiable received message. Note that the message + // is already deserialized + virtual void* GetRecvMessage() = 0; + + // Returns a modifiable multimap of the received initial metadata + virtual std::multimap<grpc::string_ref, grpc::string_ref>* + GetRecvInitialMetadata() = 0; + + // Returns a modifiable view of the received status + virtual Status* GetRecvStatus() = 0; + + // Returns a modifiable multimap of the received trailing metadata + virtual std::multimap<grpc::string_ref, grpc::string_ref>* + GetRecvTrailingMetadata() = 0; + + // Gets an intercepted channel. When a call is started on this interceptor, + // only interceptors after the current interceptor are created from the + // factory objects registered with the channel. + virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0; }; + +class Interceptor { + public: + virtual ~Interceptor() {} + + virtual void Intercept(InterceptorBatchMethods* methods) = 0; +}; + } // namespace experimental } // namespace grpc diff --git a/include/grpcpp/impl/codegen/method_handler_impl.h b/include/grpcpp/impl/codegen/method_handler_impl.h index 53117f941b..4f02e3e39b 100644 --- a/include/grpcpp/impl/codegen/method_handler_impl.h +++ b/include/grpcpp/impl/codegen/method_handler_impl.h @@ -59,21 +59,21 @@ class RpcMethodHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits<RequestType>::Deserialize( - param.request.bbuf_ptr(), &req); ResponseType rsp; + Status status = param.status; if (status.ok()) { - status = CatchingFunctionHandler([this, ¶m, &req, &rsp] { - return func_(service_, param.server_context, &req, &rsp); + status = CatchingFunctionHandler([this, ¶m, &rsp] { + return func_(service_, param.server_context, + static_cast<RequestType*>(param.request), &rsp); }); + delete static_cast<RequestType*>(param.request); } GPR_CODEGEN_ASSERT(!param.server_context->sent_initial_metadata_); CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage, CallOpServerSendStatus> ops; - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -81,11 +81,24 @@ class RpcMethodHandler : public MethodHandler { if (status.ok()) { status = ops.SendMessage(rsp); } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new RequestType(); + *status = SerializationTraits<RequestType>::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + delete request; + return nullptr; + } + private: /// Application provided rpc handler function. std::function<Status(ServiceType*, ServerContext*, const RequestType*, @@ -117,7 +130,7 @@ class ClientStreamingHandler : public MethodHandler { CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -126,7 +139,7 @@ class ClientStreamingHandler : public MethodHandler { if (status.ok()) { status = ops.SendMessage(rsp); } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } @@ -150,26 +163,25 @@ class ServerStreamingHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits<RequestType>::Deserialize( - param.request.bbuf_ptr(), &req); - + Status status = param.status; if (status.ok()) { ServerWriter<ResponseType> writer(param.call, param.server_context); - status = CatchingFunctionHandler([this, ¶m, &req, &writer] { - return func_(service_, param.server_context, &req, &writer); + status = CatchingFunctionHandler([this, ¶m, &writer] { + return func_(service_, param.server_context, + static_cast<RequestType*>(param.request), &writer); }); + delete static_cast<RequestType*>(param.request); } CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); } } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -177,6 +189,19 @@ class ServerStreamingHandler : public MethodHandler { param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new RequestType(); + *status = SerializationTraits<RequestType>::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + delete request; + return nullptr; + } + private: std::function<Status(ServiceType*, ServerContext*, const RequestType*, ServerWriter<ResponseType>*)> @@ -206,7 +231,7 @@ class TemplatedBidiStreamingHandler : public MethodHandler { CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -218,7 +243,7 @@ class TemplatedBidiStreamingHandler : public MethodHandler { "Service did not provide response message"); } } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -281,14 +306,14 @@ class ErrorMethodHandler : public MethodHandler { static void FillOps(ServerContext* context, T* ops) { Status status(code, ""); if (!context->sent_initial_metadata_) { - ops->SendInitialMetadata(context->initial_metadata_, + ops->SendInitialMetadata(&context->initial_metadata_, context->initial_metadata_flags()); if (context->compression_level_set()) { ops->set_compression_level(context->compression_level()); } context->sent_initial_metadata_ = true; } - ops->ServerSendStatus(context->trailing_metadata_, status); + ops->ServerSendStatus(&context->trailing_metadata_, status); } void RunHandler(const HandlerParameter& param) final { @@ -296,11 +321,14 @@ class ErrorMethodHandler : public MethodHandler { FillOps(param.server_context, &ops); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); - // We also have to destroy any request payload in the handler parameter - ByteBuffer* payload = param.request.bbuf_ptr(); - if (payload != nullptr) { - payload->Clear(); + } + + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + // We have to destroy any request payload + if (req != nullptr) { + g_core_codegen_interface->grpc_byte_buffer_destroy(req); } + return nullptr; } }; diff --git a/include/grpcpp/impl/codegen/rpc_service_method.h b/include/grpcpp/impl/codegen/rpc_service_method.h index 5cf88e216f..44da2bd768 100644 --- a/include/grpcpp/impl/codegen/rpc_service_method.h +++ b/include/grpcpp/impl/codegen/rpc_service_method.h @@ -40,17 +40,26 @@ class MethodHandler { public: virtual ~MethodHandler() {} struct HandlerParameter { - HandlerParameter(Call* c, ServerContext* context, grpc_byte_buffer* req) - : call(c), server_context(context) { - request.set_buffer(req); - } - ~HandlerParameter() { request.Release(); } + HandlerParameter(Call* c, ServerContext* context, void* req, + Status req_status) + : call(c), server_context(context), request(req), status(req_status) {} + ~HandlerParameter() {} Call* call; ServerContext* server_context; - // Handler required to destroy these contents - ByteBuffer request; + void* request; + Status status; }; virtual void RunHandler(const HandlerParameter& param) = 0; + + /* Returns a pointer to the deserialized request. \a status reflects the + result of deserialization. This pointer and the status should be filled in + a HandlerParameter and passed to RunHandler. It is illegal to access the + pointer after calling RunHandler. Ownership of the deserialized request is + retained by the handler. Returns nullptr if deserialization failed. */ + virtual void* Deserialize(grpc_byte_buffer* req, Status* status) { + GPR_CODEGEN_ASSERT(req == nullptr); + return nullptr; + } }; /// Server side rpc method class diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index b58f029de9..810c0bf35b 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -285,6 +285,16 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } + experimental::ServerRpcInfo* set_server_rpc_info( + const char* method, + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>& + creators) { + rpc_info_ = experimental::ServerRpcInfo(this, method); + rpc_info_.RegisterInterceptors(creators); + return &rpc_info_; + } + CompletionOp* completion_op_; bool has_notify_when_done_tag_; void* async_notify_when_done_tag_; @@ -306,6 +316,8 @@ class ServerContext { internal::CallOpSendMessage> pending_ops_; bool has_pending_ops_; + + experimental::ServerRpcInfo rpc_info_; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/server_interceptor.h b/include/grpcpp/impl/codegen/server_interceptor.h new file mode 100644 index 0000000000..3f8cbcca8d --- /dev/null +++ b/include/grpcpp/impl/codegen/server_interceptor.h @@ -0,0 +1,92 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H +#define GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H + +#include <vector> + +#include <grpc/impl/codegen/log.h> +#include <grpcpp/impl/codegen/interceptor.h> +#include <grpcpp/impl/codegen/string_ref.h> + +namespace grpc { + +class ServerContext; + +namespace internal { +class InterceptorBatchMethodsImpl; +} + +namespace experimental { +class ServerRpcInfo; + +class ServerInterceptorFactoryInterface { + public: + virtual ~ServerInterceptorFactoryInterface() {} + virtual Interceptor* CreateServerInterceptor(ServerRpcInfo* info) = 0; +}; + +class ServerRpcInfo { + public: + ServerRpcInfo() {} + + ~ServerRpcInfo(){}; + + ServerRpcInfo(const ServerRpcInfo&) = delete; + ServerRpcInfo(ServerRpcInfo&&) = default; + ServerRpcInfo& operator=(ServerRpcInfo&&) = default; + + // Getter methods + const char* method() { return method_; } + grpc::ServerContext* server_context() { return ctx_; } + + public: + // Runs interceptor at pos \a pos. + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, + unsigned int pos) { + GPR_CODEGEN_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + private: + ServerRpcInfo(grpc::ServerContext* ctx, const char* method) + : ctx_(ctx), method_(method) {} + + void RegisterInterceptors( + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>& + creators) { + for (const auto& creator : creators) { + interceptors_.push_back(std::unique_ptr<experimental::Interceptor>( + creator->CreateServerInterceptor(this))); + } + } + grpc::ServerContext* ctx_ = nullptr; + const char* method_ = nullptr; + std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_; + + friend class internal::InterceptorBatchMethodsImpl; + friend class grpc::ServerContext; +}; + +} // namespace experimental +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 237991cde6..9b4177b641 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -20,11 +20,14 @@ #define GRPCPP_IMPL_CODEGEN_SERVER_INTERFACE_H #include <grpc/impl/codegen/grpc_types.h> +//#include <grpcpp/alarm.h> #include <grpcpp/impl/codegen/byte_buffer.h> +#include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/call_hook.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/rpc_service_method.h> +#include <grpcpp/impl/codegen/server_context.h> namespace grpc { @@ -148,44 +151,69 @@ class ServerInterface : public internal::CallHook { public: BaseAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize); virtual ~BaseAsyncRequest(); bool FinalizeResult(void** tag, bool* status) override; + private: + void ContinueFinalizeResultAfterInterception(); + protected: ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; const bool delete_on_finalize_; grpc_call* call_; + internal::Call call_wrapper_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; + bool done_intercepting_; + void* dummy_alarm_; /* This should have been Alarm, but we cannot depend on + alarm.h here */ }; class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag); - - // uses BaseAsyncRequest::FinalizeResult + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, + const char* name); + + virtual bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(name_, + *server_->interceptor_creators())); + return BaseAsyncRequest::FinalizeResult(tag, status); + } protected: void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); + const char* name_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { public: - NoPayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + NoPayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag) { - IssueRequest(registered_method, nullptr, notification_cq); + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()) { + IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } // uses RegisteredAsyncRequest::FinalizeResult @@ -194,13 +222,15 @@ class ServerInterface : public internal::CallHook { template <class Message> class PayloadAsyncRequest final : public RegisteredAsyncRequest { public: - PayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + PayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag), + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()), registered_method_(registered_method), server_(server), context_(context), @@ -209,7 +239,8 @@ class ServerInterface : public internal::CallHook { notification_cq_(notification_cq), tag_(tag), request_(request) { - IssueRequest(registered_method, payload_.bbuf_ptr(), notification_cq); + IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(), + notification_cq); } ~PayloadAsyncRequest() { @@ -217,6 +248,10 @@ class ServerInterface : public internal::CallHook { } bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return RegisteredAsyncRequest::FinalizeResult(tag, status); + } if (*status) { if (!payload_.Valid() || !SerializationTraits<Message>::Deserialize( payload_.bbuf_ptr(), request_) @@ -235,15 +270,24 @@ class ServerInterface : public internal::CallHook { return false; } } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(name_, + *server_->interceptor_creators())); + /* Set interception point for recv message */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); return RegisteredAsyncRequest::FinalizeResult(tag, status); } private: - void* const registered_method_; + internal::RpcServiceMethod* const registered_method_; ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; Message* const request_; @@ -272,9 +316,8 @@ class ServerInterface : public internal::CallHook { ServerCompletionQueue* notification_cq, void* tag, Message* message) { GPR_CODEGEN_ASSERT(method); - new PayloadAsyncRequest<Message>(method->server_tag(), this, context, - stream, call_cq, notification_cq, tag, - message); + new PayloadAsyncRequest<Message>(method, this, context, stream, call_cq, + notification_cq, tag, message); } void RequestAsyncCall(internal::RpcServiceMethod* method, @@ -283,8 +326,8 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { GPR_CODEGEN_ASSERT(method); - new NoPayloadAsyncRequest(method->server_tag(), this, context, stream, - call_cq, notification_cq, tag); + new NoPayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag); } void RequestAsyncGenericCall(GenericServerContext* context, @@ -295,6 +338,13 @@ class ServerInterface : public internal::CallHook { new GenericAsyncRequest(this, context, stream, call_cq, notification_cq, tag, true); } + + private: + virtual const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* + interceptor_creators() { + return nullptr; + } }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/sync_stream.h b/include/grpcpp/impl/codegen/sync_stream.h index cbfcf25d0a..6981076f04 100644 --- a/include/grpcpp/impl/codegen/sync_stream.h +++ b/include/grpcpp/impl/codegen/sync_stream.h @@ -250,7 +250,7 @@ class ClientReader final : public ClientReaderInterface<R> { ::grpc::internal::CallOpSendMessage, ::grpc::internal::CallOpClientSendClose> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); // TODO(ctiller): don't assert GPR_CODEGEN_ASSERT(ops.SendMessage(request).ok()); @@ -327,7 +327,7 @@ class ClientWriter : public ClientWriterInterface<W> { ops.ClientSendClose(); } if (context_->initial_metadata_corked_) { - ops.SendInitialMetadata(context_->send_initial_metadata_, + ops.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); context_->set_initial_metadata_corked(false); } @@ -386,7 +386,7 @@ class ClientWriter : public ClientWriterInterface<W> { if (!context_->initial_metadata_corked_) { ::grpc::internal::CallOpSet<::grpc::internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); call_.PerformOps(&ops); cq_.Pluck(&ops); @@ -498,7 +498,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface<W, R> { ops.ClientSendClose(); } if (context_->initial_metadata_corked_) { - ops.SendInitialMetadata(context_->send_initial_metadata_, + ops.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); context_->set_initial_metadata_corked(false); } @@ -557,7 +557,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface<W, R> { if (!context_->initial_metadata_corked_) { ::grpc::internal::CallOpSet<::grpc::internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); call_.PerformOps(&ops); cq_.Pluck(&ops); @@ -583,7 +583,7 @@ class ServerReader final : public ServerReaderInterface<R> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); internal::CallOpSet<internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -635,7 +635,7 @@ class ServerWriter final : public ServerWriterInterface<W> { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); internal::CallOpSet<internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -660,7 +660,7 @@ class ServerWriter final : public ServerWriterInterface<W> { return false; } if (!ctx_->sent_initial_metadata_) { - ctx_->pending_ops_.SendInitialMetadata(ctx_->initial_metadata_, + ctx_->pending_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ctx_->pending_ops_.set_compression_level(ctx_->compression_level()); @@ -708,7 +708,7 @@ class ServerReaderWriterBody final { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); CallOpSet<CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -738,7 +738,7 @@ class ServerReaderWriterBody final { return false; } if (!ctx_->sent_initial_metadata_) { - ctx_->pending_ops_.SendInitialMetadata(ctx_->initial_metadata_, + ctx_->pending_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ctx_->pending_ops_.set_compression_level(ctx_->compression_level()); diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index 8d3e856502..2b89ffd317 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -174,7 +174,11 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq = nullptr); + grpc_resource_quota* server_rq = nullptr, + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators = std::vector<std::unique_ptr< + experimental::ServerInterceptorFactoryInterface>>()); /// Start the server. /// @@ -187,6 +191,12 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { grpc_server* server() override { return server_; }; private: + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* + interceptor_creators() override { + return &interceptor_creators_; + } + friend class AsyncGenericService; friend class ServerBuilder; friend class ServerInitializer; @@ -251,6 +261,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { // A special handler for resource exhausted in sync case std::unique_ptr<internal::MethodHandler> resource_exhausted_handler_; + + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators_; }; } // namespace grpc diff --git a/include/grpcpp/server_builder.h b/include/grpcpp/server_builder.h index a58a59c2d8..028b8cffaa 100644 --- a/include/grpcpp/server_builder.h +++ b/include/grpcpp/server_builder.h @@ -28,6 +28,7 @@ #include <grpc/support/cpu.h> #include <grpc/support/workaround_list.h> #include <grpcpp/impl/channel_argument_option.h> +#include <grpcpp/impl/codegen/server_interceptor.h> #include <grpcpp/impl/server_builder_option.h> #include <grpcpp/impl/server_builder_plugin.h> #include <grpcpp/support/config.h> @@ -212,6 +213,29 @@ class ServerBuilder { /// doc/workarounds.md. ServerBuilder& EnableWorkaround(grpc_workaround_list id); + /// NOTE: class experimental_type is not part of the public API of this class. + /// TODO(yashykt): Integrate into public API when this is no longer + /// experimental. + class experimental_type { + public: + explicit experimental_type(ServerBuilder* builder) : builder_(builder) {} + + void SetInterceptorCreators( + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators) { + builder_->interceptor_creators_ = std::move(interceptor_creators); + } + + private: + ServerBuilder* builder_; + }; + + /// NOTE: The function experimental() is not stable public API. It is a view + /// to the experimental components of this class. It may be changed or removed + /// at any time. + experimental_type experimental() { return experimental_type(this); } + protected: /// Experimental, to be deprecated struct Port { @@ -297,6 +321,8 @@ class ServerBuilder { grpc_compression_algorithm algorithm; } maybe_default_compression_algorithm_; uint32_t enabled_compression_algorithms_bitset_; + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators_; }; } // namespace grpc diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index 2cab41b3f5..f2a2a2fdc9 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -57,9 +57,8 @@ Channel::Channel( std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> interceptor_creators) : host_(host), c_channel_(channel) { - auto* vector = interceptor_creators.release(); - if (vector != nullptr) { - interceptor_creators_ = std::move(*vector); + if (interceptor_creators != nullptr) { + interceptor_creators_ = std::move(*interceptor_creators); } g_gli_initializer.summon(); } @@ -112,9 +111,10 @@ void ChannelResetConnectionBackoff(Channel* channel) { } // namespace experimental -internal::Call Channel::CreateCall(const internal::RpcMethod& method, - ClientContext* context, - CompletionQueue* cq) { +internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq, + int interceptor_pos) { const bool kRegistered = method.channel_tag() && context->authority().empty(); grpc_call* c_call = nullptr; if (kRegistered) { @@ -147,17 +147,22 @@ internal::Call Channel::CreateCall(const internal::RpcMethod& method, } grpc_census_call_set_context(c_call, context->census_context()); context->set_call(c_call, shared_from_this()); - return internal::Call(c_call, this, cq); + + auto* info = context->set_client_rpc_info( + method.name(), this, interceptor_creators_, interceptor_pos); + return internal::Call(c_call, this, cq, info); +} + +internal::Call Channel::CreateCall(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq) { + return CreateCallInternal(method, context, cq, 0); } void Channel::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); - GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), cops, nops, - ops->cq_tag(), nullptr)); + ops->FillOps( + call); // Make a copy of call. It's fine since Call just has pointers } void* Channel::RegisterMethod(const char* method) { diff --git a/src/cpp/codegen/call_wrapper.cc b/src/cpp/codegen/call_wrapper.cc new file mode 100644 index 0000000000..668fa886ef --- /dev/null +++ b/src/cpp/codegen/call_wrapper.cc @@ -0,0 +1,31 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpcpp/impl/codegen/call_wrapper.h> + +#include <grpcpp/impl/codegen/call_hook.h> + +namespace grpc { +namespace internal { + +void Call::PerformOps(CallOpSetInterface* ops) { + call_hook_->PerformOpsOnCall(ops, this); +} + +} // namespace internal +} // namespace grpc diff --git a/src/cpp/common/core_codegen.cc b/src/cpp/common/core_codegen.cc index 619aacadaa..cfaa2e7b19 100644 --- a/src/cpp/common/core_codegen.cc +++ b/src/cpp/common/core_codegen.cc @@ -102,6 +102,13 @@ size_t CoreCodegen::grpc_byte_buffer_length(grpc_byte_buffer* bb) { return ::grpc_byte_buffer_length(bb); } +grpc_call_error CoreCodegen::grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, + size_t nops, void* tag, + void* reserved) { + return ::grpc_call_start_batch(call, ops, nops, tag, reserved); +} + grpc_call_error CoreCodegen::grpc_call_cancel_with_status( grpc_call* call, grpc_status_code status, const char* description, void* reserved) { diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index 8417c45e64..fc42b6c886 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -263,7 +263,8 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() { std::unique_ptr<Server> server(new Server( max_receive_message_size_, &args, sync_server_cqs, sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, - sync_server_settings_.cq_timeout_msec, resource_quota_)); + sync_server_settings_.cq_timeout_msec, resource_quota_, + std::move(interceptor_creators_))); if (has_sync_methods) { // This is a Sync server diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 7aeddff643..9f4ec3e4ab 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -24,10 +24,13 @@ #include <grpc/grpc.h> #include <grpc/support/alloc.h> #include <grpc/support/log.h> +#include <grpcpp/alarm.h> #include <grpcpp/completion_queue.h> #include <grpcpp/generic/async_generic_service.h> #include <grpcpp/impl/codegen/async_unary_call.h> +#include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> +#include <grpcpp/impl/codegen/server_interceptor.h> #include <grpcpp/impl/grpc_library.h> #include <grpcpp/impl/method_handler_impl.h> #include <grpcpp/impl/rpc_service_method.h> @@ -208,13 +211,18 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { public: explicit CallData(Server* server, SyncRequest* mrd) : cq_(mrd->cq_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size()), ctx_(mrd->deadline_, &mrd->request_metadata_), has_request_payload_(mrd->has_request_payload_), request_payload_(has_request_payload_ ? mrd->request_payload_ : nullptr), + request_(nullptr), method_(mrd->method_), - server_(server) { + call_(mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), + server->interceptor_creators_)), + server_(server), + global_callbacks_(nullptr), + resources_(false) { ctx_.set_call(mrd->call_); ctx_.cq_ = &cq_; GPR_ASSERT(mrd->in_flight_); @@ -230,33 +238,73 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void Run(const std::shared_ptr<GlobalCallbacks>& global_callbacks, bool resources) { - ctx_.BeginCompletionOp(&call_); - global_callbacks->PreSynchronousRequest(&ctx_); - auto* handler = resources ? method_->handler() - : server_->resource_exhausted_handler_.get(); - handler->RunHandler(internal::MethodHandler::HandlerParameter( - &call_, &ctx_, request_payload_)); - global_callbacks->PostSynchronousRequest(&ctx_); - request_payload_ = nullptr; - - cq_.Shutdown(); + global_callbacks_ = global_callbacks; + resources_ = resources; + + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_); + + if (has_request_payload_) { + // Set interception point for RECV MESSAGE + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + request_ = handler->Deserialize(request_payload_, &request_status_); + + request_payload_ = nullptr; + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); + } - internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); - cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + auto f = std::bind(&CallData::ContinueRunAfterInterception, this); + if (interceptor_methods_.RunInterceptors(f)) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } - /* Ensure the cq_ is shutdown */ - DummyTag ignored_tag; - GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + void ContinueRunAfterInterception() { + { + ctx_.BeginCompletionOp(&call_); + global_callbacks_->PreSynchronousRequest(&ctx_); + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(internal::MethodHandler::HandlerParameter( + &call_, &ctx_, request_, request_status_)); + request_ = nullptr; + global_callbacks_->PostSynchronousRequest(&ctx_); + + cq_.Shutdown(); + + internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); + cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + + /* Ensure the cq_ is shutdown */ + DummyTag ignored_tag; + GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + } + delete this; } private: CompletionQueue cq_; - internal::Call call_; ServerContext ctx_; const bool has_request_payload_; grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; internal::RpcServiceMethod* const method_; + internal::Call call_; Server* server_; + std::shared_ptr<GlobalCallbacks> global_callbacks_; + bool resources_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; private: @@ -318,8 +366,9 @@ class Server::SyncRequestThreadManager : public ThreadManager { } if (ok) { - // Calldata takes ownership of the completion queue inside sync_req - SyncRequest::CallData cd(server_, sync_req); + // Calldata takes ownership of the completion queue and interceptors + // inside sync_req + auto* cd = new SyncRequest::CallData(server_, sync_req); // Prepare for the next request if (!IsShutdown()) { sync_req->SetupRequest(); // Create new completion queue for sync_req @@ -327,7 +376,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { } GPR_TIMER_SCOPE("cd.Run()", 0); - cd.Run(global_callbacks_, resources); + cd->Run(global_callbacks_, resources); } // TODO (sreek) If ok is false here (which it isn't in case of // grpc_request_registered_call), we should still re-queue the request @@ -389,7 +438,10 @@ Server::Server( std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq) + grpc_resource_quota* server_rq, + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators) : max_receive_message_size_(max_receive_message_size), sync_server_cqs_(std::move(sync_server_cqs)), started_(false), @@ -398,7 +450,8 @@ Server::Server( has_generic_service_(false), server_(nullptr), server_initializer_(new ServerInitializer(this)), - health_check_service_disabled_(false) { + health_check_service_disabled_(false), + interceptor_creators_(std::move(interceptor_creators)) { g_gli_initializer.summon(); gpr_once_init(&g_once_init_callbacks, InitGlobalCallbacks); global_callbacks_ = g_callbacks; @@ -681,31 +734,27 @@ void Server::Wait() { void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); - auto result = - grpc_call_start_batch(call->call(), cops, nops, ops->cq_tag(), nullptr); - if (result != GRPC_CALL_OK) { - gpr_log(GPR_ERROR, "Fatal: grpc_call_start_batch returned %d", result); - grpc_call_log_batch(__FILE__, __LINE__, GPR_LOG_SEVERITY_ERROR, - call->call(), cops, nops, ops); - abort(); - } + ops->FillOps(call); } ServerInterface::BaseAsyncRequest::BaseAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag, bool delete_on_finalize) + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) : server_(server), context_(context), stream_(stream), call_cq_(call_cq), + notification_cq_(notification_cq), tag_(tag), delete_on_finalize_(delete_on_finalize), - call_(nullptr) { + call_(nullptr), + done_intercepting_(false) { + /* Set up interception state partially for the receive ops. call_wrapper_ is + * not filled at this point, but it will be filled before the interceptors are + * run. */ + interceptor_methods_.SetCall(&call_wrapper_); + interceptor_methods_.SetReverse(); call_cq_->RegisterAvalanching(); // This op will trigger more ops } @@ -715,15 +764,47 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (done_intercepting_) { + delete static_cast<Alarm*>(dummy_alarm_); + dummy_alarm_ = nullptr; + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; + } context_->set_call(call_); context_->cq_ = call_cq_; - internal::Call call(call_, server_, call_cq_, - server_->max_receive_message_size()); - if (*status && call_) { - context_->BeginCompletionOp(&call); + if (call_wrapper_.call() == nullptr) { + // Fill it since it is empty. + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); } + // just the pointers inside call are copied here - stream_->BindCall(&call); + stream_->BindCall(&call_wrapper_); + + if (*status && call_ && call_wrapper_.server_rpc_info()) { + done_intercepting_ = true; + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); + auto f = std::bind(&ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception, + this); + if (interceptor_methods_.RunInterceptors(f)) { + // There are no interceptors to run. Continue + } else { + // There were interceptors to be run, so + // ContinueFinalizeResultAfterInterception will be run when interceptors + // are done. + return false; + } + } + if (*status && call_) { + context_->BeginCompletionOp(&call_wrapper_); + } *tag = tag_; if (delete_on_finalize_) { delete this; @@ -731,11 +812,23 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, return true; } +void ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception() { + context_->BeginCompletionOp(&call_wrapper_); + // Queue a tag which will be returned immediately + dummy_alarm_ = new Alarm(); + static_cast<Alarm*>(dummy_alarm_) + ->Set(notification_cq_, + g_core_codegen_interface->gpr_time_0(GPR_CLOCK_MONOTONIC), this); +} + ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag) - : BaseAsyncRequest(server, context, stream, call_cq, tag, true) {} + ServerCompletionQueue* notification_cq, void* tag, const char* name) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + true), + name_(name) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -751,7 +844,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( ServerInterface* server, GenericServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) - : BaseAsyncRequest(server, context, stream, call_cq, tag, + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, delete_on_finalize) { grpc_call_details_init(&call_details_); GPR_ASSERT(notification_cq); @@ -764,6 +857,10 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { + // If we are done intercepting, there is nothing more for us to do + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } // TODO(yangg) remove the copy here. if (*status) { static_cast<GenericServerContext*>(context_)->method_ = @@ -774,16 +871,26 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, } grpc_slice_unref(call_details_.method); grpc_slice_unref(call_details_.host); + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info( + static_cast<GenericServerContext*>(context_)->method_.c_str(), + *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, bool* status) { - if (GenericAsyncRequest::FinalizeResult(tag, status) && *status) { - new UnimplementedAsyncRequest(server_, cq_); - new UnimplementedAsyncResponse(this); + if (GenericAsyncRequest::FinalizeResult(tag, status)) { + // We either had no interceptors run or we are done intercepting + if (*status) { + new UnimplementedAsyncRequest(server_, cq_); + new UnimplementedAsyncResponse(this); + } else { + delete this; + } } else { - delete this; + // The tag was swallowed due to interception. We will see it again. } return false; } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index b7254b6bb9..42ae0ed138 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -45,9 +45,10 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { tag_(nullptr), refs_(2), finalized_(false), - cancelled_(0) {} + cancelled_(0), + done_intercepting_(false) {} - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override; + void FillOps(internal::Call* call) override; bool FinalizeResult(void** tag, bool* status) override; bool CheckCancelled(CompletionQueue* cq) { @@ -66,6 +67,35 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { void Unref(); + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + /* Servers don't allow hijacking */ + GPR_CODEGEN_ASSERT(false); + } + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override {} + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + if (!has_tag_) { + /* We don't have a tag to return. */ + std::unique_lock<std::mutex> lock(mu_); + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return; + } + /* Start a dummy op so that we can return the tag */ + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, this, nullptr)); + } + private: bool CheckCancelledNoPluck() { std::lock_guard<std::mutex> g(mu_); @@ -78,6 +108,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { int refs_; bool finalized_; int cancelled_; + bool done_intercepting_; + internal::Call call_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; void ServerContext::CompletionOp::Unref() { @@ -88,29 +121,60 @@ void ServerContext::CompletionOp::Unref() { } } -void ServerContext::CompletionOp::FillOps(grpc_call* call, grpc_op* ops, - size_t* nops) { - ops->op = GRPC_OP_RECV_CLOSE_ON_SERVER; - ops->data.recv_close_on_server.cancelled = &cancelled_; - ops->flags = 0; - ops->reserved = nullptr; - *nops = 1; +void ServerContext::CompletionOp::FillOps(internal::Call* call) { + grpc_op ops; + ops.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + ops.data.recv_close_on_server.cancelled = &cancelled_; + ops.flags = 0; + ops.reserved = nullptr; + call_ = *call; + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + interceptor_methods_.SetCallOpSetInterface(this); + GPR_ASSERT(GRPC_CALL_OK == + grpc_call_start_batch(call->call(), &ops, 1, this, nullptr)); + /* No interceptors to run here */ } bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { - std::unique_lock<std::mutex> lock(mu_); - finalized_ = true; bool ret = false; - if (has_tag_) { - *tag = tag_; - ret = true; + std::unique_lock<std::mutex> lock(mu_); + if (done_intercepting_) { + /* We are done intercepting. */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return ret; } + finalized_ = true; + if (!*status) cancelled_ = 1; - if (--refs_ == 0) { - lock.unlock(); - delete this; + /* Release the lock since we are going to be running through interceptors now + */ + lock.unlock(); + /* Add interception point and run through interceptors */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE); + if (interceptor_methods_.RunInterceptors()) { + /* No interceptors were run */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + lock.lock(); + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return ret; } - return ret; + /* There are interceptors to be run. Return false for now */ + return false; } // ServerContext body diff --git a/test/cpp/end2end/BUILD b/test/cpp/end2end/BUILD index 0415efc1ef..019ec43f96 100644 --- a/test/cpp/end2end/BUILD +++ b/test/cpp/end2end/BUILD @@ -117,6 +117,25 @@ grpc_cc_test( ], ) +grpc_cc_test( + name = "client_interceptors_end2end_test", + srcs = ["client_interceptors_end2end_test.cc"], + external_deps = [ + "gtest", + ], + deps = [ + ":test_service_impl", + "//:gpr", + "//:grpc", + "//:grpc++", + "//src/proto/grpc/testing:echo_messages_proto", + "//src/proto/grpc/testing:echo_proto", + "//test/core/util:gpr_test_util", + "//test/core/util:grpc_test_util", + "//test/cpp/util:test_util", + ], +) + grpc_cc_library( name = "end2end_test_lib", testonly = True, diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc new file mode 100644 index 0000000000..4a3a8b859a --- /dev/null +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -0,0 +1,533 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <vector> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/client_interceptor.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +class ClientInterceptorsEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ +class DummyInterceptor : public experimental::Interceptor { + public: + DummyInterceptor(experimental::ClientRpcInfo* info) {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } + methods->Proceed(); + } + + static void Reset() { num_times_run_.store(0); } + + static int GetNumTimesRun() { return num_times_run_.load(); } + + private: + static std::atomic<int> num_times_run_; +}; + +std::atomic<int> DummyInterceptor::num_times_run_; + +class DummyInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new DummyInterceptor(info); + } +}; + +/* Hijacks Echo RPC and fills in the expected values */ +class HijackingInterceptor : public experimental::Interceptor { + public: + HijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + gpr_log(GPR_ERROR, "ran this"); + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), 1); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), 0); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class HijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptor(info); + } +}; + +class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { + public: + HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + gpr_log(GPR_ERROR, "ran this"); + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), 1); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + // Make a copy of the map + metadata_map_ = *map; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + auto stub = grpc::testing::EchoTestService::NewStub( + methods->GetInterceptedChannel()); + ClientContext ctx; + EchoResponse resp; + ctx.AddMetadata(metadata_map_.begin()->first, + metadata_map_.begin()->second); + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), 0); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + gpr_log(GPR_ERROR, "hijacking"); + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + std::multimap<grpc::string, grpc::string> metadata_map_; +}; + +class HijackingInterceptorMakesAnotherCallFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptorMakesAnotherCall(info); + } +}; + +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + gpr_log(GPR_ERROR, "ran this"); + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), 1); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + EXPECT_EQ(resp->message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +void MakeCall(std::shared_ptr<Channel> channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + // Add 20 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + creators->push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + // Add 20 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + // Make sure only 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { + ChannelArguments args; + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + creators->push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorHijackingMakesAnotherCallTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + // Add 20 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + creators->push_back( + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>( + new HijackingInterceptorMakesAnotherCallFactory())); + // Add 20 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + // Make sure all interceptors were run once, since the hijacking interceptor + // makes an RPC on the intercepted channel + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 40); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc_test_init(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/doxygen/Doxyfile.c++ b/tools/doxygen/Doxyfile.c++ index 40abd726c4..aa6870fe22 100644 --- a/tools/doxygen/Doxyfile.c++ +++ b/tools/doxygen/Doxyfile.c++ @@ -945,6 +945,7 @@ include/grpcpp/impl/codegen/async_unary_call.h \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ +include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -959,6 +960,7 @@ include/grpcpp/impl/codegen/core_codegen.h \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ +include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -970,6 +972,7 @@ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ +include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 8fed272159..dc803575c4 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -946,6 +946,7 @@ include/grpcpp/impl/codegen/async_unary_call.h \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ +include/grpcpp/impl/codegen/call_wrapper.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -961,6 +962,7 @@ include/grpcpp/impl/codegen/core_codegen.h \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ +include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ @@ -972,6 +974,7 @@ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ +include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -1188,6 +1191,7 @@ src/cpp/client/generic_stub.cc \ src/cpp/client/insecure_credentials.cc \ src/cpp/client/secure_credentials.cc \ src/cpp/client/secure_credentials.h \ +src/cpp/codegen/call_wrapper.cc \ src/cpp/codegen/codegen_init.cc \ src/cpp/common/alarm.cc \ src/cpp/common/auth_property_iterator.cc \ diff --git a/tools/run_tests/generated/sources_and_headers.json b/tools/run_tests/generated/sources_and_headers.json index 20b6d36671..e3a208ac2e 100644 --- a/tools/run_tests/generated/sources_and_headers.json +++ b/tools/run_tests/generated/sources_and_headers.json @@ -11142,6 +11142,7 @@ "include/grpcpp/impl/codegen/byte_buffer.h", "include/grpcpp/impl/codegen/call.h", "include/grpcpp/impl/codegen/call_hook.h", + "include/grpcpp/impl/codegen/call_wrapper.h", "include/grpcpp/impl/codegen/callback_common.h", "include/grpcpp/impl/codegen/channel_interface.h", "include/grpcpp/impl/codegen/client_callback.h", @@ -11154,6 +11155,7 @@ "include/grpcpp/impl/codegen/core_codegen_interface.h", "include/grpcpp/impl/codegen/create_auth_context.h", "include/grpcpp/impl/codegen/grpc_library.h", + "include/grpcpp/impl/codegen/intercepted_channel.h", "include/grpcpp/impl/codegen/interceptor.h", "include/grpcpp/impl/codegen/metadata_map.h", "include/grpcpp/impl/codegen/method_handler_impl.h", @@ -11162,6 +11164,7 @@ "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", "include/grpcpp/impl/codegen/server_context.h", + "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", "include/grpcpp/impl/codegen/service_type.h", "include/grpcpp/impl/codegen/slice.h", @@ -11212,6 +11215,7 @@ "include/grpcpp/impl/codegen/byte_buffer.h", "include/grpcpp/impl/codegen/call.h", "include/grpcpp/impl/codegen/call_hook.h", + "include/grpcpp/impl/codegen/call_wrapper.h", "include/grpcpp/impl/codegen/callback_common.h", "include/grpcpp/impl/codegen/channel_interface.h", "include/grpcpp/impl/codegen/client_callback.h", @@ -11224,6 +11228,7 @@ "include/grpcpp/impl/codegen/core_codegen_interface.h", "include/grpcpp/impl/codegen/create_auth_context.h", "include/grpcpp/impl/codegen/grpc_library.h", + "include/grpcpp/impl/codegen/intercepted_channel.h", "include/grpcpp/impl/codegen/interceptor.h", "include/grpcpp/impl/codegen/metadata_map.h", "include/grpcpp/impl/codegen/method_handler_impl.h", @@ -11232,6 +11237,7 @@ "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", "include/grpcpp/impl/codegen/server_context.h", + "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", "include/grpcpp/impl/codegen/service_type.h", "include/grpcpp/impl/codegen/slice.h", @@ -11254,6 +11260,7 @@ "language": "c++", "name": "grpc++_codegen_base_src", "src": [ + "src/cpp/codegen/call_wrapper.cc", "src/cpp/codegen/codegen_init.cc" ], "third_party": false, |