From abc8dafb51f53f113297de32d9e1f073e4d8727e Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 18 Mar 2024 13:38:45 +0100 Subject: [PATCH] Precompiled query inner loop source generator Closes #32727 --- .../AnalyzerReleases.Unshipped.md | 5 + src/EFCore.Analyzers/EFCore.Analyzers.csproj | 4 + .../FakeAnalyzerConfigOptionsProvider.cs | 31 + .../Helpers/KnownTypeSymbols.cs | 37 ++ .../Helpers/NullableAttributes.cs | 17 + .../LinqQuerySourceGenerator.cs | 528 ++++++++++++++++++ .../LinqQuerySourceGenerator.props | 5 + .../Properties/AnalyzerStrings.Designer.cs | 18 + .../Properties/AnalyzerStrings.resx | 6 + .../EFCore.Relational.csproj | 1 + src/EFCore/EFCore.csproj | 1 + src/EFCore/Properties/CoreStrings.Designer.cs | 7 + src/EFCore/Properties/CoreStrings.resx | 3 + .../Internal/PrecompiledQuerySafeMarker.cs | 17 + .../Query/QueryTranslationPreprocessor.cs | 13 + .../EFCore.Analyzers.Tests.csproj | 4 +- .../LinqQuerySourceGeneratorTests.cs | 381 +++++++++++++ 17 files changed, 1077 insertions(+), 1 deletion(-) create mode 100644 src/EFCore.Analyzers/Helpers/FakeAnalyzerConfigOptionsProvider.cs create mode 100644 src/EFCore.Analyzers/Helpers/KnownTypeSymbols.cs create mode 100644 src/EFCore.Analyzers/Helpers/NullableAttributes.cs create mode 100644 src/EFCore.Analyzers/LinqQuerySourceGenerator.cs create mode 100644 src/EFCore.Analyzers/LinqQuerySourceGenerator.props create mode 100644 src/EFCore/Query/Internal/PrecompiledQuerySafeMarker.cs create mode 100644 test/EFCore.Analyzers.Tests/LinqQuerySourceGeneratorTests.cs diff --git a/src/EFCore.Analyzers/AnalyzerReleases.Unshipped.md b/src/EFCore.Analyzers/AnalyzerReleases.Unshipped.md index e69de29bb2d..15a6c1f31fd 100644 --- a/src/EFCore.Analyzers/AnalyzerReleases.Unshipped.md +++ b/src/EFCore.Analyzers/AnalyzerReleases.Unshipped.md @@ -0,0 +1,5 @@ +### New Rules + +Rule ID | Category | Severity | Notes +--------|----------|----------|-------------------------------- +EF1003 | USage | Warning | PrecompiledQuerySourceGenerator diff --git a/src/EFCore.Analyzers/EFCore.Analyzers.csproj b/src/EFCore.Analyzers/EFCore.Analyzers.csproj index beacecfce80..bd44f917535 100644 --- a/src/EFCore.Analyzers/EFCore.Analyzers.csproj +++ b/src/EFCore.Analyzers/EFCore.Analyzers.csproj @@ -34,6 +34,10 @@ + + + + diff --git a/src/EFCore.Analyzers/Helpers/FakeAnalyzerConfigOptionsProvider.cs b/src/EFCore.Analyzers/Helpers/FakeAnalyzerConfigOptionsProvider.cs new file mode 100644 index 00000000000..61c006ac1e4 --- /dev/null +++ b/src/EFCore.Analyzers/Helpers/FakeAnalyzerConfigOptionsProvider.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Microsoft.EntityFrameworkCore; + +public sealed class FakeAnalyzerConfigOptionsProvider(params (string, string)[] globalOptions) : AnalyzerConfigOptionsProvider +{ + public override AnalyzerConfigOptions GlobalOptions { get; } = new ConfigOptions(globalOptions); + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) + => GlobalOptions; + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) + => GlobalOptions; + + private sealed class ConfigOptions : AnalyzerConfigOptions + { + private readonly Dictionary _globalOptions; + + public ConfigOptions((string, string)[] globalOptions) + => _globalOptions = globalOptions.ToDictionary(t => t.Item1, t => t.Item2); + + public override bool TryGetValue(string key, [NotNullWhen(true)] out string? value) + => _globalOptions.TryGetValue(key, out value); + } +} + diff --git a/src/EFCore.Analyzers/Helpers/KnownTypeSymbols.cs b/src/EFCore.Analyzers/Helpers/KnownTypeSymbols.cs new file mode 100644 index 00000000000..542646e4c42 --- /dev/null +++ b/src/EFCore.Analyzers/Helpers/KnownTypeSymbols.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; + +// ReSharper disable InconsistentNaming + +namespace Microsoft.EntityFrameworkCore; + +internal sealed class KnownTypeSymbols(Compilation compilation) +{ + public INamedTypeSymbol? IEnumerableOfTType => GetOrResolveType(typeof(IEnumerable<>), ref _IEnumerableOfTType); + private Option _IEnumerableOfTType; + + private INamedTypeSymbol? GetOrResolveType(Type type, ref Option field) + => GetOrResolveType(type.FullName!, ref field); + + private INamedTypeSymbol? GetOrResolveType(string fullyQualifiedName, ref Option field) + { + if (field.HasValue) + { + return field.Value; + } + + // TODO: What to do if the type is not found + var type = compilation.GetTypeByMetadataName(fullyQualifiedName) + ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedName); + field = new(type); + return type; + } + + private readonly struct Option(T value) + { + public readonly bool HasValue = true; + public readonly T Value = value; + } +} diff --git a/src/EFCore.Analyzers/Helpers/NullableAttributes.cs b/src/EFCore.Analyzers/Helpers/NullableAttributes.cs new file mode 100644 index 00000000000..ff1cf24d168 --- /dev/null +++ b/src/EFCore.Analyzers/Helpers/NullableAttributes.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Diagnostics.CodeAnalysis; + +[AttributeUsage(AttributeTargets.Parameter, Inherited = false)] +internal sealed class NotNullWhenAttribute : Attribute +{ + /// Initializes the attribute with the specified return value condition. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + /// + public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + + /// Gets the return value condition. + public bool ReturnValue { get; } +} diff --git a/src/EFCore.Analyzers/LinqQuerySourceGenerator.cs b/src/EFCore.Analyzers/LinqQuerySourceGenerator.cs new file mode 100644 index 00000000000..e43b15a99ac --- /dev/null +++ b/src/EFCore.Analyzers/LinqQuerySourceGenerator.cs @@ -0,0 +1,528 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +// TODO: Temporary during development, remove these +#pragma warning disable CS0219 // Variable is assigned but its value is never used +#pragma warning disable CS0162 // Unreachable code detected +#pragma warning disable CS8321 // Local function is declared but never used + +namespace Microsoft.EntityFrameworkCore; + +/// +/// A source generator that identifies queryable LINQ queries, and checks if they're compatible with query precompilation (i.e. they're +/// a static chain of method invocation over an EF DbContext). For compatible queries, the terminating operator gets an interceptor +/// which injects an additional "safe marker" node into the query expression tree (when this marker is absent, runtime compilation will +/// fail). For incompatible queries, a warning diagnostic is reported. +/// +[Generator(LanguageNames.CSharp)] +public class LinqQuerySourceGenerator : IIncrementalGenerator +{ + public const string Id = "EF1003"; + public const string DisableRuntimeCompilationMsbuildProperty = "build_property.EFNukeDynamic"; + + private static readonly DiagnosticDescriptor DynamicQueryDiagnosticDescriptor + // HACK: Work around dotnet/roslyn-analyzers#5890 by not using target-typed new + = new DiagnosticDescriptor( + Id, + title: AnalyzerStrings.DynamicQueryTitle, + messageFormat: AnalyzerStrings.DynamicQueryMessageFormat, + category: "Usage", + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + // TODO: Reference this from the publish-time source generator + private const string InterceptorsNamespace = "Microsoft.EntityFrameworkCore.GeneratedInterceptors"; + + /// + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // TODO: Also allow per-source-file metadata which enables/disables the analysis? + // https://github.com/dotnet/roslyn/blob/main/docs/features/incremental-generators.cookbook.md#consume-msbuild-properties-and-metadata + + // TODO: ideally, the terminatingOperators pipeline below wouldn't even run if the source generator isn't enabled (this is an opt-in + // TODO: source generator, at least for now). Does that happen if e.g. this pipeline is on the left side of the Combine() + // TODO: specifically? + var isGeneratorEnabled = + context.AnalyzerConfigOptionsProvider.Select( + (provider, _) => + provider.GlobalOptions.TryGetValue(DisableRuntimeCompilationMsbuildProperty, out var configurationSwitch) + && configurationSwitch == "true"); + + // context.AnalyzerConfigOptionsProvider.Select( + // (provider, _) => + // provider.GlobalOptions.TryGetValue("build_property.EFNukeDynamic", out var nukeDynamicSwitch) + // && nukeDynamicSwitch.Equals("true", StringComparison.OrdinalIgnoreCase)) + // .SelectMany((e, _) => new[] { e }) + // .Where(e => e) + // .Combine(terminatingOperators) + // .Select((t, _) => t.Right); + + var terminatingOperators = context.SyntaxProvider + .CreateSyntaxProvider(IsPossibleTerminatingOperator, ProcessLinqQuery) + .Combine(isGeneratorEnabled) + // Filter out empty operators (i.e. operator candidates that turned out to be LINQ-to-Objects), and also filter out everything + // if the source generator is disabled. + .Where(t => t is { Left.IsEmpty: false, Right: true }) + .Select((t, _) => t.Left) + .WithTrackingName("EF" + nameof(LinqQuerySourceGenerator)); + + // TODO: Currently all interceptors from the entire project go into the same file, which may be a lot to redo every time something + // TODO: changes; maybe generate an interceptor file per source file, to limit the regeneration scope to a single file? + // TODO: Is that possible (how)? + // TODO: Possibly look at GroupWith() in ASP.NET: https://github.com/dotnet/aspnetcore/blob/main/src/Shared/RoslynUtils/IncrementalValuesProviderExtensions.cs#L11 + + context.RegisterSourceOutput( + terminatingOperators.Where(o => o.Diagnostic is not null), + (context, terminatingOperator) => context.ReportDiagnostic(terminatingOperator.Diagnostic!)); + + context.RegisterSourceOutput( + terminatingOperators.Where(o => o.Diagnostic is null).Collect(), + GenerateCode); + } + + private static bool IsPossibleTerminatingOperator(SyntaxNode syntaxNode, CancellationToken cancellationToken) + => syntaxNode switch + { + InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.Text: var identifier } } + } + // TODO: Is the perf here something to worry about? If I had FrozenSet maybe I'd use it (not available because of TFM) + // TODO: There seem to be good optimizations around switch over strings, so this may be fine + // TODO: (https://github.com/dotnet/roslyn/pull/66081). + => identifier switch + { + // On Enumerable + "AsEnumerable" or "ToArray" or "ToDictionary" or "ToHashSet" or "ToLookup" or "ToList" + // On EntityFrameworkQueryableExtensions + or "AsAsyncEnumerable" or "ToArrayAsync" or "ToDictionaryAsync" or "ToHashSetAsync" or "ToListAsync" + // or "ToLookupAsync" + + // when syntaxNode.SyntaxTree.FilePath.Contains("Program.cs") // TODO: Hack for now, since the source gen runs on project references too (i.e. EF source code)?? + => true, + + _ => false + }, + + // TODO: Handle foreach + ForEachStatementSyntax => false, + + _ => false + }; + + private TerminatorOperatorInfo ProcessLinqQuery(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + InvocationExpressionSyntax terminatingOperator; + MemberAccessExpressionSyntax memberAccess; + string? interceptorDeclaration; + + // Our input is a candidate query terminating node (e.g. ToList()); above we've just verified its name, we now verify that it's + // actually the correct query operator (e.g. Enumerable.ToList() and not some other ToList()) by checking its symbol + switch (context.Node) + { + case InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.Text: var identifier } } memberAccess2 + } invocation: + { + // TODO: Is there an advantage in doing this via Operations (as the ASP.NET generator does)? + interceptorDeclaration = identifier switch + { + // These sync terminating operators exist exist over IEnumerable only, so verify the actual argument is an IQueryable + // (otherwise this is just LINQ to Objects) + // On Enumerable: + "AsEnumerable" when IsEnumerableOperatorOverQueryable() => """ +public static global::System.Collections.Generic.IEnumerable AsEnumerable_Safe( + this global::System.Collections.Generic.IEnumerable source) +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose( + (global::System.Linq.IQueryable)source); + return global::System.Linq.Enumerable.AsEnumerable(safeWrapped); +} +""", + "ToArray" when IsEnumerableOperatorOverQueryable() => """ +public static TSource[] ToArray_Safe(this global::System.Collections.Generic.IEnumerable source) +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose( + (global::System.Linq.IQueryable)source); + return global::System.Linq.Enumerable.ToArray(safeWrapped); +} +""", + "ToDictionary" when IsEnumerableOperatorOverQueryable() => """ +public static global::System.Collections.Generic.Dictionary ToDictionary_Safe( + this global::System.Collections.Generic.IEnumerable source, + global::System.Func keySelector, + global::System.Func elementSelector) + where TKey : notnull +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose( + (global::System.Linq.IQueryable)source); + return global::System.Linq.Enumerable.ToDictionary(safeWrapped, keySelector, elementSelector); +} +""", + "ToHashSet" when IsEnumerableOperatorOverQueryable() => """ +public static global::System.Collections.Generic.HashSet ToHashSet_Safe( + this global::System.Collections.Generic.IEnumerable source) +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose( + (global::System.Linq.IQueryable)source); + return global::System.Linq.Enumerable.ToHashSet(safeWrapped); +} +""", + "ToLookup" when IsEnumerableOperatorOverQueryable() => throw new NotImplementedException(), + + "ToList" when IsEnumerableOperatorOverQueryable() => """ +public static global::System.Collections.Generic.List ToList_Safe( + this global::System.Collections.Generic.IEnumerable source) +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose(( + global::System.Linq.IQueryable)source); + return global::System.Linq.Enumerable.ToList(safeWrapped); +} +""", + + // On EntityFrameworkQueryableExtensions + "AsAsyncEnumerable" when IsOnEfQueryableExtensions() => """ +public static global::System.Collections.Generic.IAsyncEnumerable AsAsyncEnumerableAsync_Safe( + this global::System.Linq.IQueryable source) +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose(source); + return global::Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions.AsAsyncEnumerable(safeWrapped); +} +""", + "ToArrayAsync" when IsOnEfQueryableExtensions() => """ +public static global::System.Threading.Tasks.Task ToArrayAsync_Safe( + this global::System.Linq.IQueryable source, + global::System.Threading.CancellationToken cancellationToken = default) +{ +var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose(source); +return global::Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions.ToArrayAsync(safeWrapped); +} +""", + "ToDictionaryAsync" when IsOnEfQueryableExtensions() => """ +public static global::System.Threading.Tasks.Task> ToDictionaryAsync_Safe( + this global::System.Linq.IQueryable source, + global::System.Func keySelector, + global::System.Func elementSelector, + global::System.Threading.CancellationToken cancellationToken = default) + where TKey : notnull +{ + var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose(source); + return global::Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions.ToDictionaryAsync( + safeWrapped, keySelector, elementSelector); +} +""", + "ToHashSetAsync" when IsOnEfQueryableExtensions() => """ +public static global::System.Threading.Tasks.Task> ToHashSetAsync_Safe( + this global::System.Linq.IQueryable source, + global::System.Threading.CancellationToken cancellationToken = default) +{ +var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose(source); +return global::Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions.ToHashSetAsync(safeWrapped); +} +""", + "ToListAsync" when IsOnEfQueryableExtensions() => """ +public static global::System.Threading.Tasks.Task> ToListAsync_Safe( + this global::System.Linq.IQueryable source, + global::System.Threading.CancellationToken cancellationToken = default) +{ +var safeWrapped = global::Microsoft.EntityFrameworkCore.Query.Internal.PrecompiledQuerySafeMarker.Compose(source); +return global::Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions.ToListAsync(safeWrapped); +} +""", + // "ToLookupAsync" + + _ => default + }; + + // Check that we're actually dealing with a queryable LINQ query, with a well-known operator. + if (interceptorDeclaration is null) + { + return default; + } + + terminatingOperator = invocation; + memberAccess = memberAccess2; + + break; + + // TODO: we currently check symbols by their name; should we switch to loading symbols from the Compilation (via + // CompilationProvider) - like the System.Text.Json gen does - or MetadataReferenceProvider? Do symbols from dependencies + // (e.g. System.Linq.Enumerable) change across compilations? + bool IsEnumerableOperatorOverQueryable() + => context.SemanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol is IMethodSymbol + { + // TODO: Can check that the parameter is IEnumerable - but it's kind of useless, since we filter on specific method names anyway. May as well keep it like this. + ContainingType: + { Name: "Enumerable", ContainingNamespace: { Name: "Linq", ContainingNamespace.Name: "System" } } + } + && context.SemanticModel.GetSymbolInfo(memberAccess2.Expression, cancellationToken).Symbol switch + { + // Terminating operator over a method that returns an IQueryable, e.g. context.Blogs.Where(...).ToList() + // TODO: As an optimization, exclude methods defined on the Enumerable type before doing the more expensive + // TODO: IQueryable interface check? + IMethodSymbol { ReturnType: var returnType } when returnType.AllInterfaces.Any( + i => i.OriginalDefinition is + { + Name: "IQueryable", + ContainingNamespace: + { + Name: "Linq", + ContainingNamespace.Name: "System" + }, + ContainingAssembly.Name: "System.Linq.Expressions" + }) + => true, + + // Terminating operator directly over DbSet property, e.g. context.Blogs.ToList() + IPropertySymbol + { + Type: + { + Name: "DbSet", + ContainingNamespace: + { + Name: "EntityFrameworkCore", + ContainingNamespace.Name: "Microsoft" + }, + ContainingAssembly.Name: "Microsoft.EntityFrameworkCore" + } + } => true, + + // TODO: do we support DbSet fields as opposed to properties?? + + _ => false + }; + + bool IsOnEfQueryableExtensions() + => context.SemanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol is IMethodSymbol + { + ContainingType: + { + Name: "EntityFrameworkQueryableExtensions", + ContainingNamespace: + { + Name: "EntityFrameworkCore", + ContainingNamespace.Name: "Microsoft" + }, + ContainingAssembly.Name: "Microsoft.EntityFrameworkCore" + } + }; + } + + // TODO: Handle foreach + case ForEachStatementSyntax: + return default; + + default: + return default; + } + + // At this point we know we're dealing with a queryable LINQ query. Check whether it's static - i.e. a string of unbroken method + // invocations, rooted on an EF Core DbContext. If not, report a warning. + if (!IsRootOnDbContext(terminatingOperator)) + { + // TODO: Consider reporting the diagnostic on the whole fragment rather than only the terminating operator? + return new TerminatorOperatorInfo(Diagnostic.Create(DynamicQueryDiagnosticDescriptor, terminatingOperator.GetLocation())); + } + + // We now have a confirmed static EF query that needs to be intercepted. + var syntaxTree = terminatingOperator.SyntaxTree; + var startPosition = syntaxTree.GetLineSpan(memberAccess.Name.Span, cancellationToken).StartLinePosition; + + return new TerminatorOperatorInfo(syntaxTree.FilePath, startPosition.Line + 1, startPosition.Character + 1, interceptorDeclaration); + + bool IsRootOnDbContext(ExpressionSyntax expression) + { + // TODO: Carefully think about exactly what kind of verification we want to do here: static/non-static, actually get the + // TODO: method symbols and confirm it's an IQueryable flowing all the way through, etc. + + // Work backwards through the LINQ operator chain until we reach something that isn't a method invocation + while (expression is InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax { Expression: var innerExpression } + }) + { + expression = innerExpression; + } + + // We've reached a non-invocation. + + // First, check if this is a property access for a DbSet + if (expression is MemberAccessExpressionSyntax { Expression: var innerExpression2 } + && IsDbContext(innerExpression2)) + { + // TODO: Check symbol for DbSet? + return true; + } + + // If we had context.Set(), the Set() method was skipped like any other method, and we're on the context. + return IsDbContext(expression); + + bool IsDbContext(ExpressionSyntax expression) + { + switch (ModelExtensions.GetSymbolInfo(context.SemanticModel, expression, cancellationToken).Symbol) + { + case ILocalSymbol localSymbol: + return IsDbContextType(localSymbol.Type); + + case IPropertySymbol: + case IFieldSymbol: + case IMethodSymbol: + return false; // TODO + + case null: + return false; + default: + return false; // TODO: ? + } + + bool IsDbContextType(ITypeSymbol typeSymbol) + { + while (true) + { + if (typeSymbol is // TODO: Add assembly check + { + Name: "DbContext", + ContainingNamespace: + { + Name: "EntityFrameworkCore", + ContainingNamespace.Name: "Microsoft" + } + }) + { + return true; + } + + if (typeSymbol.BaseType is null) + { + return false; + } + + typeSymbol = typeSymbol.BaseType; + } + } + } + } + } + + private static void GenerateCode(SourceProductionContext context, ImmutableArray terminatingOperators) + { + if (terminatingOperators.IsDefaultOrEmpty) + { + return; + } + + // TODO: Add [GeneratedCode] to the publish-time generated code as well + var code = new StringBuilder() + .AppendLine( + $$""" +// +using System; +using System.Runtime.CompilerServices; + +namespace {{InterceptorsNamespace}} +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("{{typeof(LinqQuerySourceGenerator).Assembly.FullName}}", "{{typeof(LinqQuerySourceGenerator).Assembly.GetName().Version}}")] + file static class EntityFrameworkCoreInterceptors + { +"""); + + // TODO: Perf (GroupBy)? + foreach (var interceptionGroup in terminatingOperators.GroupBy(o => o.InterceptorDeclaration!)) + { + foreach (var terminatingOperator in interceptionGroup) + { + code.AppendLine( + $"""[InterceptsLocation("{terminatingOperator.FilePath}", {terminatingOperator.Line}, {terminatingOperator.Character})]"""); + } + + // TODO: Properly indent this for pretty generated code :) + code.AppendLine(interceptionGroup.Key); + } + + code.AppendLine( + """ + } +} + +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) { } + } +} +"""); + + context.AddSource("EFInterceptors.g.cs", code.ToString()); + } + + private readonly struct TerminatorOperatorInfo + : IEquatable + { + public TerminatorOperatorInfo(string filePath, int line, int character, string interceptorDeclaration) + { + FilePath = filePath; + Line = line; + Character = character; + InterceptorDeclaration = interceptorDeclaration; + } + + public TerminatorOperatorInfo(Diagnostic? diagnostic) + { + Diagnostic = diagnostic; + } + + public readonly string? InterceptorDeclaration; + public readonly Diagnostic? Diagnostic; + + // TODO: Do these need to participate in the equality/hashcode check? By definition they change if the originating syntax node + // changes etc. + public readonly string? FilePath; + public readonly int Line; + public readonly int Character; + + public bool IsEmpty + => InterceptorDeclaration is null && Diagnostic is null; + + public override bool Equals(object? obj) + => obj is TerminatorOperatorInfo other && Equals(other); + + public bool Equals(TerminatorOperatorInfo other) + => InterceptorDeclaration is null + ? Diagnostic!.Equals(other.Diagnostic) + // Interceptor declarations are always interned strings, so we skip calculating the hash codes and use reference comparison + // instead + : ReferenceEquals(InterceptorDeclaration, other.InterceptorDeclaration); + // && Line == other.Line + // && Character == other.Character + // && FilePath == other.FilePath; + + public override int GetHashCode() + { + unchecked + { + if (InterceptorDeclaration is null) + { + return Diagnostic!.GetHashCode(); + } + + // Interceptor declarations are always interned strings, so we skip calculating the hash codes and use the default reference + // hash code logic instead + var hashCode = RuntimeHelpers.GetHashCode(InterceptorDeclaration); + // hashCode = (hashCode * 397) ^ FilePath!.GetHashCode(); + // hashCode = (hashCode * 397) ^ Line; + // hashCode = (hashCode * 397) ^ Character; + return hashCode; + } + } + } +} diff --git a/src/EFCore.Analyzers/LinqQuerySourceGenerator.props b/src/EFCore.Analyzers/LinqQuerySourceGenerator.props new file mode 100644 index 00000000000..b6ef22e74e4 --- /dev/null +++ b/src/EFCore.Analyzers/LinqQuerySourceGenerator.props @@ -0,0 +1,5 @@ + + + + + diff --git a/src/EFCore.Analyzers/Properties/AnalyzerStrings.Designer.cs b/src/EFCore.Analyzers/Properties/AnalyzerStrings.Designer.cs index bcd1664c2ab..662d480d361 100644 --- a/src/EFCore.Analyzers/Properties/AnalyzerStrings.Designer.cs +++ b/src/EFCore.Analyzers/Properties/AnalyzerStrings.Designer.cs @@ -113,5 +113,23 @@ public class AnalyzerStrings { return ResourceManager.GetString("UninitializedDbSetWarningSuppressionJustification", resourceCulture); } } + + /// + /// Unsupported dynamic EF Core query. + /// + public static string DynamicQueryTitle { + get { + return ResourceManager.GetString("DynamicQueryTitle", resourceCulture); + } + } + + /// + /// This call to '{0}' represents a dynamic queryable LINQ query which cannot be precompiled by EF. + /// + public static string DynamicQueryMessageFormat { + get { + return ResourceManager.GetString("DynamicQueryMessageFormat", resourceCulture); + } + } } } diff --git a/src/EFCore.Analyzers/Properties/AnalyzerStrings.resx b/src/EFCore.Analyzers/Properties/AnalyzerStrings.resx index 2282189be54..edcca4fbe48 100644 --- a/src/EFCore.Analyzers/Properties/AnalyzerStrings.resx +++ b/src/EFCore.Analyzers/Properties/AnalyzerStrings.resx @@ -36,4 +36,10 @@ Method '{0}' inserts interpolated strings directly into the SQL, without any protection against SQL injection. Consider using '{1}' instead, which protects against SQL injection, or make sure that the value is sanitized and suppress the warning. + + Unsupported dynamic EF Core query. + + + This call to '{0}' represents a dynamic queryable LINQ query which cannot be precompiled by EF. + \ No newline at end of file diff --git a/src/EFCore.Relational/EFCore.Relational.csproj b/src/EFCore.Relational/EFCore.Relational.csproj index 514b7110e89..29254957763 100644 --- a/src/EFCore.Relational/EFCore.Relational.csproj +++ b/src/EFCore.Relational/EFCore.Relational.csproj @@ -9,6 +9,7 @@ true true $(NoWarn);EF1003 + $(NoWarn);CS1591 diff --git a/src/EFCore/EFCore.csproj b/src/EFCore/EFCore.csproj index 5ae18c8a667..c9a6be96da5 100644 --- a/src/EFCore/EFCore.csproj +++ b/src/EFCore/EFCore.csproj @@ -13,6 +13,7 @@ Microsoft.EntityFrameworkCore.DbSet Microsoft.EntityFrameworkCore true true + $(NoWarn);CS1591 diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 12a35ead3e1..67001486245 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2713,6 +2713,13 @@ public static string RuntimeModelMissingData public static string RuntimeParameterMissingParameter => GetString("RuntimeParameterMissingParameter"); + + /// + /// This LINQ query was not precompiled, likely because it is dynamic, and runtime query compilation has been disabled. + /// + public static string RuntimeQueryCompilationDisabled + => GetString("RuntimeQueryCompilationDisabled"); + /// /// The same parameter instance with name '{parameterName}' was used in multiple lambdas in the query tree. Each lambda must have its own parameter instances. /// diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index 1ae11dc6547..2b914fc14ff 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1480,6 +1480,9 @@ While registering a runtime parameter, the lambda expression must have only one parameter which must be same as 'QueryCompilationContext.QueryContextParameter' expression. + + This LINQ query was not precompiled, likely because it is dynamic, and runtime query compilation has been disabled. + The same parameter instance with name '{parameterName}' was used in multiple lambdas in the query tree. Each lambda must have its own parameter instances. diff --git a/src/EFCore/Query/Internal/PrecompiledQuerySafeMarker.cs b/src/EFCore/Query/Internal/PrecompiledQuerySafeMarker.cs new file mode 100644 index 00000000000..a060294b35f --- /dev/null +++ b/src/EFCore/Query/Internal/PrecompiledQuerySafeMarker.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +public class PrecompiledQuerySafeMarker : Expression +{ + internal static readonly MethodInfo ComposeMethodInfo + = typeof(PrecompiledQuerySafeMarker).GetTypeInfo().GetDeclaredMethod(nameof(Compose))!; + + public static IQueryable Compose(IQueryable source) + => source.Provider.CreateQuery( + Call( + instance: null, + method: new Func, IQueryable>(Compose).Method, + arguments: [source.Expression])); +} diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs index a9860384aa0..47502a9ed0b 100644 --- a/src/EFCore/Query/QueryTranslationPreprocessor.cs +++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs @@ -50,6 +50,7 @@ public class QueryTranslationPreprocessor /// A query expression after transformations. public virtual Expression Process(Expression query) { + query = CheckPrecompiledQuerySafeExpression(query); query = new InvocationExpressionRemovingExpressionVisitor().Visit(query); query = NormalizeQueryableMethod(query); query = new CallForwardingExpressionVisitor().Visit(query); @@ -67,6 +68,18 @@ public virtual Expression Process(Expression query) return query; } + private Expression CheckPrecompiledQuerySafeExpression(Expression query) + { + if (query is MethodCallExpression { Method.IsGenericMethod: true } methodCall + && methodCall.Method.GetGenericMethodDefinition() == PrecompiledQuerySafeMarker.ComposeMethodInfo) + { + return methodCall.Arguments[0]; + } + + // TODO: Check feature switch for whether we should allow only safe queries + throw new InvalidOperationException(CoreStrings.RuntimeQueryCompilationDisabled); + } + /// /// Normalizes queryable methods in the query. /// diff --git a/test/EFCore.Analyzers.Tests/EFCore.Analyzers.Tests.csproj b/test/EFCore.Analyzers.Tests/EFCore.Analyzers.Tests.csproj index d6d3557f289..be8d1afada1 100644 --- a/test/EFCore.Analyzers.Tests/EFCore.Analyzers.Tests.csproj +++ b/test/EFCore.Analyzers.Tests/EFCore.Analyzers.Tests.csproj @@ -38,8 +38,10 @@ + - + + diff --git a/test/EFCore.Analyzers.Tests/LinqQuerySourceGeneratorTests.cs b/test/EFCore.Analyzers.Tests/LinqQuerySourceGeneratorTests.cs new file mode 100644 index 00000000000..a43452459ec --- /dev/null +++ b/test/EFCore.Analyzers.Tests/LinqQuerySourceGeneratorTests.cs @@ -0,0 +1,381 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Runtime.Loader; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.Extensions.DependencyModel; + +// ReSharper disable InconsistentNaming + +namespace Microsoft.EntityFrameworkCore; + +public class LinqQuerySourceGeneratorTests +{ + [Fact] + public Task Query_on_multiple_operators_and_DbSet_property() + => Test(""" +var blogs = context.Blogs.Where(b => b.Id > 3).OrderBy(b => b.Id).ToList(); +Assert.Collection(blogs, + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task Query_directly_on_DbSet_property() + => Test( + """ +var blogs = context.Blogs.ToList(); +Assert.Collection(blogs, + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task Query_directly_on_Set_method() + => Test( + """ +var blogs = context.Set().ToList(); +Assert.Collection(blogs, + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task Query_over_enumerable_is_not_processed() + => Test( + "_ = new[] { 1, 2, 3 }.Where(i => i > 1).ToList();", + generatedCodeAsserter: code => Assert.Null(code)); + + [Fact] + public async Task Broken_up_query_does_not_work() + { + var exception = await Assert.ThrowsAsync(() => Test( + """ +var query = context.Blogs.Where(b => b.Id > 3); +_ = query.OrderBy(b => b.Id).ToList(); +""", + diagnosticsAsserter: d => Assert.Equal(LinqQuerySourceGenerator.Id, Assert.Single(d).Id))); + + Assert.Equal(CoreStrings.RuntimeQueryCompilationDisabled, exception.Message); + } + + [Fact] + public Task Same_terminating_operators_get_one_interceptor() + => Test(""" +var blogs = context.Blogs.Where(b => b.Id > 3).OrderBy(b => b.Id).ToList(); +Assert.Collection(blogs, + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); + +var ids = context.Blogs.Select(b => b.Id).OrderBy(id => id).ToList(); +Assert.Equivalent(new[] { 8, 9 }, ids); +""", + generatedCodeAsserter: code => + { + Assert.NotNull(code); + Assert.Equal(2, CountOccurrences(code, "[InterceptsLocation(")); + Assert.Equal(1, CountOccurrences(code, "ToList_Safe")); + }); + + [Fact] + public async Task Source_generator_is_disabled_without_config_option() + { + var exception = await Assert.ThrowsAsync(() => Test( + "var blogs = context.Blogs.ToList();", + enableSourceGenerator: false, + generatedCodeAsserter: Assert.Null)); + + Assert.Equal(CoreStrings.RuntimeQueryCompilationDisabled, exception.Message); + } + + #region Terminating operators + + [Fact] + public Task ToList() + => Test(""" +var blogs = context.Blogs.ToList(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task ToListAsync() + => Test(""" +var blogs = await context.Blogs.ToListAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task ToArray() + => Test(""" +var blogs = context.Blogs.ToArray(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task ToArrayAsync() + => Test(""" +var blogs = await context.Blogs.ToArrayAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task ToDictionary() + => Test(""" +var blogs = context.Blogs.ToDictionary(kv => kv.Id, kv => kv.Name); +Assert.Equal(2, blogs.Count); +Assert.Equal("Blog1", blogs[8]); +Assert.Equal("Blog2", blogs[9]); +"""); + + [Fact] + public Task ToDictionaryAsync() + => Test(""" +var blogs = await context.Blogs.ToDictionaryAsync(kv => kv.Id, kv => kv.Name); +Assert.Equal(2, blogs.Count); +Assert.Equal("Blog1", blogs[8]); +Assert.Equal("Blog2", blogs[9]); +"""); + + [Fact] + public Task ToHashSet() + => Test(""" +var blogs = context.Blogs.ToHashSet(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task ToHashSetAsync() + => Test(""" +var blogs = await context.Blogs.ToHashSetAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [Fact] + public Task AsEnumerable() + => Test(""" +foreach (var blog in context.Blogs.Where(b => b.Id == 8).AsEnumerable()) +{ + Assert.Equal("Blog1", blog.Name); +} +"""); + + [Fact] + public Task AsAsyncEnumerable() + => Test(""" +await foreach (var blog in context.Blogs.Where(b => b.Id == 8).AsAsyncEnumerable()) +{ + Assert.Equal("Blog1", blog.Name); +} +"""); + + #endregion Terminating operators + + private static async Task Test( + string code, + bool enableSourceGenerator = true, + Action>? diagnosticsAsserter = null, + Action? generatedCodeAsserter = null) + { + var fullCode = $$""" +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Xunit; + +public static class TestContainer +{ + public static async Task Test() + { + await using var context = new BlogContext(); + + context.Blogs.AddRange(new Blog[] + { + new() { Id = 8, Name = "Blog1" }, + new() { Id = 9, Name = "Blog2" } + }); + context.SaveChanges(); + +{{code}} + } +} + +public class BlogContext : DbContext +{ + public DbSet Blogs { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + => optionsBuilder.UseInMemoryDatabase(Guid.NewGuid().ToString()); +} + +public class Blog +{ + public int Id { get; set; } + public string? Name { get; set; } +} +"""; + + var metadataReferences + = DependencyContext.Load(typeof(LinqQuerySourceGeneratorTests).Assembly)! + .CompileLibraries + .SelectMany(c => c.ResolveReferencePaths()) + .Select(path => MetadataReference.CreateFromFile(path)) + .Cast() + .ToList(); + + var interceptorsFeature = + new[] + { + new KeyValuePair("InterceptorsPreviewNamespaces", "Microsoft.EntityFrameworkCore.GeneratedInterceptors") + }; + + var parseOptions = new CSharpParseOptions().WithFeatures(interceptorsFeature); + var syntaxTree = CSharpSyntaxTree.ParseText(fullCode, path: "Test.cs", options: parseOptions); + + var compilation = CSharpCompilation.Create( + "SourceGeneratorTests", + syntaxTrees: [syntaxTree], + metadataReferences, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var errorDiagnostics = compilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error).ToList(); + if (errorDiagnostics.Count > 0) + { + var stringBuilder = new StringBuilder(); + stringBuilder.AppendLine("Compilation failed:").AppendLine(); + + foreach (var errorDiagnostic in errorDiagnostics) + { + stringBuilder.AppendLine(errorDiagnostic.ToString()); + + var textLines = errorDiagnostic.Location.SourceTree!.GetText().Lines; + var startLine = errorDiagnostic.Location.GetLineSpan().StartLinePosition.Line; + var endLine = errorDiagnostic.Location.GetLineSpan().EndLinePosition.Line; + + if (startLine == endLine) + { + stringBuilder.Append("Line: ").AppendLine(textLines[startLine].ToString().TrimStart()); + } + else + { + stringBuilder.AppendLine("Lines:"); + for (var i = startLine; i <= endLine; i++) + { + stringBuilder.AppendLine(textLines[i].ToString()); + } + } + } + + throw new InvalidOperationException("Compilation failed:" + stringBuilder); + } + + var generator = new LinqQuerySourceGenerator(); + + var optionsProvider = new FakeAnalyzerConfigOptionsProvider( + enableSourceGenerator + ? [(LinqQuerySourceGenerator.DisableRuntimeCompilationMsbuildProperty, "true")] + : []); + + CSharpGeneratorDriver + .Create(generator) + .WithUpdatedParseOptions(parseOptions) + .WithUpdatedAnalyzerConfigOptions(optionsProvider) + .RunGeneratorsAndUpdateCompilation( + compilation, + out var outputCompilation, + out var diagnostics); + + diagnosticsAsserter ??= d => Assert.Empty(d); + diagnosticsAsserter(diagnostics); + + var (assemblyLoadContext, assembly) = EmitAndLoadAssembly(outputCompilation, ""); + + try + { + var testContainer = assembly.ExportedTypes.Single(t => t.Name == "TestContainer"); + var testMethod = testContainer.GetMethod("Test")!; + await (Task)testMethod.Invoke(obj: null, parameters: [])!; + } + finally + { + assemblyLoadContext.Unload(); + } + + if (generatedCodeAsserter is not null) + { + generatedCodeAsserter(outputCompilation.SyntaxTrees.Skip(1).SingleOrDefault()?.ToString()); + } + + static (AssemblyLoadContext, Assembly) EmitAndLoadAssembly(Compilation compilation, string assemblyLoadContextName) + { + var errorDiagnostics = compilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error).ToList(); + if (errorDiagnostics.Count > 0) + { + var stringBuilder = new StringBuilder(); + stringBuilder.AppendLine("Compilation failed:").AppendLine(); + + foreach (var errorDiagnostic in errorDiagnostics) + { + stringBuilder.AppendLine(errorDiagnostic.ToString()); + + var textLines = errorDiagnostic.Location.SourceTree!.GetText().Lines; + var startLine = errorDiagnostic.Location.GetLineSpan().StartLinePosition.Line; + var endLine = errorDiagnostic.Location.GetLineSpan().EndLinePosition.Line; + + if (startLine == endLine) + { + stringBuilder.Append("Line: ").AppendLine(textLines[startLine].ToString().TrimStart()); + } + else + { + stringBuilder.AppendLine("Lines:"); + for (var i = startLine; i <= endLine; i++) + { + stringBuilder.AppendLine(textLines[i].ToString()); + } + } + } + + throw new InvalidOperationException("Compilation failed:" + stringBuilder); + } + + using var memoryStream = new MemoryStream(); + var emitResult = compilation.Emit(memoryStream); + memoryStream.Position = 0; + + errorDiagnostics = emitResult.Diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToList(); + if (errorDiagnostics.Count > 0) + { + throw new InvalidOperationException( + "Compilation emit failed:" + Environment.NewLine + string.Join(Environment.NewLine, errorDiagnostics)); + } + + var assemblyLoadContext = new AssemblyLoadContext(assemblyLoadContextName, isCollectible: true); + var assembly = assemblyLoadContext.LoadFromStream(memoryStream); + return (assemblyLoadContext, assembly); + } + } + + private static int CountOccurrences(string s, string substring) + => s.Split(substring).Length - 1; +}