diff --git a/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/CompareSymbolsCorrectlyAnalyzer.cs b/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/CompareSymbolsCorrectlyAnalyzer.cs index 6e10c4e197..6ecc1217d5 100644 --- a/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/CompareSymbolsCorrectlyAnalyzer.cs +++ b/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/CompareSymbolsCorrectlyAnalyzer.cs @@ -19,7 +19,7 @@ public class CompareSymbolsCorrectlyAnalyzer : DiagnosticAnalyzer private static readonly string s_symbolTypeFullName = typeof(ISymbol).FullName; private const string s_symbolEqualsName = nameof(ISymbol.Equals); - public const string SymbolEqualityComparerName = "Microsoft.CodeAnalysis.Shared.Utilities.SymbolEquivalenceComparer"; + public const string SymbolEqualityComparerName = "Microsoft.CodeAnalysis.SymbolEqualityComparer"; public static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor( DiagnosticIds.CompareSymbolsCorrectlyRuleId, @@ -47,12 +47,11 @@ public override void Initialize(AnalysisContext context) return; } - // Check that the s_symbolEqualityComparerName exists and can be used, otherwise the Roslyn version + // Check that the EqualityComparer exists and can be used, otherwise the Roslyn version // being used it too low to need the change for method references - var symbolEqualityComparerType = context.Compilation.GetTypeByMetadataName(SymbolEqualityComparerName); - var operatorsToHandle = symbolEqualityComparerType is null ? - new[] { OperationKind.BinaryOperator } : - new[] { OperationKind.BinaryOperator, OperationKind.MethodReference }; + var operatorsToHandle = UseSymbolEqualityComparer(context.Compilation) ? + new[] { OperationKind.BinaryOperator, OperationKind.Invocation } : + new[] { OperationKind.BinaryOperator }; context.RegisterOperationAction(context => HandleOperation(in context, symbolType), operatorsToHandle); }); @@ -62,11 +61,11 @@ private void HandleOperation(in OperationAnalysisContext context, INamedTypeSymb { if (context.Operation is IBinaryOperation) { - HandleBinaryOperator(context, symbolType); + HandleBinaryOperator(in context, symbolType); } - if (context.Operation is IMethodReferenceOperation) + else if (context.Operation is IInvocationOperation) { - HandleMethodReferenceOperation(context, symbolType); + HandleInvocationOperation(in context, symbolType); } } @@ -113,19 +112,24 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN context.ReportDiagnostic(binary.Syntax.GetLocation().CreateDiagnostic(Rule)); } - private static void HandleMethodReferenceOperation(OperationAnalysisContext context, INamedTypeSymbol symbolType) + private static void HandleInvocationOperation(in OperationAnalysisContext context, INamedTypeSymbol symbolType) { - var methodReference = (IMethodReferenceOperation)context.Operation; + var invocationOperation = (IInvocationOperation)context.Operation; + var method = invocationOperation.TargetMethod; + if (method.Name != s_symbolEqualsName) + { + return; + } - if (methodReference.Instance != null && !IsSymbolType(methodReference.Instance, symbolType)) + if (invocationOperation.Instance != null && !IsSymbolType(invocationOperation.Instance, symbolType)) { return; } - var parameters = methodReference.Method.Parameters; - if (methodReference.Method.Name == s_symbolEqualsName && parameters.All(p => IsSymbolType(p.Type, symbolType))) + var parameters = invocationOperation.Arguments; + if (parameters.All(p => IsSymbolType(p.Value, symbolType))) { - context.ReportDiagnostic(methodReference.Syntax.GetLocation().CreateDiagnostic(Rule)); + context.ReportDiagnostic(invocationOperation.Syntax.GetLocation().CreateDiagnostic(Rule)); } } @@ -197,5 +201,8 @@ private static bool IsExplicitCastToObject(IOperation operation) return conversion.Type?.SpecialType == SpecialType.System_Object; } + + public static bool UseSymbolEqualityComparer(Compilation compilation) + => compilation.GetTypeByMetadataName(SymbolEqualityComparerName) is object; } } diff --git a/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/Fixers/CompareSymbolsCorrectlyFix.cs b/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/Fixers/CompareSymbolsCorrectlyFix.cs index c5e872a9ae..986faaba52 100644 --- a/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/Fixers/CompareSymbolsCorrectlyFix.cs +++ b/src/Microsoft.CodeAnalysis.Analyzers/Core/MetaAnalyzers/Fixers/CompareSymbolsCorrectlyFix.cs @@ -2,6 +2,7 @@ using System.Collections.Immutable; using System.Composition; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis.CodeActions; @@ -16,8 +17,6 @@ namespace Microsoft.CodeAnalysis.Analyzers.MetaAnalyzers.Fixers [Shared] public class CompareSymbolsCorrectlyFix : CodeFixProvider { - private const string s_equalityComparerIdentifier = "Microsoft.CodeAnalysis.SymbolEqualityComparer"; - public override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(CompareSymbolsCorrectlyAnalyzer.Rule.Id); public override FixAllProvider GetFixAllProvider() @@ -48,14 +47,14 @@ private async Task ConvertToEqualsAsync(Document document, TextSpan so return rawOperation switch { IBinaryOperation binaryOperation => await ConvertToEqualsAsync(document, semanticModel, binaryOperation, cancellationToken).ConfigureAwait(false), - IMethodReferenceOperation methodReferenceOperation => await EnsureEqualsCorrectAsync(document, semanticModel, methodReferenceOperation, cancellationToken).ConfigureAwait(false), + IInvocationOperation invocationOperation => await EnsureEqualsCorrectAsync(document, semanticModel, invocationOperation, cancellationToken).ConfigureAwait(false), _ => document }; } - private static async Task EnsureEqualsCorrectAsync(Document document, SemanticModel semanticModel, IMethodReferenceOperation methodReference, CancellationToken cancellationToken) + private static async Task EnsureEqualsCorrectAsync(Document document, SemanticModel semanticModel, IInvocationOperation invocationOperation, CancellationToken cancellationToken) { - if (!UseEqualityComparer(semanticModel.Compilation)) + if (!CompareSymbolsCorrectlyAnalyzer.UseSymbolEqualityComparer(semanticModel.Compilation)) { return document; } @@ -63,8 +62,20 @@ private static async Task EnsureEqualsCorrectAsync(Document document, var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); var generator = editor.Generator; - var replacement = generator.AddParameters(methodReference.Syntax, new[] { generator.IdentifierName(s_equalityComparerIdentifier) }); - editor.ReplaceNode(methodReference.Syntax, replacement.WithTriviaFrom(methodReference.Syntax)); + if (invocationOperation.Instance is null) + { + var replacement = generator.InvocationExpression( + GetEqualityComparerDefaultEquals(generator), + invocationOperation.Arguments.Select(argument => argument.Syntax).ToImmutableArray()); + + editor.ReplaceNode(invocationOperation.Syntax, replacement.WithTriviaFrom(invocationOperation.Syntax)); + } + else + { + var replacement = generator.AddParameters(invocationOperation.Syntax, new[] { GetEqualityComparerDefault(generator) }); + editor.ReplaceNode(invocationOperation.Syntax, replacement.WithTriviaFrom(invocationOperation.Syntax)); + } + return editor.GetChangedDocument(); } @@ -75,15 +86,11 @@ private static async Task ConvertToEqualsAsync(Document document, Sema var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); var generator = editor.Generator; - var replacement = UseEqualityComparer(semanticModel.Compilation) switch + var replacement = CompareSymbolsCorrectlyAnalyzer.UseSymbolEqualityComparer(semanticModel.Compilation) switch { - true => + true => generator.InvocationExpression( - generator.MemberAccessExpression( - generator.MemberAccessExpression( - generator.DottedName(s_equalityComparerIdentifier), - "Default"), - nameof(object.Equals)), + GetEqualityComparerDefaultEquals(generator), binaryOperation.LeftOperand.Syntax.WithoutLeadingTrivia(), binaryOperation.RightOperand.Syntax.WithoutTrailingTrivia()), @@ -105,9 +112,12 @@ private static async Task ConvertToEqualsAsync(Document document, Sema return editor.GetChangedDocument(); } - private static bool UseEqualityComparer(Compilation compilation) - { - return compilation.GetTypeByMetadataName(s_equalityComparerIdentifier) is object; - } + private static SyntaxNode GetEqualityComparerDefaultEquals(SyntaxGenerator generator) + => generator.MemberAccessExpression( + GetEqualityComparerDefault(generator), + nameof(object.Equals)); + + private static SyntaxNode GetEqualityComparerDefault(SyntaxGenerator generator) + => generator.MemberAccessExpression(generator.DottedName(CompareSymbolsCorrectlyAnalyzer.SymbolEqualityComparerName), "Default"); } } diff --git a/src/Microsoft.CodeAnalysis.Analyzers/UnitTests/MetaAnalyzers/CompareSymbolsCorrectlyTests.cs b/src/Microsoft.CodeAnalysis.Analyzers/UnitTests/MetaAnalyzers/CompareSymbolsCorrectlyTests.cs index d137900982..cd236a79a4 100644 --- a/src/Microsoft.CodeAnalysis.Analyzers/UnitTests/MetaAnalyzers/CompareSymbolsCorrectlyTests.cs +++ b/src/Microsoft.CodeAnalysis.Analyzers/UnitTests/MetaAnalyzers/CompareSymbolsCorrectlyTests.cs @@ -508,5 +508,35 @@ End Class await VerifyVB.VerifyAnalyzerAsync(source); } + + + [Theory] + [InlineData(nameof(ISymbol))] + [InlineData(nameof(INamedTypeSymbol))] + public async Task CompareSymbolImplementationWithInterface_EqualsComparison_CSharp(string symbolType) + { + var source = $@" +using Microsoft.CodeAnalysis; +class TestClass {{ + bool Method(ISymbol x, {symbolType} y) {{ + return [|Equals(x, y)|]; + }} +}} +"; + var fixedSource = $@" +using Microsoft.CodeAnalysis; +class TestClass {{ + bool Method(ISymbol x, {symbolType} y) {{ + return SymbolEqualityComparer.Default.Equals(x, y); + }} +}} +"; + + await new VerifyCS.Test + { + TestState = { Sources = { source, SymbolEqualityComparerStubCSharp } }, + FixedState = { Sources = { fixedSource, SymbolEqualityComparerStubCSharp } }, + }.RunAsync(); + } } }