diff --git a/Rx.NET/Source/src/System.Reactive/Linq/Observable/Zip.cs b/Rx.NET/Source/src/System.Reactive/Linq/Observable/Zip.cs index e8258aa6b6..ab11ada92e 100644 --- a/Rx.NET/Source/src/System.Reactive/Linq/Observable/Zip.cs +++ b/Rx.NET/Source/src/System.Reactive/Linq/Observable/Zip.cs @@ -274,6 +274,8 @@ public _(Func resultSelector, IObserver obser _resultSelector = resultSelector; } + int _enumerationInProgress; + private IEnumerator _rightEnumerator; private static readonly IEnumerator DisposedEnumerator = MakeDisposedEnumerator(); @@ -315,37 +317,60 @@ protected override void Dispose(bool disposing) { if (disposing) { - Interlocked.Exchange(ref _rightEnumerator, DisposedEnumerator)?.Dispose(); + if (Interlocked.Increment(ref _enumerationInProgress) == 1) + { + Interlocked.Exchange(ref _rightEnumerator, DisposedEnumerator)?.Dispose(); + } } base.Dispose(disposing); } public override void OnNext(TFirst value) { - bool hasNext; - try + var currentEnumerator = Volatile.Read(ref _rightEnumerator); + if (currentEnumerator == DisposedEnumerator) { - hasNext = _rightEnumerator.MoveNext(); + return; } - catch (Exception ex) + if (Interlocked.Increment(ref _enumerationInProgress) != 1) { - ForwardOnError(ex); return; } - - if (hasNext) + bool hasNext; + TSecond right = default; + var wasDisposed = false; + try { - TSecond right; try { - right = _rightEnumerator.Current; + hasNext = currentEnumerator.MoveNext(); + if (hasNext) + { + right = currentEnumerator.Current; + } } - catch (Exception ex) + finally { - ForwardOnError(ex); - return; + if (Interlocked.Decrement(ref _enumerationInProgress) != 0) + { + Interlocked.Exchange(ref _rightEnumerator, DisposedEnumerator)?.Dispose(); + wasDisposed = true; + } } + } + catch (Exception ex) + { + ForwardOnError(ex); + return; + } + if (wasDisposed) + { + return; + } + + if (hasNext) + { TResult result; try { diff --git a/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/Observable/ZipTest.cs b/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/Observable/ZipTest.cs index e8091e3b29..beb804a51c 100644 --- a/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/Observable/ZipTest.cs +++ b/Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/Observable/ZipTest.cs @@ -3,10 +3,13 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Reactive; +using System.Reactive.Disposables; using System.Reactive.Linq; +using System.Reactive.Subjects; using System.Threading; using Microsoft.Reactive.Testing; using ReactiveTests.Dummies; @@ -4054,6 +4057,113 @@ public void ZipWithEnumerable_SelectorThrows() ); } + [Fact] + public void ZipWithEnumerable_NoAsyncDisposeOnMoveNext() + { + var source = new Subject(); + + var disposable = new SingleAssignmentDisposable(); + + var other = new MoveNextDisposeDetectEnumerable(disposable, true); + + disposable.Disposable = source.Zip(other, (a, b) => a + b).Subscribe(); + + source.OnNext(1); + + Assert.True(other.IsDisposed); + Assert.False(other.DisposedWhileMoveNext); + Assert.False(other.DisposedWhileCurrent); + } + + [Fact] + public void ZipWithEnumerable_NoAsyncDisposeOnCurrent() + { + var source = new Subject(); + + var disposable = new SingleAssignmentDisposable(); + + var other = new MoveNextDisposeDetectEnumerable(disposable, false); + + disposable.Disposable = source.Zip(other, (a, b) => a + b).Subscribe(); + + source.OnNext(1); + + Assert.True(other.IsDisposed); + Assert.False(other.DisposedWhileMoveNext); + Assert.False(other.DisposedWhileCurrent); + } + + private class MoveNextDisposeDetectEnumerable : IEnumerable, IEnumerator + { + readonly IDisposable _disposable; + + readonly bool _disposeOnMoveNext; + + private bool _moveNextRunning; + + private bool _currentRunning; + + internal bool DisposedWhileMoveNext; + + internal bool DisposedWhileCurrent; + + internal bool IsDisposed; + + internal MoveNextDisposeDetectEnumerable(IDisposable disposable, bool disposeOnMoveNext) + { + _disposable = disposable; + _disposeOnMoveNext = disposeOnMoveNext; + } + public int Current + { + get + { + _currentRunning = true; + if (!_disposeOnMoveNext) + { + _disposable.Dispose(); + } + _currentRunning = false; + return 0; + } + } + + object IEnumerator.Current => Current; + + public void Dispose() + { + DisposedWhileMoveNext = _moveNextRunning; + DisposedWhileCurrent = _currentRunning; + IsDisposed = true; + } + + public IEnumerator GetEnumerator() + { + return this; + } + + public bool MoveNext() + { + _moveNextRunning = true; + if (_disposeOnMoveNext) + { + _disposable.Dispose(); + } + _moveNextRunning = false; + return true; + } + + public void Reset() + { + throw new NotSupportedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this; + } + } + private IEnumerable EnumerableNever(ManualResetEvent evt) { evt.WaitOne();