Skip to content

Commit

Permalink
Add test for Equal -> EqualityComparer
Browse files Browse the repository at this point in the history
  • Loading branch information
ryzngard committed Sep 26, 2019
1 parent 3842562 commit 3bc68a1
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 33 deletions.
Expand Up @@ -19,7 +19,7 @@ public class CompareSymbolsCorrectlyAnalyzer : DiagnosticAnalyzer

private static readonly string s_symbolTypeFullName = typeof(ISymbol).FullName;
private const string s_symbolEqualsName = nameof(ISymbol.Equals);
public const string SymbolEqualityComparerName = "Microsoft.CodeAnalysis.Shared.Utilities.SymbolEquivalenceComparer";
public const string SymbolEqualityComparerName = "Microsoft.CodeAnalysis.SymbolEqualityComparer";

public static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
DiagnosticIds.CompareSymbolsCorrectlyRuleId,
Expand Down Expand Up @@ -47,12 +47,11 @@ public override void Initialize(AnalysisContext context)
return;
}
// Check that the s_symbolEqualityComparerName exists and can be used, otherwise the Roslyn version
// Check that the EqualityComparer exists and can be used, otherwise the Roslyn version
// being used it too low to need the change for method references
var symbolEqualityComparerType = context.Compilation.GetTypeByMetadataName(SymbolEqualityComparerName);
var operatorsToHandle = symbolEqualityComparerType is null ?
new[] { OperationKind.BinaryOperator } :
new[] { OperationKind.BinaryOperator, OperationKind.MethodReference };
var operatorsToHandle = UseSymbolEqualityComparer(context.Compilation) ?
new[] { OperationKind.BinaryOperator, OperationKind.Invocation } :
new[] { OperationKind.BinaryOperator };
context.RegisterOperationAction(context => HandleOperation(in context, symbolType), operatorsToHandle);
});
Expand All @@ -62,11 +61,11 @@ private void HandleOperation(in OperationAnalysisContext context, INamedTypeSymb
{
if (context.Operation is IBinaryOperation)
{
HandleBinaryOperator(context, symbolType);
HandleBinaryOperator(in context, symbolType);
}
if (context.Operation is IMethodReferenceOperation)
else if (context.Operation is IInvocationOperation)
{
HandleMethodReferenceOperation(context, symbolType);
HandleInvocationOperation(in context, symbolType);
}
}

Expand Down Expand Up @@ -113,19 +112,24 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
context.ReportDiagnostic(binary.Syntax.GetLocation().CreateDiagnostic(Rule));
}

private static void HandleMethodReferenceOperation(OperationAnalysisContext context, INamedTypeSymbol symbolType)
private static void HandleInvocationOperation(in OperationAnalysisContext context, INamedTypeSymbol symbolType)
{
var methodReference = (IMethodReferenceOperation)context.Operation;
var invocationOperation = (IInvocationOperation)context.Operation;
var method = invocationOperation.TargetMethod;
if (method.Name != s_symbolEqualsName)
{
return;
}

if (methodReference.Instance != null && !IsSymbolType(methodReference.Instance, symbolType))
if (invocationOperation.Instance != null && !IsSymbolType(invocationOperation.Instance, symbolType))
{
return;
}

var parameters = methodReference.Method.Parameters;
if (methodReference.Method.Name == s_symbolEqualsName && parameters.All(p => IsSymbolType(p.Type, symbolType)))
var parameters = invocationOperation.Arguments;
if (parameters.All(p => IsSymbolType(p.Value, symbolType)))
{
context.ReportDiagnostic(methodReference.Syntax.GetLocation().CreateDiagnostic(Rule));
context.ReportDiagnostic(invocationOperation.Syntax.GetLocation().CreateDiagnostic(Rule));
}
}

Expand Down Expand Up @@ -197,5 +201,8 @@ private static bool IsExplicitCastToObject(IOperation operation)

return conversion.Type?.SpecialType == SpecialType.System_Object;
}

public static bool UseSymbolEqualityComparer(Compilation compilation)
=> compilation.GetTypeByMetadataName(SymbolEqualityComparerName) is object;
}
}
Expand Up @@ -2,6 +2,7 @@

using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
Expand All @@ -16,8 +17,6 @@ namespace Microsoft.CodeAnalysis.Analyzers.MetaAnalyzers.Fixers
[Shared]
public class CompareSymbolsCorrectlyFix : CodeFixProvider
{
private const string s_equalityComparerIdentifier = "Microsoft.CodeAnalysis.SymbolEqualityComparer";

public override ImmutableArray<string> FixableDiagnosticIds => ImmutableArray.Create(CompareSymbolsCorrectlyAnalyzer.Rule.Id);

public override FixAllProvider GetFixAllProvider()
Expand Down Expand Up @@ -48,23 +47,35 @@ private async Task<Document> ConvertToEqualsAsync(Document document, TextSpan so
return rawOperation switch
{
IBinaryOperation binaryOperation => await ConvertToEqualsAsync(document, semanticModel, binaryOperation, cancellationToken).ConfigureAwait(false),
IMethodReferenceOperation methodReferenceOperation => await EnsureEqualsCorrectAsync(document, semanticModel, methodReferenceOperation, cancellationToken).ConfigureAwait(false),
IInvocationOperation invocationOperation => await EnsureEqualsCorrectAsync(document, semanticModel, invocationOperation, cancellationToken).ConfigureAwait(false),
_ => document
};
}

private static async Task<Document> EnsureEqualsCorrectAsync(Document document, SemanticModel semanticModel, IMethodReferenceOperation methodReference, CancellationToken cancellationToken)
private static async Task<Document> EnsureEqualsCorrectAsync(Document document, SemanticModel semanticModel, IInvocationOperation invocationOperation, CancellationToken cancellationToken)
{
if (!UseEqualityComparer(semanticModel.Compilation))
if (!CompareSymbolsCorrectlyAnalyzer.UseSymbolEqualityComparer(semanticModel.Compilation))
{
return document;
}

var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
var generator = editor.Generator;

var replacement = generator.AddParameters(methodReference.Syntax, new[] { generator.IdentifierName(s_equalityComparerIdentifier) });
editor.ReplaceNode(methodReference.Syntax, replacement.WithTriviaFrom(methodReference.Syntax));
if (invocationOperation.Instance is null)
{
var replacement = generator.InvocationExpression(
GetEqualityComparerDefaultEquals(generator),
invocationOperation.Arguments.Select(argument => argument.Syntax).ToImmutableArray());

editor.ReplaceNode(invocationOperation.Syntax, replacement.WithTriviaFrom(invocationOperation.Syntax));
}
else
{
var replacement = generator.AddParameters(invocationOperation.Syntax, new[] { GetEqualityComparerDefault(generator) });
editor.ReplaceNode(invocationOperation.Syntax, replacement.WithTriviaFrom(invocationOperation.Syntax));
}

return editor.GetChangedDocument();
}

Expand All @@ -75,15 +86,11 @@ private static async Task<Document> ConvertToEqualsAsync(Document document, Sema
var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
var generator = editor.Generator;

var replacement = UseEqualityComparer(semanticModel.Compilation) switch
var replacement = CompareSymbolsCorrectlyAnalyzer.UseSymbolEqualityComparer(semanticModel.Compilation) switch
{
true =>
true =>
generator.InvocationExpression(
generator.MemberAccessExpression(
generator.MemberAccessExpression(
generator.DottedName(s_equalityComparerIdentifier),
"Default"),
nameof(object.Equals)),
GetEqualityComparerDefaultEquals(generator),
binaryOperation.LeftOperand.Syntax.WithoutLeadingTrivia(),
binaryOperation.RightOperand.Syntax.WithoutTrailingTrivia()),

Expand All @@ -105,9 +112,12 @@ private static async Task<Document> ConvertToEqualsAsync(Document document, Sema
return editor.GetChangedDocument();
}

private static bool UseEqualityComparer(Compilation compilation)
{
return compilation.GetTypeByMetadataName(s_equalityComparerIdentifier) is object;
}
private static SyntaxNode GetEqualityComparerDefaultEquals(SyntaxGenerator generator)
=> generator.MemberAccessExpression(
GetEqualityComparerDefault(generator),
nameof(object.Equals));

private static SyntaxNode GetEqualityComparerDefault(SyntaxGenerator generator)
=> generator.MemberAccessExpression(generator.DottedName(CompareSymbolsCorrectlyAnalyzer.SymbolEqualityComparerName), "Default");
}
}
Expand Up @@ -508,5 +508,35 @@ End Class

await VerifyVB.VerifyAnalyzerAsync(source);
}


[Theory]
[InlineData(nameof(ISymbol))]
[InlineData(nameof(INamedTypeSymbol))]
public async Task CompareSymbolImplementationWithInterface_EqualsComparison_CSharp(string symbolType)
{
var source = $@"
using Microsoft.CodeAnalysis;
class TestClass {{
bool Method(ISymbol x, {symbolType} y) {{
return [|Equals(x, y)|];
}}
}}
";
var fixedSource = $@"
using Microsoft.CodeAnalysis;
class TestClass {{
bool Method(ISymbol x, {symbolType} y) {{
return SymbolEqualityComparer.Default.Equals(x, y);
}}
}}
";

await new VerifyCS.Test
{
TestState = { Sources = { source, SymbolEqualityComparerStubCSharp } },
FixedState = { Sources = { fixedSource, SymbolEqualityComparerStubCSharp } },
}.RunAsync();
}
}
}

0 comments on commit 3bc68a1

Please sign in to comment.