diff --git a/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs b/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs new file mode 100644 index 0000000000..aa2dbffa23 --- /dev/null +++ b/Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Tests +{ + public class Timeout : AsyncEnumerableExTests + { + [Fact] + public async Task Timeout_Never() + { + var source = AsyncEnumerableEx.Never().Timeout(TimeSpan.FromMilliseconds(100)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + } + + [Fact] + public async Task Timeout_Double_Never() + { + var source = AsyncEnumerableEx.Never() + .Timeout(TimeSpan.FromMilliseconds(300)) + .Timeout(TimeSpan.FromMilliseconds(100)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + } + + [Fact] + public async Task Timeout_Delayed_Main() + { + var source = AsyncEnumerable.Range(1, 5) + .SelectAwait(async v => + { + await Task.Delay(300); + return v; + }) + .Timeout(TimeSpan.FromMilliseconds(100)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + } + + [Fact] + public async Task Timeout_Delayed_Main_Canceled() + { + var tcs = new TaskCompletionSource(); + + var source = AsyncEnumerable.Range(1, 5) + .SelectAwaitWithCancellation(async (v, ct) => + { + try + { + await Task.Delay(500, ct); + } + catch (TaskCanceledException) + { + tcs.SetResult(0); + } + return v; + }) + .Timeout(TimeSpan.FromMilliseconds(250)); + + var en = source.GetAsyncEnumerator(); + + try + { + await en.MoveNextAsync(); + + Assert.False(true, "MoveNextAsync should have thrown"); + } + catch (TimeoutException) + { + // expected + } + finally + { + await en.DisposeAsync(); + } + + Assert.Equal(0, await tcs.Task); + } + } +} diff --git a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs index f7f4f3640c..5c84398abf 100644 --- a/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs +++ b/Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs @@ -31,6 +31,8 @@ private sealed class TimeoutAsyncIterator : AsyncIterator private Task? _loserTask; + private CancellationTokenSource? _sourceCTS; + public TimeoutAsyncIterator(IAsyncEnumerable source, TimeSpan timeout) { _source = source; @@ -55,6 +57,11 @@ public override async ValueTask DisposeAsync() await _enumerator.DisposeAsync().ConfigureAwait(false); _enumerator = null; } + if (_sourceCTS != null) + { + _sourceCTS.Dispose(); + _sourceCTS = null; + } await base.DisposeAsync().ConfigureAwait(false); } @@ -64,7 +71,8 @@ protected override async ValueTask MoveNextCore() switch (_state) { case AsyncIteratorState.Allocated: - _enumerator = _source.GetAsyncEnumerator(_cancellationToken); + _sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken); + _enumerator = _source.GetAsyncEnumerator(_sourceCTS.Token); _state = AsyncIteratorState.Iterating; goto case AsyncIteratorState.Iterating; @@ -74,7 +82,7 @@ protected override async ValueTask MoveNextCore() if (!moveNext.IsCompleted) { - using var delayCts = new CancellationTokenSource(); + using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken); var delay = Task.Delay(_timeout, delayCts.Token); @@ -98,6 +106,8 @@ protected override async ValueTask MoveNextCore() _loserTask = next.ContinueWith((_, state) => ((IAsyncDisposable)state!).DisposeAsync().AsTask(), _enumerator); + _sourceCTS!.Cancel(); + throw new TimeoutException(); }