diff --git a/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs b/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs index 2fe9910c3..595cde38d 100644 --- a/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs +++ b/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs @@ -31,6 +31,20 @@ public static TaskSchedulerAwaiter GetAwaiter(this TaskScheduler scheduler) return new TaskSchedulerAwaiter(scheduler); } + /// + /// Gets an awaiter that schedules continuations on the specified . + /// + /// The synchronization context used to execute continuations. + /// An awaitable. + /// + /// The awaiter that is returned will always result in yielding, even if already executing within the specified . + /// + public static SynchronizationContextAwaiter GetAwaiter(this SynchronizationContext synchronizationContext) + { + Requires.NotNull(synchronizationContext, nameof(synchronizationContext)); + return new SynchronizationContextAwaiter(synchronizationContext); + } + /// /// Gets an awaitable that schedules continuations on the specified scheduler. /// @@ -442,6 +456,71 @@ public void GetResult() } } + /// + /// An awaiter returned from . + /// + public readonly struct SynchronizationContextAwaiter : ICriticalNotifyCompletion + { + private static readonly SendOrPostCallback SyncContextDelegate = s => ((Action)s!)(); + + /// + /// The context for continuations. + /// + private readonly SynchronizationContext syncContext; + + /// + /// Initializes a new instance of the struct. + /// + /// The context for continuations. + public SynchronizationContextAwaiter(SynchronizationContext syncContext) + { + this.syncContext = syncContext; + } + + /// + /// Gets a value indicating whether no yield is necessary. + /// + /// Always returns . + public bool IsCompleted => false; + + /// + /// Schedules a continuation to execute using the specified . + /// + /// The delegate to invoke. + public void OnCompleted(Action continuation) => this.syncContext.Post(SyncContextDelegate, continuation); + + /// + /// Schedules a continuation to execute using the specified + /// without capturing the . + /// + /// The action. + public void UnsafeOnCompleted(Action continuation) + { +#if NETFRAMEWORK // Only bother suppressing flow on .NET Framework where the perf would improve from doing so. + if (ExecutionContext.IsFlowSuppressed()) + { + this.syncContext.Post(SyncContextDelegate, continuation); + } + else + { + using (ExecutionContext.SuppressFlow()) + { + this.syncContext.Post(SyncContextDelegate, continuation); + } + } +#else + this.syncContext.Post(SyncContextDelegate, continuation); +#endif + } + + /// + /// Does nothing. + /// + public void GetResult() + { + } + } + /// /// An awaitable that will always lead the calling async method to yield, /// then immediately resume, possibly on the original . diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt index ffcf5d4cb..7426766cf 100644 --- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt @@ -1,9 +1,16 @@ Microsoft.VisualStudio.Threading.AsyncQueue.ToArray() -> T[]! Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool Microsoft.VisualStudio.Threading.SemaphoreFaultedException Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void +static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan diff --git a/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt index ffcf5d4cb..7426766cf 100644 --- a/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt @@ -1,9 +1,16 @@ Microsoft.VisualStudio.Threading.AsyncQueue.ToArray() -> T[]! Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool Microsoft.VisualStudio.Threading.SemaphoreFaultedException Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void +static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt index ffcf5d4cb..7426766cf 100644 --- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt @@ -1,9 +1,16 @@ Microsoft.VisualStudio.Threading.AsyncQueue.ToArray() -> T[]! Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void +Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool Microsoft.VisualStudio.Threading.SemaphoreFaultedException Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void +static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan diff --git a/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs index c293bc189..660572e55 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs @@ -566,6 +566,53 @@ public async Task ConfigureAwaitForAggregateException_InnerCanceledAndFaulted() } } + [Fact] + public void GetAwaiter_SynchronizationContext_ValidatesArgs() + { + Assert.Throws(() => AwaitExtensions.GetAwaiter((SynchronizationContext)null!)); + } + + [Fact] + public async Task SyncContext_Awaiter() + { + TaskCompletionSource syncContextSource = new(); + SingleThreadedSynchronizationContext.Frame frame = new(); + Thread? otherThread = null; + Task otherThreadTask = Task.Run(delegate + { + SingleThreadedSynchronizationContext syncContext; + try + { + syncContext = new(); + otherThread = Thread.CurrentThread; + syncContextSource.SetResult(syncContext); + } + catch (Exception ex) + { + syncContextSource.SetException(ex); + throw; + } + + syncContext.PushFrame(frame); + }); + + try + { + SynchronizationContext context = await syncContextSource.Task; + Assert.NotSame(Thread.CurrentThread, otherThread); + await context; + Assert.Same(Thread.CurrentThread, otherThread); + await Task.Yield(); + Assert.Same(Thread.CurrentThread, otherThread); + await TaskScheduler.Default.SwitchTo(alwaysYield: true); + Assert.NotSame(Thread.CurrentThread, otherThread); + } + finally + { + frame.Continue = false; + } + } + [SkippableFact] public async Task AwaitRegKeyChange() {