diff options
-rw-r--r-- | src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs | 57 | ||||
-rw-r--r-- | src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs | 22 | ||||
-rw-r--r-- | src/csharp/Grpc.Core/Channel.cs | 6 | ||||
-rw-r--r-- | src/csharp/Grpc.Core/Internal/AsyncCall.cs | 210 | ||||
-rw-r--r-- | src/csharp/Grpc.Core/Internal/AsyncCallBase.cs | 2 |
5 files changed, 238 insertions, 59 deletions
diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs index 9aab54d2d0..961b75c0b5 100644 --- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs +++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs @@ -107,6 +107,24 @@ namespace Grpc.Core.Internal.Tests } [Test] + public void AsyncUnary_RequestSerializationExceptionDoesntLeakResources() + { + string nullRequest = null; // will throw when serializing + Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCallAsync(nullRequest)); + Assert.AreEqual(0, channel.GetCallReferenceCount()); + Assert.IsTrue(fakeCall.IsDisposed); + } + + [Test] + public void AsyncUnary_StartCallFailureDoesntLeakResources() + { + fakeCall.MakeStartCallFail(); + Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCallAsync("request1")); + Assert.AreEqual(0, channel.GetCallReferenceCount()); + Assert.IsTrue(fakeCall.IsDisposed); + } + + [Test] public void ClientStreaming_StreamingReadNotAllowed() { asyncCall.ClientStreamingCallAsync(); @@ -328,6 +346,15 @@ namespace Grpc.Core.Internal.Tests } [Test] + public void ClientStreaming_StartCallFailureDoesntLeakResources() + { + fakeCall.MakeStartCallFail(); + Assert.Throws(typeof(InvalidOperationException), () => asyncCall.ClientStreamingCallAsync()); + Assert.AreEqual(0, channel.GetCallReferenceCount()); + Assert.IsTrue(fakeCall.IsDisposed); + } + + [Test] public void ServerStreaming_StreamingSendNotAllowed() { asyncCall.StartServerStreamingCall("request1"); @@ -402,6 +429,27 @@ namespace Grpc.Core.Internal.Tests } [Test] + public void ServerStreaming_RequestSerializationExceptionDoesntLeakResources() + { + string nullRequest = null; // will throw when serializing + Assert.Throws(typeof(ArgumentNullException), () => asyncCall.StartServerStreamingCall(nullRequest)); + Assert.AreEqual(0, channel.GetCallReferenceCount()); + Assert.IsTrue(fakeCall.IsDisposed); + + var responseStream = new ClientResponseStream<string, string>(asyncCall); + var readTask = responseStream.MoveNext(); + } + + [Test] + public void ServerStreaming_StartCallFailureDoesntLeakResources() + { + fakeCall.MakeStartCallFail(); + Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartServerStreamingCall("request1")); + Assert.AreEqual(0, channel.GetCallReferenceCount()); + Assert.IsTrue(fakeCall.IsDisposed); + } + + [Test] public void DuplexStreaming_NoRequestNoResponse_Success() { asyncCall.StartDuplexStreamingCall(); @@ -558,6 +606,15 @@ namespace Grpc.Core.Internal.Tests AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled); } + [Test] + public void DuplexStreaming_StartCallFailureDoesntLeakResources() + { + fakeCall.MakeStartCallFail(); + Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartDuplexStreamingCall()); + Assert.AreEqual(0, channel.GetCallReferenceCount()); + Assert.IsTrue(fakeCall.IsDisposed); + } + ClientSideStatus CreateClientSideStatus(StatusCode statusCode) { return new ClientSideStatus(new Status(statusCode, ""), new Metadata()); diff --git a/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs b/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs index 581ac3384b..42536dd0c1 100644 --- a/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs +++ b/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs @@ -31,6 +31,7 @@ namespace Grpc.Core.Internal.Tests /// </summary> internal class FakeNativeCall : INativeCall { + private bool shouldStartCallFail; public IUnaryResponseClientCallback UnaryResponseClientCallback { get; @@ -102,6 +103,7 @@ namespace Grpc.Core.Internal.Tests public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) { + StartCallMaybeFail(); UnaryResponseClientCallback = callback; } @@ -112,16 +114,19 @@ namespace Grpc.Core.Internal.Tests public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags) { + StartCallMaybeFail(); UnaryResponseClientCallback = callback; } public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) { + StartCallMaybeFail(); ReceivedStatusOnClientCallback = callback; } public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags) { + StartCallMaybeFail(); ReceivedStatusOnClientCallback = callback; } @@ -165,5 +170,22 @@ namespace Grpc.Core.Internal.Tests { IsDisposed = true; } + + /// <summary> + /// Emulate CallSafeHandle.CheckOk() failure for all future attempts + /// to start a call. + /// </summary> + public void MakeStartCallFail() + { + shouldStartCallFail = true; + } + + private void StartCallMaybeFail() + { + if (shouldStartCallFail) + { + throw new InvalidOperationException("Start call has failed."); + } + } } } diff --git a/src/csharp/Grpc.Core/Channel.cs b/src/csharp/Grpc.Core/Channel.cs index 7912b06b71..bc236da560 100644 --- a/src/csharp/Grpc.Core/Channel.cs +++ b/src/csharp/Grpc.Core/Channel.cs @@ -297,6 +297,12 @@ namespace Grpc.Core activeCallCounter.Decrement(); } + // for testing only + internal long GetCallReferenceCount() + { + return activeCallCounter.Count; + } + private ChannelState GetConnectivityState(bool tryToConnect) { try diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs index 66902f3caa..76743226b3 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs @@ -17,6 +17,7 @@ #endregion using System; +using System.Threading; using System.Threading.Tasks; using Grpc.Core.Logging; using Grpc.Core.Profiling; @@ -34,6 +35,8 @@ namespace Grpc.Core.Internal readonly CallInvocationDetails<TRequest, TResponse> details; readonly INativeCall injectedNativeCall; // for testing + bool registeredWithChannel; + // Dispose of to de-register cancellation token registration IDisposable cancellationTokenRegistration; @@ -79,42 +82,59 @@ namespace Grpc.Core.Internal { byte[] payload = UnsafeSerialize(msg); - unaryResponseTcs = new TaskCompletionSource<TResponse>(); - - lock (myLock) + bool callStartedOk = false; + try { - GrpcPreconditions.CheckState(!started); - started = true; - Initialize(cq); + unaryResponseTcs = new TaskCompletionSource<TResponse>(); - halfcloseRequested = true; - readingDone = true; - } + lock (myLock) + { + GrpcPreconditions.CheckState(!started); + started = true; + Initialize(cq); - using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) - { - var ctx = details.Channel.Environment.BatchContextPool.Lease(); - try + halfcloseRequested = true; + readingDone = true; + } + + using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) { - call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); - var ev = cq.Pluck(ctx.Handle); - bool success = (ev.success != 0); + var ctx = details.Channel.Environment.BatchContextPool.Lease(); try { - using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch")) + call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); + callStartedOk = true; + + var ev = cq.Pluck(ctx.Handle); + bool success = (ev.success != 0); + try + { + using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch")) + { + HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata()); + } + } + catch (Exception e) { - HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata()); + Logger.Error(e, "Exception occurred while invoking completion delegate."); } } - catch (Exception e) + finally { - Logger.Error(e, "Exception occurred while invoking completion delegate."); + ctx.Recycle(); } } - finally + } + catch (Exception) + { + if (!callStartedOk) { - ctx.Recycle(); + lock (myLock) + { + OnFailedToStartCallLocked(); + } } + throw; } // Once the blocking call returns, the result should be available synchronously. @@ -130,22 +150,36 @@ namespace Grpc.Core.Internal { lock (myLock) { - GrpcPreconditions.CheckState(!started); - started = true; + bool callStartedOk = false; + try + { + GrpcPreconditions.CheckState(!started); + started = true; - Initialize(details.Channel.CompletionQueue); + Initialize(details.Channel.CompletionQueue); - halfcloseRequested = true; - readingDone = true; + halfcloseRequested = true; + readingDone = true; - byte[] payload = UnsafeSerialize(msg); + byte[] payload = UnsafeSerialize(msg); - unaryResponseTcs = new TaskCompletionSource<TResponse>(); - using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + unaryResponseTcs = new TaskCompletionSource<TResponse>(); + using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + { + call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); + callStartedOk = true; + } + + return unaryResponseTcs.Task; + } + catch (Exception) { - call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); + if (!callStartedOk) + { + OnFailedToStartCallLocked(); + } + throw; } - return unaryResponseTcs.Task; } } @@ -157,20 +191,33 @@ namespace Grpc.Core.Internal { lock (myLock) { - GrpcPreconditions.CheckState(!started); - started = true; + bool callStartedOk = false; + try + { + GrpcPreconditions.CheckState(!started); + started = true; - Initialize(details.Channel.CompletionQueue); + Initialize(details.Channel.CompletionQueue); - readingDone = true; + readingDone = true; - unaryResponseTcs = new TaskCompletionSource<TResponse>(); - using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + unaryResponseTcs = new TaskCompletionSource<TResponse>(); + using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + { + call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags); + callStartedOk = true; + } + + return unaryResponseTcs.Task; + } + catch (Exception) { - call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags); + if (!callStartedOk) + { + OnFailedToStartCallLocked(); + } + throw; } - - return unaryResponseTcs.Task; } } @@ -181,21 +228,34 @@ namespace Grpc.Core.Internal { lock (myLock) { - GrpcPreconditions.CheckState(!started); - started = true; + bool callStartedOk = false; + try + { + GrpcPreconditions.CheckState(!started); + started = true; - Initialize(details.Channel.CompletionQueue); + Initialize(details.Channel.CompletionQueue); - halfcloseRequested = true; + halfcloseRequested = true; - byte[] payload = UnsafeSerialize(msg); + byte[] payload = UnsafeSerialize(msg); - streamingResponseCallFinishedTcs = new TaskCompletionSource<object>(); - using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + streamingResponseCallFinishedTcs = new TaskCompletionSource<object>(); + using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + { + call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); + callStartedOk = true; + } + call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback); + } + catch (Exception) { - call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); + if (!callStartedOk) + { + OnFailedToStartCallLocked(); + } + throw; } - call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback); } } @@ -207,17 +267,30 @@ namespace Grpc.Core.Internal { lock (myLock) { - GrpcPreconditions.CheckState(!started); - started = true; + bool callStartedOk = false; + try + { + GrpcPreconditions.CheckState(!started); + started = true; - Initialize(details.Channel.CompletionQueue); + Initialize(details.Channel.CompletionQueue); - streamingResponseCallFinishedTcs = new TaskCompletionSource<object>(); - using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + streamingResponseCallFinishedTcs = new TaskCompletionSource<object>(); + using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) + { + call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags); + callStartedOk = true; + } + call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback); + } + catch (Exception) { - call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags); + if (!callStartedOk) + { + OnFailedToStartCallLocked(); + } + throw; } - call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback); } } @@ -327,7 +400,11 @@ namespace Grpc.Core.Internal protected override void OnAfterReleaseResourcesLocked() { - details.Channel.RemoveCallReference(this); + if (registeredWithChannel) + { + details.Channel.RemoveCallReference(this); + registeredWithChannel = false; + } } protected override void OnAfterReleaseResourcesUnlocked() @@ -394,10 +471,27 @@ namespace Grpc.Core.Internal var call = CreateNativeCall(cq); details.Channel.AddCallReference(this); + registeredWithChannel = true; InitializeInternal(call); + RegisterCancellationCallback(); } + private void OnFailedToStartCallLocked() + { + ReleaseResources(); + + // We need to execute the hook that disposes the cancellation token + // registration, but it cannot be done from under a lock. + // To make things simple, we just schedule the unregistering + // on a threadpool. + // - Once the native call is disposed, the Cancel() calls are ignored anyway + // - We don't care about the overhead as OnFailedToStartCallLocked() only happens + // when something goes very bad when initializing a call and that should + // never happen when gRPC is used correctly. + ThreadPool.QueueUserWorkItem((state) => OnAfterReleaseResourcesUnlocked()); + } + private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq) { if (injectedNativeCall != null) diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs index 5a53049e4b..a93dc34620 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs @@ -189,7 +189,7 @@ namespace Grpc.Core.Internal /// </summary> protected abstract Exception GetRpcExceptionClientOnly(); - private void ReleaseResources() + protected void ReleaseResources() { if (call != null) { |