aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jan Tattermusch <jtattermusch@users.noreply.github.com>2018-09-07 18:21:35 -0700
committerGravatar GitHub <noreply@github.com>2018-09-07 18:21:35 -0700
commitd90d082ca228eb85d1de79623acf6fe5f44a4cce (patch)
tree760de2938e8202842bcf4374b3d1c714075ea1c3
parent9dfbb81d0ee8fe7f715b6b4fc9809ffc9ec16da4 (diff)
parent917af9a47f6141b897e1288a71e846f054941a39 (diff)
Merge pull request #16554 from jtattermusch/csharp_dont_leak_when_call_init_fails
C#: avoid leaking resources when starting a call fails
-rw-r--r--src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs75
-rw-r--r--src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs23
-rw-r--r--src/csharp/Grpc.Core/Channel.cs6
-rw-r--r--src/csharp/Grpc.Core/Internal/AsyncCall.cs207
-rw-r--r--src/csharp/Grpc.Core/Internal/AsyncCallBase.cs2
5 files changed, 253 insertions, 60 deletions
diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
index 9aab54d2d0..775849d89b 100644
--- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
+++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
@@ -107,6 +107,42 @@ namespace Grpc.Core.Internal.Tests
}
[Test]
+ public void AsyncUnary_RequestSerializationExceptionDoesntLeakResources()
+ {
+ string nullRequest = null; // will throw when serializing
+ Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCallAsync(nullRequest));
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
+ [Test]
+ public void AsyncUnary_StartCallFailureDoesntLeakResources()
+ {
+ fakeCall.MakeStartCallFail();
+ Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCallAsync("request1"));
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
+ [Test]
+ public void SyncUnary_RequestSerializationExceptionDoesntLeakResources()
+ {
+ string nullRequest = null; // will throw when serializing
+ Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCall(nullRequest));
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
+ [Test]
+ public void SyncUnary_StartCallFailureDoesntLeakResources()
+ {
+ fakeCall.MakeStartCallFail();
+ Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCall("request1"));
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
+ [Test]
public void ClientStreaming_StreamingReadNotAllowed()
{
asyncCall.ClientStreamingCallAsync();
@@ -328,6 +364,15 @@ namespace Grpc.Core.Internal.Tests
}
[Test]
+ public void ClientStreaming_StartCallFailureDoesntLeakResources()
+ {
+ fakeCall.MakeStartCallFail();
+ Assert.Throws(typeof(InvalidOperationException), () => asyncCall.ClientStreamingCallAsync());
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
+ [Test]
public void ServerStreaming_StreamingSendNotAllowed()
{
asyncCall.StartServerStreamingCall("request1");
@@ -402,6 +447,27 @@ namespace Grpc.Core.Internal.Tests
}
[Test]
+ public void ServerStreaming_RequestSerializationExceptionDoesntLeakResources()
+ {
+ string nullRequest = null; // will throw when serializing
+ Assert.Throws(typeof(ArgumentNullException), () => asyncCall.StartServerStreamingCall(nullRequest));
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+
+ var responseStream = new ClientResponseStream<string, string>(asyncCall);
+ var readTask = responseStream.MoveNext();
+ }
+
+ [Test]
+ public void ServerStreaming_StartCallFailureDoesntLeakResources()
+ {
+ fakeCall.MakeStartCallFail();
+ Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartServerStreamingCall("request1"));
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
+ [Test]
public void DuplexStreaming_NoRequestNoResponse_Success()
{
asyncCall.StartDuplexStreamingCall();
@@ -558,6 +624,15 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled);
}
+ [Test]
+ public void DuplexStreaming_StartCallFailureDoesntLeakResources()
+ {
+ fakeCall.MakeStartCallFail();
+ Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartDuplexStreamingCall());
+ Assert.AreEqual(0, channel.GetCallReferenceCount());
+ Assert.IsTrue(fakeCall.IsDisposed);
+ }
+
ClientSideStatus CreateClientSideStatus(StatusCode statusCode)
{
return new ClientSideStatus(new Status(statusCode, ""), new Metadata());
diff --git a/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs b/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs
index 581ac3384b..ef67918dab 100644
--- a/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs
+++ b/src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs
@@ -31,6 +31,7 @@ namespace Grpc.Core.Internal.Tests
/// </summary>
internal class FakeNativeCall : INativeCall
{
+ private bool shouldStartCallFail;
public IUnaryResponseClientCallback UnaryResponseClientCallback
{
get;
@@ -102,26 +103,31 @@ namespace Grpc.Core.Internal.Tests
public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
+ StartCallMaybeFail();
UnaryResponseClientCallback = callback;
}
public void StartUnary(BatchContextSafeHandle ctx, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
+ StartCallMaybeFail();
throw new NotImplementedException();
}
public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
+ StartCallMaybeFail();
UnaryResponseClientCallback = callback;
}
public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
+ StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback;
}
public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
+ StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback;
}
@@ -165,5 +171,22 @@ namespace Grpc.Core.Internal.Tests
{
IsDisposed = true;
}
+
+ /// <summary>
+ /// Emulate CallSafeHandle.CheckOk() failure for all future attempts
+ /// to start a call.
+ /// </summary>
+ public void MakeStartCallFail()
+ {
+ shouldStartCallFail = true;
+ }
+
+ private void StartCallMaybeFail()
+ {
+ if (shouldStartCallFail)
+ {
+ throw new InvalidOperationException("Start call has failed.");
+ }
+ }
}
}
diff --git a/src/csharp/Grpc.Core/Channel.cs b/src/csharp/Grpc.Core/Channel.cs
index 4c89ed7393..7ce929dfa3 100644
--- a/src/csharp/Grpc.Core/Channel.cs
+++ b/src/csharp/Grpc.Core/Channel.cs
@@ -297,6 +297,12 @@ namespace Grpc.Core
activeCallCounter.Decrement();
}
+ // for testing only
+ internal long GetCallReferenceCount()
+ {
+ return activeCallCounter.Count;
+ }
+
private ChannelState GetConnectivityState(bool tryToConnect)
{
try
diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs
index 66902f3caa..4cdf0ee6a7 100644
--- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs
+++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs
@@ -17,6 +17,7 @@
#endregion
using System;
+using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Logging;
using Grpc.Core.Profiling;
@@ -34,6 +35,8 @@ namespace Grpc.Core.Internal
readonly CallInvocationDetails<TRequest, TResponse> details;
readonly INativeCall injectedNativeCall; // for testing
+ bool registeredWithChannel;
+
// Dispose of to de-register cancellation token registration
IDisposable cancellationTokenRegistration;
@@ -77,43 +80,59 @@ namespace Grpc.Core.Internal
using (profiler.NewScope("AsyncCall.UnaryCall"))
using (CompletionQueueSafeHandle cq = CompletionQueueSafeHandle.CreateSync())
{
- byte[] payload = UnsafeSerialize(msg);
+ bool callStartedOk = false;
+ try
+ {
+ unaryResponseTcs = new TaskCompletionSource<TResponse>();
- unaryResponseTcs = new TaskCompletionSource<TResponse>();
+ lock (myLock)
+ {
+ GrpcPreconditions.CheckState(!started);
+ started = true;
+ Initialize(cq);
- lock (myLock)
- {
- GrpcPreconditions.CheckState(!started);
- started = true;
- Initialize(cq);
+ halfcloseRequested = true;
+ readingDone = true;
+ }
- halfcloseRequested = true;
- readingDone = true;
- }
+ byte[] payload = UnsafeSerialize(msg);
- using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
- {
- var ctx = details.Channel.Environment.BatchContextPool.Lease();
- try
+ using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
- call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
- var ev = cq.Pluck(ctx.Handle);
- bool success = (ev.success != 0);
+ var ctx = details.Channel.Environment.BatchContextPool.Lease();
try
{
- using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch"))
+ call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
+ callStartedOk = true;
+
+ var ev = cq.Pluck(ctx.Handle);
+ bool success = (ev.success != 0);
+ try
+ {
+ using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch"))
+ {
+ HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata());
+ }
+ }
+ catch (Exception e)
{
- HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata());
+ Logger.Error(e, "Exception occurred while invoking completion delegate.");
}
}
- catch (Exception e)
+ finally
{
- Logger.Error(e, "Exception occurred while invoking completion delegate.");
+ ctx.Recycle();
}
}
- finally
+ }
+ finally
+ {
+ if (!callStartedOk)
{
- ctx.Recycle();
+ lock (myLock)
+ {
+ OnFailedToStartCallLocked();
+ }
}
}
@@ -130,22 +149,35 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
- GrpcPreconditions.CheckState(!started);
- started = true;
+ bool callStartedOk = false;
+ try
+ {
+ GrpcPreconditions.CheckState(!started);
+ started = true;
- Initialize(details.Channel.CompletionQueue);
+ Initialize(details.Channel.CompletionQueue);
- halfcloseRequested = true;
- readingDone = true;
+ halfcloseRequested = true;
+ readingDone = true;
+
+ byte[] payload = UnsafeSerialize(msg);
- byte[] payload = UnsafeSerialize(msg);
+ unaryResponseTcs = new TaskCompletionSource<TResponse>();
+ using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ {
+ call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
+ callStartedOk = true;
+ }
- unaryResponseTcs = new TaskCompletionSource<TResponse>();
- using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ return unaryResponseTcs.Task;
+ }
+ finally
{
- call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
+ if (!callStartedOk)
+ {
+ OnFailedToStartCallLocked();
+ }
}
- return unaryResponseTcs.Task;
}
}
@@ -157,20 +189,32 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
- GrpcPreconditions.CheckState(!started);
- started = true;
+ bool callStartedOk = false;
+ try
+ {
+ GrpcPreconditions.CheckState(!started);
+ started = true;
- Initialize(details.Channel.CompletionQueue);
+ Initialize(details.Channel.CompletionQueue);
- readingDone = true;
+ readingDone = true;
+
+ unaryResponseTcs = new TaskCompletionSource<TResponse>();
+ using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ {
+ call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags);
+ callStartedOk = true;
+ }
- unaryResponseTcs = new TaskCompletionSource<TResponse>();
- using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ return unaryResponseTcs.Task;
+ }
+ finally
{
- call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags);
+ if (!callStartedOk)
+ {
+ OnFailedToStartCallLocked();
+ }
}
-
- return unaryResponseTcs.Task;
}
}
@@ -181,21 +225,33 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
- GrpcPreconditions.CheckState(!started);
- started = true;
+ bool callStartedOk = false;
+ try
+ {
+ GrpcPreconditions.CheckState(!started);
+ started = true;
- Initialize(details.Channel.CompletionQueue);
+ Initialize(details.Channel.CompletionQueue);
- halfcloseRequested = true;
+ halfcloseRequested = true;
- byte[] payload = UnsafeSerialize(msg);
+ byte[] payload = UnsafeSerialize(msg);
- streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
- using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
+ using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ {
+ call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
+ callStartedOk = true;
+ }
+ call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
+ }
+ finally
{
- call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
+ if (!callStartedOk)
+ {
+ OnFailedToStartCallLocked();
+ }
}
- call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
}
@@ -207,17 +263,29 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
- GrpcPreconditions.CheckState(!started);
- started = true;
+ bool callStartedOk = false;
+ try
+ {
+ GrpcPreconditions.CheckState(!started);
+ started = true;
- Initialize(details.Channel.CompletionQueue);
+ Initialize(details.Channel.CompletionQueue);
- streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
- using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
+ using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
+ {
+ call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags);
+ callStartedOk = true;
+ }
+ call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
+ }
+ finally
{
- call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags);
+ if (!callStartedOk)
+ {
+ OnFailedToStartCallLocked();
+ }
}
- call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
}
@@ -327,7 +395,11 @@ namespace Grpc.Core.Internal
protected override void OnAfterReleaseResourcesLocked()
{
- details.Channel.RemoveCallReference(this);
+ if (registeredWithChannel)
+ {
+ details.Channel.RemoveCallReference(this);
+ registeredWithChannel = false;
+ }
}
protected override void OnAfterReleaseResourcesUnlocked()
@@ -394,10 +466,27 @@ namespace Grpc.Core.Internal
var call = CreateNativeCall(cq);
details.Channel.AddCallReference(this);
+ registeredWithChannel = true;
InitializeInternal(call);
+
RegisterCancellationCallback();
}
+ private void OnFailedToStartCallLocked()
+ {
+ ReleaseResources();
+
+ // We need to execute the hook that disposes the cancellation token
+ // registration, but it cannot be done from under a lock.
+ // To make things simple, we just schedule the unregistering
+ // on a threadpool.
+ // - Once the native call is disposed, the Cancel() calls are ignored anyway
+ // - We don't care about the overhead as OnFailedToStartCallLocked() only happens
+ // when something goes very bad when initializing a call and that should
+ // never happen when gRPC is used correctly.
+ ThreadPool.QueueUserWorkItem((state) => OnAfterReleaseResourcesUnlocked());
+ }
+
private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq)
{
if (injectedNativeCall != null)
diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
index 5a53049e4b..a93dc34620 100644
--- a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
+++ b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
@@ -189,7 +189,7 @@ namespace Grpc.Core.Internal
/// </summary>
protected abstract Exception GetRpcExceptionClientOnly();
- private void ReleaseResources()
+ protected void ReleaseResources()
{
if (call != null)
{