aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs')
-rw-r--r--src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs228
1 files changed, 228 insertions, 0 deletions
diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs
new file mode 100644
index 0000000000..02f6f6ffc6
--- /dev/null
+++ b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs
@@ -0,0 +1,228 @@
+#region Copyright notice and license
+
+// Copyright 2018 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#endregion
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Grpc.Core;
+using Grpc.Core.Interceptors;
+using Grpc.Core.Internal;
+using Grpc.Core.Utils;
+using Grpc.Core.Tests;
+using NUnit.Framework;
+
+namespace Grpc.Core.Interceptors.Tests
+{
+ public class ClientInterceptorTest
+ {
+ const string Host = "127.0.0.1";
+
+ [Test]
+ public void AddRequestHeaderInClientInterceptor()
+ {
+ const string HeaderKey = "x-client-interceptor";
+ const string HeaderValue = "hello-world";
+ var helper = new MockServiceHelper(Host);
+ helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
+ {
+ var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == HeaderKey)).Value;
+ Assert.AreEqual(interceptorHeader, HeaderValue);
+ return Task.FromResult("PASS");
+ });
+ var server = helper.GetServer();
+ server.Start();
+ var callInvoker = helper.GetChannel().Intercept(metadata =>
+ {
+ metadata = metadata ?? new Metadata();
+ metadata.Add(new Metadata.Entry(HeaderKey, HeaderValue));
+ return metadata;
+ });
+ Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method<string, string>(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), ""));
+ }
+
+ [Test]
+ public void CheckInterceptorOrderInClientInterceptors()
+ {
+ var helper = new MockServiceHelper(Host);
+ helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
+ {
+ return Task.FromResult("PASS");
+ });
+ var server = helper.GetServer();
+ server.Start();
+ var stringBuilder = new StringBuilder();
+ var callInvoker = helper.GetChannel().Intercept(metadata => {
+ stringBuilder.Append("interceptor1");
+ return metadata;
+ }).Intercept(new CallbackInterceptor(() => stringBuilder.Append("array1")),
+ new CallbackInterceptor(() => stringBuilder.Append("array2")),
+ new CallbackInterceptor(() => stringBuilder.Append("array3")))
+ .Intercept(metadata =>
+ {
+ stringBuilder.Append("interceptor2");
+ return metadata;
+ }).Intercept(metadata =>
+ {
+ stringBuilder.Append("interceptor3");
+ return metadata;
+ });
+ Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method<string, string>(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), ""));
+ Assert.AreEqual("interceptor3interceptor2array1array2array3interceptor1", stringBuilder.ToString());
+ }
+
+ [Test]
+ public void CheckNullInterceptorRegistrationFails()
+ {
+ var helper = new MockServiceHelper(Host);
+ helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
+ {
+ return Task.FromResult("PASS");
+ });
+ Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(default(Interceptor)));
+ Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(new[]{default(Interceptor)}));
+ Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(new[]{new CallbackInterceptor(()=>{}), null}));
+ Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(default(Interceptor[])));
+ }
+
+ [Test]
+ public async Task CountNumberOfRequestsInClientInterceptors()
+ {
+ var helper = new MockServiceHelper(Host);
+ helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) =>
+ {
+ var stringBuilder = new StringBuilder();
+ await requestStream.ForEachAsync(request =>
+ {
+ stringBuilder.Append(request);
+ return TaskUtils.CompletedTask;
+ });
+ await Task.Delay(100);
+ return stringBuilder.ToString();
+ });
+
+ var callInvoker = helper.GetChannel().Intercept(new ClientStreamingCountingInterceptor());
+
+ var server = helper.GetServer();
+ server.Start();
+ var call = callInvoker.AsyncClientStreamingCall(new Method<string, string>(MethodType.ClientStreaming, MockServiceHelper.ServiceName, "ClientStreaming", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions());
+ await call.RequestStream.WriteAllAsync(new string[] { "A", "B", "C" });
+ Assert.AreEqual("3", await call.ResponseAsync);
+
+ Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode);
+ Assert.IsNotNull(call.GetTrailers());
+ }
+
+ private class CallbackInterceptor : Interceptor
+ {
+ readonly Action callback;
+
+ public CallbackInterceptor(Action callback)
+ {
+ this.callback = GrpcPreconditions.CheckNotNull(callback, nameof(callback));
+ }
+
+ public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
+ {
+ callback();
+ return continuation(request, context);
+ }
+
+ public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
+ {
+ callback();
+ return continuation(request, context);
+ }
+
+ public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
+ {
+ callback();
+ return continuation(request, context);
+ }
+
+ public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
+ {
+ callback();
+ return continuation(context);
+ }
+
+ public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
+ {
+ callback();
+ return continuation(context);
+ }
+ }
+
+ private class ClientStreamingCountingInterceptor : Interceptor
+ {
+ public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
+ {
+ var response = continuation(context);
+ int counter = 0;
+ var requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream,
+ message => { counter++; return message; }, null);
+ var responseAsync = response.ResponseAsync.ContinueWith(
+ unaryResponse => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker
+ );
+ return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
+ }
+ }
+
+ private class WrappedClientStreamWriter<T> : IClientStreamWriter<T>
+ {
+ readonly IClientStreamWriter<T> writer;
+ readonly Func<T, T> onMessage;
+ readonly Action onResponseStreamEnd;
+ public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onResponseStreamEnd)
+ {
+ this.writer = writer;
+ this.onMessage = onMessage;
+ this.onResponseStreamEnd = onResponseStreamEnd;
+ }
+ public Task CompleteAsync()
+ {
+ if (onResponseStreamEnd != null)
+ {
+ return writer.CompleteAsync().ContinueWith(x => onResponseStreamEnd());
+ }
+ return writer.CompleteAsync();
+ }
+ public Task WriteAsync(T message)
+ {
+ if (onMessage != null)
+ {
+ message = onMessage(message);
+ }
+ return writer.WriteAsync(message);
+ }
+ public WriteOptions WriteOptions
+ {
+ get
+ {
+ return writer.WriteOptions;
+ }
+ set
+ {
+ writer.WriteOptions = value;
+ }
+ }
+ }
+ }
+}