Skip to content

Commit

Permalink
Avoid direct descriptor comparison, so the fix can be shown
Browse files Browse the repository at this point in the history
  • Loading branch information
DoctorKrolic committed Apr 17, 2024
1 parent b07c100 commit aba7cf1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the MIT license. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
Expand All @@ -26,7 +27,9 @@ public class CompareSymbolsCorrectlyAnalyzer : DiagnosticAnalyzer
private static readonly string s_symbolTypeFullName = typeof(ISymbol).FullName;
private const string s_symbolEqualsName = nameof(ISymbol.Equals);
private const string s_HashCodeCombineName = "Combine";

public const string SymbolEqualityComparerName = "Microsoft.CodeAnalysis.SymbolEqualityComparer";
public const string RulePropertyName = "Rule";

public static readonly DiagnosticDescriptor EqualityRule = new(
DiagnosticIds.CompareSymbolsCorrectlyRuleId,
Expand Down Expand Up @@ -142,7 +145,7 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
return;
}

context.ReportDiagnostic(binary.Syntax.GetLocation().CreateDiagnostic(EqualityRule));
context.ReportDiagnostic(binary.Syntax.GetLocation().CreateDiagnostic(EqualityRule, MakeProperties(nameof(EqualityRule))));
}

private static void HandleInvocationOperation(
Expand All @@ -163,7 +166,7 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
// without the correct arguments
if (IsSymbolType(invocationOperation.Instance, symbolType))
{
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(GetHashCodeRule));
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(GetHashCodeRule, MakeProperties(nameof(GetHashCode))));
}

break;
Expand All @@ -174,7 +177,7 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
var parameters = invocationOperation.Arguments;
if (parameters.All(p => IsSymbolType(p.Value, symbolType)))
{
context.ReportDiagnostic(invocationOperation.Syntax.GetLocation().CreateDiagnostic(EqualityRule));
context.ReportDiagnostic(invocationOperation.Syntax.GetLocation().CreateDiagnostic(EqualityRule, MakeProperties(nameof(EqualityRule))));
}
}

Expand All @@ -187,7 +190,7 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
systemHashCodeType.Equals(method.ContainingType, SymbolEqualityComparer.Default) &&
invocationOperation.Arguments.Any(arg => IsSymbolType(arg.Value, symbolType)))
{
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(GetHashCodeRule));
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(GetHashCodeRule, MakeProperties(nameof(GetHashCodeRule))));
}

break;
Expand All @@ -199,7 +202,7 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
IsBehavingOnSymbolType(method, symbolType) &&
!invocationOperation.Arguments.Any(arg => IsSymbolType(arg.Value, iEqualityComparer)))
{
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(CollectionRule));
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(CollectionRule, MakeProperties(nameof(CollectionRule))));
}

break;
Expand Down Expand Up @@ -249,7 +252,7 @@ static bool IsBehavingOnSymbolType(IMethodSymbol? method, INamedTypeSymbol symbo
IsSymbolType(createdType.TypeArguments[0], symbolType) &&
!objectCreation.Arguments.Any(arg => IsSymbolType(arg.Value, iEqualityComparerType)))
{
context.ReportDiagnostic(objectCreation.CreateDiagnostic(CollectionRule));
context.ReportDiagnostic(objectCreation.CreateDiagnostic(CollectionRule, MakeProperties(nameof(CollectionRule))));
}
}

Expand Down Expand Up @@ -354,6 +357,9 @@ void AddOrUpdate(string methodName, INamedTypeSymbol typeSymbol)
}
}

private static ImmutableDictionary<string, string?> MakeProperties(string rule)
=> ImmutableDictionary.CreateRange([new KeyValuePair<string, string?>(RulePropertyName, rule)]);

public static bool UseSymbolEqualityComparer(Compilation compilation)
=> compilation.GetOrCreateTypeByMetadataName(SymbolEqualityComparerName) is object;
}
Expand Down
Expand Up @@ -32,23 +32,27 @@ public sealed override Task RegisterCodeFixesAsync(CodeFixContext context)
{
foreach (var diagnostic in context.Diagnostics)
{
if (diagnostic.Descriptor == CompareSymbolsCorrectlyAnalyzer.EqualityRule)
if (diagnostic.Properties.TryGetValue(CompareSymbolsCorrectlyAnalyzer.RulePropertyName, out var rule))
{
context.RegisterCodeFix(
CodeAction.Create(
CodeAnalysisDiagnosticsResources.CompareSymbolsCorrectlyCodeFix,
cancellationToken => ConvertToEqualsAsync(context.Document, diagnostic.Location.SourceSpan, cancellationToken),
equivalenceKey: nameof(CompareSymbolsCorrectlyFix)),
diagnostic);
}
else if (diagnostic.Descriptor == CompareSymbolsCorrectlyAnalyzer.CollectionRule)
{
context.RegisterCodeFix(
CodeAction.Create(
CodeAnalysisDiagnosticsResources.CompareSymbolsCorrectlyCodeFix,
cancellationToken => CallOverloadWithEqualityComparerAsync(context.Document, diagnostic.Location.SourceSpan, cancellationToken),
equivalenceKey: nameof(CompareSymbolsCorrectlyFix)),
diagnostic);
switch (rule)
{
case nameof(CompareSymbolsCorrectlyAnalyzer.EqualityRule):
context.RegisterCodeFix(
CodeAction.Create(
CodeAnalysisDiagnosticsResources.CompareSymbolsCorrectlyCodeFix,
cancellationToken => ConvertToEqualsAsync(context.Document, diagnostic.Location.SourceSpan, cancellationToken),
equivalenceKey: nameof(CompareSymbolsCorrectlyFix)),
diagnostic);
break;
case nameof(CompareSymbolsCorrectlyAnalyzer.CollectionRule):
context.RegisterCodeFix(
CodeAction.Create(
CodeAnalysisDiagnosticsResources.CompareSymbolsCorrectlyCodeFix,
cancellationToken => CallOverloadWithEqualityComparerAsync(context.Document, diagnostic.Location.SourceSpan, cancellationToken),
equivalenceKey: nameof(CompareSymbolsCorrectlyFix)),
diagnostic);
break;
}
}
}

Expand Down

0 comments on commit aba7cf1

Please sign in to comment.