Skip to content

Commit

Permalink
Fix consume ValueTask backed by IValueTaskSource (#2108)
Browse files Browse the repository at this point in the history
* Added AwaitHelper to properly wait for ValueTasks.

* Adjust `AwaitHelper` to allow multiple threads to use it concurrently.

* Changed AwaitHelper to static.

* Add test case to make sure ValueTasks work properly with a race condition between `IsCompleted` and `OnCompleted`.

Changed AwaitHelper to use `ManualResetEventSlim` instead of `Monitor.Wait`.

* Make `ValueTaskWaiter.Wait` generic.

* Compare types directly.
  • Loading branch information
timcassell committed Mar 19, 2024
1 parent 0d30991 commit 7306ee7
Show file tree
Hide file tree
Showing 10 changed files with 434 additions and 188 deletions.
14 changes: 5 additions & 9 deletions src/BenchmarkDotNet/Code/DeclarationsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
{
return $"() => {method.Name}().GetAwaiter().GetResult()";
return $"() => BenchmarkDotNet.Helpers.AwaitHelper.GetResult({method.Name}())";
}

return method.Name;
Expand Down Expand Up @@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
{
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";

protected override Type WorkloadMethodReturnType => typeof(void);
}
Expand All @@ -168,11 +166,9 @@ internal class GenericTaskDeclarationsProvider : NonVoidDeclarationsProvider

protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ return BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
}
}
108 changes: 108 additions & 0 deletions src/BenchmarkDotNet/Helpers/AwaitHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace BenchmarkDotNet.Helpers
{
public static class AwaitHelper
{
private class ValueTaskWaiter
{
// We use thread static field so that each thread uses its own individual callback and reset event.
[ThreadStatic]
private static ValueTaskWaiter ts_current;
internal static ValueTaskWaiter Current => ts_current ??= new ValueTaskWaiter();

// We cache the callback to prevent allocations for memory diagnoser.
private readonly Action awaiterCallback;
private readonly ManualResetEventSlim resetEvent;

private ValueTaskWaiter()
{
resetEvent = new ();
awaiterCallback = resetEvent.Set;
}

internal void Wait<TAwaiter>(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion
{
resetEvent.Reset();
awaiter.UnsafeOnCompleted(awaiterCallback);

// The fastest way to wait for completion is to spin a bit before waiting on the event. This is the same logic that Task.GetAwaiter().GetResult() uses.
var spinner = new SpinWait();
while (!resetEvent.IsSet)
{
if (spinner.NextSpinWillYield)
{
resetEvent.Wait();
return;
}
spinner.SpinOnce();
}
}
}

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public static void GetResult(Task task) => task.GetAwaiter().GetResult();

public static T GetResult<T>(Task<T> task) => task.GetAwaiter().GetResult();

// ValueTask can be backed by an IValueTaskSource that only supports asynchronous awaits,
// so we have to hook up a callback instead of calling .GetAwaiter().GetResult() like we do for Task.
// The alternative is to convert it to Task using .AsTask(), but that causes allocations which we must avoid for memory diagnoser.
public static void GetResult(ValueTask task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
ValueTaskWaiter.Current.Wait(awaiter);
}
awaiter.GetResult();
}

public static T GetResult<T>(ValueTask<T> task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
ValueTaskWaiter.Current.Wait(awaiter);
}
return awaiter.GetResult();
}

internal static MethodInfo GetGetResultMethod(Type taskType)
{
if (!taskType.IsGenericType)
{
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Static, null, new Type[1] { taskType }, null);
}

Type compareType = taskType.GetGenericTypeDefinition() == typeof(ValueTask<>) ? typeof(ValueTask<>)
: typeof(Task).IsAssignableFrom(taskType.GetGenericTypeDefinition()) ? typeof(Task<>)
: null;
if (compareType == null)
{
return null;
}
var resultType = taskType
.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
.ReturnType;
return typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Static)
.First(m =>
{
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
Type paramType = m.GetParameters().First().ParameterType;
return paramType.IsGenericType && paramType.GetGenericTypeDefinition() == compareType;
})
.MakeGenericMethod(new[] { resultType });
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using BenchmarkDotNet.Engines;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
Expand All @@ -16,28 +18,24 @@ public ConsumableTypeInfo(Type methodReturnType)

OriginMethodReturnType = methodReturnType;

// Please note this code does not support await over extension methods.
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
if (getAwaiterMethod == null)
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
|| (methodReturnType.GetTypeInfo().IsGenericType
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));

if (!IsAwaitable)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
var getResultMethod = getAwaiterMethod
WorkloadMethodReturnType = methodReturnType
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);

if (getResultMethod == null)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
WorkloadMethodReturnType = getResultMethod.ReturnType;
GetAwaiterMethod = getAwaiterMethod;
GetResultMethod = getResultMethod;
}
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
.ReturnType;
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
}

if (WorkloadMethodReturnType == null)
Expand Down Expand Up @@ -74,14 +72,13 @@ public ConsumableTypeInfo(Type methodReturnType)
public Type WorkloadMethodReturnType { get; }
public Type OverheadMethodReturnType { get; }

public MethodInfo? GetAwaiterMethod { get; }
public MethodInfo? GetResultMethod { get; }

public bool IsVoid { get; }
public bool IsByRef { get; }
public bool IsConsumable { get; }
public FieldInfo? WorkloadConsumableField { get; }

public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
public bool IsAwaitable { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ private void DefineFields()

Type argLocalsType;
Type argFieldType;
MethodInfo? opConversion = null;
MethodInfo opConversion = null;
if (parameterType.IsByRef)
{
argLocalsType = parameterType;
Expand Down Expand Up @@ -582,42 +582,28 @@ private MethodInfo EmitWorkloadImplementation(string methodName)
workloadInvokeMethod.ReturnParameter,
args);
args = methodBuilder.GetEmitParameters(args);
var callResultType = consumableInfo.OriginMethodReturnType;
var awaiterType = consumableInfo.GetAwaiterMethod?.ReturnType
?? throw new InvalidOperationException($"Bug: {nameof(consumableInfo.GetAwaiterMethod)} is null");

var ilBuilder = methodBuilder.GetILGenerator();

/*
.locals init (
[0] valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>
)
*/
var callResultLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(callResultType, consumableInfo.GetAwaiterMethod);
var awaiterLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(awaiterType, consumableInfo.GetResultMethod);

/*
// return TaskSample(arg0). ... ;
IL_0000: ldarg.0
IL_0001: ldarg.1
IL_0002: call instance class [mscorlib]System.Threading.Tasks.Task`1<int32> [BenchmarkDotNet]BenchmarkDotNet.Samples.SampleBenchmark::TaskSample(int64)
*/
IL_0026: ldarg.0
IL_0027: ldloc.0
IL_0028: ldloc.1
IL_0029: ldloc.2
IL_002a: ldloc.3
IL_002b: call instance class [System.Private.CoreLib]System.Threading.Tasks.Task`1<object> BenchmarkDotNet.Helpers.Runnable_0::WorkloadMethod(string, string, string, string)
*/
if (!Descriptor.WorkloadMethod.IsStatic)
ilBuilder.Emit(OpCodes.Ldarg_0);
ilBuilder.EmitLdargs(args);
ilBuilder.Emit(OpCodes.Call, Descriptor.WorkloadMethod);

/*
// ... .GetAwaiter().GetResult();
IL_0007: callvirt instance valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<!0> class [mscorlib]System.Threading.Tasks.Task`1<int32>::GetAwaiter()
IL_000c: stloc.0
IL_000d: ldloca.s 0
IL_000f: call instance !0 valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>::GetResult()
*/
ilBuilder.EmitInstanceCallThisValueOnStack(callResultLocal, consumableInfo.GetAwaiterMethod);
ilBuilder.EmitInstanceCallThisValueOnStack(awaiterLocal, consumableInfo.GetResultMethod);
// BenchmarkDotNet.Helpers.AwaitHelper.GetResult(...);
IL_000e: call !!0 BenchmarkDotNet.Helpers.AwaitHelper::GetResult<int32>(valuetype [System.Runtime]System.Threading.Tasks.ValueTask`1<!!0>)
*/

ilBuilder.Emit(OpCodes.Call, consumableInfo.GetResultMethod);

/*
IL_0014: ret
Expand Down Expand Up @@ -833,19 +819,6 @@ private MethodBuilder EmitForDisassemblyDiagnoser(string methodName)
var skipFirstArg = workloadMethod.IsStatic;
var argLocals = EmitDeclareArgLocals(ilBuilder, skipFirstArg);

LocalBuilder? callResultLocal = null;
LocalBuilder? awaiterLocal = null;
if (consumableInfo.IsAwaitable)
{
var callResultType = consumableInfo.OriginMethodReturnType;
var awaiterType = consumableInfo.GetAwaiterMethod?.ReturnType
?? throw new InvalidOperationException($"Bug: {nameof(consumableInfo.GetAwaiterMethod)} is null");
callResultLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(callResultType, consumableInfo.GetAwaiterMethod);
awaiterLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(awaiterType, consumableInfo.GetResultMethod);
}

consumeEmitter.DeclareDisassemblyDiagnoserLocals(ilBuilder);

var notElevenLabel = ilBuilder.DefineLabel();
Expand All @@ -870,29 +843,27 @@ private MethodBuilder EmitForDisassemblyDiagnoser(string methodName)
EmitLoadArgFieldsToLocals(ilBuilder, argLocals, skipFirstArg);

/*
// return TaskSample(_argField) ... ;
IL_0011: ldarg.0
IL_0012: ldloc.0
IL_0013: call instance class [mscorlib]System.Threading.Tasks.Task`1<int32> [BenchmarkDotNet]BenchmarkDotNet.Samples.SampleBenchmark::TaskSample(int64)
IL_0018: ret
IL_0026: ldarg.0
IL_0027: ldloc.0
IL_0028: ldloc.1
IL_0029: ldloc.2
IL_002a: ldloc.3
IL_002b: call instance class [System.Private.CoreLib]System.Threading.Tasks.Task`1<object> BenchmarkDotNet.Helpers.Runnable_0::WorkloadMethod(string, string, string, string)
*/

if (!workloadMethod.IsStatic)
{
ilBuilder.Emit(OpCodes.Ldarg_0);
}
ilBuilder.EmitLdLocals(argLocals);
ilBuilder.Emit(OpCodes.Call, workloadMethod);

if (consumableInfo.IsAwaitable)
{
/*
// ... .GetAwaiter().GetResult();
IL_0007: callvirt instance valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<!0> class [mscorlib]System.Threading.Tasks.Task`1<int32>::GetAwaiter()
IL_000c: stloc.0
IL_000d: ldloca.s 0
IL_000f: call instance !0 valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>::GetResult()
*/
ilBuilder.EmitInstanceCallThisValueOnStack(callResultLocal, consumableInfo.GetAwaiterMethod);
ilBuilder.EmitInstanceCallThisValueOnStack(awaiterLocal, consumableInfo.GetResultMethod);
// BenchmarkDotNet.Helpers.AwaitHelper.GetResult(...);
IL_000e: call !!0 BenchmarkDotNet.Helpers.AwaitHelper::GetResult<int32>(valuetype [System.Runtime]System.Threading.Tasks.ValueTask`1<!!0>)
*/
ilBuilder.Emit(OpCodes.Call, consumableInfo.GetResultMethod);
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public BenchmarkActionTask(object instance, MethodInfo method, int unrollFactor)
private void Overhead() { }

// must be kept in sync with TaskDeclarationsProvider.TargetMethodDelegate
private void ExecuteBlocking() => startTaskCallback.Invoke().GetAwaiter().GetResult();
private void ExecuteBlocking() => Helpers.AwaitHelper.GetResult(startTaskCallback.Invoke());

[MethodImpl(CodeGenHelper.AggressiveOptimizationOption)]
private void WorkloadActionUnroll(long repeatCount)
Expand Down Expand Up @@ -165,7 +165,7 @@ public BenchmarkActionTask(object instance, MethodInfo method, int unrollFactor)
private T Overhead() => default;

// must be kept in sync with GenericTaskDeclarationsProvider.TargetMethodDelegate
private T ExecuteBlocking() => startTaskCallback().GetAwaiter().GetResult();
private T ExecuteBlocking() => Helpers.AwaitHelper.GetResult(startTaskCallback.Invoke());

private void InvokeSingleHardcoded() => result = callback();

Expand Down Expand Up @@ -217,7 +217,7 @@ public BenchmarkActionValueTask(object instance, MethodInfo method, int unrollFa
private T Overhead() => default;

// must be kept in sync with GenericTaskDeclarationsProvider.TargetMethodDelegate
private T ExecuteBlocking() => startTaskCallback().GetAwaiter().GetResult();
private T ExecuteBlocking() => Helpers.AwaitHelper.GetResult(startTaskCallback.Invoke());

private void InvokeSingleHardcoded() => result = callback();

Expand Down
18 changes: 3 additions & 15 deletions src/BenchmarkDotNet/Validators/ExecutionValidatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Extensions;
using BenchmarkDotNet.Helpers;
using BenchmarkDotNet.Running;

namespace BenchmarkDotNet.Validators
Expand Down Expand Up @@ -130,21 +131,8 @@ private void TryToGetTaskResult(object result)
return;
}

var returnType = result.GetType();
if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
var asTaskMethod = result.GetType().GetMethod("AsTask");
result = asTaskMethod.Invoke(result, null);
}

if (result is Task task)
{
task.GetAwaiter().GetResult();
}
else if (result is ValueTask valueTask)
{
valueTask.GetAwaiter().GetResult();
}
AwaitHelper.GetGetResultMethod(result.GetType())
?.Invoke(null, new[] { result });
}

private bool TryToSetParamsFields(object benchmarkTypeInstance, List<ValidationError> errors)
Expand Down

0 comments on commit 7306ee7

Please sign in to comment.