diff options
-rw-r--r-- | src/csharp/Grpc.Core.Tests/ClientServerTest.cs | 219 | ||||
-rw-r--r-- | src/csharp/Grpc.Core/Calls.cs | 17 | ||||
-rw-r--r-- | src/csharp/Grpc.Core/Internal/AsyncCallServer.cs | 6 | ||||
-rw-r--r-- | src/csharp/Grpc.Core/Internal/ServerCallHandler.cs | 83 | ||||
-rw-r--r-- | src/csharp/Grpc.IntegrationTesting/InteropClient.cs | 65 | ||||
-rw-r--r-- | src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs | 12 |
6 files changed, 322 insertions, 80 deletions
diff --git a/src/csharp/Grpc.Core.Tests/ClientServerTest.cs b/src/csharp/Grpc.Core.Tests/ClientServerTest.cs index 9e510e82a6..c91efe53b4 100644 --- a/src/csharp/Grpc.Core.Tests/ClientServerTest.cs +++ b/src/csharp/Grpc.Core.Tests/ClientServerTest.cs @@ -47,117 +47,220 @@ namespace Grpc.Core.Tests const string Host = "localhost"; const string ServiceName = "/tests.Test"; - static readonly Method<string, string> UnaryEchoStringMethod = new Method<string, string>( + static readonly Method<string, string> EchoMethod = new Method<string, string>( MethodType.Unary, - "/tests.Test/UnaryEchoString", + "/tests.Test/Echo", + Marshallers.StringMarshaller, + Marshallers.StringMarshaller); + + static readonly Method<string, string> ConcatAndEchoMethod = new Method<string, string>( + MethodType.ClientStreaming, + "/tests.Test/ConcatAndEcho", + Marshallers.StringMarshaller, + Marshallers.StringMarshaller); + + static readonly Method<string, string> NonexistentMethod = new Method<string, string>( + MethodType.Unary, + "/tests.Test/NonexistentMethod", Marshallers.StringMarshaller, Marshallers.StringMarshaller); static readonly ServerServiceDefinition ServiceDefinition = ServerServiceDefinition.CreateBuilder(ServiceName) - .AddMethod(UnaryEchoStringMethod, HandleUnaryEchoString).Build(); + .AddMethod(EchoMethod, EchoHandler) + .AddMethod(ConcatAndEchoMethod, ConcatAndEchoHandler) + .Build(); + + Server server; + Channel channel; [TestFixtureSetUp] + public void InitClass() + { + GrpcEnvironment.Initialize(); + } + + [SetUp] public void Init() { GrpcEnvironment.Initialize(); + + server = new Server(); + server.AddServiceDefinition(ServiceDefinition); + int port = server.AddListeningPort(Host + ":0"); + server.Start(); + channel = new Channel(Host + ":" + port); } - [TestFixtureTearDown] + [TearDown] public void Cleanup() { + channel.Dispose(); + server.ShutdownAsync().Wait(); + } + + [TestFixtureTearDown] + public void CleanupClass() + { GrpcEnvironment.Shutdown(); } [Test] public void UnaryCall() { - var server = new Server(); - server.AddServiceDefinition(ServiceDefinition); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); - - using (Channel channel = new Channel(Host + ":" + port)) - { - var call = new Call<string, string>(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); - Assert.AreEqual("ABC", Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken))); - } - - server.ShutdownAsync().Wait(); + var call = new Call<string, string>(ServiceName, EchoMethod, channel, Metadata.Empty); + Assert.AreEqual("ABC", Calls.BlockingUnaryCall(call, "ABC", CancellationToken.None)); } [Test] - public void CallOnDisposedChannel() + public void UnaryCall_ServerHandlerThrows() { - var server = new Server(); - server.AddServiceDefinition(ServiceDefinition); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); - - Channel channel = new Channel(Host + ":" + port); - channel.Dispose(); - - var call = new Call<string, string>(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); + var call = new Call<string, string>(ServiceName, EchoMethod, channel, Metadata.Empty); try { - Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); - Assert.Fail(); + Calls.BlockingUnaryCall(call, "THROW", CancellationToken.None); + Assert.Fail(); } - catch (ObjectDisposedException e) + catch (RpcException e) { + Assert.AreEqual(StatusCode.Unknown, e.Status.StatusCode); } + } - server.ShutdownAsync().Wait(); + [Test] + public void AsyncUnaryCall() + { + var call = new Call<string, string>(ServiceName, EchoMethod, channel, Metadata.Empty); + var result = Calls.AsyncUnaryCall(call, "ABC", CancellationToken.None).Result; + Assert.AreEqual("ABC", result); } [Test] - public void UnaryCallPerformance() + public void AsyncUnaryCall_ServerHandlerThrows() { - var server = new Server(); - server.AddServiceDefinition(ServiceDefinition); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); + Task.Run(async () => + { + var call = new Call<string, string>(ServiceName, EchoMethod, channel, Metadata.Empty); + try + { + await Calls.AsyncUnaryCall(call, "THROW", CancellationToken.None); + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Unknown, e.Status.StatusCode); + } + }).Wait(); + } - using (Channel channel = new Channel(Host + ":" + port)) + [Test] + public void ClientStreamingCall() + { + Task.Run(async () => { - var call = new Call<string, string>(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); - BenchmarkUtil.RunBenchmark(100, 100, - () => { Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); }); - } + var call = new Call<string, string>(ServiceName, ConcatAndEchoMethod, channel, Metadata.Empty); + var callResult = Calls.AsyncClientStreamingCall(call, CancellationToken.None); - server.ShutdownAsync().Wait(); + await callResult.RequestStream.WriteAll(new string[] { "A", "B", "C" }); + Assert.AreEqual("ABC", await callResult.Result); + }).Wait(); } [Test] - public void UnknownMethodHandler() + public void ClientStreamingCall_ServerHandlerThrows() { - var server = new Server(); - server.AddServiceDefinition(ServerServiceDefinition.CreateBuilder(ServiceName).Build()); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); + Task.Run(async () => + { + var call = new Call<string, string>(ServiceName, ConcatAndEchoMethod, channel, Metadata.Empty); + var callResult = Calls.AsyncClientStreamingCall(call, CancellationToken.None); + // TODO(jtattermusch): if we send "A", "THROW", "C", server hangs. + await callResult.RequestStream.WriteAll(new string[] { "A", "B", "THROW" }); + + try + { + await callResult.Result; + } + catch(RpcException e) + { + Assert.AreEqual(StatusCode.Unknown, e.Status.StatusCode); + } + }).Wait(); + } - using (Channel channel = new Channel(Host + ":" + port)) + [Test] + public void ClientStreamingCall_CancelAfterBegin() + { + Task.Run(async () => { - var call = new Call<string, string>(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); + var call = new Call<string, string>(ServiceName, ConcatAndEchoMethod, channel, Metadata.Empty); + + var cts = new CancellationTokenSource(); + var callResult = Calls.AsyncClientStreamingCall(call, cts.Token); + cts.Cancel(); + try { - Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); - Assert.Fail(); + await callResult.Result; } - catch (RpcException e) + catch(RpcException e) { - Assert.AreEqual(StatusCode.Unimplemented, e.Status.StatusCode); + Assert.AreEqual(StatusCode.Cancelled, e.Status.StatusCode); } + }).Wait(); + } + + [Test] + public void UnaryCall_DisposedChannel() + { + channel.Dispose(); + + var call = new Call<string, string>(ServiceName, EchoMethod, channel, Metadata.Empty); + Assert.Throws(typeof(ObjectDisposedException), () => Calls.BlockingUnaryCall(call, "ABC", CancellationToken.None)); + } + + [Test] + public void UnaryCallPerformance() + { + var call = new Call<string, string>(ServiceName, EchoMethod, channel, Metadata.Empty); + BenchmarkUtil.RunBenchmark(100, 100, + () => { Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); }); + } + + [Test] + public void UnknownMethodHandler() + { + var call = new Call<string, string>(ServiceName, NonexistentMethod, channel, Metadata.Empty); + try + { + Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Unimplemented, e.Status.StatusCode); } + } - server.ShutdownAsync().Wait(); + private static async Task<string> EchoHandler(string request) + { + if (request == "THROW") + { + throw new Exception("This was thrown on purpose by a test"); + } + return request; } - /// <summary> - /// Handler for unaryEchoString method. - /// </summary> - private static Task<string> HandleUnaryEchoString(string request) + private static async Task<string> ConcatAndEchoHandler(IAsyncStreamReader<string> requestStream) { - return Task.FromResult(request); + string result = ""; + await requestStream.ForEach(async (request) => + { + if (request == "THROW") + { + throw new Exception("This was thrown on purpose by a test"); + } + result += request; + }); + return result; } } } diff --git a/src/csharp/Grpc.Core/Calls.cs b/src/csharp/Grpc.Core/Calls.cs index 9365ccd9fb..c2397290fd 100644 --- a/src/csharp/Grpc.Core/Calls.cs +++ b/src/csharp/Grpc.Core/Calls.cs @@ -46,6 +46,8 @@ namespace Grpc.Core public static TResponse BlockingUnaryCall<TRequest, TResponse>(Call<TRequest, TResponse> call, TRequest req, CancellationToken token) { var asyncCall = new AsyncCall<TRequest, TResponse>(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); + // TODO(jtattermusch): this gives a race that cancellation can be requested before the call even starts. + RegisterCancellationCallback(asyncCall, token); return asyncCall.UnaryCall(call.Channel, call.Name, req, call.Headers); } @@ -53,7 +55,9 @@ namespace Grpc.Core { var asyncCall = new AsyncCall<TRequest, TResponse>(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); - return await asyncCall.UnaryCallAsync(req, call.Headers); + var asyncResult = asyncCall.UnaryCallAsync(req, call.Headers); + RegisterCancellationCallback(asyncCall, token); + return await asyncResult; } public static AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(Call<TRequest, TResponse> call, TRequest req, CancellationToken token) @@ -61,6 +65,7 @@ namespace Grpc.Core var asyncCall = new AsyncCall<TRequest, TResponse>(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); asyncCall.StartServerStreamingCall(req, call.Headers); + RegisterCancellationCallback(asyncCall, token); var responseStream = new ClientResponseStream<TRequest, TResponse>(asyncCall); return new AsyncServerStreamingCall<TResponse>(responseStream); } @@ -70,6 +75,7 @@ namespace Grpc.Core var asyncCall = new AsyncCall<TRequest, TResponse>(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); var resultTask = asyncCall.ClientStreamingCallAsync(call.Headers); + RegisterCancellationCallback(asyncCall, token); var requestStream = new ClientRequestStream<TRequest, TResponse>(asyncCall); return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, resultTask); } @@ -79,11 +85,20 @@ namespace Grpc.Core var asyncCall = new AsyncCall<TRequest, TResponse>(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); asyncCall.StartDuplexStreamingCall(call.Headers); + RegisterCancellationCallback(asyncCall, token); var requestStream = new ClientRequestStream<TRequest, TResponse>(asyncCall); var responseStream = new ClientResponseStream<TRequest, TResponse>(asyncCall); return new AsyncDuplexStreamingCall<TRequest, TResponse>(requestStream, responseStream); } + private static void RegisterCancellationCallback<TRequest, TResponse>(AsyncCall<TRequest, TResponse> asyncCall, CancellationToken token) + { + if (token.CanBeCanceled) + { + token.Register( () => asyncCall.Cancel() ); + } + } + /// <summary> /// Gets shared completion queue used for async calls. /// </summary> diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs index 25dc15bbc0..449009336f 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs @@ -121,6 +121,12 @@ namespace Grpc.Core.Internal { finished = true; + if (readCompletionDelegate == null) + { + // allow disposal of native call + readingDone = true; + } + ReleaseResourcesIfPossible(); } // TODO: handle error ... diff --git a/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs b/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs index 2ae924b4a8..0416eada34 100644 --- a/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs +++ b/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs @@ -66,13 +66,21 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream<TRequest, TResponse>(asyncCall); var responseStream = new ServerResponseStream<TRequest, TResponse>(asyncCall); - var request = await requestStream.ReadNext(); - // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. - Preconditions.CheckArgument(await requestStream.ReadNext() == null); - - var result = await handler(request); - await responseStream.Write(result); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + var request = await requestStream.ReadNext(); + // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. + Preconditions.CheckArgument(await requestStream.ReadNext() == null); + var result = await handler(request); + await responseStream.Write(result); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -99,12 +107,21 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream<TRequest, TResponse>(asyncCall); var responseStream = new ServerResponseStream<TRequest, TResponse>(asyncCall); - var request = await requestStream.ReadNext(); - // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. - Preconditions.CheckArgument(await requestStream.ReadNext() == null); - - await handler(request, responseStream); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + var request = await requestStream.ReadNext(); + // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. + Preconditions.CheckArgument(await requestStream.ReadNext() == null); + + await handler(request, responseStream); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -131,9 +148,18 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream<TRequest, TResponse>(asyncCall); var responseStream = new ServerResponseStream<TRequest, TResponse>(asyncCall); - var result = await handler(requestStream); - await responseStream.Write(result); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + var result = await handler(requestStream); + await responseStream.Write(result); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -160,8 +186,17 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream<TRequest, TResponse>(asyncCall); var responseStream = new ServerResponseStream<TRequest, TResponse>(asyncCall); - await handler(requestStream, responseStream); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + await handler(requestStream, responseStream); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -173,15 +208,25 @@ namespace Grpc.Core.Internal // We don't care about the payload type here. var asyncCall = new AsyncCallServer<byte[], byte[]>( (payload) => payload, (payload) => payload); - + asyncCall.Initialize(call); var finishedTask = asyncCall.ServerSideCallAsync(); var requestStream = new ServerRequestStream<byte[], byte[]>(asyncCall); var responseStream = new ServerResponseStream<byte[], byte[]>(asyncCall); + await responseStream.WriteStatus(new Status(StatusCode.Unimplemented, "No such method.")); // TODO(jtattermusch): if we don't read what client has sent, the server call never gets disposed. await requestStream.ToList(); await finishedTask; } } + + internal static class HandlerUtils + { + public static Status StatusFromException(Exception e) + { + // TODO(jtattermusch): what is the right status code here? + return new Status(StatusCode.Unknown, "Exception was thrown by handler."); + } + } } diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs index 573ab30452..440702d06f 100644 --- a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs +++ b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs @@ -34,6 +34,7 @@ using System; using System.Collections.Generic; using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; using Google.ProtocolBuffers; @@ -166,6 +167,12 @@ namespace Grpc.IntegrationTesting case "compute_engine_creds": RunComputeEngineCreds(client); break; + case "cancel_after_begin": + RunCancelAfterBegin(client); + break; + case "cancel_after_first_response": + RunCancelAfterFirstResponse(client); + break; case "benchmark_empty_unary": RunBenchmarkEmptyUnary(client); break; @@ -351,6 +358,64 @@ namespace Grpc.IntegrationTesting Console.WriteLine("Passed!"); } + public static void RunCancelAfterBegin(TestServiceGrpc.ITestServiceClient client) + { + Task.Run(async () => + { + Console.WriteLine("running cancel_after_begin"); + + var cts = new CancellationTokenSource(); + var call = client.StreamingInputCall(cts.Token); + cts.Cancel(); + + try + { + var response = await call.Result; + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Cancelled, e.Status.StatusCode); + } + Console.WriteLine("Passed!"); + }).Wait(); + } + + public static void RunCancelAfterFirstResponse(TestServiceGrpc.ITestServiceClient client) + { + Task.Run(async () => + { + Console.WriteLine("running cancel_after_first_response"); + + var cts = new CancellationTokenSource(); + var call = client.FullDuplexCall(cts.Token); + + StreamingOutputCallResponse response; + + await call.RequestStream.Write(StreamingOutputCallRequest.CreateBuilder() + .SetResponseType(PayloadType.COMPRESSABLE) + .AddResponseParameters(ResponseParameters.CreateBuilder().SetSize(31415)) + .SetPayload(CreateZerosPayload(27182)).Build()); + + response = await call.ResponseStream.ReadNext(); + Assert.AreEqual(PayloadType.COMPRESSABLE, response.Payload.Type); + Assert.AreEqual(31415, response.Payload.Body.Length); + + cts.Cancel(); + + try + { + response = await call.ResponseStream.ReadNext(); + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Cancelled, e.Status.StatusCode); + } + Console.WriteLine("Passed!"); + }).Wait(); + } + // This is not an official interop test, but it's useful. public static void RunBenchmarkEmptyUnary(TestServiceGrpc.ITestServiceClient client) { diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs b/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs index e929b76b5e..45380227c2 100644 --- a/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs +++ b/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs @@ -114,8 +114,16 @@ namespace Grpc.IntegrationTesting InteropClient.RunEmptyStream(client); } - // TODO: add cancel_after_begin + [Test] + public void CancelAfterBegin() + { + InteropClient.RunCancelAfterBegin(client); + } - // TODO: add cancel_after_first_response + [Test] + public void CancelAfterFirstResponse() + { + InteropClient.RunCancelAfterFirstResponse(client); + } } } |