Skip to content

Commit

Permalink
Merge pull request #994 from AArnott/SyncContextAwaiter
Browse files Browse the repository at this point in the history
Add `SynchronizationContext.GetAwaiter` extension method
  • Loading branch information
AArnott committed Mar 8, 2022
2 parents 1fab666 + f48369e commit 683b45c
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 0 deletions.
79 changes: 79 additions & 0 deletions src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs
Expand Up @@ -31,6 +31,20 @@ public static TaskSchedulerAwaiter GetAwaiter(this TaskScheduler scheduler)
return new TaskSchedulerAwaiter(scheduler);
}

/// <summary>
/// Gets an awaiter that schedules continuations on the specified <see cref="SynchronizationContext"/>.
/// </summary>
/// <param name="synchronizationContext">The synchronization context used to execute continuations.</param>
/// <returns>An awaitable.</returns>
/// <remarks>
/// The awaiter that is returned will <em>always</em> result in yielding, even if already executing within the specified <paramref name="synchronizationContext"/>.
/// </remarks>
public static SynchronizationContextAwaiter GetAwaiter(this SynchronizationContext synchronizationContext)
{
Requires.NotNull(synchronizationContext, nameof(synchronizationContext));
return new SynchronizationContextAwaiter(synchronizationContext);
}

/// <summary>
/// Gets an awaitable that schedules continuations on the specified scheduler.
/// </summary>
Expand Down Expand Up @@ -442,6 +456,71 @@ public void GetResult()
}
}

/// <summary>
/// An awaiter returned from <see cref="GetAwaiter(SynchronizationContext)"/>.
/// </summary>
public readonly struct SynchronizationContextAwaiter : ICriticalNotifyCompletion
{
private static readonly SendOrPostCallback SyncContextDelegate = s => ((Action)s!)();

/// <summary>
/// The context for continuations.
/// </summary>
private readonly SynchronizationContext syncContext;

/// <summary>
/// Initializes a new instance of the <see cref="SynchronizationContextAwaiter"/> struct.
/// </summary>
/// <param name="syncContext">The context for continuations.</param>
public SynchronizationContextAwaiter(SynchronizationContext syncContext)
{
this.syncContext = syncContext;
}

/// <summary>
/// Gets a value indicating whether no yield is necessary.
/// </summary>
/// <value>Always returns <see langword="false"/>.</value>
public bool IsCompleted => false;

/// <summary>
/// Schedules a continuation to execute using the specified <see cref="SynchronizationContext"/>.
/// </summary>
/// <param name="continuation">The delegate to invoke.</param>
public void OnCompleted(Action continuation) => this.syncContext.Post(SyncContextDelegate, continuation);

/// <summary>
/// Schedules a continuation to execute using the specified <see cref="SynchronizationContext"/>
/// without capturing the <see cref="ExecutionContext"/>.
/// </summary>
/// <param name="continuation">The action.</param>
public void UnsafeOnCompleted(Action continuation)
{
#if NETFRAMEWORK // Only bother suppressing flow on .NET Framework where the perf would improve from doing so.
if (ExecutionContext.IsFlowSuppressed())
{
this.syncContext.Post(SyncContextDelegate, continuation);
}
else
{
using (ExecutionContext.SuppressFlow())
{
this.syncContext.Post(SyncContextDelegate, continuation);
}
}
#else
this.syncContext.Post(SyncContextDelegate, continuation);
#endif
}

/// <summary>
/// Does nothing.
/// </summary>
public void GetResult()
{
}
}

/// <summary>
/// An awaitable that will always lead the calling async method to yield,
/// then immediately resume, possibly on the original <see cref="SynchronizationContext"/>.
Expand Down
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue<T>.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock<TMoniker, TResource>.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue<T>.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock<TMoniker, TResource>.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue<T>.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock<TMoniker, TResource>.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
Expand Up @@ -566,6 +566,53 @@ public async Task ConfigureAwaitForAggregateException_InnerCanceledAndFaulted()
}
}

[Fact]
public void GetAwaiter_SynchronizationContext_ValidatesArgs()
{
Assert.Throws<ArgumentNullException>(() => AwaitExtensions.GetAwaiter((SynchronizationContext)null!));
}

[Fact]
public async Task SyncContext_Awaiter()
{
TaskCompletionSource<SingleThreadedSynchronizationContext> syncContextSource = new();
SingleThreadedSynchronizationContext.Frame frame = new();
Thread? otherThread = null;
Task otherThreadTask = Task.Run(delegate
{
SingleThreadedSynchronizationContext syncContext;
try
{
syncContext = new();
otherThread = Thread.CurrentThread;
syncContextSource.SetResult(syncContext);
}
catch (Exception ex)
{
syncContextSource.SetException(ex);
throw;
}

syncContext.PushFrame(frame);
});

try
{
SynchronizationContext context = await syncContextSource.Task;
Assert.NotSame(Thread.CurrentThread, otherThread);
await context;
Assert.Same(Thread.CurrentThread, otherThread);
await Task.Yield();
Assert.Same(Thread.CurrentThread, otherThread);
await TaskScheduler.Default.SwitchTo(alwaysYield: true);
Assert.NotSame(Thread.CurrentThread, otherThread);
}
finally
{
frame.Continue = false;
}
}

[SkippableFact]
public async Task AwaitRegKeyChange()
{
Expand Down

0 comments on commit 683b45c

Please sign in to comment.