aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorGravatar Jan Tattermusch <jtattermusch@users.noreply.github.com>2017-10-06 10:22:27 +0200
committerGravatar GitHub <noreply@github.com>2017-10-06 10:22:27 +0200
commitcbae862552c3f8d5a28c29b1605ae09d58086100 (patch)
treefd47ace4066b1617aaad74d987b2e5b9eb82298c /src
parentf2747dd729ab5d543059a99989445952c3fe98f4 (diff)
parent32d196f1b1859d6a0eaa906f65a82c42ad1b2c87 (diff)
Merge pull request #12822 from jtattermusch/csharp_cancellable_movenext
Add support for CancellationToken in IAsyncStreamReader.MoveNext()
Diffstat (limited to 'src')
-rw-r--r--src/csharp/Grpc.Core.Tests/CallCancellationTest.cs182
-rw-r--r--src/csharp/Grpc.Core.Tests/ClientServerTest.cs68
-rw-r--r--src/csharp/Grpc.Core/IAsyncStreamReader.cs7
-rw-r--r--src/csharp/Grpc.Core/Internal/ClientResponseStream.cs20
-rw-r--r--src/csharp/Grpc.Core/Internal/ServerRequestStream.cs11
-rw-r--r--src/csharp/tests.json1
6 files changed, 206 insertions, 83 deletions
diff --git a/src/csharp/Grpc.Core.Tests/CallCancellationTest.cs b/src/csharp/Grpc.Core.Tests/CallCancellationTest.cs
new file mode 100644
index 0000000000..e040f52380
--- /dev/null
+++ b/src/csharp/Grpc.Core.Tests/CallCancellationTest.cs
@@ -0,0 +1,182 @@
+#region Copyright notice and license
+
+// Copyright 2015 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#endregion
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Grpc.Core;
+using Grpc.Core.Internal;
+using Grpc.Core.Profiling;
+using Grpc.Core.Utils;
+using NUnit.Framework;
+
+namespace Grpc.Core.Tests
+{
+ public class CallCancellationTest
+ {
+ const string Host = "127.0.0.1";
+
+ MockServiceHelper helper;
+ Server server;
+ Channel channel;
+
+ [SetUp]
+ public void Init()
+ {
+ helper = new MockServiceHelper(Host);
+ server = helper.GetServer();
+ server.Start();
+ channel = helper.GetChannel();
+ }
+
+ [TearDown]
+ public void Cleanup()
+ {
+ channel.ShutdownAsync().Wait();
+ server.ShutdownAsync().Wait();
+ }
+
+ [Test]
+ public async Task ClientStreamingCall_CancelAfterBegin()
+ {
+ var barrier = new TaskCompletionSource<object>();
+
+ helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) =>
+ {
+ barrier.SetResult(null);
+ await requestStream.ToListAsync();
+ return "";
+ });
+
+ var cts = new CancellationTokenSource();
+ var call = Calls.AsyncClientStreamingCall(helper.CreateClientStreamingCall(new CallOptions(cancellationToken: cts.Token)));
+
+ await barrier.Task; // make sure the handler has started.
+ cts.Cancel();
+
+ try
+ {
+ // cannot use Assert.ThrowsAsync because it uses Task.Wait and would deadlock.
+ await call.ResponseAsync;
+ Assert.Fail();
+ }
+ catch (RpcException ex)
+ {
+ Assert.AreEqual(StatusCode.Cancelled, ex.Status.StatusCode);
+ }
+ }
+
+ [Test]
+ public async Task ClientStreamingCall_ServerSideReadAfterCancelNotificationReturnsNull()
+ {
+ var handlerStartedBarrier = new TaskCompletionSource<object>();
+ var cancelNotificationReceivedBarrier = new TaskCompletionSource<object>();
+ var successTcs = new TaskCompletionSource<string>();
+
+ helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) =>
+ {
+ handlerStartedBarrier.SetResult(null);
+
+ // wait for cancellation to be delivered.
+ context.CancellationToken.Register(() => cancelNotificationReceivedBarrier.SetResult(null));
+ await cancelNotificationReceivedBarrier.Task;
+
+ var moveNextResult = await requestStream.MoveNext();
+ successTcs.SetResult(!moveNextResult ? "SUCCESS" : "FAIL");
+ return "";
+ });
+
+ var cts = new CancellationTokenSource();
+ var call = Calls.AsyncClientStreamingCall(helper.CreateClientStreamingCall(new CallOptions(cancellationToken: cts.Token)));
+
+ await handlerStartedBarrier.Task;
+ cts.Cancel();
+
+ try
+ {
+ await call.ResponseAsync;
+ Assert.Fail();
+ }
+ catch (RpcException ex)
+ {
+ Assert.AreEqual(StatusCode.Cancelled, ex.Status.StatusCode);
+ }
+ Assert.AreEqual("SUCCESS", await successTcs.Task);
+ }
+
+ [Test]
+ public async Task ClientStreamingCall_CancelServerSideRead()
+ {
+ helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) =>
+ {
+ var cts = new CancellationTokenSource();
+ var moveNextTask = requestStream.MoveNext(cts.Token);
+ cts.Cancel();
+ await moveNextTask;
+ return "";
+ });
+
+ var call = Calls.AsyncClientStreamingCall(helper.CreateClientStreamingCall());
+ try
+ {
+ // cannot use Assert.ThrowsAsync because it uses Task.Wait and would deadlock.
+ await call.ResponseAsync;
+ Assert.Fail();
+ }
+ catch (RpcException ex)
+ {
+ Assert.AreEqual(StatusCode.Cancelled, ex.Status.StatusCode);
+ }
+ }
+
+ [Test]
+ public async Task ServerStreamingCall_CancelClientSideRead()
+ {
+ helper.ServerStreamingHandler = new ServerStreamingServerMethod<string, string>(async (request, responseStream, context) =>
+ {
+ await responseStream.WriteAsync("abc");
+ while (!context.CancellationToken.IsCancellationRequested)
+ {
+ await Task.Delay(10);
+ }
+ });
+
+ var call = Calls.AsyncServerStreamingCall(helper.CreateServerStreamingCall(), "");
+ await call.ResponseStream.MoveNext();
+ Assert.AreEqual("abc", call.ResponseStream.Current);
+
+ var cts = new CancellationTokenSource();
+ var moveNextTask = call.ResponseStream.MoveNext(cts.Token);
+ cts.Cancel();
+
+ try
+ {
+ // cannot use Assert.ThrowsAsync because it uses Task.Wait and would deadlock.
+ await moveNextTask;
+ Assert.Fail();
+ }
+ catch (RpcException ex)
+ {
+ Assert.AreEqual(StatusCode.Cancelled, ex.Status.StatusCode);
+ }
+ }
+ }
+}
diff --git a/src/csharp/Grpc.Core.Tests/ClientServerTest.cs b/src/csharp/Grpc.Core.Tests/ClientServerTest.cs
index 72d9035a6f..90dd365b07 100644
--- a/src/csharp/Grpc.Core.Tests/ClientServerTest.cs
+++ b/src/csharp/Grpc.Core.Tests/ClientServerTest.cs
@@ -273,74 +273,6 @@ namespace Grpc.Core.Tests
}
[Test]
- public async Task ClientStreamingCall_CancelAfterBegin()
- {
- var barrier = new TaskCompletionSource<object>();
-
- helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) =>
- {
- barrier.SetResult(null);
- await requestStream.ToListAsync();
- return "";
- });
-
- var cts = new CancellationTokenSource();
- var call = Calls.AsyncClientStreamingCall(helper.CreateClientStreamingCall(new CallOptions(cancellationToken: cts.Token)));
-
- await barrier.Task; // make sure the handler has started.
- cts.Cancel();
-
- try
- {
- // cannot use Assert.ThrowsAsync because it uses Task.Wait and would deadlock.
- await call.ResponseAsync;
- Assert.Fail();
- }
- catch (RpcException ex)
- {
- Assert.AreEqual(StatusCode.Cancelled, ex.Status.StatusCode);
- }
- }
-
- [Test]
- public async Task ClientStreamingCall_ServerSideReadAfterCancelNotificationReturnsNull()
- {
- var handlerStartedBarrier = new TaskCompletionSource<object>();
- var cancelNotificationReceivedBarrier = new TaskCompletionSource<object>();
- var successTcs = new TaskCompletionSource<string>();
-
- helper.ClientStreamingHandler = new ClientStreamingServerMethod<string, string>(async (requestStream, context) =>
- {
- handlerStartedBarrier.SetResult(null);
-
- // wait for cancellation to be delivered.
- context.CancellationToken.Register(() => cancelNotificationReceivedBarrier.SetResult(null));
- await cancelNotificationReceivedBarrier.Task;
-
- var moveNextResult = await requestStream.MoveNext();
- successTcs.SetResult(!moveNextResult ? "SUCCESS" : "FAIL");
- return "";
- });
-
- var cts = new CancellationTokenSource();
- var call = Calls.AsyncClientStreamingCall(helper.CreateClientStreamingCall(new CallOptions(cancellationToken: cts.Token)));
-
- await handlerStartedBarrier.Task;
- cts.Cancel();
-
- try
- {
- await call.ResponseAsync;
- Assert.Fail();
- }
- catch (RpcException ex)
- {
- Assert.AreEqual(StatusCode.Cancelled, ex.Status.StatusCode);
- }
- Assert.AreEqual("SUCCESS", await successTcs.Task);
- }
-
- [Test]
public async Task AsyncUnaryCall_EchoMetadata()
{
helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
diff --git a/src/csharp/Grpc.Core/IAsyncStreamReader.cs b/src/csharp/Grpc.Core/IAsyncStreamReader.cs
index 42bfbb87e0..3751d549e3 100644
--- a/src/csharp/Grpc.Core/IAsyncStreamReader.cs
+++ b/src/csharp/Grpc.Core/IAsyncStreamReader.cs
@@ -41,6 +41,13 @@ namespace Grpc.Core
/// (<c>MoveNext</c> will return <c>false</c>) and the <c>CancellationToken</c>
/// associated with the call will be cancelled to signal the failure.
/// </para>
+ /// <para>
+ /// <c>MoveNext()</c> operations can be cancelled via a cancellation token. Cancelling
+ /// an individual read operation has the same effect as cancelling the entire call
+ /// (which will also result in the read operation returning prematurely), but the per-read cancellation
+ /// tokens passed to MoveNext() only result in cancelling the call if the read operation haven't finished
+ /// yet.
+ /// </para>
/// </summary>
/// <typeparam name="T">The message type.</typeparam>
public interface IAsyncStreamReader<T> : IAsyncEnumerator<T>
diff --git a/src/csharp/Grpc.Core/Internal/ClientResponseStream.cs b/src/csharp/Grpc.Core/Internal/ClientResponseStream.cs
index 851b6ca213..ab649ee766 100644
--- a/src/csharp/Grpc.Core/Internal/ClientResponseStream.cs
+++ b/src/csharp/Grpc.Core/Internal/ClientResponseStream.cs
@@ -49,19 +49,19 @@ namespace Grpc.Core.Internal
public async Task<bool> MoveNext(CancellationToken token)
{
- if (token != CancellationToken.None)
+ var cancellationTokenRegistration = token.CanBeCanceled ? token.Register(() => call.Cancel()) : (IDisposable) null;
+ using (cancellationTokenRegistration)
{
- throw new InvalidOperationException("Cancellation of individual reads is not supported.");
- }
- var result = await call.ReadMessageAsync().ConfigureAwait(false);
- this.current = result;
+ var result = await call.ReadMessageAsync().ConfigureAwait(false);
+ this.current = result;
- if (result == null)
- {
- await call.StreamingResponseCallFinishedTask.ConfigureAwait(false);
- return false;
+ if (result == null)
+ {
+ await call.StreamingResponseCallFinishedTask.ConfigureAwait(false);
+ return false;
+ }
+ return true;
}
- return true;
}
public void Dispose()
diff --git a/src/csharp/Grpc.Core/Internal/ServerRequestStream.cs b/src/csharp/Grpc.Core/Internal/ServerRequestStream.cs
index c65b960afb..058dddb7eb 100644
--- a/src/csharp/Grpc.Core/Internal/ServerRequestStream.cs
+++ b/src/csharp/Grpc.Core/Internal/ServerRequestStream.cs
@@ -49,13 +49,14 @@ namespace Grpc.Core.Internal
public async Task<bool> MoveNext(CancellationToken token)
{
- if (token != CancellationToken.None)
+
+ var cancellationTokenRegistration = token.CanBeCanceled ? token.Register(() => call.Cancel()) : (IDisposable) null;
+ using (cancellationTokenRegistration)
{
- throw new InvalidOperationException("Cancellation of individual reads is not supported.");
+ var result = await call.ReadMessageAsync().ConfigureAwait(false);
+ this.current = result;
+ return result != null;
}
- var result = await call.ReadMessageAsync().ConfigureAwait(false);
- this.current = result;
- return result != null;
}
public void Dispose()
diff --git a/src/csharp/tests.json b/src/csharp/tests.json
index 7841051052..65a0ed293f 100644
--- a/src/csharp/tests.json
+++ b/src/csharp/tests.json
@@ -10,6 +10,7 @@
"Grpc.Core.Tests.AppDomainUnloadTest",
"Grpc.Core.Tests.AuthContextTest",
"Grpc.Core.Tests.AuthPropertyTest",
+ "Grpc.Core.Tests.CallCancellationTest",
"Grpc.Core.Tests.CallCredentialsTest",
"Grpc.Core.Tests.CallOptionsTest",
"Grpc.Core.Tests.ChannelCredentialsTest",