Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SynchronizationContext.GetAwaiter extension method #994

Merged
merged 1 commit into from Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/Microsoft.VisualStudio.Threading/AwaitExtensions.cs
Expand Up @@ -31,6 +31,20 @@ public static TaskSchedulerAwaiter GetAwaiter(this TaskScheduler scheduler)
return new TaskSchedulerAwaiter(scheduler);
}

/// <summary>
/// Gets an awaiter that schedules continuations on the specified <see cref="SynchronizationContext"/>.
/// </summary>
/// <param name="synchronizationContext">The synchronization context used to execute continuations.</param>
/// <returns>An awaitable.</returns>
/// <remarks>
/// The awaiter that is returned will <em>always</em> result in yielding, even if already executing within the specified <paramref name="synchronizationContext"/>.
/// </remarks>
public static SynchronizationContextAwaiter GetAwaiter(this SynchronizationContext synchronizationContext)
{
Requires.NotNull(synchronizationContext, nameof(synchronizationContext));
return new SynchronizationContextAwaiter(synchronizationContext);
}

/// <summary>
/// Gets an awaitable that schedules continuations on the specified scheduler.
/// </summary>
Expand Down Expand Up @@ -442,6 +456,71 @@ public void GetResult()
}
}

/// <summary>
/// An awaiter returned from <see cref="GetAwaiter(SynchronizationContext)"/>.
/// </summary>
public readonly struct SynchronizationContextAwaiter : ICriticalNotifyCompletion
{
private static readonly SendOrPostCallback SyncContextDelegate = s => ((Action)s!)();

/// <summary>
/// The context for continuations.
/// </summary>
private readonly SynchronizationContext syncContext;

/// <summary>
/// Initializes a new instance of the <see cref="SynchronizationContextAwaiter"/> struct.
/// </summary>
/// <param name="syncContext">The context for continuations.</param>
public SynchronizationContextAwaiter(SynchronizationContext syncContext)
{
this.syncContext = syncContext;
}

/// <summary>
/// Gets a value indicating whether no yield is necessary.
/// </summary>
/// <value>Always returns <see langword="false"/>.</value>
public bool IsCompleted => false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we ever want to short-cut this awaiter if the current SynchorizationContext matches what we want here?


/// <summary>
/// Schedules a continuation to execute using the specified <see cref="SynchronizationContext"/>.
/// </summary>
/// <param name="continuation">The delegate to invoke.</param>
public void OnCompleted(Action continuation) => this.syncContext.Post(SyncContextDelegate, continuation);

/// <summary>
/// Schedules a continuation to execute using the specified <see cref="SynchronizationContext"/>
/// without capturing the <see cref="ExecutionContext"/>.
/// </summary>
/// <param name="continuation">The action.</param>
public void UnsafeOnCompleted(Action continuation)
{
#if NETFRAMEWORK // Only bother suppressing flow on .NET Framework where the perf would improve from doing so.
if (ExecutionContext.IsFlowSuppressed())
{
this.syncContext.Post(SyncContextDelegate, continuation);
}
else
{
using (ExecutionContext.SuppressFlow())
{
this.syncContext.Post(SyncContextDelegate, continuation);
}
}
#else
this.syncContext.Post(SyncContextDelegate, continuation);
#endif
}

/// <summary>
/// Does nothing.
/// </summary>
public void GetResult()
{
}
}

/// <summary>
/// An awaitable that will always lead the calling async method to yield,
/// then immediately resume, possibly on the original <see cref="SynchronizationContext"/>.
Expand Down
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue<T>.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock<TMoniker, TResource>.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue<T>.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock<TMoniker, TResource>.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
@@ -1,9 +1,16 @@
Microsoft.VisualStudio.Threading.AsyncQueue<T>.ToArray() -> T[]!
Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.AsyncReaderWriterLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics = false) -> void
Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock<TMoniker, TResource>.AsyncReaderWriterResourceLock(Microsoft.VisualStudio.Threading.JoinableTaskContext? joinableTaskContext, bool captureDiagnostics) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.GetResult() -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.IsCompleted.get -> bool
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.OnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.SynchronizationContextAwaiter(System.Threading.SynchronizationContext! syncContext) -> void
Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void
Microsoft.VisualStudio.Threading.JoinableTaskContext.IsMainThreadMaybeBlocked() -> bool
Microsoft.VisualStudio.Threading.SemaphoreFaultedException
Microsoft.VisualStudio.Threading.SemaphoreFaultedException.SemaphoreFaultedException() -> void
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException
Microsoft.VisualStudio.Threading.IllegalSemaphoreUsageException.IllegalSemaphoreUsageException(string! message) -> void
static Microsoft.VisualStudio.Threading.AwaitExtensions.GetAwaiter(this System.Threading.SynchronizationContext! synchronizationContext) -> Microsoft.VisualStudio.Threading.AwaitExtensions.SynchronizationContextAwaiter
virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterLock.DeadlockCheckTimeout.get -> System.TimeSpan
Expand Up @@ -566,6 +566,53 @@ public async Task ConfigureAwaitForAggregateException_InnerCanceledAndFaulted()
}
}

[Fact]
public void GetAwaiter_SynchronizationContext_ValidatesArgs()
{
Assert.Throws<ArgumentNullException>(() => AwaitExtensions.GetAwaiter((SynchronizationContext)null!));
}

[Fact]
public async Task SyncContext_Awaiter()
{
TaskCompletionSource<SingleThreadedSynchronizationContext> syncContextSource = new();
SingleThreadedSynchronizationContext.Frame frame = new();
Thread? otherThread = null;
Task otherThreadTask = Task.Run(delegate
{
SingleThreadedSynchronizationContext syncContext;
try
{
syncContext = new();
otherThread = Thread.CurrentThread;
syncContextSource.SetResult(syncContext);
}
catch (Exception ex)
{
syncContextSource.SetException(ex);
throw;
}

syncContext.PushFrame(frame);
});

try
{
SynchronizationContext context = await syncContextSource.Task;
Assert.NotSame(Thread.CurrentThread, otherThread);
await context;
Assert.Same(Thread.CurrentThread, otherThread);
await Task.Yield();
Assert.Same(Thread.CurrentThread, otherThread);
await TaskScheduler.Default.SwitchTo(alwaysYield: true);
Assert.NotSame(Thread.CurrentThread, otherThread);
}
finally
{
frame.Continue = false;
}
}

[SkippableFact]
public async Task AwaitRegKeyChange()
{
Expand Down