Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More cases to optimize linq #1157

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions ChangeLog.md
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Add SECURITY.md ([#1147](https://github.com/josefpihrt/roslynator/pull/1147))
- Add custom FixAllProvider for [RCS1014](https://github.com/JosefPihrt/Roslynator/blob/main/docs/analyzers/RCS1014.md) ([#1070](https://github.com/JosefPihrt/Roslynator/pull/1070)).
- Support for more linq optimizations ([#1157](https://github.com/josefpihrt/roslynator/pull/1157))

### Fixed

Expand Down
Expand Up @@ -332,8 +332,17 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
InvocationExpressionSyntax invocation = invocationInfo.InvocationExpression;
InvocationExpressionSyntax invocation2 = invocationInfo2.InvocationExpression;

SimpleNameSyntax name = (invocationInfo2.NameText, invocationInfo.NameText) switch
{
("OrderBy", "FirstOrDefault") => (SimpleNameSyntax)ParseName("MinBy"),
("OrderByDescending", "FirstOrDefault") => (SimpleNameSyntax)ParseName("MaxBy"),
("OrderBy", "First") => (SimpleNameSyntax)ParseName("MinBy"),
("OrderByDescending", "First") => (SimpleNameSyntax)ParseName("MaxBy"),
_ => invocationInfo.Name
};

InvocationExpressionSyntax newNode = invocation2.WithExpression(
invocationInfo2.MemberAccessExpression.WithName(invocationInfo.Name.WithTriviaFrom(invocationInfo2.Name)));
invocationInfo2.MemberAccessExpression.WithName(name.WithTriviaFrom(invocationInfo2.Name)));

IEnumerable<SyntaxTrivia> trivia = invocation.DescendantTrivia(TextSpan.FromBounds(invocation2.Span.End, invocation.Span.End));

Expand Down
Expand Up @@ -97,9 +97,11 @@ private static ExpressionSyntax GetNewNode(PrefixUnaryExpressionSyntax logicalNo

SingleParameterLambdaExpressionInfo lambdaInfo = SyntaxInfo.SingleParameterLambdaExpressionInfo(lambdaExpression);

var logicalNot2 = (PrefixUnaryExpressionSyntax)SimplifyLogicalNegationAnalyzer.GetReturnExpression(lambdaInfo.Body).WalkDownParentheses();
ExpressionSyntax logicalNot2 = SimplifyLogicalNegationAnalyzer.GetReturnExpression(lambdaInfo.Body).WalkDownParentheses();

InvocationExpressionSyntax newNode = invocationExpression.ReplaceNode(logicalNot2, logicalNot2.Operand.WithTriviaFrom(logicalNot2));
ExpressionSyntax invertedExperssion = SyntaxLogicalInverter.GetInstance(document).LogicallyInvert(logicalNot2);

InvocationExpressionSyntax newNode = invocationExpression.ReplaceNode(logicalNot2, invertedExperssion.WithTriviaFrom(logicalNot2));

return SyntaxRefactorings.ChangeInvokedMethodName(newNode, (memberAccessExpression.Name.Identifier.ValueText == "All") ? "Any" : "All");
}
Expand Down
2 changes: 2 additions & 0 deletions src/Analyzers/CSharp/Analysis/InvocationExpressionAnalyzer.cs
Expand Up @@ -129,6 +129,7 @@ private static void AnalyzeInvocationExpression(SyntaxNodeAnalysisContext contex

OptimizeLinqMethodCallAnalysis.AnalyzeWhere(context, invocationInfo);
OptimizeLinqMethodCallAnalysis.AnalyzeFirst(context, invocationInfo);
OptimizeLinqMethodCallAnalysis.AnalyzerOrderByAndFirst(context, invocationInfo, shouldThrowIfEmpty: true);
}

break;
Expand Down Expand Up @@ -184,6 +185,7 @@ private static void AnalyzeInvocationExpression(SyntaxNodeAnalysisContext contex
{
OptimizeLinqMethodCallAnalysis.AnalyzeWhere(context, invocationInfo);
OptimizeLinqMethodCallAnalysis.AnalyzeFirstOrDefault(context, invocationInfo);
OptimizeLinqMethodCallAnalysis.AnalyzerOrderByAndFirst(context, invocationInfo, shouldThrowIfEmpty: false);
}

break;
Expand Down
84 changes: 84 additions & 0 deletions src/Analyzers/CSharp/Analysis/OptimizeLinqMethodCallAnalysis.cs
Expand Up @@ -199,6 +199,90 @@ public static void AnalyzeWhere(SyntaxNodeAnalysisContext context, in SimpleMemb
Report(context, invocation, span, checkDirectives: true, properties: properties);
}


// for reference types
// items.OrderBy(selector).FirstOrDefault() >>> items.MaxBy(selector)
// items.OrderByDescending(selector).FirstOrDefault() >>> items.MaxBy(selector)
// for value types:
// items.OrderBy(selector).First() >>> items.MaxBy(selector)
// items.OrderByDescending(selector).First() >>> items.MaxBy(selector)
public static void AnalyzerOrderByAndFirst(SyntaxNodeAnalysisContext context, in SimpleMemberInvocationExpressionInfo invocationInfo, bool shouldThrowIfEmpty)
{
// MinBy / MaxBy are only supported for net6.0 onwards
INamedTypeSymbol enumerableSymbol = context.Compilation.GetTypeByMetadataName("System.Linq.Enumerable");

if (enumerableSymbol.FindMember<IMethodSymbol>("MinBy") is null)
return;

SimpleMemberInvocationExpressionInfo previousInvocationInfo = SyntaxInfo.SimpleMemberInvocationExpressionInfo(invocationInfo.Expression);

if (!previousInvocationInfo.Success)
return;

if (previousInvocationInfo.Arguments.Count != 1)
return;

if (previousInvocationInfo.NameText != "OrderBy" && previousInvocationInfo.NameText != "OrderByDescending")
return;

InvocationExpressionSyntax invocation = invocationInfo.InvocationExpression;

SemanticModel semanticModel = context.SemanticModel;
CancellationToken cancellationToken = context.CancellationToken;

IMethodSymbol methodSymbol = semanticModel.GetExtensionMethodInfo(invocation, cancellationToken).Symbol;

if (methodSymbol is null)
return;

if (!SymbolUtility.IsLinqExtensionOfIEnumerableOfTWithoutParameters(methodSymbol, invocationInfo.NameText))
return;

IMethodSymbol methodSymbol2 = semanticModel.GetExtensionMethodInfo(previousInvocationInfo.InvocationExpression, cancellationToken).Symbol;

if (methodSymbol2 is null)
return;


switch (previousInvocationInfo.NameText)
{
case "OrderBy":
{
if (!SymbolUtility.IsLinqOrderBy(methodSymbol2, allowImmutableArrayExtension: true))
return;

break;
}
case "OrderByDescending":
{
if (!SymbolUtility.IsLinqOrderByDescending(methodSymbol2, allowImmutableArrayExtension: true))
return;

break;
}
default:
{
throw new InvalidOperationException();
}
}

// First throws if no values found. MaxBy/MinBy match this behaviour if TSource is a not reference type.
var lambda = previousInvocationInfo.InvocationExpression.ArgumentList.Arguments[0].Expression;
var delegateType = semanticModel.GetTypeInfo(lambda).ConvertedType;
if (delegateType is not INamedTypeSymbol { TypeKind: TypeKind.Delegate } namedDelegateType)
return;

var tSource = namedDelegateType.TypeArguments.First();

if (tSource.IsReferenceType == shouldThrowIfEmpty)
return;

TextSpan span = TextSpan.FromBounds(previousInvocationInfo.Name.SpanStart, invocation.Span.End);

Report(context, invocation, span, checkDirectives: true, properties: Properties.SimplifyLinqMethodChain);
}


public static void AnalyzeFirstOrDefault(SyntaxNodeAnalysisContext context, in SimpleMemberInvocationExpressionInfo invocationInfo)
{
InvocationExpressionSyntax invocation = invocationInfo.InvocationExpression;
Expand Down
Expand Up @@ -173,9 +173,7 @@ public static void Analyze(SyntaxNodeAnalysisContext context, in SimpleMemberInv
if (!lambdaInfo.Success)
return;

ExpressionSyntax expression = GetReturnExpression(lambdaInfo.Body)?.WalkDownParentheses();

if (expression?.IsKind(SyntaxKind.LogicalNotExpression) != true)
if (GetReturnExpression(lambdaInfo.Body) is null)
return;

IMethodSymbol methodSymbol = context.SemanticModel.GetReducedExtensionMethodInfo(invocationInfo.InvocationExpression, context.CancellationToken).Symbol;
Expand Down
14 changes: 14 additions & 0 deletions src/Core/SymbolUtility.cs
Expand Up @@ -305,6 +305,20 @@ internal static bool IsPropertyOfNullableOfT(ISymbol symbol, string name)
return IsLinqExtensionOfIEnumerableOfTWithPredicate(methodSymbol, "Where", parameterCount: 2, allowImmutableArrayExtension: allowImmutableArrayExtension);
}

internal static bool IsLinqOrderBy(
IMethodSymbol methodSymbol,
bool allowImmutableArrayExtension = false)
{
return IsLinqExtensionOfIEnumerableOfT(methodSymbol, "OrderBy", parameterCount: 2, allowImmutableArrayExtension: allowImmutableArrayExtension);
}

internal static bool IsLinqOrderByDescending(
IMethodSymbol methodSymbol,
bool allowImmutableArrayExtension = false)
{
return IsLinqExtensionOfIEnumerableOfT(methodSymbol, "OrderByDescending", parameterCount: 2, allowImmutableArrayExtension: allowImmutableArrayExtension);
}

internal static bool IsLinqWhereWithIndex(IMethodSymbol methodSymbol)
{
if (!IsLinqExtensionOfIEnumerableOfT(methodSymbol, "Where", parameterCount: 2, allowImmutableArrayExtension: false))
Expand Down
68 changes: 68 additions & 0 deletions src/Tests/Analyzers.Tests/RCS1068SimplifyLogicalNegationTests2.cs
Expand Up @@ -120,6 +120,40 @@ void M()
");
}

[Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.SimplifyLogicalNegation)]
public async Task Test_NotAny4()
{
await VerifyDiagnosticAndFixAsync(@"
using System.Linq;
using System.Collections.Generic;

class C
{
void M()
{
bool f1 = false;
var items = new List<int>();

f1 = [|!items.Any(i => i % 2 == 0)|];
}
}
", @"
using System.Linq;
using System.Collections.Generic;

class C
{
void M()
{
bool f1 = false;
var items = new List<int>();

f1 = items.All(i => i % 2 != 0);
}
}
");
}

[Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.SimplifyLogicalNegation)]
public async Task Test_NotAll()
{
Expand Down Expand Up @@ -225,6 +259,40 @@ void M()
f1 = items.Any<string>(s => s.Equals(s));
}
}
");
}

[Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.SimplifyLogicalNegation)]
public async Task Test_NotAll4()
{
await VerifyDiagnosticAndFixAsync(@"
using System.Linq;
using System.Collections.Generic;

class C
{
void M()
{
bool f1 = false;
var items = new List<int>();

f1 = [|!items.All(i => i % 2 == 0)|];
}
}
", @"
using System.Linq;
using System.Collections.Generic;

class C
{
void M()
{
bool f1 = false;
var items = new List<int>();

f1 = items.Any(i => i % 2 != 0);
}
}
");
}
}
88 changes: 88 additions & 0 deletions src/Tests/Analyzers.Tests/RCS1077OptimizeLinqMethodCallTests.cs
Expand Up @@ -282,6 +282,94 @@ void M()
", source, expected);
}

[Theory, Trait(Traits.Analyzer, DiagnosticIdentifiers.OptimizeLinqMethodCall)]
[InlineData("OrderBy(f => f.Length).FirstOrDefault()", "MinBy(f => f.Length)")]
[InlineData("OrderByDescending(f => f.Length).FirstOrDefault()", "MaxBy(f => f.Length)")]
public async Task Test_CombineOrderByFirstOrDefault(string source, string expected)
{
await VerifyDiagnosticAndFixAsync(@"
using System.Collections.Generic;
using System.Linq;

namespace N
{
class C
{
string M()
{
var items = new List<string>();

return items.[||];
}
}
}", source, expected);
}

[Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.OptimizeLinqMethodCall)]
public async Task Test_CombineOrderByFirstOrDefault_NoDiagnosticIfTsourceIsValueType()
{
await VerifyNoDiagnosticAsync(@"
using System.Collections.Generic;
using System.Linq;

namespace N
{
class C
{
void M()
{
var items = new List<int>();

var y = items.OrderBy(x=>x).FirstOrDefault();
}
}
}");
}

[Theory, Trait(Traits.Analyzer, DiagnosticIdentifiers.OptimizeLinqMethodCall)]
[InlineData("OrderBy(f => f).First()", "MinBy(f => f)")]
[InlineData("OrderByDescending(f => f).First()", "MaxBy(f => f)")]
public async Task Test_CombineOrderByFirst(string source, string expected)
{
await VerifyDiagnosticAndFixAsync(@"
using System.Collections.Generic;
using System.Linq;

namespace N
{
class C
{
int M()
{
var items = new List<int>();

return items.[||];
}
}
}", source, expected);
}

[Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.OptimizeLinqMethodCall)]
public async Task Test_CombineOrderByFirst_NoDiagnosticIfTsourceIsReferenceType()
{
await VerifyNoDiagnosticAsync(@"
using System.Collections.Generic;
using System.Linq;

namespace N
{
class C
{
void M()
{
var items = new List<string>();

var y = items.OrderBy(x=>x.Length).First();
}
}
}");
}

[Theory, Trait(Traits.Analyzer, DiagnosticIdentifiers.OptimizeLinqMethodCall)]
[InlineData(@"Where(f => f.StartsWith(""a"")).Any(f => f.StartsWith(""b""))", @"Any(f => f.StartsWith(""a"") && f.StartsWith(""b""))")]
[InlineData(@"Where((f) => f.StartsWith(""a"")).Any(f => f.StartsWith(""b""))", @"Any((f) => f.StartsWith(""a"") && f.StartsWith(""b""))")]
Expand Down