diff options
author | 2018-02-12 01:09:34 -0800 | |
---|---|---|
committer | 2018-02-21 18:30:19 -0800 | |
commit | 6c3cb2299124773abe4b7039b94a976c5552c432 (patch) | |
tree | 1e36bf0c7611f0ae166760f411d896a326c4a59f | |
parent | e97fe27f687e27680577dc9e8bedae258ef5a36b (diff) |
Add server-side interceptor helper facility to GenericInterceptor
-rw-r--r-- | src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs | 150 |
1 files changed, 137 insertions, 13 deletions
diff --git a/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs b/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs index b9fc5e0a19..ed90ded889 100644 --- a/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs +++ b/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs @@ -29,7 +29,6 @@ namespace Grpc.Core.Interceptors /// </summary> public abstract class GenericInterceptor : Interceptor { - /// <summary> /// Provides hooks through which an invocation should be intercepted. /// </summary> @@ -94,6 +93,65 @@ namespace Grpc.Core.Interceptors } /// <summary> + /// Provides hooks through which a server-side handler should be intercepted. + /// </summary> + public sealed class ServerCallArbitrator<TRequest, TResponse> + where TRequest : class + where TResponse : class + { + internal ServerCallArbitrator<TRequest, TResponse> Freeze() + { + return (ServerCallArbitrator<TRequest, TResponse>)MemberwiseClone(); + } + /// <summary> + /// Override the request for the outgoing invocation for non-client-streaming invocations. + /// </summary> + public TRequest UnaryRequest { get; set; } + /// <summary> + /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it. + /// </summary> + public Func<TResponse, TResponse> OnUnaryResponse { get; set; } + /// <summary> + /// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message. + /// </summary> + public Func<TRequest, TRequest> OnRequestMessage { get; set; } + /// <summary> + /// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message. + /// </summary> + public Func<TResponse, TResponse> OnResponseMessage { get; set; } + /// <summary> + /// Callback that gets invoked when handler is finished executing. + /// </summary> + public Action OnHandlerEnd { get; set; } + /// <summary> + /// Callback that gets invoked when request stream is finished. + /// </summary> + public Action OnRequestStreamEnd { get; set; } + } + + /// <summary> + /// Intercepts an incoming service handler invocation on the server side. + /// Derived classes that intend to intercept incoming handlers on the server side should + /// override this and return the appropriate hooks in the form of a ServerCallArbitrator instance. + /// </summary> + /// <param name="context">The context of the incoming invocation.</param> + /// <param name="clientStreaming">True if the invocation is client-streaming.</param> + /// <param name="serverStreaming">True if the invocation is server-streaming.</param> + /// <param name="request">The request message for client-unary invocations, null otherwise.</param> + /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam> + /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam> + /// <returns> + /// The derived class should return an instance of ServerCallArbitrator to control the trajectory + /// as they see fit, or null if it does not intend to pursue the invocation any further. + /// </returns> + protected virtual Task<ServerCallArbitrator<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) + where TRequest : class + where TResponse : class + { + return Task.FromResult<ServerCallArbitrator<TRequest, TResponse>>(null); + } + + /// <summary> /// Intercepts a blocking invocation of a simple remote call and dispatches the events accordingly. /// </summary> public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation) @@ -138,7 +196,7 @@ namespace Grpc.Core.Interceptors if (arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) { response = new AsyncServerStreamingCall<TResponse>( - new WrappedClientStreamReader<TResponse>(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd), + new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd), response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } return response; @@ -187,7 +245,7 @@ namespace Grpc.Core.Interceptors var responseStream = response.ResponseStream; if (arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) { - responseStream = new WrappedClientStreamReader<TResponse>(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd); + responseStream = new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd); } response = new AsyncDuplexStreamingCall<TRequest, TResponse>(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } @@ -199,9 +257,17 @@ namespace Grpc.Core.Interceptors /// </summary> /// <typeparam name="TRequest">Request message type for this method.</typeparam> /// <typeparam name="TResponse">Response message type for this method.</typeparam> - public override Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation) + public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation) { - return continuation(request, context); + var arbitrator = (await InterceptHandler<TRequest, TResponse>(context, false, false, request))?.Freeze(); + request = arbitrator?.UnaryRequest ?? request; + var response = await continuation(request, context); + if (arbitrator?.OnUnaryResponse != null) + { + response = arbitrator.OnUnaryResponse(response); + } + arbitrator?.OnHandlerEnd(); + return response; } /// <summary> @@ -209,9 +275,20 @@ namespace Grpc.Core.Interceptors /// </summary> /// <typeparam name="TRequest">Request message type for this method.</typeparam> /// <typeparam name="TResponse">Response message type for this method.</typeparam> - public override Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation) + public override async Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation) { - return continuation(requestStream, context); + var arbitrator = (await InterceptHandler<TRequest, TResponse>(context, true, false, null))?.Freeze(); + if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + { + requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + } + var response = await continuation(requestStream, context); + if (arbitrator?.OnUnaryResponse != null) + { + response = arbitrator.OnUnaryResponse(response); + } + arbitrator?.OnHandlerEnd(); + return response; } /// <summary> @@ -219,9 +296,16 @@ namespace Grpc.Core.Interceptors /// </summary> /// <typeparam name="TRequest">Request message type for this method.</typeparam> /// <typeparam name="TResponse">Response message type for this method.</typeparam> - public override Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation) + public override async Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation) { - return continuation(request, responseStream, context); + var arbitrator = (await InterceptHandler<TRequest, TResponse>(context, false, true, request))?.Freeze(); + request = arbitrator?.UnaryRequest ?? request; + if (arbitrator?.OnResponseMessage != null) + { + responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, arbitrator.OnResponseMessage); + } + await continuation(request, responseStream, context); + arbitrator?.OnHandlerEnd(); } /// <summary> @@ -229,17 +313,27 @@ namespace Grpc.Core.Interceptors /// </summary> /// <typeparam name="TRequest">Request message type for this method.</typeparam> /// <typeparam name="TResponse">Response message type for this method.</typeparam> - public override Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation) + public override async Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation) { - return continuation(requestStream, responseStream, context); + var arbitrator = (await InterceptHandler<TRequest, TResponse>(context, true, true, null))?.Freeze(); + if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + { + requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + } + if (arbitrator?.OnResponseMessage != null) + { + responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, arbitrator.OnResponseMessage); + } + await continuation(requestStream, responseStream, context); + arbitrator?.OnHandlerEnd(); } - private class WrappedClientStreamReader<T> : IAsyncStreamReader<T> + private class WrappedAsyncStreamReader<T> : IAsyncStreamReader<T> { readonly IAsyncStreamReader<T> reader; readonly Func<T, T> onMessage; readonly Action onStreamEnd; - public WrappedClientStreamReader(IAsyncStreamReader<T> reader, Func<T, T> onMessage, Action onStreamEnd) + public WrappedAsyncStreamReader(IAsyncStreamReader<T> reader, Func<T, T> onMessage, Action onStreamEnd) { this.reader = reader; this.onMessage = onMessage; @@ -321,5 +415,35 @@ namespace Grpc.Core.Interceptors } } } + + private class WrappedAsyncStreamWriter<T> : IServerStreamWriter<T> + { + readonly IAsyncStreamWriter<T> writer; + readonly Func<T, T> onMessage; + public WrappedAsyncStreamWriter(IAsyncStreamWriter<T> writer, Func<T, T> onMessage) + { + this.writer = writer; + this.onMessage = onMessage; + } + public Task WriteAsync(T message) + { + if (onMessage != null) + { + message = onMessage(message); + } + return writer.WriteAsync(message); + } + public WriteOptions WriteOptions + { + get + { + return writer.WriteOptions; + } + set + { + writer.WriteOptions = value; + } + } + } } } |