Skip to content

Commit

Permalink
Fix | AsyncHelper.WaitForCompletion leaks unobserved exceptions (#692)
Browse files Browse the repository at this point in the history
  • Loading branch information
jm771 committed Oct 22, 2020
1 parent 4107f24 commit a947ded
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 22 deletions.
Expand Up @@ -17,6 +17,8 @@
using System.Transactions;
using Microsoft.Data.Common;

[assembly: InternalsVisibleTo("FunctionalTests")]

namespace Microsoft.Data.SqlClient
{
internal static class AsyncHelper
Expand Down Expand Up @@ -204,6 +206,7 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout
}
if (!task.IsCompleted)
{
task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception
if (onTimeout != null)
{
onTimeout();
Expand Down
Expand Up @@ -17,6 +17,8 @@
using System.Threading.Tasks;
using SysTx = System.Transactions;

[assembly: InternalsVisibleTo("FunctionalTests")]

namespace Microsoft.Data.SqlClient
{
using Microsoft.Data.Common;
Expand Down Expand Up @@ -189,6 +191,7 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout
}
if (!task.IsCompleted)
{
task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception
if (onTimeout != null)
{
onTimeout();
Expand Down
Expand Up @@ -13,6 +13,11 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class BaseProviderAsyncTest
{
private static void AssertTaskFaults(Task t)
{
Assert.ThrowsAny<Exception>(() => t.Wait(TimeSpan.FromMilliseconds(1)));
}

[Fact]
public static void TestDbConnection()
{
Expand All @@ -37,8 +42,8 @@ public static void TestDbConnection()
{
Fail = true
};
connectionFail.OpenAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
connectionFail.OpenAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
AssertTaskFaults(connectionFail.OpenAsync());
AssertTaskFaults(connectionFail.OpenAsync(source.Token));

// Verify base implementation does not call Open when passed an already cancelled cancellation token
source.Cancel();
Expand Down Expand Up @@ -90,14 +95,14 @@ public static void TestDbCommand()
{
Fail = true
};
commandFail.ExecuteNonQueryAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteNonQueryAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteReaderAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteReaderAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess, source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteScalarAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
commandFail.ExecuteScalarAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
AssertTaskFaults(commandFail.ExecuteNonQueryAsync());
AssertTaskFaults(commandFail.ExecuteNonQueryAsync(source.Token));
AssertTaskFaults(commandFail.ExecuteReaderAsync());
AssertTaskFaults(commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess));
AssertTaskFaults(commandFail.ExecuteReaderAsync(source.Token));
AssertTaskFaults(commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess, source.Token));
AssertTaskFaults(commandFail.ExecuteScalarAsync());
AssertTaskFaults(commandFail.ExecuteScalarAsync(source.Token));

// Verify base implementation does not call Open when passed an already cancelled cancellation token
source.Cancel();
Expand All @@ -116,17 +121,17 @@ public static void TestDbCommand()
source = new CancellationTokenSource();
Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); });
Task result = command.ExecuteNonQueryAsync(source.Token);
Assert.True(result.IsFaulted, "Task result should be faulted");
Assert.True(result.Exception != null, "Task result should be faulted");

source = new CancellationTokenSource();
Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); });
result = command.ExecuteReaderAsync(source.Token);
Assert.True(result.IsFaulted, "Task result should be faulted");
Assert.True(result.Exception != null, "Task result should be faulted");

source = new CancellationTokenSource();
Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); });
result = command.ExecuteScalarAsync(source.Token);
Assert.True(result.IsFaulted, "Task result should be faulted");
Assert.True(result.Exception != null, "Task result should be faulted");
}

[Fact]
Expand Down Expand Up @@ -155,9 +160,9 @@ public static void TestDbDataReader()

GetFieldValueAsync<object>(reader, 2, DBNull.Value);
GetFieldValueAsync<DBNull>(reader, 2, DBNull.Value);
reader.GetFieldValueAsync<int?>(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
reader.GetFieldValueAsync<string>(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
reader.GetFieldValueAsync<bool>(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
AssertTaskFaults(reader.GetFieldValueAsync<int?>(2));
AssertTaskFaults(reader.GetFieldValueAsync<string>(2));
AssertTaskFaults(reader.GetFieldValueAsync<bool>(2));
AssertEqualsWithDescription("GetValue", reader.LastCommand, "Last command was not as expected");

result = reader.ReadAsync();
Expand All @@ -174,12 +179,12 @@ public static void TestDbDataReader()
Assert.False(result.Result, "Should NOT have received a Result from NextResultAsync");

MockDataReader readerFail = new MockDataReader { Results = query.GetEnumerator(), Fail = true };
readerFail.ReadAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
readerFail.ReadAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
readerFail.NextResultAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
readerFail.NextResultAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
readerFail.GetFieldValueAsync<object>(0).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
readerFail.GetFieldValueAsync<object>(0, source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
AssertTaskFaults(readerFail.ReadAsync());
AssertTaskFaults(readerFail.ReadAsync(source.Token));
AssertTaskFaults(readerFail.NextResultAsync());
AssertTaskFaults(readerFail.NextResultAsync(source.Token));
AssertTaskFaults(readerFail.GetFieldValueAsync<object>(0));
AssertTaskFaults(readerFail.GetFieldValueAsync<object>(0, source.Token));

source.Cancel();
reader.LastCommand = "Nothing";
Expand Down
Expand Up @@ -42,6 +42,7 @@
<Compile Include="SqlCredentialTest.cs" />
<Compile Include="SqlDataRecordTest.cs" />
<Compile Include="SqlExceptionTest.cs" />
<Compile Include="SqlHelperTest.cs" />
<Compile Include="SqlParameterTest.cs" />
<Compile Include="SqlClientFactoryTest.cs" />
<Compile Include="SqlErrorCollectionTest.cs" />
Expand Down
@@ -0,0 +1,55 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.Data.SqlClient.Tests
{
public class SqlHelperTest
{
private void TimeOutATask()
{
TaskCompletionSource<bool> tcs = new TaskCompletionSource<bool>();
AsyncHelper.WaitForCompletion(tcs.Task, 1); //Will time out as task uncompleted
tcs.SetException(new TimeoutException("Dummy timeout exception")); //Our task now completes with an error
}

private Exception UnwrapException(Exception e)
{
return e?.InnerException != null ? UnwrapException(e.InnerException) : e;
}

[Fact]
public void WaitForCompletion_DoesNotCreateUnobservedException()
{
var unobservedExceptionHappenedEvent = new AutoResetEvent(false);
Exception unhandledException = null;
void handleUnobservedException(object o, UnobservedTaskExceptionEventArgs a)
{ unhandledException = a.Exception; unobservedExceptionHappenedEvent.Set(); }

TaskScheduler.UnobservedTaskException += handleUnobservedException;

try
{
TimeOutATask(); //Create the task in another function so the task has no reference remaining
GC.Collect(); //Force collection of unobserved task
GC.WaitForPendingFinalizers();

bool unobservedExceptionHappend = unobservedExceptionHappenedEvent.WaitOne(1);
if (unobservedExceptionHappend) //Save doing string interpolation in the happy case
{
var e = UnwrapException(unhandledException);
Assert.False(true, $"Did not expect an unobserved exception, but found a {e?.GetType()} with message \"{e?.Message}\"");
}
}
finally
{
TaskScheduler.UnobservedTaskException -= handleUnobservedException;
}
}
}
}

0 comments on commit a947ded

Please sign in to comment.