diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs index 606f20deba..056b0fcebf 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -17,6 +17,8 @@ using System.Transactions; using Microsoft.Data.Common; +[assembly: InternalsVisibleTo("FunctionalTests")] + namespace Microsoft.Data.SqlClient { internal static class AsyncHelper @@ -204,6 +206,7 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout } if (!task.IsCompleted) { + task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception if (onTimeout != null) { onTimeout(); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs index 9d64e1a3d9..0cc0276d6f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -17,6 +17,8 @@ using System.Threading.Tasks; using SysTx = System.Transactions; +[assembly: InternalsVisibleTo("FunctionalTests")] + namespace Microsoft.Data.SqlClient { using Microsoft.Data.Common; @@ -189,6 +191,7 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout } if (!task.IsCompleted) { + task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception if (onTimeout != null) { onTimeout(); diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/BaseProviderAsyncTest/BaseProviderAsyncTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/BaseProviderAsyncTest/BaseProviderAsyncTest.cs index 1c340f771b..0d13ad5f77 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/BaseProviderAsyncTest/BaseProviderAsyncTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/BaseProviderAsyncTest/BaseProviderAsyncTest.cs @@ -13,6 +13,11 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public static class BaseProviderAsyncTest { + private static void AssertTaskFaults(Task t) + { + Assert.ThrowsAny(() => t.Wait(TimeSpan.FromMilliseconds(1))); + } + [Fact] public static void TestDbConnection() { @@ -37,8 +42,8 @@ public static void TestDbConnection() { Fail = true }; - connectionFail.OpenAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - connectionFail.OpenAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); + AssertTaskFaults(connectionFail.OpenAsync()); + AssertTaskFaults(connectionFail.OpenAsync(source.Token)); // Verify base implementation does not call Open when passed an already cancelled cancellation token source.Cancel(); @@ -90,14 +95,14 @@ public static void TestDbCommand() { Fail = true }; - commandFail.ExecuteNonQueryAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteNonQueryAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteReaderAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteReaderAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess, source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteScalarAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - commandFail.ExecuteScalarAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); + AssertTaskFaults(commandFail.ExecuteNonQueryAsync()); + AssertTaskFaults(commandFail.ExecuteNonQueryAsync(source.Token)); + AssertTaskFaults(commandFail.ExecuteReaderAsync()); + AssertTaskFaults(commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess)); + AssertTaskFaults(commandFail.ExecuteReaderAsync(source.Token)); + AssertTaskFaults(commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess, source.Token)); + AssertTaskFaults(commandFail.ExecuteScalarAsync()); + AssertTaskFaults(commandFail.ExecuteScalarAsync(source.Token)); // Verify base implementation does not call Open when passed an already cancelled cancellation token source.Cancel(); @@ -116,17 +121,17 @@ public static void TestDbCommand() source = new CancellationTokenSource(); Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); }); Task result = command.ExecuteNonQueryAsync(source.Token); - Assert.True(result.IsFaulted, "Task result should be faulted"); + Assert.True(result.Exception != null, "Task result should be faulted"); source = new CancellationTokenSource(); Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); }); result = command.ExecuteReaderAsync(source.Token); - Assert.True(result.IsFaulted, "Task result should be faulted"); + Assert.True(result.Exception != null, "Task result should be faulted"); source = new CancellationTokenSource(); Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); }); result = command.ExecuteScalarAsync(source.Token); - Assert.True(result.IsFaulted, "Task result should be faulted"); + Assert.True(result.Exception != null, "Task result should be faulted"); } [Fact] @@ -155,9 +160,9 @@ public static void TestDbDataReader() GetFieldValueAsync(reader, 2, DBNull.Value); GetFieldValueAsync(reader, 2, DBNull.Value); - reader.GetFieldValueAsync(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - reader.GetFieldValueAsync(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - reader.GetFieldValueAsync(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); + AssertTaskFaults(reader.GetFieldValueAsync(2)); + AssertTaskFaults(reader.GetFieldValueAsync(2)); + AssertTaskFaults(reader.GetFieldValueAsync(2)); AssertEqualsWithDescription("GetValue", reader.LastCommand, "Last command was not as expected"); result = reader.ReadAsync(); @@ -174,12 +179,12 @@ public static void TestDbDataReader() Assert.False(result.Result, "Should NOT have received a Result from NextResultAsync"); MockDataReader readerFail = new MockDataReader { Results = query.GetEnumerator(), Fail = true }; - readerFail.ReadAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - readerFail.ReadAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - readerFail.NextResultAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - readerFail.NextResultAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - readerFail.GetFieldValueAsync(0).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); - readerFail.GetFieldValueAsync(0, source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait(); + AssertTaskFaults(readerFail.ReadAsync()); + AssertTaskFaults(readerFail.ReadAsync(source.Token)); + AssertTaskFaults(readerFail.NextResultAsync()); + AssertTaskFaults(readerFail.NextResultAsync(source.Token)); + AssertTaskFaults(readerFail.GetFieldValueAsync(0)); + AssertTaskFaults(readerFail.GetFieldValueAsync(0, source.Token)); source.Cancel(); reader.LastCommand = "Nothing"; diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index fc7ac315b4..e93c6c867f 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -42,6 +42,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs new file mode 100644 index 0000000000..30152bd2c3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Data.SqlClient.Tests +{ + public class SqlHelperTest + { + private void TimeOutATask() + { + TaskCompletionSource tcs = new TaskCompletionSource(); + AsyncHelper.WaitForCompletion(tcs.Task, 1); //Will time out as task uncompleted + tcs.SetException(new TimeoutException("Dummy timeout exception")); //Our task now completes with an error + } + + private Exception UnwrapException(Exception e) + { + return e?.InnerException != null ? UnwrapException(e.InnerException) : e; + } + + [Fact] + public void WaitForCompletion_DoesNotCreateUnobservedException() + { + var unobservedExceptionHappenedEvent = new AutoResetEvent(false); + Exception unhandledException = null; + void handleUnobservedException(object o, UnobservedTaskExceptionEventArgs a) + { unhandledException = a.Exception; unobservedExceptionHappenedEvent.Set(); } + + TaskScheduler.UnobservedTaskException += handleUnobservedException; + + try + { + TimeOutATask(); //Create the task in another function so the task has no reference remaining + GC.Collect(); //Force collection of unobserved task + GC.WaitForPendingFinalizers(); + + bool unobservedExceptionHappend = unobservedExceptionHappenedEvent.WaitOne(1); + if (unobservedExceptionHappend) //Save doing string interpolation in the happy case + { + var e = UnwrapException(unhandledException); + Assert.False(true, $"Did not expect an unobserved exception, but found a {e?.GetType()} with message \"{e?.Message}\""); + } + } + finally + { + TaskScheduler.UnobservedTaskException -= handleUnobservedException; + } + } + } +}