diff options
Diffstat (limited to 'src/csharp/Grpc.Core.Tests')
3 files changed, 357 insertions, 10 deletions
diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs new file mode 100644 index 0000000000..02f6f6ffc6 --- /dev/null +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs @@ -0,0 +1,228 @@ +#region Copyright notice and license + +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Grpc.Core.Interceptors; +using Grpc.Core.Internal; +using Grpc.Core.Utils; +using Grpc.Core.Tests; +using NUnit.Framework; + +namespace Grpc.Core.Interceptors.Tests +{ + public class ClientInterceptorTest + { + const string Host = "127.0.0.1"; + + [Test] + public void AddRequestHeaderInClientInterceptor() + { + const string HeaderKey = "x-client-interceptor"; + const string HeaderValue = "hello-world"; + var helper = new MockServiceHelper(Host); + helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => + { + var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == HeaderKey)).Value; + Assert.AreEqual(interceptorHeader, HeaderValue); + return Task.FromResult("PASS"); + }); + var server = helper.GetServer(); + server.Start(); + var callInvoker = helper.GetChannel().Intercept(metadata => + { + metadata = metadata ?? new Metadata(); + metadata.Add(new Metadata.Entry(HeaderKey, HeaderValue)); + return metadata; + }); + Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method<string, string>(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), "")); + } + + [Test] + public void CheckInterceptorOrderInClientInterceptors() + { + var helper = new MockServiceHelper(Host); + helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => + { + return Task.FromResult("PASS"); + }); + var server = helper.GetServer(); + server.Start(); + var stringBuilder = new StringBuilder(); + var callInvoker = helper.GetChannel().Intercept(metadata => { + stringBuilder.Append("interceptor1"); + return metadata; + }).Intercept(new CallbackInterceptor(() => stringBuilder.Append("array1")), + new CallbackInterceptor(() => stringBuilder.Append("array2")), + new CallbackInterceptor(() => stringBuilder.Append("array3"))) + .Intercept(metadata => + { + stringBuilder.Append("interceptor2"); + return metadata; + }).Intercept(metadata => + { + stringBuilder.Append("interceptor3"); + return metadata; + }); + Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method<string, string>(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), "")); + Assert.AreEqual("interceptor3interceptor2array1array2array3interceptor1", stringBuilder.ToString()); + } + + [Test] + public void CheckNullInterceptorRegistrationFails() + { + var helper = new MockServiceHelper(Host); + helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => + { + return Task.FromResult("PASS"); + }); + Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(default(Interceptor))); + Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(new[]{default(Interceptor)})); + Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(new[]{new CallbackInterceptor(()=>{}), null})); + Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(default(Interceptor[]))); + } + + [Test] + public async Task CountNumberOfRequestsInClientInterceptors() + { + var helper = new MockServiceHelper(Host); + helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) => + { + var stringBuilder = new StringBuilder(); + await requestStream.ForEachAsync(request => + { + stringBuilder.Append(request); + return TaskUtils.CompletedTask; + }); + await Task.Delay(100); + return stringBuilder.ToString(); + }); + + var callInvoker = helper.GetChannel().Intercept(new ClientStreamingCountingInterceptor()); + + var server = helper.GetServer(); + server.Start(); + var call = callInvoker.AsyncClientStreamingCall(new Method<string, string>(MethodType.ClientStreaming, MockServiceHelper.ServiceName, "ClientStreaming", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions()); + await call.RequestStream.WriteAllAsync(new string[] { "A", "B", "C" }); + Assert.AreEqual("3", await call.ResponseAsync); + + Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); + Assert.IsNotNull(call.GetTrailers()); + } + + private class CallbackInterceptor : Interceptor + { + readonly Action callback; + + public CallbackInterceptor(Action callback) + { + this.callback = GrpcPreconditions.CheckNotNull(callback, nameof(callback)); + } + + public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(request, context); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(request, context); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(request, context); + } + + public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(context); + } + + public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation) + { + callback(); + return continuation(context); + } + } + + private class ClientStreamingCountingInterceptor : Interceptor + { + public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation) + { + var response = continuation(context); + int counter = 0; + var requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, + message => { counter++; return message; }, null); + var responseAsync = response.ResponseAsync.ContinueWith( + unaryResponse => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker + ); + return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); + } + } + + private class WrappedClientStreamWriter<T> : IClientStreamWriter<T> + { + readonly IClientStreamWriter<T> writer; + readonly Func<T, T> onMessage; + readonly Action onResponseStreamEnd; + public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onResponseStreamEnd) + { + this.writer = writer; + this.onMessage = onMessage; + this.onResponseStreamEnd = onResponseStreamEnd; + } + public Task CompleteAsync() + { + if (onResponseStreamEnd != null) + { + return writer.CompleteAsync().ContinueWith(x => onResponseStreamEnd()); + } + return writer.CompleteAsync(); + } + public Task WriteAsync(T message) + { + if (onMessage != null) + { + message = onMessage(message); + } + return writer.WriteAsync(message); + } + public WriteOptions WriteOptions + { + get + { + return writer.WriteOptions; + } + set + { + writer.WriteOptions = value; + } + } + } + } +} diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs new file mode 100644 index 0000000000..e76f21d098 --- /dev/null +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs @@ -0,0 +1,126 @@ +#region Copyright notice and license + +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Grpc.Core.Interceptors; +using Grpc.Core.Internal; +using Grpc.Core.Tests; +using Grpc.Core.Utils; +using NUnit.Framework; + +namespace Grpc.Core.Interceptors.Tests +{ + public class ServerInterceptorTest + { + const string Host = "127.0.0.1"; + + [Test] + public void AddRequestHeaderInServerInterceptor() + { + var helper = new MockServiceHelper(Host); + const string MetadataKey = "x-interceptor"; + const string MetadataValue = "hello world"; + var interceptor = new ServerCallContextInterceptor(ctx => ctx.RequestHeaders.Add(new Metadata.Entry(MetadataKey, MetadataValue))); + helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => + { + var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == MetadataKey)).Value; + Assert.AreEqual(interceptorHeader, MetadataValue); + return Task.FromResult("PASS"); + }); + helper.ServiceDefinition = helper.ServiceDefinition.Intercept(interceptor); + var server = helper.GetServer(); + server.Start(); + var channel = helper.GetChannel(); + Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), "")); + } + + [Test] + public void VerifyInterceptorOrdering() + { + var helper = new MockServiceHelper(Host); + helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => + { + return Task.FromResult("PASS"); + }); + var stringBuilder = new StringBuilder(); + helper.ServiceDefinition = helper.ServiceDefinition + .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("A"))) + .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("B1")), + new ServerCallContextInterceptor(ctx => stringBuilder.Append("B2")), + new ServerCallContextInterceptor(ctx => stringBuilder.Append("B3"))) + .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("C"))); + var server = helper.GetServer(); + server.Start(); + var channel = helper.GetChannel(); + Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), "")); + Assert.AreEqual("CB1B2B3A", stringBuilder.ToString()); + } + + [Test] + public void CheckNullInterceptorRegistrationFails() + { + var helper = new MockServiceHelper(Host); + var sd = helper.ServiceDefinition; + Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor))); + Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{default(Interceptor)})); + Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{new ServerCallContextInterceptor(ctx=>{}), null})); + Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor[]))); + } + + private class ServerCallContextInterceptor : Interceptor + { + readonly Action<ServerCallContext> interceptor; + + public ServerCallContextInterceptor(Action<ServerCallContext> interceptor) + { + GrpcPreconditions.CheckNotNull(interceptor, nameof(interceptor)); + this.interceptor = interceptor; + } + + public override Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(request, context); + } + + public override Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(requestStream, context); + } + + public override Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(request, responseStream, context); + } + + public override Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation) + { + interceptor(context); + return continuation(requestStream, responseStream, context); + } + } + } +} diff --git a/src/csharp/Grpc.Core.Tests/MockServiceHelper.cs b/src/csharp/Grpc.Core.Tests/MockServiceHelper.cs index 7f4677d57f..a925f865ff 100644 --- a/src/csharp/Grpc.Core.Tests/MockServiceHelper.cs +++ b/src/csharp/Grpc.Core.Tests/MockServiceHelper.cs @@ -37,7 +37,6 @@ namespace Grpc.Core.Tests public const string ServiceName = "tests.Test"; readonly string host; - readonly ServerServiceDefinition serviceDefinition; readonly IEnumerable<ChannelOption> channelOptions; readonly Method<string, string> unaryMethod; @@ -87,7 +86,7 @@ namespace Grpc.Core.Tests marshaller, marshaller); - serviceDefinition = ServerServiceDefinition.CreateBuilder() + ServiceDefinition = ServerServiceDefinition.CreateBuilder() .AddMethod(unaryMethod, (request, context) => unaryHandler(request, context)) .AddMethod(clientStreamingMethod, (requestStream, context) => clientStreamingHandler(requestStream, context)) .AddMethod(serverStreamingMethod, (request, responseStream, context) => serverStreamingHandler(request, responseStream, context)) @@ -131,7 +130,7 @@ namespace Grpc.Core.Tests // Disable SO_REUSEPORT to prevent https://github.com/grpc/grpc/issues/10755 server = new Server(new[] { new ChannelOption(ChannelOptions.SoReuseport, 0) }) { - Services = { serviceDefinition }, + Services = { ServiceDefinition }, Ports = { { Host, ServerPort.PickUnused, ServerCredentials.Insecure } } }; } @@ -178,13 +177,7 @@ namespace Grpc.Core.Tests } } - public ServerServiceDefinition ServiceDefinition - { - get - { - return this.serviceDefinition; - } - } + public ServerServiceDefinition ServiceDefinition { get; set; } public UnaryServerMethod<string, string> UnaryHandler { |