diff --git a/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs b/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs
index 2fe9910c3..595cde38d 100644
--- a/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs
+++ b/src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs
@@ -31,6 +31,20 @@ public static TaskSchedulerAwaiter GetAwaiter(this TaskScheduler scheduler)
return new TaskSchedulerAwaiter(scheduler);
}
+ ///
+ /// Gets an awaiter that schedules continuations on the specified .
+ ///
+ /// The synchronization context used to execute continuations.
+ /// An awaitable.
+ ///
+ /// The awaiter that is returned will always result in yielding, even if already executing within the specified .
+ ///
+ public static SynchronizationContextAwaiter GetAwaiter(this SynchronizationContext synchronizationContext)
+ {
+ Requires.NotNull(synchronizationContext, nameof(synchronizationContext));
+ return new SynchronizationContextAwaiter(synchronizationContext);
+ }
+
///
/// Gets an awaitable that schedules continuations on the specified scheduler.
///
@@ -442,6 +456,71 @@ public void GetResult()
}
}
+ ///
+ /// An awaiter returned from .
+ ///
+ public readonly struct SynchronizationContextAwaiter : ICriticalNotifyCompletion
+ {
+ private static readonly SendOrPostCallback SyncContextDelegate = s => ((Action)s!)();
+
+ ///
+ /// The context for continuations.
+ ///
+ private readonly SynchronizationContext syncContext;
+
+ ///
+ /// Initializes a new instance of the struct.
+ ///
+ /// The context for continuations.
+ public SynchronizationContextAwaiter(SynchronizationContext syncContext)
+ {
+ this.syncContext = syncContext;
+ }
+
+ ///
+ /// Gets a value indicating whether no yield is necessary.
+ ///
+ /// Always returns .
+ public bool IsCompleted => false;
+
+ ///
+ /// Schedules a continuation to execute using the specified .
+ ///
+ /// The delegate to invoke.
+ public void OnCompleted(Action continuation) => this.syncContext.Post(SyncContextDelegate, continuation);
+
+ ///
+ /// Schedules a continuation to execute using the specified
+ /// without capturing the .
+ ///
+ /// The action.
+ public void UnsafeOnCompleted(Action continuation)
+ {
+#if NETFRAMEWORK // Only bother suppressing flow on .NET Framework where the perf would improve from doing so.
+ if (ExecutionContext.IsFlowSuppressed())
+ {
+ this.syncContext.Post(SyncContextDelegate, continuation);
+ }
+ else
+ {
+ using (ExecutionContext.SuppressFlow())
+ {
+ this.syncContext.Post(SyncContextDelegate, continuation);
+ }
+ }
+#else
+ this.syncContext.Post(SyncContextDelegate, continuation);
+#endif
+ }
+
+ ///
+ /// Does nothing.
+ ///
+ public void GetResult()
+ {
+ }
+ }
+
///
/// An awaitable that will always lead the calling async method to yield,
/// then immediately resume, possibly on the original .
diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt
index ffcf5d4cb..7426766cf 100644
--- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt
+++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
+static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
diff --git a/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt
index ffcf5d4cb..7426766cf 100644
--- a/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt
+++ b/src/Microsoft.VisualStudio.Threading/netcoreapp3.1/PublicAPI.Unshipped.txt
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
+static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt
index ffcf5d4cb..7426766cf 100644
--- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt
+++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
+Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
+static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
diff --git a/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs
index c293bc189..660572e55 100644
--- a/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs
+++ b/test/Microsoft.VisualStudio.Threading.Tests/AwaitExtensionsTests.cs
@@ -566,6 +566,53 @@ public async Task ConfigureAwaitForAggregateException_InnerCanceledAndFaulted()
}
}
+ [Fact]
+ public void GetAwaiter_SynchronizationContext_ValidatesArgs()
+ {
+ Assert.Throws(() => AwaitExtensions.GetAwaiter((SynchronizationContext)null!));
+ }
+
+ [Fact]
+ public async Task SyncContext_Awaiter()
+ {
+ TaskCompletionSource syncContextSource = new();
+ SingleThreadedSynchronizationContext.Frame frame = new();
+ Thread? otherThread = null;
+ Task otherThreadTask = Task.Run(delegate
+ {
+ SingleThreadedSynchronizationContext syncContext;
+ try
+ {
+ syncContext = new();
+ otherThread = Thread.CurrentThread;
+ syncContextSource.SetResult(syncContext);
+ }
+ catch (Exception ex)
+ {
+ syncContextSource.SetException(ex);
+ throw;
+ }
+
+ syncContext.PushFrame(frame);
+ });
+
+ try
+ {
+ SynchronizationContext context = await syncContextSource.Task;
+ Assert.NotSame(Thread.CurrentThread, otherThread);
+ await context;
+ Assert.Same(Thread.CurrentThread, otherThread);
+ await Task.Yield();
+ Assert.Same(Thread.CurrentThread, otherThread);
+ await TaskScheduler.Default.SwitchTo(alwaysYield: true);
+ Assert.NotSame(Thread.CurrentThread, otherThread);
+ }
+ finally
+ {
+ frame.Continue = false;
+ }
+ }
+
[SkippableFact]
public async Task AwaitRegKeyChange()
{