diff --git a/Ix.NET/Source/Directory.build.props b/Ix.NET/Source/Directory.build.props index a604effcfe..5059f8b1d4 100644 --- a/Ix.NET/Source/Directory.build.props +++ b/Ix.NET/Source/Directory.build.props @@ -23,7 +23,7 @@ - + diff --git a/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Amb.cs b/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Amb.cs new file mode 100644 index 0000000000..6f4336d078 --- /dev/null +++ b/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Amb.cs @@ -0,0 +1,342 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Tests +{ + public class Amb : AsyncEnumerableExTests + { + [Fact] + public void Amb_Null() + { + Assert.Throws(() => AsyncEnumerableEx.Amb(default, Return42)); + Assert.Throws(() => AsyncEnumerableEx.Amb(Return42, default)); + } + + [Fact] + public async Task Amb_First_Wins() + { + var source = AsyncEnumerable.Range(1, 5).Amb(AsyncEnumerableEx.Never()); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_First_Wins_Alt() + { + var source = AsyncEnumerable.Range(1, 5).Amb(AsyncEnumerable.Range(1, 5).SelectAwait(async v => + { + await Task.Delay(500); + return v; + })); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Second_Wins() + { + var source = AsyncEnumerableEx.Never().Amb(AsyncEnumerable.Range(1, 5)); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Second_Wins_Alt() + { + var source = AsyncEnumerable.Range(1, 5).SelectAwait(async v => + { + await Task.Delay(500); + return v; + }).Amb(AsyncEnumerable.Range(6, 5)); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i + 5, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Many_First_Wins() + { + var source = AsyncEnumerableEx.Amb( + AsyncEnumerable.Range(1, 5), + AsyncEnumerableEx.Never(), + AsyncEnumerableEx.Never() + ); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Many_Last_Wins() + { + var source = AsyncEnumerableEx.Amb( + AsyncEnumerableEx.Never(), + AsyncEnumerableEx.Never(), + AsyncEnumerable.Range(1, 5) + ); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Many_Enum_First_Wins() + { + var source = AsyncEnumerableEx.Amb(new[] { + AsyncEnumerable.Range(1, 5), + AsyncEnumerableEx.Never(), + AsyncEnumerableEx.Never() + }.AsEnumerable() + ); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Many_Enum_Last_Wins() + { + var source = AsyncEnumerableEx.Amb(new[] { + AsyncEnumerableEx.Never(), + AsyncEnumerableEx.Never(), + AsyncEnumerable.Range(1, 5) + }.AsEnumerable() + ); + + var xs = source.GetAsyncEnumerator(); + + try + { + for (var i = 1; i <= 5; i++) + { + Assert.True(await xs.MoveNextAsync()); + Assert.Equal(i, xs.Current); + } + + Assert.False(await xs.MoveNextAsync()); + } + finally + { + await xs.DisposeAsync(); + } + } + + + [Fact] + public async Task Amb_First_GetAsyncEnumerator_Crashes() + { + var source = new FailingGetAsyncEnumerator().Amb(AsyncEnumerableEx.Never()); + + var xs = source.GetAsyncEnumerator(); + + try + { + await xs.MoveNextAsync(); + + Assert.False(true, "Should not have gotten here"); + } + catch (InvalidOperationException) + { + // we expect this + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Second_GetAsyncEnumerator_Crashes() + { + var source = AsyncEnumerableEx.Never().Amb(new FailingGetAsyncEnumerator()); + + var xs = source.GetAsyncEnumerator(); + + try + { + await xs.MoveNextAsync(); + + Assert.False(true, "Should not have gotten here"); + } + catch (InvalidOperationException) + { + // we expect this + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Many_First_GetAsyncEnumerator_Crashes() + { + var source = AsyncEnumerableEx.Amb( + new FailingGetAsyncEnumerator(), + AsyncEnumerableEx.Never(), + AsyncEnumerableEx.Never() + ); + + var xs = source.GetAsyncEnumerator(); + + try + { + await xs.MoveNextAsync(); + + Assert.False(true, "Should not have gotten here"); + } + catch (InvalidOperationException) + { + // we expect this + } + finally + { + await xs.DisposeAsync(); + } + } + + [Fact] + public async Task Amb_Many_Last_GetAsyncEnumerator_Crashes() + { + var source = AsyncEnumerableEx.Amb( + AsyncEnumerableEx.Never(), + AsyncEnumerableEx.Never(), + new FailingGetAsyncEnumerator() + ); + + var xs = source.GetAsyncEnumerator(); + + try + { + await xs.MoveNextAsync(); + + Assert.False(true, "Should not have gotten here"); + } + catch (InvalidOperationException) + { + // we expect this + } + finally + { + await xs.DisposeAsync(); + } + } + + private class FailingGetAsyncEnumerator : IAsyncEnumerable + { + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + throw new InvalidOperationException(); + } + } + } +} diff --git a/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs b/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs new file mode 100644 index 0000000000..aa2dbffa23 --- /dev/null +++ b/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Tests +{ + public class Timeout : AsyncEnumerableExTests + { + [Fact] + public async Task Timeout_Never() + { + var source = AsyncEnumerableEx.Never().Timeout(TimeSpan.FromMilliseconds(100)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + } + + [Fact] + public async Task Timeout_Double_Never() + { + var source = AsyncEnumerableEx.Never() + .Timeout(TimeSpan.FromMilliseconds(300)) + .Timeout(TimeSpan.FromMilliseconds(100)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + } + + [Fact] + public async Task Timeout_Delayed_Main() + { + var source = AsyncEnumerable.Range(1, 5) + .SelectAwait(async v => + { + await Task.Delay(300); + return v; + }) + .Timeout(TimeSpan.FromMilliseconds(100)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + } + + [Fact] + public async Task Timeout_Delayed_Main_Canceled() + { + var tcs = new TaskCompletionSource(); + + var source = AsyncEnumerable.Range(1, 5) + .SelectAwaitWithCancellation(async (v, ct) => + { + try + { + await Task.Delay(500, ct); + } + catch (TaskCanceledException) + { + tcs.SetResult(0); + } + return v; + }) + .Timeout(TimeSpan.FromMilliseconds(250)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + + Assert.Equal(0, await tcs.Task); + } + } +} diff --git a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs index 849a41f8f5..6274d8a8ad 100644 --- a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs +++ b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs @@ -27,6 +27,14 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) Task? firstMoveNext = null; Task? secondMoveNext = null; + // + // We need separate tokens for each source so that the non-winner can get disposed and unblocked + // i.e., see Never() + // + + var firstCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var secondCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + try { // @@ -36,7 +44,7 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) // adding a WhenAny combinator that does exactly that. We can even avoid calling AsTask. // - firstEnumerator = first.GetAsyncEnumerator(cancellationToken); + firstEnumerator = first.GetAsyncEnumerator(firstCancelToken.Token); firstMoveNext = firstEnumerator.MoveNextAsync().AsTask(); // @@ -44,11 +52,14 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) // overload which performs GetAsyncEnumerator/MoveNextAsync in pairs, rather than phased. // - secondEnumerator = second.GetAsyncEnumerator(cancellationToken); + secondEnumerator = second.GetAsyncEnumerator(secondCancelToken.Token); secondMoveNext = secondEnumerator.MoveNextAsync().AsTask(); } catch { + secondCancelToken.Cancel(); + firstCancelToken.Cancel(); + // NB: AwaitMoveNextAsyncAndDispose checks for null for both arguments, reducing the need for many null // checks over here. @@ -58,6 +69,7 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator) }; + await Task.WhenAll(cleanup).ConfigureAwait(false); throw; @@ -83,11 +95,13 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) if (moveNextWinner == firstMoveNext) { winner = firstEnumerator; + secondCancelToken.Cancel(); disposeLoser = AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator); } else { winner = secondEnumerator; + firstCancelToken.Cancel(); disposeLoser = AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator); } @@ -143,12 +157,17 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) var enumerators = new IAsyncEnumerator[n]; var moveNexts = new Task[n]; + var individualTokenSources = new CancellationTokenSource[n]; + for (var i = 0; i < n; i++) + { + individualTokenSources[i] = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + } try { for (var i = 0; i < n; i++) { - var enumerator = sources[i].GetAsyncEnumerator(cancellationToken); + var enumerator = sources[i].GetAsyncEnumerator(individualTokenSources[i].Token); enumerators[i] = enumerator; moveNexts[i] = enumerator.MoveNextAsync().AsTask(); @@ -158,12 +177,15 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) { var cleanup = new Task[n]; - for (var i = 0; i < n; i++) + for (var i = n - 1; i >= 0; i--) { + individualTokenSources[i].Cancel(); + cleanup[i] = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]); } await Task.WhenAll(cleanup).ConfigureAwait(false); + throw; } @@ -185,10 +207,11 @@ async IAsyncEnumerator Core(CancellationToken cancellationToken) var loserCleanupTasks = new List(n - 1); - for (var i = 0; i < n; i++) + for (var i = n - 1; i >= 0; i--) { if (i != winnerIndex) { + individualTokenSources[i].Cancel(); var loserCleanupTask = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]); loserCleanupTasks.Add(loserCleanupTask); } @@ -236,7 +259,14 @@ await using (enumerator.ConfigureAwait(false)) { if (moveNextAsync != null) { - await moveNextAsync.ConfigureAwait(false); + try + { + await moveNextAsync.ConfigureAwait(false); + } + catch (TaskCanceledException) + { + // ignored because of cancelling the non-winners + } } } } diff --git a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Never.cs b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Never.cs index 66efb62596..969e2177d9 100644 --- a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Never.cs +++ b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Never.cs @@ -49,7 +49,7 @@ public ValueTask MoveNextAsync() _once = true; var task = new TaskCompletionSource(); - _registration = _token.Register(state => ((TaskCompletionSource)state!).SetCanceled(), task); + _registration = _token.Register(state => ((TaskCompletionSource)state).TrySetCanceled(_token), task); return new ValueTask(task.Task); } } diff --git a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs index f7f4f3640c..5c84398abf 100644 --- a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs +++ b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs @@ -31,6 +31,8 @@ private sealed class TimeoutAsyncIterator : AsyncIterator private Task? _loserTask; + private CancellationTokenSource? _sourceCTS; + public TimeoutAsyncIterator(IAsyncEnumerable source, TimeSpan timeout) { _source = source; @@ -55,6 +57,11 @@ public override async ValueTask DisposeAsync() await _enumerator.DisposeAsync().ConfigureAwait(false); _enumerator = null; } + if (_sourceCTS != null) + { + _sourceCTS.Dispose(); + _sourceCTS = null; + } await base.DisposeAsync().ConfigureAwait(false); } @@ -64,7 +71,8 @@ protected override async ValueTask MoveNextCore() switch (_state) { case AsyncIteratorState.Allocated: - _enumerator = _source.GetAsyncEnumerator(_cancellationToken); + _sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken); + _enumerator = _source.GetAsyncEnumerator(_sourceCTS.Token); _state = AsyncIteratorState.Iterating; goto case AsyncIteratorState.Iterating; @@ -74,7 +82,7 @@ protected override async ValueTask MoveNextCore() if (!moveNext.IsCompleted) { - using var delayCts = new CancellationTokenSource(); + using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken); var delay = Task.Delay(_timeout, delayCts.Token); @@ -98,6 +106,8 @@ protected override async ValueTask MoveNextCore() _loserTask = next.ContinueWith((_, state) => ((IAsyncDisposable)state!).DisposeAsync().AsTask(), _enumerator); + _sourceCTS!.Cancel(); + throw new TimeoutException(); } diff --git a/Rx.NET/Source/Directory.build.props b/Rx.NET/Source/Directory.build.props index 101cc3a79b..f7177a820e 100644 --- a/Rx.NET/Source/Directory.build.props +++ b/Rx.NET/Source/Directory.build.props @@ -25,7 +25,7 @@ - + diff --git a/Rx.NET/Source/src/System.Reactive/Concurrency/EventLoopScheduler.cs b/Rx.NET/Source/src/System.Reactive/Concurrency/EventLoopScheduler.cs index ed68a66d14..0c0e62e217 100644 --- a/Rx.NET/Source/src/System.Reactive/Concurrency/EventLoopScheduler.cs +++ b/Rx.NET/Source/src/System.Reactive/Concurrency/EventLoopScheduler.cs @@ -153,7 +153,7 @@ public override IDisposable Schedule(TState state, TimeSpan dueTime, Fun { if (_disposed) { - throw new ObjectDisposedException(""); + throw new ObjectDisposedException(nameof(EventLoopScheduler)); } if (dueTime <= TimeSpan.Zero) @@ -351,7 +351,15 @@ private void Run() { if (!item.IsCanceled) { - item.Invoke(); + try + { + item.Invoke(); + } + catch (ObjectDisposedException ex) when (nameof(EventLoopScheduler).Equals(ex.ObjectName)) + { + // Since we are not inside the lock at this point + // the scheduler can be disposed before the item had a chance to run + } } } } diff --git a/Rx.NET/Source/tests/Tests.System.Reactive.ApiApprovals/Tests.System.Reactive.ApiApprovals.csproj b/Rx.NET/Source/tests/Tests.System.Reactive.ApiApprovals/Tests.System.Reactive.ApiApprovals.csproj index b1d9859757..ec91a0a353 100644 --- a/Rx.NET/Source/tests/Tests.System.Reactive.ApiApprovals/Tests.System.Reactive.ApiApprovals.csproj +++ b/Rx.NET/Source/tests/Tests.System.Reactive.ApiApprovals/Tests.System.Reactive.ApiApprovals.csproj @@ -31,7 +31,7 @@ - + diff --git a/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Concurrency/EventLoopSchedulerTest.cs b/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Concurrency/EventLoopSchedulerTest.cs index 24288732ea..44a0a26159 100644 --- a/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Concurrency/EventLoopSchedulerTest.cs +++ b/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Concurrency/EventLoopSchedulerTest.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.Reactive.Concurrency; using System.Reactive.Disposables; +using System.Reactive.Linq; using System.Threading; using Microsoft.Reactive.Testing; using Xunit; @@ -41,6 +42,19 @@ public void EventLoop_Now() Assert.True(res.Seconds < 1); } + [Fact] + public void EventLoop_DisposeWithInFlightActions() + { + using (var scheduler = new EventLoopScheduler()) + using (var subscription = Observable + .Range(1, 10) + .ObserveOn(scheduler) + .Subscribe(_ => Thread.Sleep(50))) + { + Thread.Sleep(50); + } + } + [Fact] public void EventLoop_ScheduleAction() {