diff options
author | Mehrdad Afshari <mehrdada@users.noreply.github.com> | 2018-02-22 07:52:48 -0800 |
---|---|---|
committer | Mehrdad Afshari <mehrdada@users.noreply.github.com> | 2018-02-22 08:00:55 -0800 |
commit | a7c1b6251c151bbb3b020e88ab340cedb4ca4d0d (patch) | |
tree | 5c6f2f91e3f48b877e8880f45bfd6f53200507a6 /src/csharp/Grpc.Core.Tests | |
parent | 074b802c9f3b1c22f57f5cea57e755487cc01832 (diff) |
Eliminate GenericInterceptor to simplify this PR
Diffstat (limited to 'src/csharp/Grpc.Core.Tests')
-rw-r--r-- | src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs | 130 | ||||
-rw-r--r-- | src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs | 93 |
2 files changed, 142 insertions, 81 deletions
diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs index d7c01d08ac..02f6f6ffc6 100644 --- a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs @@ -58,22 +58,6 @@ namespace Grpc.Core.Interceptors.Tests Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method<string, string>(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), "")); } - private class CallbackInterceptor : GenericInterceptor - { - readonly Action callback; - - public CallbackInterceptor(Action callback) - { - this.callback = callback; - } - - protected override ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request) - { - callback(); - return null; - } - } - [Test] public void CheckInterceptorOrderInClientInterceptors() { @@ -118,23 +102,6 @@ namespace Grpc.Core.Interceptors.Tests Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(default(Interceptor[]))); } - private class CountingInterceptor : GenericInterceptor - { - protected override ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request) - { - if (!clientStreaming) - { - return null; - } - int counter = 0; - return new ClientCallHooks<TRequest, TResponse> - { - OnRequestMessage = m => { counter++; return m; }, - OnUnaryResponse = x => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker - }; - } - } - [Test] public async Task CountNumberOfRequestsInClientInterceptors() { @@ -151,7 +118,7 @@ namespace Grpc.Core.Interceptors.Tests return stringBuilder.ToString(); }); - var callInvoker = helper.GetChannel().Intercept(new CountingInterceptor()); + var callInvoker = helper.GetChannel().Intercept(new ClientStreamingCountingInterceptor()); var server = helper.GetServer(); server.Start(); @@ -162,5 +129,100 @@ namespace Grpc.Core.Interceptors.Tests Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); Assert.IsNotNull(call.GetTrailers()); } + + private class CallbackInterceptor : Interceptor + { + readonly Action callback; + + public CallbackInterceptor(Action callback) + { + this.callback = GrpcPreconditions.CheckNotNull(callback, nameof(callback)); + } + + public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(request, context); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(request, context); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(request, context); + } + + public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(context); + } + + public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(context); + } + } + + private class ClientStreamingCountingInterceptor : Interceptor + { + public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation) + { + var response = continuation(context); + int counter = 0; + var requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, + message => { counter++; return message; }, null); + var responseAsync = response.ResponseAsync.ContinueWith( + unaryResponse => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker + ); + return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); + } + } + + private class WrappedClientStreamWriter<T> : IClientStreamWriter<T> + { + readonly IClientStreamWriter<T> writer; + readonly Func<T, T> onMessage; + readonly Action onResponseStreamEnd; + public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onResponseStreamEnd) + { + this.writer = writer; + this.onMessage = onMessage; + this.onResponseStreamEnd = onResponseStreamEnd; + } + public Task CompleteAsync() + { + if (onResponseStreamEnd != null) + { + return writer.CompleteAsync().ContinueWith(x => onResponseStreamEnd()); + } + return writer.CompleteAsync(); + } + public Task WriteAsync(T message) + { + if (onMessage != null) + { + message = onMessage(message); + } + return writer.WriteAsync(message); + } + public WriteOptions WriteOptions + { + get + { + return writer.WriteOptions; + } + set + { + writer.WriteOptions = value; + } + } + } } } diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs index c0957a2b42..e76f21d098 100644 --- a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs @@ -35,33 +35,17 @@ namespace Grpc.Core.Interceptors.Tests { const string Host = "127.0.0.1"; - private class AddRequestHeaderServerInterceptor : GenericInterceptor - { - readonly Metadata.Entry header; - - public AddRequestHeaderServerInterceptor(string key, string value) - { - this.header = new Metadata.Entry(key, value); - } - - protected override Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) - { - context.RequestHeaders.Add(header); - return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null); - } - - public Metadata.Entry Header => header; - } - [Test] public void AddRequestHeaderInServerInterceptor() { var helper = new MockServiceHelper(Host); - var interceptor = new AddRequestHeaderServerInterceptor("x-interceptor", "hello world"); + const string MetadataKey = "x-interceptor"; + const string MetadataValue = "hello world"; + var interceptor = new ServerCallContextInterceptor(ctx => ctx.RequestHeaders.Add(new Metadata.Entry(MetadataKey, MetadataValue))); helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => { - var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == interceptor.Header.Key)).Value; - Assert.AreEqual(interceptorHeader, interceptor.Header.Value); + var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == MetadataKey)).Value; + Assert.AreEqual(interceptorHeader, MetadataValue); return Task.FromResult("PASS"); }); helper.ServiceDefinition = helper.ServiceDefinition.Intercept(interceptor); @@ -71,22 +55,6 @@ namespace Grpc.Core.Interceptors.Tests Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), "")); } - private class ArbitraryActionInterceptor : GenericInterceptor - { - readonly Action action; - - public ArbitraryActionInterceptor(Action action) - { - this.action = action; - } - - protected override Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) - { - action(); - return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null); - } - } - [Test] public void VerifyInterceptorOrdering() { @@ -97,11 +65,11 @@ namespace Grpc.Core.Interceptors.Tests }); var stringBuilder = new StringBuilder(); helper.ServiceDefinition = helper.ServiceDefinition - .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("A"))) - .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("B1")), - new ArbitraryActionInterceptor(() => stringBuilder.Append("B2")), - new ArbitraryActionInterceptor(() => stringBuilder.Append("B3"))) - .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("C"))); + .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("A"))) + .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("B1")), + new ServerCallContextInterceptor(ctx => stringBuilder.Append("B2")), + new ServerCallContextInterceptor(ctx => stringBuilder.Append("B3"))) + .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("C"))); var server = helper.GetServer(); server.Start(); var channel = helper.GetChannel(); @@ -113,15 +81,46 @@ namespace Grpc.Core.Interceptors.Tests public void CheckNullInterceptorRegistrationFails() { var helper = new MockServiceHelper(Host); - helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => - { - return Task.FromResult("PASS"); - }); var sd = helper.ServiceDefinition; Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor))); Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{default(Interceptor)})); - Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{new ArbitraryActionInterceptor(()=>{}), null})); + Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{new ServerCallContextInterceptor(ctx=>{}), null})); Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor[]))); } + + private class ServerCallContextInterceptor : Interceptor + { + readonly Action<ServerCallContext> interceptor; + + public ServerCallContextInterceptor(Action<ServerCallContext> interceptor) + { + GrpcPreconditions.CheckNotNull(interceptor, nameof(interceptor)); + this.interceptor = interceptor; + } + + public override Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(request, context); + } + + public override Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(requestStream, context); + } + + public override Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(request, responseStream, context); + } + + public override Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(requestStream, responseStream, context); + } + } } } |