From 6c3cb2299124773abe4b7039b94a976c5552c432 Mon Sep 17 00:00:00 2001 From: Mehrdad Afshari Date: Mon, 12 Feb 2018 01:09:34 -0800 Subject: Add server-side interceptor helper facility to GenericInterceptor --- .../Grpc.Core/Interceptors/GenericInterceptor.cs | 150 +++++++++++++++++++-- 1 file changed, 137 insertions(+), 13 deletions(-) diff --git a/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs b/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs index b9fc5e0a19..ed90ded889 100644 --- a/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs +++ b/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs @@ -29,7 +29,6 @@ namespace Grpc.Core.Interceptors /// public abstract class GenericInterceptor : Interceptor { - /// /// Provides hooks through which an invocation should be intercepted. /// @@ -93,6 +92,65 @@ namespace Grpc.Core.Interceptors return null; } + /// + /// Provides hooks through which a server-side handler should be intercepted. + /// + public sealed class ServerCallArbitrator + where TRequest : class + where TResponse : class + { + internal ServerCallArbitrator Freeze() + { + return (ServerCallArbitrator)MemberwiseClone(); + } + /// + /// Override the request for the outgoing invocation for non-client-streaming invocations. + /// + public TRequest UnaryRequest { get; set; } + /// + /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it. + /// + public Func OnUnaryResponse { get; set; } + /// + /// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message. + /// + public Func OnRequestMessage { get; set; } + /// + /// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message. + /// + public Func OnResponseMessage { get; set; } + /// + /// Callback that gets invoked when handler is finished executing. + /// + public Action OnHandlerEnd { get; set; } + /// + /// Callback that gets invoked when request stream is finished. + /// + public Action OnRequestStreamEnd { get; set; } + } + + /// + /// Intercepts an incoming service handler invocation on the server side. + /// Derived classes that intend to intercept incoming handlers on the server side should + /// override this and return the appropriate hooks in the form of a ServerCallArbitrator instance. + /// + /// The context of the incoming invocation. + /// True if the invocation is client-streaming. + /// True if the invocation is server-streaming. + /// The request message for client-unary invocations, null otherwise. + /// Request message type for the current invocation. + /// Response message type for the current invocation. + /// + /// The derived class should return an instance of ServerCallArbitrator to control the trajectory + /// as they see fit, or null if it does not intend to pursue the invocation any further. + /// + protected virtual Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) + where TRequest : class + where TResponse : class + { + return Task.FromResult>(null); + } + /// /// Intercepts a blocking invocation of a simple remote call and dispatches the events accordingly. /// @@ -138,7 +196,7 @@ namespace Grpc.Core.Interceptors if (arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) { response = new AsyncServerStreamingCall( - new WrappedClientStreamReader(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd), + new WrappedAsyncStreamReader(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd), response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } return response; @@ -187,7 +245,7 @@ namespace Grpc.Core.Interceptors var responseStream = response.ResponseStream; if (arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) { - responseStream = new WrappedClientStreamReader(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd); + responseStream = new WrappedAsyncStreamReader(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd); } response = new AsyncDuplexStreamingCall(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } @@ -199,9 +257,17 @@ namespace Grpc.Core.Interceptors /// /// Request message type for this method. /// Response message type for this method. - public override Task UnaryServerHandler(TRequest request, ServerCallContext context, UnaryServerMethod continuation) + public override async Task UnaryServerHandler(TRequest request, ServerCallContext context, UnaryServerMethod continuation) { - return continuation(request, context); + var arbitrator = (await InterceptHandler(context, false, false, request))?.Freeze(); + request = arbitrator?.UnaryRequest ?? request; + var response = await continuation(request, context); + if (arbitrator?.OnUnaryResponse != null) + { + response = arbitrator.OnUnaryResponse(response); + } + arbitrator?.OnHandlerEnd(); + return response; } /// @@ -209,9 +275,20 @@ namespace Grpc.Core.Interceptors /// /// Request message type for this method. /// Response message type for this method. - public override Task ClientStreamingServerHandler(IAsyncStreamReader requestStream, ServerCallContext context, ClientStreamingServerMethod continuation) + public override async Task ClientStreamingServerHandler(IAsyncStreamReader requestStream, ServerCallContext context, ClientStreamingServerMethod continuation) { - return continuation(requestStream, context); + var arbitrator = (await InterceptHandler(context, true, false, null))?.Freeze(); + if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + { + requestStream = new WrappedAsyncStreamReader(requestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + } + var response = await continuation(requestStream, context); + if (arbitrator?.OnUnaryResponse != null) + { + response = arbitrator.OnUnaryResponse(response); + } + arbitrator?.OnHandlerEnd(); + return response; } /// @@ -219,9 +296,16 @@ namespace Grpc.Core.Interceptors /// /// Request message type for this method. /// Response message type for this method. - public override Task ServerStreamingServerHandler(TRequest request, IServerStreamWriter responseStream, ServerCallContext context, ServerStreamingServerMethod continuation) + public override async Task ServerStreamingServerHandler(TRequest request, IServerStreamWriter responseStream, ServerCallContext context, ServerStreamingServerMethod continuation) { - return continuation(request, responseStream, context); + var arbitrator = (await InterceptHandler(context, false, true, request))?.Freeze(); + request = arbitrator?.UnaryRequest ?? request; + if (arbitrator?.OnResponseMessage != null) + { + responseStream = new WrappedAsyncStreamWriter(responseStream, arbitrator.OnResponseMessage); + } + await continuation(request, responseStream, context); + arbitrator?.OnHandlerEnd(); } /// @@ -229,17 +313,27 @@ namespace Grpc.Core.Interceptors /// /// Request message type for this method. /// Response message type for this method. - public override Task DuplexStreamingServerHandler(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context, DuplexStreamingServerMethod continuation) + public override async Task DuplexStreamingServerHandler(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context, DuplexStreamingServerMethod continuation) { - return continuation(requestStream, responseStream, context); + var arbitrator = (await InterceptHandler(context, true, true, null))?.Freeze(); + if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + { + requestStream = new WrappedAsyncStreamReader(requestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + } + if (arbitrator?.OnResponseMessage != null) + { + responseStream = new WrappedAsyncStreamWriter(responseStream, arbitrator.OnResponseMessage); + } + await continuation(requestStream, responseStream, context); + arbitrator?.OnHandlerEnd(); } - private class WrappedClientStreamReader : IAsyncStreamReader + private class WrappedAsyncStreamReader : IAsyncStreamReader { readonly IAsyncStreamReader reader; readonly Func onMessage; readonly Action onStreamEnd; - public WrappedClientStreamReader(IAsyncStreamReader reader, Func onMessage, Action onStreamEnd) + public WrappedAsyncStreamReader(IAsyncStreamReader reader, Func onMessage, Action onStreamEnd) { this.reader = reader; this.onMessage = onMessage; @@ -321,5 +415,35 @@ namespace Grpc.Core.Interceptors } } } + + private class WrappedAsyncStreamWriter : IServerStreamWriter + { + readonly IAsyncStreamWriter writer; + readonly Func onMessage; + public WrappedAsyncStreamWriter(IAsyncStreamWriter writer, Func onMessage) + { + this.writer = writer; + this.onMessage = onMessage; + } + public Task WriteAsync(T message) + { + if (onMessage != null) + { + message = onMessage(message); + } + return writer.WriteAsync(message); + } + public WriteOptions WriteOptions + { + get + { + return writer.WriteOptions; + } + set + { + writer.WriteOptions = value; + } + } + } } } -- cgit v1.2.3