Skip to content

Commit

Permalink
Fixes microsoft#693. VSTHRD002 now checks if the task has completed v…
Browse files Browse the repository at this point in the history
…ia Task.WhenAll prior to the problematic member access (e.g. Result/Wait), and that the task variable has not been used in-between in a way that would invalidate this check.
  • Loading branch information
bluetarpmedia committed Jul 12, 2021
1 parent 127e17d commit c063c87
Show file tree
Hide file tree
Showing 3 changed files with 463 additions and 0 deletions.
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.VisualStudio.Threading.Analyzers
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -100,11 +101,192 @@ internal static bool ShouldIgnoreContext(SyntaxNodeAnalysisContext context)
typeReceiver.Name == item.Method.ContainingType.Name &&
typeReceiver.BelongsToNamespace(item.Method.ContainingType.Namespace))
{
if (HasTaskCompleted(context, memberAccessSyntax))
{
return;
}

Location? location = memberAccessSyntax.Name.GetLocation();
context.ReportDiagnostic(Diagnostic.Create(descriptor, location));
}
}
}
}

private static SyntaxNode? GetEnclosingBlock(SyntaxNode node)
{
while (node != null)
{
if (node.IsKind(SyntaxKind.Block))
{
return node;
}

node = node.Parent;
}

return null;
}

private static bool IsVariablePassedToInvocation(InvocationExpressionSyntax invocationExpr, string variableName, bool byRef)
{
ArgumentListSyntax? argList = invocationExpr.ChildNodes().OfType<ArgumentListSyntax>().FirstOrDefault();
if (argList == null)
{
return false;
}

foreach (ArgumentSyntax arg in argList.ChildNodes().OfType<ArgumentSyntax>())
{
// `byRef` includes `out` parameters because they are the same as `ref` except don't require initialization first.
if (byRef && (arg.RefKindKeyword.Kind() != SyntaxKind.RefKeyword && arg.RefKindKeyword.Kind() != SyntaxKind.OutKeyword))
{
continue;
}

IdentifierNameSyntax identiferName = arg.ChildNodes().OfType<IdentifierNameSyntax>().FirstOrDefault();
if (identiferName == null)
{
return false;
}

if (identiferName.Identifier.ValueText == variableName)
{
return true;
}
}

return false;
}

private static bool IsTaskCompletedWithWhenAll(SyntaxNodeAnalysisContext context, InvocationExpressionSyntax invocationExpr, string taskVariableName)
{
// We only care about awaited invocations, because an un-awaited Task.WhenAll will be an error.
if (invocationExpr.Parent is not AwaitExpressionSyntax)
{
return false;
}

IEnumerable<MemberAccessExpressionSyntax>? memberAccessList = invocationExpr.ChildNodes().OfType<MemberAccessExpressionSyntax>();
if (memberAccessList.Count() != 1)
{
return false;
}

MemberAccessExpressionSyntax? memberAccess = memberAccessList.First();

// Does the invocation have the expected `Task.WhenAll` syntax? This is cheaper to verify before looking up its semantic type.
var correctSyntax =
((IdentifierNameSyntax)memberAccess.Expression).Identifier.ValueText == Types.Task.TypeName &&
((IdentifierNameSyntax)memberAccess.Name).Identifier.ValueText == Types.Task.WhenAll;

if (!correctSyntax)
{
return false;
}

// Is this `Task.WhenAll` invocation from the System.Threading.Tasks.Task type?
ITypeSymbol? classType = context.SemanticModel.GetTypeInfo(memberAccess.Expression).Type;
var correctType = classType.Name == Types.Task.TypeName && classType.BelongsToNamespace(Types.Task.Namespace);
if (!correctType)
{
return false;
}

// Is the task variable passed as an argument to `Task.WhenAll`?
return IsVariablePassedToInvocation(invocationExpr, taskVariableName, byRef: false);
}

private static bool HasTaskCompleted(SyntaxNodeAnalysisContext context, MemberAccessExpressionSyntax memberAccessSyntax)
{
SyntaxNode? enclosingBlock = GetEnclosingBlock(memberAccessSyntax);
if (enclosingBlock == null)
{
return false;
}

// Get the task variable name from the problematic member access expression so that we can later try
// and determine if it has been used in a `Task.WhenAll` invocation.
// Examples:
// task1.Result;
// task2.GetAwaiter().GetResult();
string? taskVariableName = null;
ExpressionSyntax parentExpr = memberAccessSyntax.Expression;
while (parentExpr != null)
{
if (parentExpr is IdentifierNameSyntax identifierExpr)
{
taskVariableName = identifierExpr.Identifier.ValueText;
break;
}
else if (parentExpr is MemberAccessExpressionSyntax memberAccessExpr)
{
parentExpr = memberAccessExpr.Expression;
}
else if (parentExpr is InvocationExpressionSyntax invocExpr)
{
parentExpr = invocExpr.Expression;
}
else
{
break;
}
}

if (taskVariableName == null)
{
return false;
}

// Find all `Task.WhenAll` invocations that precede the problematic member access, which are also in the same enclosing block.
IEnumerable<InvocationExpressionSyntax>? taskWhenAllInvocationList =
from invoc in enclosingBlock.DescendantNodes().OfType<InvocationExpressionSyntax>()
where memberAccessSyntax.SpanStart > invoc.Span.End &&
IsTaskCompletedWithWhenAll(context, invoc, taskVariableName)
select invoc;

if (!taskWhenAllInvocationList.Any())
{
return false;
}

// If a `Task.WhenAll` invocation precedes the problematic member access, and the task variable has not been
// invalidated in between, then we consider the task to be completed.
// Example:
// await Task.WhenAll(task1, task2, task3);
// task1 = Task.Run(...); // Invalidates `task1`
// DoSomething(ref task2); // Invalidates `task2`
// task1.Result; // Warn
// task2.Result; // Warn
// task3.Result; // No warning, task3 has not been invalidated in between WhenAll and this problematic member access
foreach (InvocationExpressionSyntax? taskWhenAllInvocation in taskWhenAllInvocationList)
{
// Has the task variable been assigned to a new task?
IEnumerable<AssignmentExpressionSyntax>? assignmentList =
from assign in enclosingBlock.DescendantNodes().OfType<AssignmentExpressionSyntax>()
where assign.SpanStart > taskWhenAllInvocation.Span.End &&
assign.SpanStart < memberAccessSyntax.SpanStart &&
((IdentifierNameSyntax)assign.Left).Identifier.ValueText == taskVariableName
select assign;

if (assignmentList.Any())
{
return false;
}

// Has the task variable been passed by ref to a method?
// If so, we must assume the worst case that the method has assigned it to a new task.
IEnumerable<InvocationExpressionSyntax>? invocationList =
from invoc in enclosingBlock.DescendantNodes().OfType<InvocationExpressionSyntax>()
where invoc.SpanStart > taskWhenAllInvocation.Span.End &&
invoc.SpanStart < memberAccessSyntax.SpanStart &&
IsVariablePassedToInvocation(invoc, taskVariableName, byRef: true)
select invoc;

return !invocationList.Any();
}

return false;
}
}
}
2 changes: 2 additions & 0 deletions src/Microsoft.VisualStudio.Threading.Analyzers/Types.cs
Expand Up @@ -183,6 +183,8 @@ internal static class Task

internal const string CompletedTask = nameof(System.Threading.Tasks.Task.CompletedTask);

internal const string WhenAll = "WhenAll";

internal static readonly IReadOnlyList<string> Namespace = Namespaces.SystemThreadingTasks;
}

Expand Down

0 comments on commit c063c87

Please sign in to comment.