diff --git a/EFCore.sln.DotSettings b/EFCore.sln.DotSettings
index 6881961da98..d9724cb9987 100644
--- a/EFCore.sln.DotSettings
+++ b/EFCore.sln.DotSettings
@@ -301,9 +301,9 @@ The .NET Foundation licenses this file to you under the MIT license.
True
True
True
- True
True
True
+ True
True
True
True
@@ -325,7 +325,10 @@ The .NET Foundation licenses this file to you under the MIT license.
True
True
True
+ True
True
+ True
+ True
True
True
True
diff --git a/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs b/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs
index 60d4ef6301a..ce83a43ec36 100644
--- a/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs
+++ b/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs
@@ -3,6 +3,7 @@
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Migrations.Internal;
+using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Scaffolding.Internal;
namespace Microsoft.EntityFrameworkCore.Design;
diff --git a/src/EFCore.Design/Design/Internal/CSharpHelper.cs b/src/EFCore.Design/Design/Internal/CSharpHelper.cs
index 01d2f2c1acd..8b5cb3c4bac 100644
--- a/src/EFCore.Design/Design/Internal/CSharpHelper.cs
+++ b/src/EFCore.Design/Design/Internal/CSharpHelper.cs
@@ -1587,13 +1587,25 @@ private string ToSourceCode(SyntaxNode node)
public virtual string Statement(
Expression node,
ISet collectedNamespaces,
+ ISet unsafeAccessors,
IReadOnlyDictionary? constantReplacements,
IReadOnlyDictionary? memberAccessReplacements)
- => ToSourceCode(_translator.TranslateStatement(
- node,
- constantReplacements,
- memberAccessReplacements,
- collectedNamespaces));
+ {
+ var unsafeAccessorDeclarations = new HashSet();
+
+ var code = ToSourceCode(
+ _translator.TranslateStatement(
+ node,
+ constantReplacements,
+ memberAccessReplacements,
+ collectedNamespaces,
+ unsafeAccessorDeclarations));
+
+ // TODO: Possibly improve this (e.g. expose a single string that contains all the accessors concatenated?)
+ unsafeAccessors.UnionWith(unsafeAccessorDeclarations.Select(ToSourceCode));
+
+ return code;
+ }
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -1604,13 +1616,25 @@ private string ToSourceCode(SyntaxNode node)
public virtual string Expression(
Expression node,
ISet collectedNamespaces,
+ ISet unsafeAccessors,
IReadOnlyDictionary? constantReplacements,
IReadOnlyDictionary? memberAccessReplacements)
- => ToSourceCode(_translator.TranslateExpression(
- node,
- constantReplacements,
- memberAccessReplacements,
- collectedNamespaces));
+ {
+ var unsafeAccessorDeclarations = new HashSet();
+
+ var code = ToSourceCode(
+ _translator.TranslateExpression(
+ node,
+ constantReplacements,
+ memberAccessReplacements,
+ collectedNamespaces,
+ unsafeAccessorDeclarations));
+
+ // TODO: Possibly improve this (e.g. expose a single string that contains all the accessors concatenated?)
+ unsafeAccessors.UnionWith(unsafeAccessorDeclarations.Select(ToSourceCode));
+
+ return code;
+ }
private static bool IsIdentifierStartCharacter(char ch)
{
diff --git a/src/EFCore.Design/EFCore.Design.csproj b/src/EFCore.Design/EFCore.Design.csproj
index 7db02fcc958..8916afeb532 100644
--- a/src/EFCore.Design/EFCore.Design.csproj
+++ b/src/EFCore.Design/EFCore.Design.csproj
@@ -58,6 +58,8 @@
+
+
diff --git a/src/EFCore.Design/Properties/DesignStrings.Designer.cs b/src/EFCore.Design/Properties/DesignStrings.Designer.cs
index 4eff1157449..7c8b90a1f7e 100644
--- a/src/EFCore.Design/Properties/DesignStrings.Designer.cs
+++ b/src/EFCore.Design/Properties/DesignStrings.Designer.cs
@@ -207,6 +207,12 @@ public static string DuplicateMigrationName(object? migrationName)
GetString("DuplicateMigrationName", nameof(migrationName)),
migrationName);
+ ///
+ /// Dynamic LINQ queries are not supported when precompiling queries.
+ ///
+ public static string DynamicQueryNotSupported
+ => GetString("DynamicQueryNotSupported");
+
///
/// The encoding '{encoding}' specified in the output directive will be ignored. EF Core always scaffolds files using the encoding 'utf-8'.
///
@@ -609,6 +615,12 @@ public static string ProviderReturnedNullModel(object? providerTypeName)
GetString("ProviderReturnedNullModel", nameof(providerTypeName)),
providerTypeName);
+ ///
+ /// LINQ query comprehension syntax is currently unsupported in precompiled queries.
+ ///
+ public static string QueryComprehensionSyntaxNotSupportedInPrecompiledQueries
+ => GetString("QueryComprehensionSyntaxNotSupportedInPrecompiledQueries");
+
///
/// No files were generated in directory '{outputDirectoryName}'. The following file(s) already exist(s) and must be made writeable to continue: {readOnlyFiles}.
///
diff --git a/src/EFCore.Design/Properties/DesignStrings.resx b/src/EFCore.Design/Properties/DesignStrings.resx
index 842e2eb9170..c8ed2a6006b 100644
--- a/src/EFCore.Design/Properties/DesignStrings.resx
+++ b/src/EFCore.Design/Properties/DesignStrings.resx
@@ -192,6 +192,9 @@
The name '{migrationName}' is used by an existing migration.
+
+ Dynamic LINQ queries are not supported when precompiling queries.
+
The encoding '{encoding}' specified in the output directive will be ignored. EF Core always scaffolds files using the encoding 'utf-8'.
@@ -359,6 +362,9 @@ Change your target project to the migrations project by using the Package Manage
Metadata model returned should not be null. Provider: {providerTypeName}.
+
+ LINQ query comprehension syntax is currently unsupported in precompiled queries.
+
No files were generated in directory '{outputDirectoryName}'. The following file(s) already exist(s) and must be made writeable to continue: {readOnlyFiles}.
diff --git a/src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs b/src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
new file mode 100644
index 00000000000..7df792f374d
--- /dev/null
+++ b/src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
@@ -0,0 +1,1309 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Immutable;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
+using System.Runtime.CompilerServices;
+using System.Text;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.EntityFrameworkCore.Internal;
+using static System.Linq.Expressions.Expression;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal;
+
+///
+/// Translates a Roslyn syntax tree into a LINQ expression tree.
+///
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+public class CSharpToLinqTranslator : CSharpSyntaxVisitor
+{
+ private static readonly SymbolDisplayFormat QualifiedTypeNameSymbolDisplayFormat = new(
+ typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);
+
+ private Compilation? _compilation;
+
+#pragma warning disable CS8618 // Uninitialized non-nullable fields. We check _compilation to make sure LoadCompilation was invoked.
+ private DbContext _userDbContext;
+ private Assembly? _additionalAssembly;
+ private INamedTypeSymbol _userDbContextSymbol;
+ private INamedTypeSymbol _formattableStringSymbol;
+#pragma warning restore CS8618
+
+ private SemanticModel _semanticModel = null!;
+
+ private static MethodInfo? _stringConcatMethod;
+ private static MethodInfo? _stringFormatMethod;
+ private static MethodInfo? _formattableStringFactoryCreateMethod;
+
+ ///
+ /// Loads the given and prepares to translate queries using the given .
+ ///
+ /// A containing the syntax nodes to be translated.
+ /// An instance of the user's .
+ /// An optional additional assemblies to resolve CLR types from.
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual void Load(Compilation compilation, DbContext userDbContext, Assembly? additionalAssembly = null)
+ {
+ _compilation = compilation;
+ _userDbContext = userDbContext;
+ _additionalAssembly = additionalAssembly;
+ _userDbContextSymbol = GetTypeSymbolOrThrow(userDbContext.GetType().FullName!);
+ _formattableStringSymbol = GetTypeSymbolOrThrow("System.FormattableString");
+
+ INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName)
+ => _compilation.GetTypeByMetadataName(fullyQualifiedMetadataName)
+ ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedMetadataName);
+ }
+
+ private readonly Stack> _parameterStack
+ = new(new[] { ImmutableDictionary.Empty });
+
+ private readonly Dictionary _capturedVariables = new(SymbolEqualityComparer.Default);
+
+ ///
+ /// Translates a Roslyn syntax tree into a LINQ expression tree.
+ ///
+ /// The Roslyn syntax node to be translated.
+ ///
+ /// The for the Roslyn of which is a part.
+ ///
+ /// A LINQ expression tree translated from the provided .
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual Expression Translate(SyntaxNode node, SemanticModel semanticModel)
+ {
+ if (_compilation is null)
+ {
+ throw new InvalidOperationException("A compilation must be loaded.");
+ }
+
+ Check.DebugAssert(
+ ReferenceEquals(semanticModel.SyntaxTree, node.SyntaxTree),
+ "Provided semantic model doesn't match the provided syntax node");
+
+ _semanticModel = semanticModel;
+
+ // Perform data flow analysis to detect all captured data (closure parameters)
+ _capturedVariables.Clear();
+ foreach (var captured in _semanticModel.AnalyzeDataFlow(node).Captured)
+ {
+ _capturedVariables[captured] = null;
+ }
+
+ var result = Visit(node);
+
+ // TODO: Sanity check: make sure all captured variables in _capturedVariables have non-null values
+ // (i.e. have been encountered and referenced)
+
+ Debug.Assert(_parameterStack.Count == 1);
+ return result;
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [return: NotNullIfNotNull("node")]
+ public override Expression? Visit(SyntaxNode? node)
+ => base.Visit(node);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitAnonymousObjectCreationExpression(AnonymousObjectCreationExpressionSyntax anonymousObjectCreation)
+ {
+ // Creating an actual anonymous object means creating a new type, which can only be done with Reflection.Emit.
+ // At least for EF's purposes, it doesn't matter, so we build a placeholder.
+ if (_semanticModel.GetSymbolInfo(anonymousObjectCreation).Symbol is not IMethodSymbol constructorSymbol)
+ {
+ throw new InvalidOperationException(
+ "Could not find symbol for anonymous object creation initializer: " + anonymousObjectCreation);
+ }
+
+ var anonymousType = ResolveType(constructorSymbol.ContainingType);
+
+ var parameters = constructorSymbol.Parameters.ToArray();
+
+ var parameterInfos = new ParameterInfo[parameters.Length];
+ var memberInfos = new MemberInfo[parameters.Length];
+ var arguments = new Expression[parameters.Length];
+
+ foreach (var initializer in anonymousObjectCreation.Initializers)
+ {
+ // If the initializer's name isn't explicitly specified, infer it from the initializer's expression like the compiler does
+ var name = initializer.NameEquals is not null
+ ? initializer.NameEquals.Name.Identifier.Text
+ : initializer.Expression is MemberAccessExpressionSyntax memberAccess
+ ? memberAccess.Name.Identifier.Text
+ : throw new InvalidOperationException(
+ $"AnonymousObjectCreation: unnamed initializer with non-MemberAccess expression: {initializer.Expression}");
+
+ var position = Array.FindIndex(parameters, p => p.Name == name);
+ var parameter = parameters[position];
+ var parameterType = ResolveType(parameter.Type) ?? throw new InvalidOperationException(
+ "Could not resolve type symbol for: " + parameter.Type);
+
+ parameterInfos[position] = new FakeParameterInfo(name, parameterType, position);
+ arguments[position] = Visit(initializer.Expression);
+ memberInfos[position] = anonymousType.GetProperty(parameter.Name)!;
+ }
+
+ return New(
+ new FakeConstructorInfo(anonymousType, parameterInfos),
+ arguments: arguments,
+ memberInfos);
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitArgument(ArgumentSyntax argument)
+ {
+ if (!argument.RefKindKeyword.IsKind(SyntaxKind.None))
+ {
+ throw new InvalidOperationException($"Argument with ref/out: {argument}");
+ }
+
+ return Visit(argument.Expression);
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitArrayCreationExpression(ArrayCreationExpressionSyntax arrayCreation)
+ {
+ if (_semanticModel.GetTypeInfo(arrayCreation).Type is not IArrayTypeSymbol arrayTypeSymbol)
+ {
+ throw new InvalidOperationException($"ArrayCreation: non-array type symbol: {arrayCreation}");
+ }
+
+ if (arrayTypeSymbol.Rank > 1)
+ {
+ throw new NotImplementedException($"ArrayCreation: multi-dimensional array: {arrayCreation}");
+ }
+
+ var elementType = ResolveType(arrayTypeSymbol.ElementType);
+ Check.DebugAssert(elementType is not null, "elementType is not null");
+
+ return arrayCreation.Initializer is null
+ ? NewArrayBounds(elementType, Visit(arrayCreation.Type.RankSpecifiers[0].Sizes[0]))
+ : NewArrayInit(elementType, arrayCreation.Initializer.Expressions.Select(e => Visit(e)));
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitBinaryExpression(BinaryExpressionSyntax binary)
+ {
+ var left = Visit(binary.Left);
+ var right = Visit(binary.Right);
+
+ // https://learn.microsoft.com/dotnet/api/Microsoft.CodeAnalysis.CSharp.Syntax.BinaryExpressionSyntax
+ return binary.Kind() switch
+ {
+ // String concatenation
+ SyntaxKind.AddExpression
+ when left.Type == typeof(string) && right.Type == typeof(string)
+ => Add(left, right,
+ _stringConcatMethod ??=
+ typeof(string).GetMethod(nameof(string.Concat), new[] { typeof(string), typeof(string) })),
+
+ SyntaxKind.AddExpression => Add(left, right),
+ SyntaxKind.SubtractExpression => Subtract(left, right),
+ SyntaxKind.MultiplyExpression => Multiply(left, right),
+ SyntaxKind.DivideExpression => Divide(left, right),
+ SyntaxKind.ModuloExpression => Modulo(left, right),
+ SyntaxKind.LeftShiftExpression => LeftShift(left, right),
+ SyntaxKind.RightShiftExpression => RightShift(left, right),
+ // TODO UnsignedRightShiftExpression
+ SyntaxKind.LogicalOrExpression => OrElse(left, right),
+ SyntaxKind.LogicalAndExpression => AndAlso(left, right),
+
+ // For bitwise operations over enums, we the enum to its underlying type before the bitwise operation, and then back to the
+ // enum afterwards (this is corresponds to the LINQ expression tree that the compiler generates)
+ SyntaxKind.BitwiseOrExpression when left.Type.IsEnum || right.Type.IsEnum
+ => Convert(Or(Convert(left, left.Type.GetEnumUnderlyingType()), Convert(right, right.Type.GetEnumUnderlyingType())), left.Type),
+ SyntaxKind.BitwiseAndExpression when left.Type.IsEnum || right.Type.IsEnum
+ => Convert(And(Convert(left, left.Type.GetEnumUnderlyingType()), Convert(right, right.Type.GetEnumUnderlyingType())), left.Type),
+ SyntaxKind.ExclusiveOrExpression when left.Type.IsEnum || right.Type.IsEnum
+ => Convert(ExclusiveOr(Convert(left, left.Type.GetEnumUnderlyingType()), Convert(right, right.Type.GetEnumUnderlyingType())), left.Type),
+
+ SyntaxKind.BitwiseOrExpression => Or(left, right),
+ SyntaxKind.BitwiseAndExpression => And(left, right),
+ SyntaxKind.ExclusiveOrExpression => ExclusiveOr(left, right),
+
+ SyntaxKind.EqualsExpression => Equal(left, right),
+ SyntaxKind.NotEqualsExpression => NotEqual(left, right),
+ SyntaxKind.LessThanExpression => LessThan(left, right),
+ SyntaxKind.LessThanOrEqualExpression => LessThanOrEqual(left, right),
+ SyntaxKind.GreaterThanExpression => GreaterThan(left, right),
+ SyntaxKind.GreaterThanOrEqualExpression => GreaterThanOrEqual(left, right),
+ SyntaxKind.IsExpression => TypeIs(left, right is ConstantExpression { Value : Type type }
+ ? type
+ : throw new InvalidOperationException(
+ $"Encountered {SyntaxKind.IsExpression} with non-constant type right argument: {right}")),
+ SyntaxKind.AsExpression => TypeAs(left, right is ConstantExpression { Value : Type type }
+ ? type
+ : throw new InvalidOperationException(
+ $"Encountered {SyntaxKind.AsExpression} with non-constant type right argument: {right}")),
+ SyntaxKind.CoalesceExpression => Coalesce(left, right),
+
+ _ => throw new ArgumentOutOfRangeException($"BinaryExpressionSyntax with {binary.Kind()}")
+ };
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitCastExpression(CastExpressionSyntax cast)
+ => Convert(Visit(cast.Expression), ResolveType(cast.Type));
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitConditionalExpression(ConditionalExpressionSyntax conditional)
+ => Condition(
+ Visit(conditional.Condition),
+ Visit(conditional.WhenTrue),
+ Visit(conditional.WhenFalse));
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitElementAccessExpression(ElementAccessExpressionSyntax elementAccessExpression)
+ {
+ var arguments = elementAccessExpression.ArgumentList.Arguments;
+ var visitedExpression = Visit(elementAccessExpression.Expression);
+
+ switch (_semanticModel.GetTypeInfo(elementAccessExpression.Expression).ConvertedType)
+ {
+ case IArrayTypeSymbol:
+ Check.DebugAssert(elementAccessExpression.ArgumentList.Arguments.Count == 1,
+ $"ElementAccessExpressionSyntax over array with {arguments.Count} arguments");
+ return ArrayIndex(visitedExpression, Visit(arguments[0].Expression));
+
+ case INamedTypeSymbol:
+ var property = visitedExpression.Type
+ .GetProperties()
+ .Select(p => new { Property = p, IndexParameters = p.GetIndexParameters() })
+ .Where(
+ t => t.IndexParameters.Length == arguments.Count
+ && t.IndexParameters
+ .Select(p => p.ParameterType)
+ .SequenceEqual(arguments.Select(a => ResolveType(a.Expression))))
+ .Select(t => t.Property)
+ .FirstOrDefault();
+
+ if (property?.GetMethod is null)
+ {
+ throw new UnreachableException("No matching property found for ElementAccessExpressionSyntax");
+ }
+
+ return Call(visitedExpression, property.GetMethod, arguments.Select(a => Visit(a.Expression)));
+
+ case null:
+ throw new InvalidOperationException(
+ $"No type for expression {elementAccessExpression.Expression} in {nameof(ElementAccessExpressionSyntax)}");
+
+ default:
+ throw new NotImplementedException($"{nameof(ElementAccessExpressionSyntax)} over non-array");
+ }
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitIdentifierName(IdentifierNameSyntax identifierName)
+ {
+ if (_parameterStack.Peek().TryGetValue(identifierName.Identifier.Text, out var parameter))
+ {
+ return parameter;
+ }
+
+ var symbol = _semanticModel.GetSymbolInfo(identifierName).Symbol;
+
+ ITypeSymbol typeSymbol;
+ switch (symbol)
+ {
+ case INamedTypeSymbol s:
+ return Constant(ResolveType(s));
+ case ILocalSymbol s:
+ typeSymbol = s.Type;
+ break;
+ case IFieldSymbol s:
+ typeSymbol = s.Type;
+ break;
+ case IPropertySymbol s:
+ typeSymbol = s.Type;
+ break;
+ case null:
+ throw new InvalidOperationException($"Identifier without symbol: {identifierName}");
+ default:
+ throw new NotImplementedException($"IdentifierName of type {symbol.GetType().Name}: {identifierName}");
+ }
+
+ // TODO: Separate out EF Core-specific logic (EF Core would extend this visitor)
+ if (typeSymbol.Name.Contains("DbSet"))
+ {
+ throw new NotImplementedException("DbSet local symbol");
+ }
+
+ // We have an identifier which isn't in our parameters stack.
+
+ // First, if the identifier type is the user's DbContext type (e.g. DbContext local variable, or field/property),
+ // return a constant over that.
+ if (typeSymbol.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default))
+ {
+ return Constant(_userDbContext);
+ }
+
+ // The Translate entry point into the translator uses Roslyn's data flow analysis to locate all captured variables, and populates
+ // the _capturedVariable dictionary with them (with null values).
+ // TODO: Test closure over class member (not local variable)
+ if (symbol is ILocalSymbol localSymbol && _capturedVariables.TryGetValue(localSymbol, out var memberExpression))
+ {
+ // The first time we see a captured variable, we create MemberExpression for it and cache it in _capturedVariables.
+ return memberExpression
+ ?? (_capturedVariables[localSymbol] =
+ Field(
+ Constant(new FakeClosureFrameClass()),
+ new FakeFieldInfo(
+ typeof(FakeClosureFrameClass),
+ ResolveType(localSymbol.Type),
+ localSymbol.Name)));
+ }
+
+ throw new InvalidOperationException(
+ $"Encountered unknown identifier name '{identifierName}', which doesn't correspond to a lambda parameter or captured variable");
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitImplicitArrayCreationExpression(ImplicitArrayCreationExpressionSyntax implicitArrayCreation)
+ {
+ if (_semanticModel.GetTypeInfo(implicitArrayCreation).Type is not IArrayTypeSymbol arrayTypeSymbol)
+ {
+ throw new InvalidOperationException($"ArrayCreation: non-array type symbol: {implicitArrayCreation}");
+ }
+
+ if (arrayTypeSymbol.Rank > 1)
+ {
+ throw new NotImplementedException($"ArrayCreation: multi-dimensional array: {implicitArrayCreation}");
+ }
+
+ var elementType = ResolveType(arrayTypeSymbol.ElementType);
+ Check.DebugAssert(elementType is not null, "elementType is not null");
+
+ var initializers = implicitArrayCreation.Initializer.Expressions.Select(e => Visit(e));
+
+ return NewArrayInit(elementType, initializers);
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitInterpolatedStringExpression(InterpolatedStringExpressionSyntax interpolatedString)
+ {
+ var formatBuilder = new StringBuilder();
+ var arguments = new List();
+ foreach (var fragment in interpolatedString.Contents)
+ {
+ switch (fragment)
+ {
+ case InterpolatedStringTextSyntax text:
+ formatBuilder.Append(text);
+ break;
+ case InterpolationSyntax interpolation:
+ var interpolationExpression = Visit(interpolation.Expression);
+ if (interpolationExpression.Type != typeof(object))
+ {
+ interpolationExpression = Convert(interpolationExpression, typeof(object));
+ }
+ arguments.Add(interpolationExpression);
+ formatBuilder.Append('{').Append(arguments.Count - 1).Append('}');
+ break;
+ default:
+ throw new UnreachableException();
+ }
+ }
+
+ // Return a call to string.Format(), unless we have an implicit conversion to FormattableString, in which case return a call to
+ // FormattableStringFactory.Create().
+ return Call(
+ _semanticModel.GetTypeInfo(interpolatedString).ConvertedType switch
+ {
+ { } t when t.Equals(_formattableStringSymbol, SymbolEqualityComparer.Default)
+ => _formattableStringFactoryCreateMethod ??= typeof(FormattableStringFactory).GetMethod(
+ nameof(FormattableStringFactory.Create), [typeof(string), typeof(object[])])!,
+
+ _ => _stringFormatMethod ??= typeof(string).GetMethod(nameof(string.Format), [typeof(string), typeof(object[])])!
+ },
+ Constant(formatBuilder.ToString()),
+ NewArrayInit(typeof(object), arguments));
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitInvocationExpression(InvocationExpressionSyntax invocation)
+ {
+ if (_semanticModel.GetSymbolInfo(invocation).Symbol is not IMethodSymbol methodSymbol)
+ {
+ throw new InvalidOperationException("Could not find symbol for method invocation: " + invocation);
+ }
+
+ // First, if the method return type is the user's DbContext type (e.g. DbContext local variable, or field/property), return a
+ // constant over that DbContext type; the invocation can serve as the root for a LINQ query we can precompile.
+ if (methodSymbol.ReturnType.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default))
+ {
+ return Constant(_userDbContext);
+ }
+
+ var declaringType = ResolveType(methodSymbol.ContainingType);
+
+ Expression? instance = null;
+ if (!methodSymbol.IsStatic || methodSymbol.IsExtensionMethod)
+ {
+ // In normal method calls (the ones we support), the invocation node is composed on top of a member access
+ if (invocation.Expression is not MemberAccessExpressionSyntax { Expression: var receiver })
+ {
+ throw new NotSupportedException($"Invocation over non-member access: {invocation}");
+ }
+
+ instance = Visit(receiver);
+ }
+
+ MethodInfo? methodInfo;
+
+ if (methodSymbol.IsGenericMethod)
+ {
+ var originalDefinition = methodSymbol.OriginalDefinition;
+ if (originalDefinition.ReducedFrom is not null)
+ {
+ originalDefinition = originalDefinition.ReducedFrom;
+ }
+
+ // To accurately find the right open generic method definition based on the Roslyn symbol, we need to create a mapping between
+ // generic type parameter names (based on the Roslyn side) and .NET reflection Types representing those type parameters.
+ // This includes both type parameters immediately on the generic method, as well as type parameters from the method's
+ // containing type (and recursively, its containing types)
+ var typeTypeParameterMap = new Dictionary(Foo(methodSymbol.ContainingType));
+
+ IEnumerable> Foo(INamedTypeSymbol typeSymbol)
+ {
+ // TODO: We match Roslyn type parameters by name, not sure that's right; also for the method's generic type parameters
+
+ if (typeSymbol.ContainingType is INamedTypeSymbol containingTypeSymbol)
+ {
+ foreach (var kvp in Foo(containingTypeSymbol))
+ {
+ yield return kvp;
+ }
+ }
+
+ var type = ResolveType(typeSymbol);
+ var genericArguments = type.GetGenericArguments();
+
+ Check.DebugAssert(
+ genericArguments.Length == typeSymbol.TypeParameters.Length,
+ "genericArguments.Length == typeSymbol.TypeParameters.Length");
+
+ foreach (var (typeParamSymbol, typeParamType) in typeSymbol.TypeParameters.Zip(genericArguments))
+ {
+ yield return new KeyValuePair(typeParamSymbol.Name, typeParamType);
+ }
+ }
+
+ var definitionMethodInfos = declaringType.GetMethods()
+ .Where(m =>
+ {
+ if (m.Name == methodSymbol.Name
+ && m.IsGenericMethodDefinition
+ && m.GetGenericArguments() is var candidateGenericArguments
+ && candidateGenericArguments.Length == originalDefinition.TypeParameters.Length
+ && m.GetParameters() is var candidateParams
+ && candidateParams.Length == originalDefinition.Parameters.Length)
+ {
+ var methodTypeParameterMap = new Dictionary(typeTypeParameterMap);
+
+ // Prepare a dictionary that will be used to resolve generic type parameters (ITypeParameterSymbol) to the
+ // corresponding reflection Type. This is needed to correctly (and recursively) resolve the type of parameters
+ // below.
+ foreach (var (symbol, type) in methodSymbol.TypeParameters.Zip(candidateGenericArguments))
+ {
+ if (symbol.Name != type.Name)
+ {
+ return false;
+ }
+
+ methodTypeParameterMap[symbol.Name] = type;
+ }
+
+ for (var i = 0; i < candidateParams.Length; i++)
+ {
+ var translatedParamType = ResolveType(originalDefinition.Parameters[i].Type, methodTypeParameterMap);
+ if (translatedParamType != candidateParams[i].ParameterType)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+ }).ToArray();
+
+ if (definitionMethodInfos.Length != 1)
+ {
+ throw new InvalidOperationException($"Invocation: Found {definitionMethodInfos.Length} matches for generic method: {invocation}");
+ }
+
+ var definitionMethodInfo = definitionMethodInfos[0];
+ var typeParams = methodSymbol.TypeArguments.Select(a => ResolveType(a)).ToArray();
+ methodInfo = definitionMethodInfo.MakeGenericMethod(typeParams);
+ }
+ else
+ {
+ // Non-generic method
+ var reducedMethodSymbol = methodSymbol.ReducedFrom ?? methodSymbol;
+
+ methodInfo = declaringType.GetMethod(
+ methodSymbol.Name,
+ BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static,
+ reducedMethodSymbol.Parameters.Select(p => ResolveType(p.Type)).ToArray());
+
+ if (methodInfo is null)
+ {
+ throw new InvalidOperationException(
+ $"Invocation: couldn't find method '{methodSymbol.Name}' on type '{declaringType.Name}': {invocation}");
+ }
+ }
+
+ // We have the reflection MethodInfo for the method, prepare the arguments.
+
+ // We can have less arguments than parameters when the method has optional parameters; fill in the missing ones with the default
+ // value.
+ // If the method also has a "params" parameter, we also need to take care of that - the syntactic arguments will need to be packed
+ // into the "params" array etc.
+ var parameters = methodInfo.GetParameters();
+ var sourceArguments = invocation.ArgumentList.Arguments;
+ var destArguments = new Expression?[parameters.Length];
+ var paramIndex = 0;
+
+ // At the syntactic level, an extension method invocation looks like a normal instance's.
+ // Prepend the instance to the argument list.
+ // TODO: Test invoking extension without extension syntax (as static)
+ if (methodSymbol is { IsExtensionMethod: true /*, ReceiverType: { } */ })
+ {
+ destArguments[0] = instance;
+ paramIndex = 1;
+ instance = null;
+ }
+
+ for (var sourceArgIndex = 0; paramIndex < parameters.Length; paramIndex++)
+ {
+ var parameter = parameters[paramIndex];
+ if (parameter.IsDefined(typeof(ParamArrayAttribute)))
+ {
+ // We've reached a "params" parameter; pack all the remaining args (possibly zero) into a NewArrayExpression
+ var elementType = parameter.ParameterType.GetElementType()!;
+ var paramsArguments = new Expression[sourceArguments.Count - sourceArgIndex];
+ for (var paramsArgIndex = 0; sourceArgIndex < sourceArguments.Count; sourceArgIndex++, paramsArgIndex++)
+ {
+ var arg = invocation.ArgumentList.Arguments[sourceArgIndex];
+ Check.DebugAssert(arg.NameColon is null, "Named argument in params");
+
+ paramsArguments[paramsArgIndex] = Visit(arg);
+ }
+
+ destArguments[paramIndex] = NewArrayInit(elementType, paramsArguments);
+ Check.DebugAssert(paramIndex == parameters.Length - 1, "Parameters after params");
+ break;
+ }
+
+ if (sourceArgIndex >= sourceArguments.Count)
+ {
+ // Fewer arguments than there are parameters - we have optional parameters.
+ Check.DebugAssert(parameter.IsOptional, "Missing non-optional argument");
+
+ destArguments[paramIndex] = Constant(
+ parameter.DefaultValue is null && parameter.ParameterType.IsValueType
+ ? Activator.CreateInstance(parameter.ParameterType)
+ : parameter.DefaultValue,
+ parameter.ParameterType);
+ continue;
+ }
+
+ var argument = invocation.ArgumentList.Arguments[sourceArgIndex++];
+
+ // Positional argument
+ if (argument.NameColon is null)
+ {
+ destArguments[paramIndex] = Visit(argument);
+ continue;
+ }
+
+ // Named argument
+ throw new NotImplementedException("Named argument");
+ }
+
+ Check.DebugAssert(destArguments.All(a => a is not null), "arguments.All(a => a is not null)");
+
+ // TODO: Generic type arguments
+ return Call(instance, methodInfo, destArguments!);
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitLiteralExpression(LiteralExpressionSyntax literal)
+ => _semanticModel.GetTypeInfo(literal) is { ConvertedType: ITypeSymbol type }
+ ? Constant(literal.Token.Value, ResolveType(type))
+ : Constant(literal.Token.Value);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitMemberAccessExpression(MemberAccessExpressionSyntax memberAccess)
+ {
+ var expression = Visit(memberAccess.Expression);
+
+ if (_semanticModel.GetSymbolInfo(memberAccess).Symbol is not ISymbol memberSymbol)
+ {
+ throw new InvalidOperationException($"MemberAccess: Couldn't find symbol for member: {memberAccess}");
+ }
+
+ var containingType = ResolveType(memberSymbol.ContainingType);
+ var memberInfo = memberSymbol switch
+ {
+ IPropertySymbol p => (MemberInfo?)containingType.GetProperty(p.Name),
+ IFieldSymbol f => containingType.GetField(f.Name),
+ INamedTypeSymbol t => containingType.GetNestedType(t.Name),
+
+ null => throw new InvalidOperationException($"MemberAccess: Couldn't find symbol for member: {memberAccess}"),
+ _ => throw new NotSupportedException($"MemberAccess: unsupported member symbol '{memberSymbol.GetType().Name}': {memberAccess}")
+ };
+
+ switch (memberInfo)
+ {
+ case Type nestedType:
+ return Constant(nestedType);
+
+ case null:
+ throw new InvalidOperationException($"MemberAccess: couldn't find member '{memberSymbol.Name}': {memberAccess}");
+ }
+
+ // Enum field constant
+ if (containingType.IsEnum)
+ {
+ return Constant(Enum.Parse(containingType, memberInfo.Name), containingType);
+ }
+
+ // array.Length
+ if (expression.Type.IsArray && memberInfo.Name == "Length")
+ {
+ if (expression.Type.GetArrayRank() != 1)
+ {
+ throw new NotImplementedException("MemberAccess on multi-dimensional array");
+ }
+
+ return ArrayLength(expression);
+ }
+
+ return MakeMemberAccess(
+ expression is ConstantExpression { Value: Type } ? null : expression,
+ memberInfo);
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitObjectCreationExpression(ObjectCreationExpressionSyntax objectCreation)
+ {
+ if (_semanticModel.GetSymbolInfo(objectCreation).Symbol is not IMethodSymbol constructorSymbol)
+ {
+ throw new InvalidOperationException($"ObjectCreation: couldn't find IMethodSymbol for constructor: {objectCreation}");
+ }
+
+ Check.DebugAssert(constructorSymbol.MethodKind == MethodKind.Constructor, "constructorSymbol.MethodKind == MethodKind.Constructor");
+
+ var type = ResolveType(constructorSymbol.ContainingType);
+
+ // Find the reflection constructor that matches the constructor symbol's signature
+ var parameterTypes = constructorSymbol.Parameters.Select(ps => ResolveType(ps.Type)).ToArray();
+ var constructor = type.GetConstructor(parameterTypes);
+
+ var newExpression = constructor is not null
+ ? New(
+ constructor,
+ objectCreation.ArgumentList?.Arguments.Select(a => Visit(a)) ?? Array.Empty())
+ : parameterTypes.Length == 0 // For structs, there's no actual parameterless constructor
+ ? New(type)
+ : throw new InvalidOperationException($"ObjectCreation: Missing constructor: {objectCreation}");
+
+ switch (objectCreation.Initializer)
+ {
+ // No initializers, just return the NewExpression
+ case null or { Expressions: [] }:
+ return newExpression;
+
+ // Assignment initializer (new Blog { Name = "foo" })
+ case { Expressions: [AssignmentExpressionSyntax, ..] }:
+ return MemberInit(
+ newExpression,
+ objectCreation.Initializer.Expressions.Select(
+ e =>
+ {
+ if (e is not AssignmentExpressionSyntax { Left: var lValue, Right: var value })
+ {
+ throw new NotSupportedException(
+ $"ObjectCreation: non-assignment initializer expression of type '{e.GetType().Name}': {objectCreation}");
+ }
+
+ var lValueSymbol = _semanticModel.GetSymbolInfo(lValue).Symbol;
+ var memberInfo = lValueSymbol switch
+ {
+ IPropertySymbol p => (MemberInfo?)type.GetProperty(p.Name),
+ IFieldSymbol f => type.GetField(f.Name),
+
+ _ => throw new InvalidOperationException(
+ $"ObjectCreation: unsupported initializer for member of type '{lValueSymbol?.GetType().Name}': {e}")
+ };
+
+ if (memberInfo is null)
+ {
+ throw new InvalidOperationException(
+ $"ObjectCreation: couldn't find initialized member '{lValueSymbol.Name}': {e}");
+ }
+
+ return Bind(memberInfo, Visit(value));
+ }));
+
+ // Non-assignment initializer => list initializer (new List { 1, 2, 3 })
+ default:
+ // Find the correct Add() method on the collection type
+ // TODO: This doesn't work if there are multiple Add() methods (contrived). Complete solution would be to find the base
+ // Type for all initializer expressions and find an Add overload of that type (or a superclass thereof)
+ var addMethod = type.GetMethods().SingleOrDefault(m => m.Name == "Add" && m.GetParameters().Length == 1);
+ if (addMethod is null)
+ {
+ throw new InvalidOperationException(
+ $"Couldn't find single Add method on type '{type.Name}', required for list initializer");
+ }
+
+ // TODO: Dictionary initializer, where each ElementInit has more than one expression
+
+ return ListInit(
+ newExpression,
+ objectCreation.Initializer.Expressions.Select(e => ElementInit(addMethod, Visit(e))));
+ }
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitParenthesizedExpression(ParenthesizedExpressionSyntax parenthesized)
+ => Visit(parenthesized.Expression);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitParenthesizedLambdaExpression(ParenthesizedLambdaExpressionSyntax lambda)
+ => VisitLambdaExpression(lambda);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitPredefinedType(PredefinedTypeSyntax predefinedType)
+ => Constant(ResolveType(predefinedType), typeof(Type));
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitPrefixUnaryExpression(PrefixUnaryExpressionSyntax unary)
+ {
+ var operand = Visit(unary.Operand);
+
+ // https://learn.microsoft.com/dotnet/api/Microsoft.CodeAnalysis.CSharp.Syntax.PrefixUnaryExpressionSyntax
+
+ return unary.Kind() switch
+ {
+ SyntaxKind.UnaryPlusExpression => UnaryPlus(operand),
+ SyntaxKind.UnaryMinusExpression => Negate(operand),
+ SyntaxKind.BitwiseNotExpression => Not(operand),
+ SyntaxKind.LogicalNotExpression => Not(operand),
+
+ SyntaxKind.AddressOfExpression => throw NotSupportedInExpressionTrees(),
+ SyntaxKind.IndexExpression => throw NotSupportedInExpressionTrees(),
+ SyntaxKind.PointerIndirectionExpression => throw NotSupportedInExpressionTrees(),
+ SyntaxKind.PreDecrementExpression => throw NotSupportedInExpressionTrees(),
+ SyntaxKind.PreIncrementExpression => throw NotSupportedInExpressionTrees(),
+
+ _ => throw new UnreachableException(
+ $"Unexpected syntax kind '{unary.Kind()}' when visiting a {nameof(PrefixUnaryExpressionSyntax)}")
+ };
+
+ NotSupportedException NotSupportedInExpressionTrees()
+ => throw new UnreachableException(
+ $"Unary expression of type {unary.Kind()} is not supported in expression trees");
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitPostfixUnaryExpression(PostfixUnaryExpressionSyntax unary)
+ {
+ var operand = Visit(unary.Operand);
+
+ // https://learn.microsoft.com/dotnet/api/Microsoft.CodeAnalysis.CSharp.Syntax.PostfixUnaryExpressionSyntax
+
+ return unary.Kind() switch
+ {
+ SyntaxKind.SuppressNullableWarningExpression => operand,
+
+ SyntaxKind.PostIncrementExpression => throw NotSupportedInExpressionTrees(),
+ SyntaxKind.PostDecrementExpression => throw NotSupportedInExpressionTrees(),
+
+ _ => throw new UnreachableException(
+ $"Unexpected syntax kind '{unary.Kind()}' when visiting a {nameof(PostfixUnaryExpressionSyntax)}")
+ };
+
+ NotSupportedException NotSupportedInExpressionTrees()
+ => throw new UnreachableException(
+ $"Unary expression of type {unary.Kind()} is not supported in expression trees");
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitQueryExpression(QueryExpressionSyntax node)
+ => throw new NotSupportedException(DesignStrings.QueryComprehensionSyntaxNotSupportedInPrecompiledQueries);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax lambda)
+ => VisitLambdaExpression(lambda);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression VisitTypeOfExpression(TypeOfExpressionSyntax typeOf)
+ {
+ if (_semanticModel.GetSymbolInfo(typeOf.Type).Symbol is not ITypeSymbol typeSymbol)
+ {
+ throw new InvalidOperationException(
+ "Could not find symbol for typeof() expression: " + typeOf);
+ }
+
+ var type = ResolveType(typeSymbol);
+ return Constant(type, typeof(Type));
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override Expression DefaultVisit(SyntaxNode node)
+ => throw new NotSupportedException($"Unsupported syntax node of type '{node.GetType()}': {node}");
+
+ private Expression VisitLambdaExpression(AnonymousFunctionExpressionSyntax lambda)
+ {
+ if (lambda.ExpressionBody is null)
+ {
+ throw new NotSupportedException("Lambda with null expression body");
+ }
+
+ if (lambda.Modifiers.Any())
+ {
+ throw new NotSupportedException("Lambda with modifiers not supported: " + lambda.Modifiers);
+ }
+
+ if (!lambda.AsyncKeyword.IsKind(SyntaxKind.None))
+ {
+ throw new NotSupportedException("Async lambdas are not supported");
+ }
+
+ var lambdaParameters = lambda switch
+ {
+ SimpleLambdaExpressionSyntax simpleLambda => SyntaxFactory.SingletonSeparatedList(simpleLambda.Parameter),
+ ParenthesizedLambdaExpressionSyntax parenthesizedLambda => parenthesizedLambda.ParameterList.Parameters,
+
+ _ => throw new UnreachableException()
+ };
+
+ var translatedParameters = new List();
+ foreach (var parameter in lambdaParameters)
+ {
+ if (_semanticModel.GetDeclaredSymbol(parameter) is not { } parameterSymbol ||
+ ResolveType(parameterSymbol.Type) is not { } parameterType)
+ {
+ throw new InvalidOperationException("Could not found symbol for parameter lambda: " + parameter);
+ }
+
+ translatedParameters.Add(Parameter(parameterType, parameter.Identifier.Text));
+ }
+
+ _parameterStack.Push(_parameterStack.Peek()
+ .AddRange(translatedParameters.Select(p => new KeyValuePair(p.Name ?? throw new NotImplementedException(), p))));
+
+ try
+ {
+ var body = Visit(lambda.ExpressionBody);
+ return Lambda(body, translatedParameters);
+ }
+ finally
+ {
+ _parameterStack.Pop();
+ }
+ }
+
+ ///
+ /// Given a Roslyn type symbol, returns a .NET reflection .
+ ///
+ /// The type symbol to be translated.
+ /// A .NET reflection that corresponds to .
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual Type TranslateType(ITypeSymbol typeSymbol)
+ => ResolveType(typeSymbol);
+
+ private Type ResolveType(SyntaxNode node)
+ => _semanticModel.GetTypeInfo(node).Type is { } typeSymbol
+ ? ResolveType(typeSymbol)
+ : throw new InvalidOperationException("Could not find type symbol for: " + node);
+
+ private Type ResolveType(ITypeSymbol typeSymbol, Dictionary? genericParameterMap = null)
+ {
+ switch (typeSymbol)
+ {
+ case INamedTypeSymbol { IsAnonymousType: true } anonymousTypeSymbol:
+ _anonymousTypeDefinitions ??= LoadAnonymousTypes(anonymousTypeSymbol.ContainingAssembly);
+ var properties = anonymousTypeSymbol.GetMembers().OfType().ToArray();
+ var found = _anonymousTypeDefinitions.TryGetValue(properties.Select(p => p.Name).ToArray(),
+ out var anonymousTypeGenericDefinition);
+ Debug.Assert(found, "Anonymous type not found");
+
+ var constructorParameters = anonymousTypeGenericDefinition!.GetConstructors()[0].GetParameters();
+ var genericTypeArguments = new Type[constructorParameters.Length];
+
+ for (var i = 0; i < constructorParameters.Length; i++)
+ {
+ genericTypeArguments[i] =
+ ResolveType(properties.FirstOrDefault(p => p.Name == constructorParameters[i].Name)!.Type);
+ }
+
+ // TODO: Cache closed anonymous types
+
+ return anonymousTypeGenericDefinition.MakeGenericType(genericTypeArguments);
+
+ case INamedTypeSymbol { IsDefinition: true } genericTypeSymbol:
+ return GetClrType(genericTypeSymbol);
+
+ case INamedTypeSymbol { IsGenericType: true } genericTypeSymbol:
+ {
+ var definition = GetClrType(genericTypeSymbol.OriginalDefinition);
+ var typeArguments = genericTypeSymbol.TypeArguments.Select(a => ResolveType(a, genericParameterMap)).ToArray();
+ return definition.MakeGenericType(typeArguments);
+ }
+
+ case ITypeParameterSymbol typeParameterSymbol:
+ return genericParameterMap?.TryGetValue(typeParameterSymbol.Name, out var type) == true
+ ? type
+ : throw new InvalidOperationException($"Unknown generic type parameter symbol {typeParameterSymbol}");
+
+ case INamedTypeSymbol namedTypeSymbol:
+ return GetClrType(namedTypeSymbol);
+
+ case IArrayTypeSymbol arrayTypeSymbol:
+ // The ContainingAssembly of array type symbols can be null; recurse down the element types (down to the non-array element
+ // type) to get the assembly.
+ var containingAssembly = arrayTypeSymbol.ContainingAssembly;
+ ITypeSymbol currentSymbol = arrayTypeSymbol;
+ while (containingAssembly is null && currentSymbol is IArrayTypeSymbol { ElementType: var nestedTypeSymbol })
+ {
+ currentSymbol = nestedTypeSymbol;
+ containingAssembly = currentSymbol.ContainingAssembly;
+ }
+
+ return GetClrTypeFromAssembly(
+ containingAssembly,
+ typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat));
+
+ default:
+ return GetClrTypeFromAssembly(
+ typeSymbol.ContainingAssembly,
+ typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat));
+ }
+
+ Type GetClrType(INamedTypeSymbol symbol)
+ {
+ var name = symbol.ContainingType is null
+ ? typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)
+ : typeSymbol.Name;
+
+ if (symbol.IsGenericType)
+ {
+ name += '`' + symbol.Arity.ToString();
+ }
+
+ if (symbol.ContainingType is not null)
+ {
+ var containingType = ResolveType(symbol.ContainingType);
+
+ return containingType.GetNestedType(name)
+ ?? throw new InvalidOperationException(
+ $"Couldn't find nested type '{name}' on containing type '{containingType.Name}'");
+ }
+
+ return GetClrTypeFromAssembly(typeSymbol.ContainingAssembly, name);
+ }
+
+ Type GetClrTypeFromAssembly(IAssemblySymbol? assemblySymbol, string name)
+ => (assemblySymbol is null
+ ? Type.GetType(name)!
+ : Type.GetType($"{name}, {assemblySymbol.Name}"))
+ // If we can't find the Type, check the assembly where the user's DbContext type lives; this is primarily to support
+ // testing, where user code is in an assembly that's built as part of the the test and loaded into a specific
+ // AssemblyLoadContext (which gets unloaded later).
+ ?? _additionalAssembly?.GetType(name)
+ ?? throw new InvalidOperationException(
+ $"Couldn't resolve CLR type '{name}' in assembly '{assemblySymbol?.Name}'");
+
+ Dictionary LoadAnonymousTypes(IAssemblySymbol assemblySymbol)
+ {
+ Assembly? assembly;
+ try
+ {
+ assembly = Assembly.Load(assemblySymbol.Name);
+ }
+ catch (FileNotFoundException)
+ {
+ // If we can't find the assembly, use the assembly where the user's DbContext type lives; this is primarily to support
+ // testing, where user code is in an assembly that's built as part of the the test and loaded into a specific
+ // AssemblyLoadContext (which gets unloaded later).
+
+ // TODO: Strings
+ assembly = _additionalAssembly
+ ?? throw new InvalidOperationException($"Could not load assembly for IAssemblySymbol '{assemblySymbol.Name}'");
+ }
+
+ // Get all the anonymous type in the assembly, and index them by the ordered names of their properties.
+ // Note that anonymous types are generic, so we don't have property types in the key.
+
+ // TODO: An alternative strategy would be to just generate the types as we need them (with ref.emit) - that's probably safer.
+ // TODO: Though it may mean that the resulting CLR Type can't be anonymous (Type.IsAnonymousType()) - not sure that matters.
+ return assembly.GetTypes()
+ .Where(t => t.IsAnonymousType())
+ .ToDictionary(t => t.GetProperties().Select(x => x.Name).ToArray(), t => t, new ArrayStructuralComparer());
+ }
+ }
+
+ private sealed class ArrayStructuralComparer : IEqualityComparer
+ {
+ public bool Equals(T[]? x, T[]? y)
+ => x is null ? y is null : y is not null && x.SequenceEqual(y);
+
+ public int GetHashCode(T[] obj)
+ {
+ var hashcode = new HashCode();
+
+ foreach (var value in obj)
+ {
+ hashcode.Add(value);
+ }
+
+ return hashcode.ToHashCode();
+ }
+ }
+
+ private Dictionary? _anonymousTypeDefinitions;
+
+ [CompilerGenerated]
+ private sealed class FakeClosureFrameClass;
+
+ private sealed class FakeFieldInfo(Type declaringType, Type fieldType, string name) : FieldInfo
+ {
+ public override object[] GetCustomAttributes(bool inherit)
+ => Array.Empty();
+
+ public override object[] GetCustomAttributes(Type attributeType, bool inherit)
+ => Array.Empty();
+
+ public override bool IsDefined(Type attributeType, bool inherit)
+ => false;
+
+ public override Type DeclaringType { get; } = declaringType;
+
+ public override string Name { get; } = name;
+
+ public override Type? ReflectedType => null;
+
+ // We implement GetValue since ExpressionTreeFuncletizer calls it to get the parameter value. In AOT generation time, we obviously
+ // have no parameter values, nor do we need them for the first part of the query pipeline.
+ public override object? GetValue(object? obj)
+ => FieldType.IsValueType
+ ? Activator.CreateInstance(FieldType)
+ : FieldType == typeof(string)
+ ? ""
+ : null;
+
+ public override void SetValue(object? obj, object? value, BindingFlags invokeAttr, Binder? binder,
+ CultureInfo? culture)
+ => throw new NotSupportedException();
+
+ public override FieldAttributes Attributes
+ => FieldAttributes.Public;
+
+ public override RuntimeFieldHandle FieldHandle
+ => throw new NotSupportedException();
+
+ public override Type FieldType { get; } = fieldType;
+ }
+
+ private sealed class FakeConstructorInfo(Type type, ParameterInfo[] parameters) : ConstructorInfo
+ {
+ public override object[] GetCustomAttributes(bool inherit)
+ => Array.Empty();
+
+ public override object[] GetCustomAttributes(Type attributeType, bool inherit)
+ => Array.Empty();
+
+ public override bool IsDefined(Type attributeType, bool inherit)
+ => false;
+
+ public override Type DeclaringType { get; } = type;
+
+ public override string Name
+ => ".ctor";
+
+ public override Type ReflectedType
+ => DeclaringType;
+
+ public override MethodImplAttributes GetMethodImplementationFlags()
+ => MethodImplAttributes.Managed;
+
+ public override ParameterInfo[] GetParameters()
+ => parameters;
+
+ public override MethodAttributes Attributes
+ => MethodAttributes.Public;
+
+ public override RuntimeMethodHandle MethodHandle
+ => throw new NotSupportedException();
+
+ public override object Invoke(object? obj, BindingFlags invokeAttr, Binder? binder, object?[]? parameters,
+ CultureInfo? culture)
+ => throw new NotSupportedException();
+
+ public override object Invoke(BindingFlags invokeAttr, Binder? binder, object?[]? parameters,
+ CultureInfo? culture)
+ => throw new NotSupportedException();
+ }
+
+ private sealed class FakeParameterInfo(string name, Type parameterType, int position) : ParameterInfo
+ {
+ public override ParameterAttributes Attributes
+ => ParameterAttributes.In;
+
+ public override string? Name { get; } = name;
+ public override Type ParameterType { get; } = parameterType;
+ public override int Position { get; } = position;
+
+ public override MemberInfo Member
+ => throw new NotSupportedException();
+ }
+}
diff --git a/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs b/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs
index e7b5f0d1d45..4ff06e5fcda 100644
--- a/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs
+++ b/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs
@@ -10,11 +10,9 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
-using Microsoft.EntityFrameworkCore.ChangeTracking.Internal;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-using E = System.Linq.Expressions.Expression;
namespace Microsoft.EntityFrameworkCore.Query.Internal;
@@ -43,6 +41,19 @@ private sealed record LiftedState
internal readonly Dictionary Variables = new();
internal readonly HashSet VariableNames = [];
internal readonly List UnassignedVariableDeclarations = [];
+
+ internal LiftedState CreateChild()
+ {
+ var child = new LiftedState();
+
+ foreach (var (parameter, name) in Variables)
+ {
+ child.Variables.Add(parameter, name);
+ }
+ child.VariableNames.UnionWith(VariableNames);
+
+ return child;
+ }
}
private LiftedState _liftedState = new();
@@ -53,13 +64,15 @@ private sealed record LiftedState
private readonly HashSet _capturedVariables = [];
private ISet _collectedNamespaces = null!;
+ private readonly Dictionary _methodUnsafeAccessors = new();
+ private readonly Dictionary<(FieldInfo Field, bool ForWrite), MethodDeclarationSyntax> _fieldUnsafeAccessors = new();
- private static MethodInfo? _activatorCreateInstanceMethod;
private static MethodInfo? _mathPowMethod;
private readonly SideEffectDetectionSyntaxWalker _sideEffectDetector = new();
private readonly ConstantDetectionSyntaxWalker _constantDetector = new();
private readonly SyntaxGenerator _g;
+ private readonly StringBuilder _stringBuilder = new();
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -98,8 +111,9 @@ public virtual IReadOnlySet CapturedVariables
public virtual SyntaxNode TranslateStatement(
Expression node,
IReadOnlyDictionary? constantReplacements,
- ISet collectedNamespaces)
- => TranslateCore(node, constantReplacements, collectedNamespaces, statementContext: true);
+ ISet collectedNamespaces,
+ ISet unsafeAccessors)
+ => TranslateCore(node, constantReplacements, collectedNamespaces, unsafeAccessors, statementContext: true);
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -110,8 +124,9 @@ public virtual IReadOnlySet CapturedVariables
public virtual SyntaxNode TranslateExpression(
Expression node,
IReadOnlyDictionary? constantReplacements,
- ISet collectedNamespaces)
- => TranslateCore(node, constantReplacements, collectedNamespaces, statementContext: false);
+ ISet collectedNamespaces,
+ ISet unsafeAccessors)
+ => TranslateCore(node, constantReplacements, collectedNamespaces, unsafeAccessors, statementContext: false);
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -123,6 +138,7 @@ public virtual IReadOnlySet CapturedVariables
Expression node,
IReadOnlyDictionary? constantReplacements,
ISet collectedNamespaces,
+ ISet unsafeAccessors,
bool statementContext)
{
_capturedVariables.Clear();
@@ -146,6 +162,11 @@ public virtual IReadOnlySet CapturedVariables
Check.DebugAssert(_stack.Peek().Labels.Count == 0, "_stack.Peek().Labels.Count == 0");
Check.DebugAssert(_stack.Peek().UnnamedLabelNames.Count == 0, "_stack.Peek().UnnamedLabelNames.Count == 0");
+ foreach (var unsafeAccessor in _fieldUnsafeAccessors.Values.Concat(_methodUnsafeAccessors.Values))
+ {
+ unsafeAccessors.Add(unsafeAccessor);
+ }
+
return Result!;
}
@@ -288,7 +309,7 @@ protected override Expression VisitBinary(BinaryExpression binary)
case ExpressionType.Power when binary.Left.Type == typeof(double) && binary.Right.Type == typeof(double):
return Visit(
- E.Call(
+ Expression.Call(
_mathPowMethod ??= typeof(Math).GetMethod(
nameof(Math.Pow), BindingFlags.Static | BindingFlags.Public, [typeof(double), typeof(double)])!,
binary.Left,
@@ -299,9 +320,9 @@ protected override Expression VisitBinary(BinaryExpression binary)
case ExpressionType.PowerAssign:
return Visit(
- E.Assign(
+ Expression.Assign(
binary.Left,
- E.Power(
+ Expression.Power(
binary.Left,
binary.Right)));
}
@@ -461,7 +482,7 @@ protected override Expression VisitBlock(BlockExpression block)
if (blockContext != ExpressionContext.Expression)
{
ownStackFrame = PushNewStackFrame();
- _liftedState = new LiftedState();
+ _liftedState = parentLiftedState.CreateChild();
}
var stackFrame = _stack.Peek();
@@ -469,217 +490,211 @@ protected override Expression VisitBlock(BlockExpression block)
// Do a 1st pass to identify and register any labels, since goto can appear before its label.
PreprocessLabels();
- try
+ // Go over the block's variables, assign names to any unnamed ones and uniquify. Then add them to our stack frame, unless
+ // this is an expression block that will get lifted.
+ foreach (var parameter in block.Variables)
{
- // Go over the block's variables, assign names to any unnamed ones and uniquify. Then add them to our stack frame, unless
- // this is an expression block that will get lifted.
- foreach (var parameter in block.Variables)
- {
- var (variables, variableNames) = (stackFrame.Variables, stackFrame.VariableNames);
+ var (variables, variableNames) = (stackFrame.Variables, stackFrame.VariableNames);
- var uniquifiedName = UniquifyVariableName(parameter.Name ?? "unnamed");
+ var uniquifiedName = UniquifyVariableName(parameter.Name ?? "unnamed");
- if (blockContext == ExpressionContext.Expression)
+ if (blockContext == ExpressionContext.Expression)
+ {
+ if (!_liftedState.Variables.TryAdd(parameter, uniquifiedName))
{
- if (_liftedState.Variables.ContainsKey(parameter))
- {
- throw new NotSupportedException("Parameter clash during expression lifting for: " + parameter.Name);
- }
-
- _liftedState.Variables.Add(parameter, uniquifiedName);
- _liftedState.VariableNames.Add(uniquifiedName);
+ throw new NotSupportedException("Parameter clash during expression lifting for: " + parameter.Name);
}
- else
- {
- if (!variables.TryAdd(parameter, uniquifiedName))
- {
- throw new InvalidOperationException(
- DesignStrings.SameParameterExpressionDeclaredAsVariableInNestedBlocks(parameter.Name ?? ""));
- }
- variableNames.Add(uniquifiedName);
+ _liftedState.VariableNames.Add(uniquifiedName);
+ }
+ else
+ {
+ if (!variables.TryAdd(parameter, uniquifiedName))
+ {
+ throw new InvalidOperationException(
+ DesignStrings.SameParameterExpressionDeclaredAsVariableInNestedBlocks(parameter.Name ?? ""));
}
+
+ variableNames.Add(uniquifiedName);
}
+ }
- var unassignedVariables = block.Variables.ToList();
+ var unassignedVariables = block.Variables.ToList();
- var statements = new List();
- LabeledStatementSyntax? pendingLabeledStatement = null;
+ var statements = new List();
+ LabeledStatementSyntax? pendingLabeledStatement = null;
- // Now visit the block's expressions
- for (var i = 0; i < block.Expressions.Count; i++)
- {
- var expression = block.Expressions[i];
- var onLastBlockLine = i == block.Expressions.Count - 1;
- _onLastLambdaLine = parentOnLastLambdaLine && onLastBlockLine;
+ // Now visit the block's expressions
+ for (var i = 0; i < block.Expressions.Count; i++)
+ {
+ var expression = block.Expressions[i];
+ var onLastBlockLine = i == block.Expressions.Count - 1;
+ _onLastLambdaLine = parentOnLastLambdaLine && onLastBlockLine;
- // Any lines before the last are evaluated in statement context (they aren't returned); the last line is evaluated in the
- // context of the block as a whole. _context now refers to the statement's context, blockContext to the block's.
- var statementContext = onLastBlockLine ? _context : ExpressionContext.Statement;
+ // Any lines before the last are evaluated in statement context (they aren't returned); the last line is evaluated in the
+ // context of the block as a whole. _context now refers to the statement's context, blockContext to the block's.
+ var statementContext = onLastBlockLine ? _context : ExpressionContext.Statement;
- SyntaxNode translated;
- using (ChangeContext(statementContext))
+ SyntaxNode translated;
+ using (ChangeContext(statementContext))
+ {
+ translated = Translate(expression);
+ }
+
+ // If we have a labeled statement, unwrap it and keep the label as pending. VisitLabel returns a dummy statement (since
+ // LINQ labels don't have a statement, unlike C#), so we'll skip that statement and add the label to the next real one.
+ if (translated is LabeledStatementSyntax labeledStatement)
+ {
+ if (pendingLabeledStatement is not null)
{
- translated = Translate(expression);
+ throw new NotImplementedException("Multiple labels on the same statement");
}
- // If we have a labeled statement, unwrap it and keep the label as pending. VisitLabel returns a dummy statement (since
- // LINQ labels don't have a statement, unlike C#), so we'll skip that statement and add the label to the next real one.
- if (translated is LabeledStatementSyntax labeledStatement)
- {
- if (pendingLabeledStatement is not null)
- {
- throw new NotImplementedException("Multiple labels on the same statement");
- }
+ pendingLabeledStatement = labeledStatement;
+ translated = labeledStatement.Statement;
+ }
- pendingLabeledStatement = labeledStatement;
- translated = labeledStatement.Statement;
- }
+ // Syntax optimization. This is an assignment of a block variable to some value. Render this as:
+ // var x = ;
+ // ... instead of:
+ // int x;
+ // x = ;
+ // ... except for expression context (i.e. on the last line), where we just return the value if needed.
+ if (expression is BinaryExpression { NodeType: ExpressionType.Assign, Left: ParameterExpression lValue }
+ && translated is AssignmentExpressionSyntax { Right: var valueSyntax }
+ && statementContext == ExpressionContext.Statement
+ && unassignedVariables.Remove(lValue))
+ {
+ var useExplicitVariableType = valueSyntax.Kind() == SyntaxKind.NullLiteralExpression;
- // Syntax optimization. This is an assignment of a block variable to some value. Render this as:
- // var x = ;
- // ... instead of:
- // int x;
- // x = ;
- // ... except for expression context (i.e. on the last line), where we just return the value if needed.
- if (expression is BinaryExpression { NodeType: ExpressionType.Assign, Left: ParameterExpression lValue }
- && translated is AssignmentExpressionSyntax { Right: var valueSyntax }
- && statementContext == ExpressionContext.Statement
- && unassignedVariables.Remove(lValue))
- {
- var useExplicitVariableType = valueSyntax.Kind() == SyntaxKind.NullLiteralExpression;
+ translated = useExplicitVariableType
+ ? _g.LocalDeclarationStatement(Generate(lValue.Type), LookupVariableName(lValue), valueSyntax)
+ : _g.LocalDeclarationStatement(LookupVariableName(lValue), valueSyntax);
+ }
- translated = useExplicitVariableType
- ? _g.LocalDeclarationStatement(Generate(lValue.Type), LookupVariableName(lValue), valueSyntax)
- : _g.LocalDeclarationStatement(LookupVariableName(lValue), valueSyntax);
- }
+ if (statementContext == ExpressionContext.Expression)
+ {
+ // We're on the last line of a block in expression context - the block is being lifted out.
+ // All statements before the last line (this one) have already been added to _liftedStatements, just return the last
+ // expression.
+ Check.DebugAssert(onLastBlockLine, "onLastBlockLine");
+ Result = translated;
+ break;
+ }
- if (statementContext == ExpressionContext.Expression)
+ if (blockContext != ExpressionContext.Expression)
+ {
+ if (_liftedState.Statements.Count > 0)
{
- // We're on the last line of a block in expression context - the block is being lifted out.
- // All statements before the last line (this one) have already been added to _liftedStatements, just return the last
- // expression.
- Check.DebugAssert(onLastBlockLine, "onLastBlockLine");
- Result = translated;
- break;
+ // If any expressions were lifted out of the current expression, flatten them into our own block, just before the
+ // expression from which they were lifted. Note that we don't do this in Expression context, since our own block is
+ // lifted out.
+ statements.AddRange(_liftedState.Statements);
+ _liftedState.Statements.Clear();
}
- if (blockContext != ExpressionContext.Expression)
+ // Same for any variables being lifted out of the block; we add them to our own stack frame so that we can do proper
+ // variable name uniquification etc.
+ if (_liftedState.Variables.Count > 0)
{
- if (_liftedState.Statements.Count > 0)
+ foreach (var (parameter, name) in _liftedState.Variables)
{
- // If any expressions were lifted out of the current expression, flatten them into our own block, just before the
- // expression from which they were lifted. Note that we don't do this in Expression context, since our own block is
- // lifted out.
- statements.AddRange(_liftedState.Statements);
- _liftedState.Statements.Clear();
+ stackFrame.Variables[parameter] = name;
+ stackFrame.VariableNames.Add(name);
}
- // Same for any variables being lifted out of the block; we add them to our own stack frame so that we can do proper
- // variable name uniquification etc.
- if (_liftedState.Variables.Count > 0)
- {
- foreach (var (parameter, name) in _liftedState.Variables)
- {
- stackFrame.Variables[parameter] = name;
- stackFrame.VariableNames.Add(name);
- }
-
- _liftedState.Variables.Clear();
- }
- }
-
- // Skip useless expressions with no side effects in statement context (these can be the result of switch/conditional lifting
- // with assignment lowering)
- if (statementContext == ExpressionContext.Statement && !_sideEffectDetector.MayHaveSideEffects(translated))
- {
- continue;
+ _liftedState.Variables.Clear();
}
+ }
- var statement = translated switch
- {
- StatementSyntax s => s,
+ // Skip useless expressions with no side effects in statement context (these can be the result of switch/conditional lifting
+ // with assignment lowering)
+ if (statementContext == ExpressionContext.Statement && !_sideEffectDetector.MayHaveSideEffects(translated))
+ {
+ continue;
+ }
- // If this is the last line in an expression lambda, wrap it in a return statement.
- ExpressionSyntax e when _onLastLambdaLine && statementContext == ExpressionContext.ExpressionLambda
- => ReturnStatement(e),
+ var statement = translated switch
+ {
+ StatementSyntax s => s,
- // If we're in statement context and we have an expression that can't stand alone (e.g. literal), assign it to discard
- ExpressionSyntax e when statementContext == ExpressionContext.Statement && !IsExpressionValidAsStatement(e)
- => ExpressionStatement((ExpressionSyntax)_g.AssignmentStatement(_g.IdentifierName("_"), e)),
+ // If this is the last line in an expression lambda, wrap it in a return statement.
+ ExpressionSyntax e when _onLastLambdaLine && statementContext == ExpressionContext.ExpressionLambda
+ => ReturnStatement(e),
- ExpressionSyntax e => ExpressionStatement(e),
+ // If we're in statement context and we have an expression that can't stand alone (e.g. literal), assign it to discard
+ ExpressionSyntax e when statementContext == ExpressionContext.Statement && !IsExpressionValidAsStatement(e)
+ => ExpressionStatement((ExpressionSyntax)_g.AssignmentStatement(_g.IdentifierName("_"), e)),
- _ => throw new ArgumentOutOfRangeException()
- };
+ ExpressionSyntax e => ExpressionStatement(e),
- if (blockContext == ExpressionContext.Expression)
- {
- // This block is in expression context, and so will be lifted (we won't be returning a block).
- _liftedState.Statements.Add(statement);
- }
- else
- {
- if (pendingLabeledStatement is not null)
- {
- statement = pendingLabeledStatement.WithStatement(statement);
- pendingLabeledStatement = null;
- }
+ _ => throw new ArgumentOutOfRangeException()
+ };
- statements.Add(statement);
- }
+ if (blockContext == ExpressionContext.Expression)
+ {
+ // This block is in expression context, and so will be lifted (we won't be returning a block).
+ _liftedState.Statements.Add(statement);
}
-
- // If a label existed on the last line of the block, add an empty statement (since C# requires it); for expression blocks we'd
- // have to lift that, not supported for now.
- if (pendingLabeledStatement is not null)
+ else
{
- if (blockContext == ExpressionContext.Expression)
+ if (pendingLabeledStatement is not null)
{
- throw new NotImplementedException("Label on last expression of an expression block");
+ statement = pendingLabeledStatement.WithStatement(statement);
+ pendingLabeledStatement = null;
}
- else
- {
- statements.Add(pendingLabeledStatement.WithStatement(EmptyStatement()));
- }
- }
- // Above we transform top-level assignments (i = 8) to var-declarations with initializers (var i = 8); those variables have
- // already been taken care of and removed from the list.
- // But there may still be variables that get assigned inside nested blocks or other situations; prepare declarations for those
- // and either add them to the block, or lift them if we're an expression block.
- var unassignedVariableDeclarations =
- unassignedVariables.Select(
- v => (LocalDeclarationStatementSyntax)_g.LocalDeclarationStatement(Generate(v.Type), LookupVariableName(v)));
+ statements.Add(statement);
+ }
+ }
+ // If a label existed on the last line of the block, add an empty statement (since C# requires it); for expression blocks we'd
+ // have to lift that, not supported for now.
+ if (pendingLabeledStatement is not null)
+ {
if (blockContext == ExpressionContext.Expression)
{
- _liftedState.UnassignedVariableDeclarations.AddRange(unassignedVariableDeclarations);
+ throw new NotImplementedException("Label on last expression of an expression block");
}
else
{
- statements.InsertRange(0, unassignedVariableDeclarations.Concat(_liftedState.UnassignedVariableDeclarations));
- _liftedState.UnassignedVariableDeclarations.Clear();
-
- // We're done. If the block is in an expression context, it needs to be lifted out; but not if it's in a lambda (in that
- // case we just added return above).
- Result = Block(statements);
+ statements.Add(pendingLabeledStatement.WithStatement(EmptyStatement()));
}
+ }
+
+ // Above we transform top-level assignments (i = 8) to var-declarations with initializers (var i = 8); those variables have
+ // already been taken care of and removed from the list.
+ // But there may still be variables that get assigned inside nested blocks or other situations; prepare declarations for those
+ // and either add them to the block, or lift them if we're an expression block.
+ var unassignedVariableDeclarations =
+ unassignedVariables.Select(
+ v => (LocalDeclarationStatementSyntax)_g.LocalDeclarationStatement(Generate(v.Type), LookupVariableName(v), initializer: _g.DefaultExpression(Generate(v.Type))));
- return block;
+ if (blockContext == ExpressionContext.Expression)
+ {
+ _liftedState.UnassignedVariableDeclarations.AddRange(unassignedVariableDeclarations);
}
- finally
+ else
{
- _onLastLambdaLine = parentOnLastLambdaLine;
- _liftedState = parentLiftedState;
+ statements.InsertRange(0, unassignedVariableDeclarations.Concat(_liftedState.UnassignedVariableDeclarations));
+ _liftedState.UnassignedVariableDeclarations.Clear();
- if (ownStackFrame is not null)
- {
- var popped = _stack.Pop();
- Check.DebugAssert(popped.Equals(ownStackFrame), "popped.Equals(ownStackFrame)");
- }
+ // We're done. If the block is in an expression context, it needs to be lifted out; but not if it's in a lambda (in that
+ // case we just added return above).
+ Result = Block(statements);
+ }
+
+ if (ownStackFrame is not null)
+ {
+ var popped = _stack.Pop();
+ Check.DebugAssert(popped.Equals(ownStackFrame), "popped.Equals(ownStackFrame)");
}
+ _onLastLambdaLine = parentOnLastLambdaLine;
+ _liftedState = parentLiftedState;
+
+ return block;
+
// Returns true for expressions which have side-effects, and can therefore appear alone as a statement
static bool IsExpressionValidAsStatement(ExpressionSyntax expression)
=> expression.Kind() switch
@@ -826,7 +841,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional
}
var parentLiftedState = _liftedState;
- _liftedState = new LiftedState();
+ _liftedState = parentLiftedState.CreateChild();
// If we're in a lambda body, we try to translate as an expression if possible (i.e. no blocks in the true/false arms).
using (ChangeContext(ExpressionContext.Expression))
@@ -858,13 +873,13 @@ protected override Expression VisitConditional(ConditionalExpression conditional
TranslateConditionalStatement(
conditional.Update(
conditional.Test,
- conditional.IfTrue is BlockExpression ? conditional.IfTrue : E.Block(conditional.IfTrue),
- conditional.IfFalse is BlockExpression ? conditional.IfFalse : E.Block(conditional.IfFalse))));
+ conditional.IfTrue is BlockExpression ? conditional.IfTrue : Expression.Block(conditional.IfTrue),
+ conditional.IfFalse is BlockExpression ? conditional.IfFalse : Expression.Block(conditional.IfFalse))));
}
// We're in regular expression context, and there are lifted expressions inside one of the arms; we translate to an if/else
// statement but lowering an assignment into both sides of the condition
- _liftedState = new LiftedState();
+ _liftedState = parentLiftedState.CreateChild();
IdentifierNameSyntax assignmentVariable;
TypeSyntax? loweredAssignmentVariableType = null;
@@ -872,7 +887,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional
if (lowerableAssignmentVariable is null)
{
var name = UniquifyVariableName("liftedConditional");
- var parameter = E.Parameter(conditional.Type, name);
+ var parameter = Expression.Parameter(conditional.Type, name);
assignmentVariable = IdentifierName(name);
loweredAssignmentVariableType = Generate(parameter.Type);
}
@@ -1101,11 +1116,6 @@ IEqualityComparer c
Generate(typeof(Encoding)),
IdentifierName(nameof(Encoding.Default))),
- FieldInfo fieldInfo
- => HandleFieldInfo(fieldInfo),
-
- //TODO: Handle PropertyInfo
-
_ => GenerateUnknownValue(value)
};
@@ -1162,27 +1172,6 @@ ExpressionSyntax HandleValueTuple(ITuple tuple)
return TupleExpression(SeparatedList(arguments));
}
-
- ExpressionSyntax HandleFieldInfo(FieldInfo fieldInfo)
- => fieldInfo.DeclaringType is null
- ? throw new NotSupportedException("Field without a declaring type: " + fieldInfo.Name)
- : (ExpressionSyntax)InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- TypeOfExpression(Generate(fieldInfo.DeclaringType)),
- IdentifierName(nameof(Type.GetField))),
- ArgumentList(
- SeparatedList(new[] {
- Argument(LiteralExpression(
- SyntaxKind.StringLiteralExpression,
- Literal(fieldInfo.Name))),
- Argument(BinaryExpression(
- SyntaxKind.BitwiseOrExpression,
- HandleEnum(fieldInfo.IsStatic ? BindingFlags.Static : BindingFlags.Instance),
- BinaryExpression(
- SyntaxKind.BitwiseOrExpression,
- HandleEnum(fieldInfo.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic),
- HandleEnum(BindingFlags.DeclaredOnly)))) })));
}
///
@@ -1200,6 +1189,12 @@ protected virtual ExpressionSyntax GenerateUnknownValue(object value)
return DefaultExpression(Generate(type));
}
+ if (value is IRelationalQuotableExpression relationalQuotableExpression
+ && Translate(relationalQuotableExpression.Quote()) is ExpressionSyntax expressionSyntax)
+ {
+ return expressionSyntax;
+ }
+
throw new NotSupportedException(
$"Encountered a constant of unsupported type '{value.GetType().Name}'. Only primitive constant nodes are supported."
+ Environment.NewLine + value);
@@ -1227,35 +1222,48 @@ protected override Expression VisitGoto(GotoExpression gotoNode)
///
protected override Expression VisitInvocation(InvocationExpression invocation)
{
- var lambda = (LambdaExpression)invocation.Expression;
-
- // We need to inline the lambda invocation into the tree, by replacing parameters in the lambda body with the invocation arguments.
- // However, if an argument to the invocation can have side effects (e.g. a method call), and it's referenced multiple times from
- // the body, then that would cause multiple evaluation, which is wrong (same if the arguments are evaluated only once but in reverse
- // order).
- // So we have to lift such arguments.
- var arguments = new Expression[invocation.Arguments.Count];
-
- for (var i = 0; i < arguments.Length; i++)
+ if (invocation.Expression is LambdaExpression lambda)
{
- var argument = invocation.Arguments[i];
+ // We need to inline the lambda invocation into the tree, by replacing parameters in the lambda body with the invocation arguments.
+ // However, if an argument to the invocation can have side effects (e.g. a method call), and it's referenced multiple times from
+ // the body, then that would cause multiple evaluation, which is wrong (same if the arguments are evaluated only once but in reverse
+ // order).
+ // So we have to lift such arguments.
+ var arguments = new Expression[invocation.Arguments.Count];
- if (argument is ConstantExpression)
+ for (var i = 0; i < arguments.Length; i++)
{
- // No need to evaluate into a separate variable, just pass directly
- arguments[i] = argument;
- continue;
+ var argument = invocation.Arguments[i];
+
+ if (argument is ConstantExpression)
+ {
+ // No need to evaluate into a separate variable, just pass directly
+ arguments[i] = argument;
+ continue;
+ }
+
+ // Need to lift
+ var name = UniquifyVariableName(lambda.Parameters[i].Name ?? "lifted");
+ var parameter = Expression.Parameter(argument.Type, name);
+ _liftedState.Statements.Add(GenerateVarDeclaration(name, Translate(argument)));
+ _liftedState.VariableNames.Add(name);
+ arguments[i] = parameter;
}
- // Need to lift
- var name = UniquifyVariableName(lambda.Parameters[i].Name ?? "lifted");
- var parameter = E.Parameter(argument.Type, name);
- _liftedState.Statements.Add(GenerateVarDeclaration(name, Translate(argument)));
- arguments[i] = parameter;
+ var replacedBody = new ReplacingExpressionVisitor(lambda.Parameters, arguments).Visit(lambda.Body);
+ Result = Translate(replacedBody);
}
+ else
+ {
+ // The invocation is over a non-inline lambda expression (i.e. field/property/method)
+ var expression = (ExpressionSyntax)Translate(invocation.Expression);
+
+ var translatedExpressions = TranslateList(invocation.Arguments);
- var replacedBody = new ReplacingExpressionVisitor(lambda.Parameters, arguments).Visit(lambda.Body);
- Result = Translate(replacedBody);
+ Result = InvocationExpression(
+ expression,
+ ArgumentList(SeparatedList(translatedExpressions.Select(Argument))));
+ }
return invocation;
}
@@ -1330,14 +1338,7 @@ protected virtual TypeSyntax Generate(Type type)
generic);
}
- if (type.IsNested)
- {
- AddNamespace(type.DeclaringType!);
- }
- else if (type.Namespace != null)
- {
- _collectedNamespaces.Add(type.Namespace);
- }
+ AddNamespace(type);
return generic;
}
@@ -1509,10 +1510,10 @@ protected override Expression VisitLoop(LoopExpression loop)
if (loop.ContinueLabel is not null)
{
- var blockBody = loop.Body is BlockExpression b ? b : E.Block(loop.Body);
+ var blockBody = loop.Body is BlockExpression b ? b : Expression.Block(loop.Body);
blockBody = blockBody.Update(
blockBody.Variables,
- new[] { E.Label(loop.ContinueLabel) }.Concat(blockBody.Expressions));
+ new[] { Expression.Label(loop.ContinueLabel) }.Concat(blockBody.Expressions));
rewrittenLoop1 = loop.Update(
loop.BreakLabel,
@@ -1525,9 +1526,9 @@ protected override Expression VisitLoop(LoopExpression loop)
if (loop.BreakLabel is not null)
{
rewrittenLoop2 =
- E.Block(
+ Expression.Block(
rewrittenLoop1.Update(breakLabel: null, rewrittenLoop1.ContinueLabel, rewrittenLoop1.Body),
- E.Label(loop.BreakLabel));
+ Expression.Label(loop.BreakLabel));
}
if (rewrittenLoop2 != loop)
@@ -1567,18 +1568,27 @@ protected override Expression VisitMember(MemberExpression member)
when constantExpression.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate)
&& System.Attribute.IsDefined(constantExpression.Type, typeof(CompilerGeneratedAttribute), inherit: true):
// Unwrap closure
- VisitConstant(E.Constant(closureField.GetValue(constantExpression.Value), member.Type));
+ VisitConstant(Expression.Constant(closureField.GetValue(constantExpression.Value), member.Type));
break;
// TODO: private event
default:
- Result = MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- member.Expression is null
- ? Generate(member.Member.DeclaringType!) // static
- : Translate(member.Expression),
- IdentifierName(member.Member.Name));
+ var expression = member switch
+ {
+ // Static member
+ { Expression: null } => Generate(member.Member.DeclaringType!),
+
+ // If the member isn't declared on the same type as the expression, (e.g. explicit interface implementation), add
+ // a cast up to the declaring type.
+ _ when member.Member.DeclaringType is Type declaringType && declaringType != member.Expression.Type
+ => ParenthesizedExpression(
+ CastExpression(Generate(member.Member.DeclaringType), Translate(member.Expression))),
+
+ _ => Translate(member.Expression)
+ };
+
+ Result = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, expression, IdentifierName(member.Member.Name));
break;
}
@@ -1591,24 +1601,28 @@ when constantExpression.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- protected virtual void TranslateNonPublicMemberAccess(MemberExpression member)
+ protected virtual void TranslateNonPublicMemberAccess(MemberExpression memberExpression)
{
- if (member.Expression is null)
+ if (memberExpression.Expression is null)
{
throw new NotImplementedException("Private static field access");
}
- var translatedExpression = Translate(member.Expression);
- Result = ParenthesizedExpression(
- CastExpression(
- Generate(member.Type),
- InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- GenerateValue(member.Member),
- IdentifierName(nameof(FieldInfo.GetValue))),
- ArgumentList(
- SingletonSeparatedList(Argument(translatedExpression))))));
+ // Get an unsafe accessor for this field/property (this internally caches and adds it to the output list of unsafe accessors)
+
+ // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]
+ // static extern ref int UnsafeAccessor_Foo_Name(Foo f);
+ var unsafeAccessorDeclaration = GetUnsafeAccessorDeclaration(
+ memberExpression.Member is PropertyInfo propertyInfo
+ ? propertyInfo.GetMethod ?? throw new UnreachableException("Attempting to read from property without getter")
+ : memberExpression.Member,
+ forWrite: false);
+
+ // The unsafe accessor declaration has been created; invoke it.
+ Result =
+ _g.InvocationExpression(
+ _g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text),
+ Translate(memberExpression.Expression));
}
///
@@ -1622,19 +1636,205 @@ protected virtual void TranslateNonPublicMemberAccess(MemberExpression member)
Expression value,
SyntaxKind assignmentKind)
{
- // LINQ expression trees can directly access private members. Use the .NET [UnsafeAccessor] feature.
+ // LINQ expression trees can directly access private members, but C# code cannot. Use the .NET [UnsafeAccessor] feature.
if (memberExpression.Expression is null)
{
throw new NotImplementedException("Private static field assignment");
}
- Result = InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- GenerateValue(memberExpression.Member),
- IdentifierName(nameof(FieldInfo.SetValue))),
- ArgumentList(
- SeparatedList(new[] { Argument(Translate(memberExpression.Expression)), Argument(Translate(value)) })));
+ // Get an unsafe accessor for this field/property (this internally caches and adds it to the output list of unsafe accessors)
+
+ // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]
+ // static extern ref int UnsafeAccessor_Foo_Name(Foo f);
+ var unsafeAccessorDeclaration = GetUnsafeAccessorDeclaration(
+ memberExpression.Member is PropertyInfo propertyInfo
+ ? propertyInfo.SetMethod ?? throw new UnreachableException("Attempting to assign to property without setter")
+ : memberExpression.Member,
+ forWrite: true);
+
+ // The unsafe accessor declaration has been created; invoke it.
+ Result = memberExpression.Member switch
+ {
+ FieldInfo => AssignmentExpression(
+ assignmentKind,
+ (ExpressionSyntax)_g.InvocationExpression(
+ _g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text),
+ Translate(memberExpression.Expression)),
+ Translate(value)),
+
+ PropertyInfo =>
+ _g.InvocationExpression(
+ _g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text), Translate(memberExpression.Expression),
+ assignmentKind is SyntaxKind.SimpleAssignmentExpression
+ ? Translate(value)
+ : throw new NotImplementedException("Compound assignment of private property not yet supported")),
+
+ _ => throw new UnreachableException()
+ };
+ }
+
+ private MethodDeclarationSyntax GetUnsafeAccessorDeclaration(MemberInfo member, bool forWrite = false)
+ {
+ MethodDeclarationSyntax? unsafeAccessorDeclaration;
+
+ switch (member)
+ {
+ case FieldInfo field:
+ {
+ // Note that we generate two accessors for fields (get/set), since the get accessor needs to be used in expression trees,
+ // which don't support ref return
+ if (_fieldUnsafeAccessors.TryGetValue((field, forWrite), out unsafeAccessorDeclaration))
+ {
+ return unsafeAccessorDeclaration;
+ }
+
+ break;
+ }
+
+ case MethodBase method: // Also constructors
+ {
+ if (_methodUnsafeAccessors.TryGetValue(method, out unsafeAccessorDeclaration))
+ {
+ return unsafeAccessorDeclaration;
+ }
+
+ break;
+ }
+
+ default:
+ throw new UnreachableException();
+ }
+
+ _stringBuilder.Clear().Append("UnsafeAccessor_");
+
+ if (member.DeclaringType?.Namespace?.Replace(".", "_") is string typeNamespace)
+ {
+ _stringBuilder.Append(typeNamespace).Append('_');
+ }
+
+ _stringBuilder.Append(member.DeclaringType!.Name).Append('_');
+
+ var memberName = member.Name;
+ _stringBuilder.Append(
+ member switch
+ {
+ // If this is the backing field of an auto-property, extract the name of the property from its compiler-generated name
+ // (e.g. k__BackingField)
+ FieldInfo when memberName[0] == '<' && memberName.IndexOf(">k__BackingField", StringComparison.Ordinal) is > 1 and var pos
+ => memberName[1..pos],
+ ConstructorInfo => "Ctor",
+ _ => memberName
+ });
+
+ var unsafeAccessorName = _stringBuilder.ToString();
+
+ switch (member)
+ {
+ case FieldInfo field:
+ {
+ // Unsafe accessor for fields:
+ // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_bar")]
+ // private static extern ref int GetSetPrivateField(Foo f);
+ // Note that we generate two accessors for fields (get/set), since the get accessor needs to be used in expression trees,
+ // which don't support ref return
+ unsafeAccessorDeclaration = (MethodDeclarationSyntax)_g.MethodDeclaration(
+ unsafeAccessorName + (forWrite ? "_Set" : "_Get"),
+ accessibility: Accessibility.Private,
+ modifiers: DeclarationModifiers.Static | DeclarationModifiers.Extern,
+ returnType: forWrite
+ ? RefType(Generate(field.FieldType))
+ : Generate(field.FieldType),
+ parameters: [_g.ParameterDeclaration("instance", Generate(member.DeclaringType))]);
+
+ unsafeAccessorDeclaration =
+ (MethodDeclarationSyntax)_g.AddAttributes(
+ unsafeAccessorDeclaration,
+ _g.Attribute(
+ "UnsafeAccessor",
+ _g.MemberAccessExpression(Generate(typeof(UnsafeAccessorKind)), nameof(UnsafeAccessorKind.Field)),
+ _g.AttributeArgument(
+ nameof(UnsafeAccessorAttribute.Name), _g.LiteralExpression(member.Name))));
+ break;
+ }
+
+ case MethodInfo { IsStatic: false } method:
+ {
+ // Unsafe accessor for methods. Note that this is used also for property getter and setter:
+ // [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Bar")]
+ // private static void SetPrivateProperty(Foo f, int value);
+ unsafeAccessorDeclaration = (MethodDeclarationSyntax)_g.MethodDeclaration(
+ unsafeAccessorName,
+ accessibility: Accessibility.Private,
+ modifiers: DeclarationModifiers.Static | DeclarationModifiers.Extern,
+ parameters:
+ [
+ _g.ParameterDeclaration("instance", Generate(member.DeclaringType)),
+ .. method.GetParameters()
+ .Select(
+ p => _g.ParameterDeclaration(
+ p.Name ?? throw new UnreachableException("Missing parameter name"),
+ Generate(p.ParameterType)))
+ ]);
+
+ unsafeAccessorDeclaration =
+ (MethodDeclarationSyntax)_g.AddAttributes(
+ unsafeAccessorDeclaration,
+ _g.Attribute(
+ "UnsafeAccessor",
+ _g.MemberAccessExpression(Generate(typeof(UnsafeAccessorKind)), nameof(UnsafeAccessorKind.Method)),
+ _g.AttributeArgument(
+ nameof(UnsafeAccessorAttribute.Name), _g.LiteralExpression(memberName))));
+
+ break;
+ }
+
+ case ConstructorInfo constructor:
+ {
+ // Unsafe accessor for constructors:
+ // [UnsafeAccessor(UnsafeAccessorKind.Constructor)]
+ // extern static Class PrivateCtor(int i);
+ unsafeAccessorDeclaration = (MethodDeclarationSyntax)_g.MethodDeclaration(
+ unsafeAccessorName,
+ accessibility: Accessibility.Private,
+ modifiers: DeclarationModifiers.Static | DeclarationModifiers.Extern,
+ returnType: Generate(member.DeclaringType),
+ parameters: constructor.GetParameters()
+ .Select(
+ p => _g.ParameterDeclaration(
+ p.Name ?? throw new UnreachableException("Missing parameter name"),
+ Generate(p.ParameterType))));
+
+ unsafeAccessorDeclaration =
+ (MethodDeclarationSyntax)_g.AddAttributes(
+ unsafeAccessorDeclaration,
+ _g.Attribute(
+ "UnsafeAccessor",
+ _g.MemberAccessExpression(Generate(typeof(UnsafeAccessorKind)), nameof(UnsafeAccessorKind.Constructor))));
+
+ break;
+ }
+
+ default:
+ throw new UnreachableException("Unsafe declaration for unknown member type: " + member.GetType().Name);
+ }
+
+ unsafeAccessorDeclaration = unsafeAccessorDeclaration
+ .WithBody(null)
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+
+ switch (member)
+ {
+ case FieldInfo field:
+ _fieldUnsafeAccessors[(field, forWrite)] = unsafeAccessorDeclaration;
+ break;
+ case MethodBase method:
+ _methodUnsafeAccessors[method] = unsafeAccessorDeclaration;
+ break;
+ default:
+ throw new UnreachableException();
+ }
+
+ return unsafeAccessorDeclaration;
}
///
@@ -1711,24 +1911,25 @@ protected override Expression VisitMethodCall(MethodCallExpression call)
}
else
{
- ExpressionSyntax expression;
- if (call.Object is null)
+ var expression = call switch
{
- // Static method call. Recursively add MemberAccessExpressions for all declaring types (for methods on nested types)
- expression = GetMemberAccessesForAllDeclaringTypes(call.Method.DeclaringType);
+ { Method.IsStatic: true } => GetMemberAccessesForAllDeclaringTypes(call.Method.DeclaringType),
- ExpressionSyntax GetMemberAccessesForAllDeclaringTypes(Type type)
- => type.DeclaringType is null
- ? Generate(type)
- : MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- GetMemberAccessesForAllDeclaringTypes(type.DeclaringType),
- IdentifierName(type.Name));
- }
- else
- {
- expression = Translate(call.Object);
- }
+ // If the member isn't declared on the same type as the expression, (e.g. explicit interface implementation), add
+ // a cast up to the declaring type.
+ { Method.DeclaringType: Type declaringType, Object.Type: Type objectType, } when declaringType != objectType
+ => ParenthesizedExpression(CastExpression(Generate(declaringType), Translate(call.Object))),
+
+ _ => Translate(call.Object)
+ };
+
+ ExpressionSyntax GetMemberAccessesForAllDeclaringTypes(Type type)
+ => type.DeclaringType is null
+ ? Generate(type)
+ : MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ GetMemberAccessesForAllDeclaringTypes(type.DeclaringType),
+ IdentifierName(type.Name));
if (call.Method.Name.StartsWith("get_", StringComparison.Ordinal)
&& call.Method.GetParameters().Length == 1
@@ -1834,32 +2035,16 @@ protected override Expression VisitNew(NewExpression node)
return node;
}
- // If the type has any required properties and the constructor doesn't have [SetsRequiredMembers], we can't just generate an
- // instantiation expression.
- // TODO: Currently matching attributes by name since we target .NET 6.0. If/when we target .NET 7.0 and above, match the type.
- if (node.Type.GetCustomAttributes(inherit: true)
- .Any(a => a.GetType().FullName == "System.Runtime.CompilerServices.RequiredMemberAttribute")
- && node.Constructor is not null
- && node.Constructor.GetCustomAttributes()
- .Any(a => a.GetType().FullName == "System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute")
- != true)
- {
- // If the constructor is parameterless, we generate Activator.Create() which is almost as fast (<10ns difference).
- // For constructors with parameters, we currently throw as not supported (we can pass parameters, but boxing, probably
- // speed degradation etc.).
- if (node.Constructor.GetParameters().Length == 0)
- {
- Result =
- Translate(
- E.Call(
- (_activatorCreateInstanceMethod ??= typeof(Activator).GetMethod(
- nameof(Activator.CreateInstance), [])!)
- .MakeGenericMethod(node.Type)));
- }
- else
- {
- throw new NotImplementedException("Instantiation of type with required properties via constructor that has parameters");
- }
+ // If the constructor isn't public, or it has required properties and the constructor doesn't have [SetsRequiredMembers], we can't
+ // just generate a regular instantiation expression (won't compile). Generate an unsafe accessor instead.
+ if (node.Constructor is ConstructorInfo constructor
+ && (!constructor.IsPublic
+ || node.Type.GetCustomAttribute() is not null
+ && constructor.GetCustomAttribute() is null))
+ {
+ var unsafeAccessorDeclaration = GetUnsafeAccessorDeclaration(constructor);
+
+ Result = _g.InvocationExpression(_g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text), arguments);
}
else
{
@@ -1870,9 +2055,9 @@ protected override Expression VisitNew(NewExpression node)
initializer: null);
}
- if (node.Constructor?.DeclaringType is not null)
+ if (node.Type.Namespace is not null)
{
- AddNamespace(node.Constructor?.DeclaringType!);
+ AddNamespace(node.Type);
}
return node;
@@ -1940,7 +2125,7 @@ protected virtual CSharpSyntaxNode TranslateSwitch(SwitchExpression switchNode,
case ExpressionContext.Statement:
{
var parentLiftedState = _liftedState;
- _liftedState = new LiftedState();
+ _liftedState = parentLiftedState.CreateChild();
var cases = List(
switchNode.Cases.Select(
@@ -1993,7 +2178,7 @@ SyntaxList ProcessArmBody(Expression body)
}
var parentLiftedState = _liftedState;
- _liftedState = new LiftedState();
+ _liftedState = parentLiftedState.CreateChild();
// Translate all arms
var arms = SeparatedList(
@@ -2020,7 +2205,7 @@ SyntaxList ProcessArmBody(Expression body)
// There are lifted expressions inside some of the arms, we must lift the entire switch expression, rewriting it to
// a switch statement.
- _liftedState = new LiftedState();
+ _liftedState = parentLiftedState.CreateChild();
IdentifierNameSyntax assignmentVariable;
TypeSyntax? loweredAssignmentVariableType = null;
@@ -2028,7 +2213,7 @@ SyntaxList ProcessArmBody(Expression body)
if (lowerableAssignmentVariable is null)
{
var name = UniquifyVariableName("liftedSwitch");
- var parameter = E.Parameter(switchNode.Type, name);
+ var parameter = Expression.Parameter(switchNode.Type, name);
assignmentVariable = IdentifierName(name);
loweredAssignmentVariableType = Generate(parameter.Type);
}
@@ -2116,8 +2301,8 @@ static ConditionalExpression RewriteSwitchToConditionals(SwitchExpression node)
.Aggregate(
node.DefaultBody,
(expression, arm) => expression is null
- ? E.IfThen(E.Equal(node.SwitchValue, arm.Label), arm.Body)
- : E.IfThenElse(E.Equal(node.SwitchValue, arm.Label), arm.Body, expression))
+ ? Expression.IfThen(Expression.Equal(node.SwitchValue, arm.Label), arm.Body)
+ : Expression.IfThenElse(Expression.Equal(node.SwitchValue, arm.Label), arm.Body, expression))
?? throw new NotImplementedException("Empty switch statement"));
}
@@ -2128,8 +2313,8 @@ static ConditionalExpression RewriteSwitchToConditionals(SwitchExpression node)
.Reverse()
.Aggregate(
node.DefaultBody,
- (expression, arm) => E.Condition(
- E.Equal(node.SwitchValue, arm.Label),
+ (expression, arm) => Expression.Condition(
+ Expression.Equal(node.SwitchValue, arm.Label),
arm.Body,
expression));
}
@@ -2162,7 +2347,7 @@ protected override Expression VisitTry(TryExpression tryNode)
Result = _g.TryCatchStatement(
translatedBody,
- catchClauses: [TranslateCatchBlock(E.Catch(typeof(Exception), tryNode.Fault), noType: true)]);
+ catchClauses: [TranslateCatchBlock(Expression.Catch(typeof(Exception), tryNode.Fault), noType: true)]);
return tryNode;
}
@@ -2239,8 +2424,8 @@ protected override Expression VisitUnary(UnaryExpression unary)
ExpressionType.Quote => operand,
ExpressionType.UnaryPlus => PrefixUnaryExpression(SyntaxKind.UnaryPlusExpression, operand),
ExpressionType.Unbox => operand,
- ExpressionType.Increment => Translate(E.Add(unary.Operand, E.Constant(1))),
- ExpressionType.Decrement => Translate(E.Subtract(unary.Operand, E.Constant(1))),
+ ExpressionType.Increment => Translate(Expression.Add(unary.Operand, Expression.Constant(1))),
+ ExpressionType.Decrement => Translate(Expression.Subtract(unary.Operand, Expression.Constant(1))),
ExpressionType.PostIncrementAssign => PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, operand),
ExpressionType.PostDecrementAssign => PostfixUnaryExpression(SyntaxKind.PostDecrementExpression, operand),
ExpressionType.PreIncrementAssign => PrefixUnaryExpression(SyntaxKind.PreIncrementExpression, operand),
@@ -2633,6 +2818,10 @@ public bool MayHaveSideEffects(SyntaxNode node)
public override void Visit(SyntaxNode node)
{
_mayHaveSideEffects |= MayHaveSideEffectsCore(node);
+ if (_mayHaveSideEffects)
+ {
+ return;
+ }
base.Visit(node);
}
@@ -2640,9 +2829,10 @@ public override void Visit(SyntaxNode node)
private static bool MayHaveSideEffectsCore(SyntaxNode node)
=> node switch
{
- IdentifierNameSyntax or LiteralExpressionSyntax => false,
+ IdentifierNameSyntax or LiteralExpressionSyntax or PredefinedTypeSyntax => false,
ExpressionStatementSyntax e => MayHaveSideEffectsCore(e.Expression),
EmptyStatementSyntax => false,
+ DefaultExpressionSyntax => false,
// TODO: we can exempt most binary and unary expressions as well, e.g. i + 5, but not anything involving assignment
_ => true
diff --git a/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs
new file mode 100644
index 00000000000..11ae12ebae7
--- /dev/null
+++ b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs
@@ -0,0 +1,1128 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using System.Runtime.ExceptionServices;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Editing;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal;
+
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+public class PrecompiledQueryCodeGenerator
+{
+ private readonly QueryLocator _queryLocator;
+ private readonly CSharpToLinqTranslator _csharpToLinqTranslator;
+
+ private SyntaxGenerator _g = null!;
+ private IQueryCompiler _queryCompiler = null!;
+ private ExpressionTreeFuncletizer _funcletizer = null!;
+ private LinqToCSharpSyntaxTranslator _linqToCSharpTranslator = null!;
+ private LiftableConstantProcessor _liftableConstantProcessor = null!;
+
+ private Symbols _symbols;
+
+ private readonly HashSet _namespaces = new();
+ private readonly HashSet _unsafeAccessors = new();
+ private readonly IndentedStringBuilder _code = new();
+
+ private const string InterceptorsNamespace = "Microsoft.EntityFrameworkCore.GeneratedInterceptors";
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public PrecompiledQueryCodeGenerator()
+ {
+ _queryLocator = new QueryLocator();
+ _csharpToLinqTranslator = new CSharpToLinqTranslator();
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual IReadOnlyList GeneratePrecompiledQueries(
+ Compilation compilation,
+ SyntaxGenerator syntaxGenerator,
+ DbContext dbContext,
+ List precompilationErrors,
+ Assembly? additionalAssembly = null,
+ CancellationToken cancellationToken = default)
+ {
+ _queryLocator.Initialize(compilation);
+ _symbols = Symbols.Load(compilation);
+ _g = syntaxGenerator;
+ _linqToCSharpTranslator = new LinqToCSharpSyntaxTranslator(_g);
+ _liftableConstantProcessor = new LiftableConstantProcessor(null!);
+ _queryCompiler = dbContext.GetService();
+ _unsafeAccessors.Clear();
+ _funcletizer = new ExpressionTreeFuncletizer(
+ dbContext.Model,
+ dbContext.GetService(),
+ dbContext.GetType(),
+ generateContextAccessors: false,
+ dbContext.GetService>());
+
+ // This must be done after we complete generating the final compilation above
+ _csharpToLinqTranslator.Load(compilation, dbContext, additionalAssembly);
+
+ // TODO: Ignore our auto-generated code! Also compiled model, generated code (comment, filename...?).
+ var generatedSyntaxTrees = new List();
+ foreach (var syntaxTree in compilation.SyntaxTrees)
+ {
+ if (_queryLocator.LocateQueries(syntaxTree, precompilationErrors, cancellationToken) is not { Count: > 0 } locatedQueries)
+ {
+ continue;
+ }
+
+ var semanticModel = compilation.GetSemanticModel(syntaxTree);
+ var generatedSyntaxTree = ProcessSyntaxTreeAsync(
+ syntaxTree, semanticModel, locatedQueries, precompilationErrors, cancellationToken);
+ if (generatedSyntaxTree is not null)
+ {
+ generatedSyntaxTrees.Add(generatedSyntaxTree);
+ }
+ }
+
+ return generatedSyntaxTrees;
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ protected virtual GeneratedInterceptorFile? ProcessSyntaxTreeAsync(
+ SyntaxTree syntaxTree,
+ SemanticModel semanticModel,
+ IReadOnlyList locatedQueries,
+ List precompilationErrors,
+ CancellationToken cancellationToken)
+ {
+ var queriesPrecompiledInFile = 0;
+ _namespaces.Clear();
+ _code.Clear();
+ _code
+ .AppendLine()
+ .AppendLine("#pragma warning disable EF9100 // Precompiled query is experimental")
+ .AppendLine()
+ .Append("namespace ").AppendLine(InterceptorsNamespace)
+ .AppendLine("{")
+ .IncrementIndent()
+ .AppendLine("file static class EntityFrameworkCoreInterceptors")
+ .AppendLine("{")
+ .IncrementIndent();
+
+ for (var queryNum = 0; queryNum < locatedQueries.Count; queryNum++)
+ {
+ var querySyntax = locatedQueries[queryNum];
+
+ try
+ {
+ // We have a query lambda, as a Roslyn syntax tree. Translate to LINQ expression tree.
+ // TODO: Add verification that this is an EF query over our user's context. If translation returns null the moment
+ // there's another query root (another context or another LINQ provider), that's fine.
+ if (_csharpToLinqTranslator.Translate(querySyntax, semanticModel) is not MethodCallExpression terminatingOperator)
+ {
+ throw new UnreachableException("Non-method call encountered as the root of a LINQ query");
+ }
+
+ // We have a LINQ representation of the query tree as it appears in the user's source code, but this isn't the same as the
+ // LINQ tree the EF query pipeline needs to get; the latter is the result of evaluating the queryable operators in the user's
+ // source code. For example, in the user's code the root is a DbSet as the root, but the expression tree we require needs to
+ // contain an EntityQueryRootExpression. To get the LINQ tree for EF, we need to evaluate the operator chain, building an
+ // expression tree as usual.
+
+ // However, we cannot evaluate the last operator, since that would execute the query instead of returning an expression tree.
+ // So we need to chop off the last operator before evaluation, and then (optionally) recompose it back afterwards.
+ // For ToList(), we don't actually recompose it (since ToList() isn't a node in the expression tree), and for async operators,
+ // we need to rewrite them to their sync counterparts (since that's what gets injected into the query tree).
+ var penultimateOperator = terminatingOperator switch
+ {
+ // This is needed e.g. for GetEnumerator(), DbSet.AsAsyncEnumerable (non-static terminating operators)
+ { Object: Expression @object } => @object,
+ { Arguments: [var sourceArgument, ..] } => sourceArgument,
+ _ => throw new UnreachableException()
+ };
+
+ penultimateOperator = Expression.Lambda>(penultimateOperator)
+ .Compile(preferInterpretation: true)().Expression;
+
+ // Pass the query through EF's query pipeline; this returns the query's executor function, which can produce an enumerable
+ // that invokes the query.
+ // Note that we cannot recompose the terminating operator on top of the evaluated penultimate, since method signatures
+ // may not allow that (e.g. DbSet.AsAsyncEnumerable() requires a DbSet, but the evaluated value for a DbSet is
+ // EntityQueryRootExpression. So we handle the penultimate and the terminating separately.
+ var queryExecutor = CompileQuery(penultimateOperator, terminatingOperator);
+
+ // The query has been compiled successfully by the EF query pipeline.
+ // Now go over each LINQ operator, generating an interceptor for it.
+ _code.AppendLine($"#region Query{queryNum + 1}").AppendLine();
+
+ try
+ {
+ _funcletizer.ResetPathCalculation();
+
+ if (querySyntax is not { Expression: MemberAccessExpressionSyntax { Expression: var penultimateOperatorSyntax } })
+ {
+ throw new UnreachableException();
+ }
+
+ // Generate interceptors for all LINQ operators in the query, starting from the root up until the penultimate.
+ // Then generate the interceptor for the terminating operator, and finally the query's executor.
+ GenerateOperatorInterceptorsRecursively(
+ _code, penultimateOperator, penultimateOperatorSyntax, semanticModel, queryNum + 1, out var operatorNum,
+ cancellationToken: cancellationToken);
+
+ GenerateOperatorInterceptor(
+ _code, terminatingOperator, querySyntax, semanticModel, queryNum + 1, operatorNum + 1, isTerminatingOperator: true,
+ cancellationToken);
+
+ GenerateQueryExecutor(_code, queryNum + 1, queryExecutor, _namespaces, _unsafeAccessors);
+ }
+ finally
+ {
+ _code
+ .AppendLine()
+ .AppendLine($"#endregion Query{queryNum + 1}");
+ }
+ }
+ catch (Exception e)
+ {
+ precompilationErrors.Add(new(querySyntax, e));
+ continue;
+ }
+
+ // We're done generating the interceptors for the query's LINQ operators.
+
+ queriesPrecompiledInFile++;
+ }
+
+ if (queriesPrecompiledInFile == 0)
+ {
+ return null;
+ }
+
+ // Output all the unsafe accessors that were generated for all intercepted shapers, e.g.:
+ // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]
+ // static extern ref int GetSet_Foo_Name(Foo f);
+ if (_unsafeAccessors.Count > 0)
+ {
+ _code.AppendLine("#region Unsafe accessors");
+ foreach (var unsafeAccessor in _unsafeAccessors)
+ {
+ _code.AppendLine(unsafeAccessor.NormalizeWhitespace().ToFullString());
+ }
+ _code.AppendLine("#endregion Unsafe accessors");
+ }
+
+ _code
+ .DecrementIndent().AppendLine("}")
+ .DecrementIndent().AppendLine("}");
+
+ var mainCode = _code.ToString();
+
+ _code.Clear();
+ _code.AppendLine("// ").AppendLine();
+
+ foreach (var ns in _namespaces
+ // In addition to the namespaces auto-detected by LinqToCSharpTranslator, we manually add these namespaces which are required
+ // by manually generated code above.
+ .Append("System")
+ .Append("System.Collections.Concurrent")
+ .Append("System.Collections.Generic")
+ .Append("System.Linq")
+ .Append("System.Linq.Expressions")
+ .Append("System.Runtime.CompilerServices")
+ .Append("System.Reflection")
+ .Append("System.Threading.Tasks")
+ .Append("Microsoft.EntityFrameworkCore")
+ .Append("Microsoft.EntityFrameworkCore.ChangeTracking.Internal")
+ .Append("Microsoft.EntityFrameworkCore.Diagnostics")
+ .Append("Microsoft.EntityFrameworkCore.Infrastructure")
+ .Append("Microsoft.EntityFrameworkCore.Infrastructure.Internal")
+ .Append("Microsoft.EntityFrameworkCore.Internal")
+ .Append("Microsoft.EntityFrameworkCore.Metadata")
+ .Append("Microsoft.EntityFrameworkCore.Query")
+ .Append("Microsoft.EntityFrameworkCore.Query.Internal")
+ .Append("Microsoft.EntityFrameworkCore.Storage")
+ .OrderBy(
+ ns => ns switch
+ {
+ _ when ns.StartsWith("System.", StringComparison.Ordinal) => 10,
+ _ when ns.StartsWith("Microsoft.", StringComparison.Ordinal) => 9,
+ _ => 0
+ })
+ .ThenBy(ns => ns))
+ {
+ _code.Append("using ").Append(ns).AppendLine(";");
+ }
+
+ _code.AppendLine(mainCode);
+
+ _code.AppendLine(
+ """
+namespace System.Runtime.CompilerServices
+{
+ [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
+ sealed class InterceptsLocationAttribute : Attribute
+ {
+ public InterceptsLocationAttribute(string filePath, int line, int column) { }
+ }
+}
+""");
+
+ return new(
+ $"{Path.GetFileNameWithoutExtension(syntaxTree.FilePath)}.EFInterceptors.g{Path.GetExtension(syntaxTree.FilePath)}",
+ _code.ToString());
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ protected virtual Expression CompileQuery(Expression penultimateOperator, MethodCallExpression terminatingOperator)
+ {
+ // First, check whether this is an async query.
+ var async = terminatingOperator.Type.IsGenericType
+ && terminatingOperator.Type.GetGenericTypeDefinition() is var genericDefinition
+ && (genericDefinition == typeof(Task<>) || genericDefinition == typeof(ValueTask<>));
+
+ var preparedQuery = PrepareQueryForCompilation(penultimateOperator, terminatingOperator);
+
+ // We now need to figure out the return type of the query's executor.
+ // Non-scalar query expressions (e.g. ToList()) return an IQueryable; the query executor will return an enumerable (sync or async).
+ // Scalar query expressions just return the scalar type.
+ var returnType = preparedQuery.Type.IsGenericType
+ && preparedQuery.Type.GetGenericTypeDefinition().IsAssignableTo(typeof(IQueryable))
+ ? (async
+ ? typeof(IAsyncEnumerable<>)
+ : typeof(IEnumerable<>)).MakeGenericType(preparedQuery.Type.GetGenericArguments()[0])
+ : terminatingOperator.Type;
+
+ // We now have the query as a finalized LINQ expression tree, ready for compilation.
+ // Compile the query, invoking CompileQueryToExpression on the IQueryCompiler from the user's context instance.
+ try
+ {
+ return (Expression)_queryCompiler.GetType()
+ .GetMethod(nameof(IQueryCompiler.PrecompileQuery))!
+ .MakeGenericMethod(returnType)
+ .Invoke(_queryCompiler, [preparedQuery, async])!;
+ }
+ catch (TargetInvocationException e) when (e.InnerException is not null)
+ {
+ // Unwrap the TargetInvocationException wrapper we get from Invoke()
+ ExceptionDispatchInfo.Capture(e.InnerException).Throw();
+ throw;
+ }
+ }
+
+ private void GenerateOperatorInterceptorsRecursively(
+ IndentedStringBuilder code,
+ Expression operatorExpression,
+ ExpressionSyntax operatorSyntax,
+ SemanticModel semanticModel,
+ int queryNum,
+ out int operatorNum,
+ CancellationToken cancellationToken)
+ {
+ // For non-root operators, we get here with an InvocationExpressionSyntax and its corresponding LINQ MethodCallExpression.
+ // For the query root, we usually don't get called here: a regular EntityQueryRootExpression corresponds to a DbSet (either
+ // property access on DbContext or a Set<>() method invocation). We can't intercept property accesses, and in any case there's
+ // nothing to intercept there.
+ // However, for FromSql specifically, we get here with an InvocationExpressionSyntax (representing the FromSql() invocation), but
+ // with a corresponding FromSqlQueryRootExpression - not a MethodCallExpression. We must pass this query root through the
+ // funcletizer as usual to mimic the normal flow.
+ switch (operatorExpression)
+ {
+ // Regular, non-root LINQ operator; the LINQ method call must correspond to a Roslyn syntax invocation.
+ // We first recurse to handle the nested operator (i.e. generate the interceptor from the root outer).
+ case MethodCallExpression operatorMethodCall:
+ if (operatorSyntax is not InvocationExpressionSyntax
+ {
+ Expression: MemberAccessExpressionSyntax { Expression: var nestedOperatorSyntax }
+ })
+ {
+ throw new UnreachableException();
+ }
+
+ // We're an operator (not the query root).
+ // Continue recursing down - we want to handle from the root up.
+
+ var nestedOperatorExpression = operatorMethodCall switch
+ {
+ // This is needed e.g. for GetEnumerator(), DbSet.AsAsyncEnumerable (non-static terminating operators)
+ { Object: Expression @object } => @object,
+ { Arguments: [var sourceArgument, ..] } => sourceArgument,
+ _ => throw new UnreachableException()
+ };
+
+ GenerateOperatorInterceptorsRecursively(
+ code, nestedOperatorExpression, nestedOperatorSyntax, semanticModel, queryNum, out operatorNum,
+ cancellationToken: cancellationToken);
+
+ operatorNum++;
+
+ GenerateOperatorInterceptor(
+ code, operatorExpression, operatorSyntax, semanticModel, queryNum, operatorNum, isTerminatingOperator: false,
+ cancellationToken);
+ return;
+
+ // For FromSql() queries, an InvocationExpressionSyntax (representing the FromSql() invocation), but with a corresponding
+ // FromSqlQueryRootExpression - not a MethodCallExpression.
+ // We must generate an interceptor for FromSql() and pass the arguments array through the funcletizer as usual.
+ case FromSqlQueryRootExpression:
+ operatorNum = 1;
+ GenerateOperatorInterceptor(
+ code, operatorExpression, operatorSyntax, semanticModel, queryNum, operatorNum, isTerminatingOperator: false,
+ cancellationToken);
+ return;
+
+ // For other query roots, we don't generate interceptors - there are no possible captured variables that need to be
+ // pass through funcletization (as with FromSqlQueryRootExpression). Simply return to process the first non-root operator.
+ case QueryRootExpression:
+ operatorNum = 0;
+ return;
+
+ default:
+ throw new UnreachableException();
+ }
+ }
+
+ private void GenerateOperatorInterceptor(
+ IndentedStringBuilder code,
+ Expression operatorExpression,
+ ExpressionSyntax operatorSyntax,
+ SemanticModel semanticModel,
+ int queryNum,
+ int operatorNum,
+ bool isTerminatingOperator,
+ CancellationToken cancellationToken)
+ {
+ // At this point we know we're intercepting a method call invocation.
+ // Extract the MemberAccessExpressionSyntax for the invocation, representing the method being called.
+ var memberAccessSyntax = (operatorSyntax as InvocationExpressionSyntax)?.Expression as MemberAccessExpressionSyntax
+ ?? throw new UnreachableException();
+
+ // Create the parameter list for our interceptor method from the LINQ operator method's parameter list
+ if (semanticModel.GetSymbolInfo(memberAccessSyntax, cancellationToken).Symbol is not IMethodSymbol operatorSymbol)
+ {
+ throw new InvalidOperationException("Couldn't find method symbol for: " + memberAccessSyntax);
+ }
+
+ // Throughout the code generation below, we will only be dealing with the original generic definition of the operator (and
+ // generating a generic interceptor); we'll never be dealing with the concrete types for this invocation, since these may
+ // be unspeakable anonymous types which we can't embed in generated code.
+ operatorSymbol = operatorSymbol.OriginalDefinition;
+
+ // For extension methods, this provides the form which has the "this" as its first parameter.
+ // TODO: Validate the below, throw informative (e.g. top-level TVF fails here because non-generic)
+ var reducedOperatorSymbol = operatorSymbol.GetConstructedReducedFrom() ?? operatorSymbol;
+
+ var (sourceVariableName, sourceTypeSymbol) = reducedOperatorSymbol.IsStatic
+ ? (reducedOperatorSymbol.Parameters[0].Name, reducedOperatorSymbol.Parameters[0].Type)
+ : ("source", reducedOperatorSymbol.ReceiverType!);
+
+ if (sourceTypeSymbol is not INamedTypeSymbol { TypeArguments: [var sourceElementTypeSymbol]})
+ {
+ throw new UnreachableException($"Non-IQueryable first parameter in LINQ operator '{operatorSymbol.Name}'");
+ }
+
+ var sourceElementTypeName = sourceElementTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+
+ var returnTypeSymbol = reducedOperatorSymbol.ReturnType;
+
+ // Unwrap Task to get the element type (e.g. Task>)
+ var returnTypeWithoutTask = returnTypeSymbol is INamedTypeSymbol namedReturnType
+ && returnTypeSymbol.OriginalDefinition.Equals(_symbols.GenericTask, SymbolEqualityComparer.Default)
+ ? namedReturnType.TypeArguments[0]
+ : returnTypeSymbol;
+
+ var returnElementTypeSymbol = returnTypeWithoutTask switch
+ {
+ IArrayTypeSymbol arrayTypeSymbol => arrayTypeSymbol.ElementType,
+ INamedTypeSymbol namedReturnType2
+ when namedReturnType2.AllInterfaces.Prepend(namedReturnType2)
+ .Any(
+ i => i.OriginalDefinition.Equals(_symbols.GenericEnumerable, SymbolEqualityComparer.Default)
+ || i.OriginalDefinition.Equals(_symbols.GenericAsyncEnumerable, SymbolEqualityComparer.Default)
+ || i.OriginalDefinition.Equals(_symbols.GenericEnumerator, SymbolEqualityComparer.Default))
+ => namedReturnType2.TypeArguments[0],
+ _ => null
+ };
+
+ // Output the interceptor method signature preceded by the [InterceptsLocation] attribute.
+ var startPosition = operatorSyntax.SyntaxTree.GetLineSpan(memberAccessSyntax.Name.Span, cancellationToken).StartLinePosition;
+ var interceptorName = $"Query{queryNum}_{memberAccessSyntax.Name}{operatorNum}";
+ code.AppendLine($"""[InterceptsLocation("{operatorSyntax.SyntaxTree.FilePath}", {startPosition.Line + 1}, {startPosition.Character + 1})]""");
+ GenerateInterceptorMethodSignature();
+ code.AppendLine("{").IncrementIndent();
+
+ // If this is the first query operator (no nested operator), cast the input source to IInfrastructure and extract the
+ // DbContext, create a new QueryContext, and wrap it all in a PrecompiledQueryContext that will flow through to the
+ // terminating operator, where the query will actually get executed.
+ // Otherwise, if this is a non-first operator, receive the PrecompiledQueryContext from the nested operator and flow it forward.
+ code.AppendLine(
+ "var precompiledQueryContext = "
+ + (operatorNum == 1
+ ? $"new PrecompiledQueryContext<{sourceElementTypeName}>(((IInfrastructure){sourceVariableName}).Instance);"
+ : $"(PrecompiledQueryContext<{sourceElementTypeName}>){sourceVariableName};"));
+
+ var declaredQueryContextVariable = false;
+
+ ProcessCapturedVariables();
+
+ if (isTerminatingOperator)
+ {
+ // We're intercepting the query's terminating operator - this is where the query actually gets executed.
+ if (!declaredQueryContextVariable)
+ {
+ code.AppendLine("var queryContext = precompiledQueryContext.QueryContext;");
+ }
+
+ var executorFieldIdentifier = $"Query{queryNum}_Executor";
+ code.AppendLine(
+ $"{executorFieldIdentifier} ??= Query{queryNum}_GenerateExecutor(precompiledQueryContext.DbContext, precompiledQueryContext.QueryContext);");
+
+ if (returnElementTypeSymbol is null)
+ {
+ // The query returns a scalar, not an enumerable (e.g. the terminating operator is Max()).
+ // The executor directly returns the needed result (e.g. int), so just return that.
+ var returnType = returnTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ code.AppendLine($"return ((Func)({executorFieldIdentifier}))(queryContext);");
+ }
+ else
+ {
+ // The query returns an IEnumerable/IAsyncEnumerable/IQueryable, which is a bit trickier: the executor doesn't return a
+ // simple value as in the scalar case, but rather e.g. SingleQueryingEnumerable; we need to compose the terminating
+ // operator (e.g. ToList()) on top of that. Cast the executor delegate to Func>
+ // (contravariance).
+ var isAsync =
+ operatorExpression.Type.IsGenericType
+ && operatorExpression.Type.GetGenericTypeDefinition() is var genericDefinition
+ && (
+ genericDefinition == typeof(Task<>)
+ || genericDefinition == typeof(ValueTask<>)
+ || genericDefinition == typeof(IAsyncEnumerable<>));
+
+ var isQueryable = !isAsync
+ && operatorExpression.Type.IsGenericType
+ && operatorExpression.Type.GetGenericTypeDefinition() == typeof(IQueryable<>);
+
+ var returnValue = isAsync
+ ? $"IAsyncEnumerable<{sourceElementTypeName}>"
+ : $"IEnumerable<{sourceElementTypeName}>";
+
+ code.AppendLine(
+ $"var queryingEnumerable = ((Func)({executorFieldIdentifier}))(queryContext);");
+
+ if (isQueryable)
+ {
+ // If the terminating operator returns IQueryable, that means the query is actually evaluated via foreach
+ // (i.e. there's no method such as AsEnumerable/ToList which evaluates). Note that this is necessarily sync only -
+ // IQueryable can't be directly inside await foreach (AsAsyncEnumerable() is required).
+ // For this case, we need to compose AsQueryable() on top, to make the querying enumerable compatible with the
+ // operator signature.
+ code.AppendLine("return queryingEnumerable.AsQueryable();");
+ }
+ else
+ {
+ if (isAsync)
+ {
+ // For sync queries, we get an IEnumerable above, and can just compose the original terminating operator
+ // directly on top of that (ToList(), ToDictionary()...).
+ // But for async queries, we get an IAsyncEnumerable above, but cannot directly compose the original
+ // terminating operator (ToListAsync(), ToDictionaryAsync()...), since those require an IQueryable in their
+ // signature (which they internally case to IAsyncEnumerable).
+ // So we introduce an adapter in the middle, which implements both IQueryable (to be able to compose
+ // ToListAsync() on top), and IAsyncEnumerable (so that the actual implementation of ToListAsync() works).
+ // TODO: This is an additional runtime allocation; if we had System.Linq.Async we wouldn't need this. We could
+ // have additional versions of all async terminating operators over IAsyncEnumerable (effectively duplicating
+ // System.Linq.Async) as an alternative.
+ code.AppendLine($"var asyncQueryingEnumerable = new PrecompiledQueryableAsyncEnumerableAdapter<{sourceElementTypeName}>(queryingEnumerable);");
+ code.Append("return asyncQueryingEnumerable");
+ }
+ else
+ {
+ code.Append("return queryingEnumerable");
+ }
+
+ // Invoke the original terminating operator (e.g. ToList(), ToDictionary()...) on the querying enumerable, passing
+ // through the interceptor's arguments.
+ code.AppendLine(
+ $".{memberAccessSyntax.Name}({string.Join(", ", operatorSymbol.Parameters.Select(p => p.Name))});");
+ }
+ }
+ }
+ else
+ {
+ // Non-terminating operator - we need to flow precompiledQueryContext forward.
+
+ // The operator returns a different IQueryable type as its source (e.g. Select), convert the precompiledQueryContext
+ // before returning it.
+ Check.DebugAssert(returnElementTypeSymbol is not null, "Non-terminating operator must return IEnumerable");
+
+ code.AppendLine(
+ returnTypeSymbol switch
+ {
+ // The operator return IQueryable or IOrderedQueryable.
+ // If T is the same as the source, simply return our context as is (note that PrecompiledQueryContext implements
+ // IOrderedQueryable). Otherwise, e.g. Select() is being applied - change the context's type.
+ _ when returnTypeSymbol.OriginalDefinition.Equals(_symbols.IQueryable, SymbolEqualityComparer.Default)
+ || returnTypeSymbol.OriginalDefinition.Equals(_symbols.IOrderedQueryable, SymbolEqualityComparer.Default)
+ => SymbolEqualityComparer.Default.Equals(sourceElementTypeSymbol, returnElementTypeSymbol)
+ ? "return precompiledQueryContext;"
+ : $"return precompiledQueryContext.ToType<{returnElementTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>();",
+
+ _ when returnTypeSymbol.OriginalDefinition.Equals(_symbols.IIncludableQueryable, SymbolEqualityComparer.Default)
+ && returnTypeSymbol is INamedTypeSymbol { OriginalDefinition.TypeArguments: [_, var includedPropertySymbol] }
+ => $"return precompiledQueryContext.ToIncludable<{includedPropertySymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>();",
+
+ _ => throw new UnreachableException()
+ });
+ }
+
+ code.DecrementIndent().AppendLine("}").AppendLine();
+
+ void GenerateInterceptorMethodSignature()
+ {
+ code
+ .Append("public static ")
+ .Append(_g.TypeExpression(reducedOperatorSymbol.ReturnType).ToFullString())
+ .Append(' ')
+ .Append(interceptorName);
+
+ var (typeParameters, constraints) = (reducedOperatorSymbol.IsGenericMethod, reducedOperatorSymbol.ContainingType.IsGenericType) switch
+ {
+ (true, false) => (reducedOperatorSymbol.TypeParameters, ((MethodDeclarationSyntax)_g.MethodDeclaration(reducedOperatorSymbol)).ConstraintClauses),
+ (false, true) => (reducedOperatorSymbol.ContainingType.TypeParameters, ((TypeDeclarationSyntax)_g.Declaration(reducedOperatorSymbol.ContainingType)).ConstraintClauses),
+ (false, false) => ([], []),
+ (true, true) => throw new NotImplementedException("Generic method on generic type not supported")
+ };
+
+ if (typeParameters.Length > 0)
+ {
+ code.Append('<');
+ for (var i = 0; i < typeParameters.Length; i++)
+ {
+ if (i > 0)
+ {
+ code.Append(", ");
+ }
+
+ code.Append(_g.TypeExpression(typeParameters[i]).ToFullString());
+ }
+
+ code.Append('>');
+ }
+
+ code.Append('(');
+
+ // For instance methods (IEnumerable.GetEnumerator(), DbSet.GetAsyncEnumerable()...), we generate an extension method
+ // (with this) for the interceptor.
+ if (reducedOperatorSymbol is { IsStatic: false, ReceiverType: not null })
+ {
+ code
+ .Append("this ")
+ .Append(_g.TypeExpression(reducedOperatorSymbol.ReceiverType).ToFullString())
+ .Append(' ')
+ .Append(sourceVariableName);
+ }
+
+ for (var i = 0; i < reducedOperatorSymbol.Parameters.Length; i++)
+ {
+ var parameter = reducedOperatorSymbol.Parameters[i];
+
+ if (i == 0)
+ {
+ switch (reducedOperatorSymbol)
+ {
+ case { IsExtensionMethod: true }:
+ code.Append("this ");
+ break;
+
+ // For instance methods we already added a this parameter above
+ case { IsStatic: false, ReceiverType: not null }:
+ code.Append(", ");
+ break;
+
+ default:
+ throw new NotImplementedException("Non-extension static method not supported");
+ }
+ }
+ else
+ {
+ code.Append(", ");
+ }
+
+ code
+ .Append(_g.TypeExpression(parameter.Type).ToFullString())
+ .Append(' ')
+ .Append(parameter.Name);
+ }
+
+ code.AppendLine(")");
+
+ foreach (var f in constraints)
+ {
+ code.AppendLine(f.NormalizeWhitespace().ToFullString());
+ }
+ }
+
+ void ProcessCapturedVariables()
+ {
+ // Go over the operator's arguments (skipping the first, which is the source).
+ // For those which have captured variables, run them through our funcletizer, which will return code for extracting any captured
+ // variables from them.
+ switch (operatorExpression)
+ {
+ // Regular case: this is an operator method
+ case MethodCallExpression operatorMethodCall:
+ {
+ var parameters = operatorMethodCall.Method.GetParameters();
+
+ for (var i = 1; i < parameters.Length; i++)
+ {
+ var parameter = parameters[i];
+
+ if (parameter.ParameterType == typeof(CancellationToken))
+ {
+ continue;
+ }
+
+ if (_funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i) is not ExpressionTreeFuncletizer.PathNode
+ evaluatableRootPaths)
+ {
+ // There are no captured variables in this lambda argument - skip the argument
+ continue;
+ }
+
+ // We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code
+ // which extracts them and sets them on our query context.
+ if (!declaredQueryContextVariable)
+ {
+ code.AppendLine("var queryContext = precompiledQueryContext.QueryContext;");
+ declaredQueryContextVariable = true;
+ }
+
+ if (!parameter.ParameterType.IsSubclassOf(typeof(Expression)))
+ {
+ // Special case: this is a non-lambda argument (Skip/Take/FromSql).
+ // Simply add the argument directly as a parameter
+ code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameter.Name});""");
+ continue;
+ }
+
+ var variableCounter = 0;
+
+ // Lambda argument. Recurse through evaluatable path trees.
+ foreach (var child in evaluatableRootPaths.Children!)
+ {
+ GenerateCapturedVariableExtractors(parameter.Name!, parameter.ParameterType, child);
+
+ void GenerateCapturedVariableExtractors(
+ string currentIdentifier,
+ Type currentType,
+ ExpressionTreeFuncletizer.PathNode capturedVariablesPathTree)
+ {
+ var linqPathSegment =
+ capturedVariablesPathTree.PathFromParent!(Expression.Parameter(currentType, currentIdentifier));
+ var collectedNamespaces = new HashSet();
+ var unsafeAccessors = new HashSet();
+ var roslynPathSegment = _linqToCSharpTranslator.TranslateExpression(
+ linqPathSegment, constantReplacements: null, collectedNamespaces, unsafeAccessors);
+
+ var variableName = capturedVariablesPathTree.ExpressionType.Name;
+ variableName = char.ToLower(variableName[0]) + variableName[1..^"Expression".Length] + ++variableCounter;
+ code.AppendLine(
+ $"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");
+
+ if (capturedVariablesPathTree.Children?.Count > 0)
+ {
+ // This is an intermediate node which has captured variables in the children. Continue recursing down.
+ foreach (var child in capturedVariablesPathTree.Children)
+ {
+ GenerateCapturedVariableExtractors(variableName, capturedVariablesPathTree.ExpressionType, child);
+ }
+
+ return;
+ }
+
+ // We've reached a leaf, meaning that it's an evaluatable node that contains captured variables.
+ // Generate code to evaluate this node and assign the result to the parameters dictionary:
+
+ // TODO: For the common case of a simple parameter (member access over closure type), generate reflection code directly
+ // TODO: instead of going through the interpreter, as we do in the funcletizer itself (for perf)
+ // TODO: Remove the convert to object. We can flow out the actual type of the evaluatable root, and just stick it
+ // in Func<> instead of object.
+ // TODO: For specific cases, don't go through the interpreter, but just integrate code that extracts the value directly.
+ // (see ExpressionTreeFuncletizer.Evaluate()).
+ // TODO: Basically this means that the evaluator should come from ExpressionTreeFuncletizer itself, as part of its outputs
+ // TODO: Integrate try/catch around the evaluation?
+ code.AppendLine("queryContext.AddParameter(");
+ using (code.Indent())
+ {
+ code
+ .Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",")
+ .AppendLine($"Expression.Lambda>(Expression.Convert({variableName}, typeof(object)))")
+ .AppendLine(".Compile(preferInterpretation: true)")
+ .AppendLine(".Invoke());");
+ }
+ }
+ }
+ }
+
+ break;
+ }
+
+ // Special case: this is a FromSql query root; we're intercepting the invocation syntax for the FromSql() call, but on the LINQ
+ // side we have a query root (i.e. not the MethodCallExpression for the FromSql(), but rather its evaluated result)
+ case FromSqlQueryRootExpression fromSqlQueryRoot:
+ {
+ if (_funcletizer.CalculatePathsToEvaluatableRoots(fromSqlQueryRoot.Argument) is not ExpressionTreeFuncletizer.PathNode
+ evaluatableRootPaths)
+ {
+ // There are no captured variables in this FromSqlQueryRootExpression, skip it.
+ break;
+ }
+
+ // We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code
+ // which extracts them and sets them on our query context.
+ if (!declaredQueryContextVariable)
+ {
+ code.AppendLine("var queryContext = precompiledQueryContext.QueryContext;");
+ declaredQueryContextVariable = true;
+ }
+
+ var argumentsParameter = reducedOperatorSymbol switch
+ {
+ { Name: "FromSqlRaw", Parameters: [_, _, { Name: "parameters" }] } => "parameters",
+ { Name: "FromSql", Parameters: [_, { Name: "sql" }] } => "sql.GetArguments()",
+ { Name: "FromSqlInterpolated", Parameters: [_, { Name: "sql" }] } => "sql.GetArguments()",
+ _ => throw new UnreachableException()
+ };
+
+ code.AppendLine(
+ $"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {argumentsParameter});""");
+
+ break;
+ }
+
+ default:
+ throw new UnreachableException();
+ }
+ }
+ }
+
+ private void GenerateQueryExecutor(
+ IndentedStringBuilder code,
+ int queryNum,
+ Expression queryExecutor,
+ HashSet namespaces,
+ HashSet unsafeAccessors)
+ {
+ // We're going to generate the method which will create the query executor (Func).
+ // Note that the we store the executor itself (and return it) as object, not as a typed Func.
+ // We can't strong-type it since it may return an anonymous type, which is unspeakable; so instead we cast down from object to
+ // the real strongly-typed signature inside the interceptor, where the return value is represented as a generic type parameter
+ // (which can be an anonymous type).
+ code
+ .AppendLine($"private static object Query{queryNum}_GenerateExecutor(DbContext dbContext, QueryContext queryContext)")
+ .AppendLine("{")
+ .IncrementIndent()
+ .AppendLine("var relationalModel = dbContext.Model.GetRelationalModel();")
+ .AppendLine("var relationalTypeMappingSource = dbContext.GetService();")
+ .AppendLine("var materializerLiftableConstantContext = new RelationalMaterializerLiftableConstantContext(dbContext.GetService(), dbContext.GetService());");
+
+ HashSet variableNames = ["relationalModel", "relationalTypeMappingSource", "materializerLiftableConstantContext"];
+
+ var materializerLiftableConstantContext =
+ Expression.Parameter(typeof(RelationalMaterializerLiftableConstantContext), "materializerLiftableConstantContext");
+
+ // The materializer expression tree contains LiftedConstantExpression nodes, which contain instructions on how to resolve
+ // constant values which need to be lifted.
+ var queryExecutorAfterLiftingExpression =
+ _liftableConstantProcessor.LiftConstants(queryExecutor, materializerLiftableConstantContext, variableNames);
+
+ foreach (var liftedConstant in _liftableConstantProcessor.LiftedConstants)
+ {
+ var variableValueSyntax = _linqToCSharpTranslator.TranslateExpression(
+ liftedConstant.Expression, constantReplacements: null, namespaces, unsafeAccessors);
+ // code.AppendLine($"{liftedConstant.Parameter.Type.Name} {liftedConstant.Parameter.Name} = {variableValueSyntax.NormalizeWhitespace().ToFullString()};");
+ code.AppendLine($"var {liftedConstant.Parameter.Name} = {variableValueSyntax.NormalizeWhitespace().ToFullString()};");
+ }
+
+ var queryExecutorSyntaxTree =
+ (AnonymousFunctionExpressionSyntax)_linqToCSharpTranslator.TranslateExpression(
+ queryExecutorAfterLiftingExpression,
+ constantReplacements: null,
+ namespaces,
+ unsafeAccessors);
+
+ code
+ .AppendLine($"return {queryExecutorSyntaxTree.NormalizeWhitespace().ToFullString()};")
+ .DecrementIndent()
+ .AppendLine("}")
+ .AppendLine()
+ .AppendLine($"private static object Query{queryNum}_Executor;");
+ }
+
+ ///
+ /// Performs processing of a query's terminating operator before handing the query off for EF compilation.
+ /// This involves removing the operator when it shouldn't be in the tree (e.g. ToList()), and rewriting async terminating operators
+ /// to their sync counterparts (e.g. MaxAsync() -> Max()). This only needs to be modified/overridden if a new terminating operator
+ /// is introduced which needs to be rewritten.
+ ///
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ protected virtual Expression PrepareQueryForCompilation(Expression penultimateOperator, MethodCallExpression terminatingOperator)
+ {
+ var method = terminatingOperator.Method;
+
+ return method.Name switch
+ {
+ // These sync terminating operators are defined over IEnumerable, and don't inject a node into the query tree. Simply remove them.
+ nameof(Enumerable.AsEnumerable)
+ or nameof(Enumerable.ToArray)
+ or nameof(Enumerable.ToDictionary)
+ or nameof(Enumerable.ToHashSet)
+ or nameof(Enumerable.ToLookup)
+ or nameof(Enumerable.ToList)
+ when method.DeclaringType == typeof(Enumerable)
+ => penultimateOperator,
+
+ nameof(IEnumerable.GetEnumerator)
+ when method.DeclaringType is { IsConstructedGenericType: true } declaringType
+ && declaringType.GetGenericTypeDefinition() == typeof(IEnumerable<>)
+ => penultimateOperator,
+
+ // Async ToListAsync, ToArrayAsync and AsAsyncEnumerable don't inject a node into the query tree - remove these as well.
+ nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable)
+ or nameof(EntityFrameworkQueryableExtensions.ToArrayAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ToDictionaryAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ToHashSetAsync)
+ // or nameof(EntityFrameworkQueryableExtensions.ToLookupAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ToListAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
+ => penultimateOperator,
+
+ // There's also an instance method version of AsAsyncEnumerable on DbSet, remove that as well.
+ nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable)
+ when method.DeclaringType?.IsConstructedGenericType == true
+ && method.DeclaringType.GetGenericTypeDefinition() == typeof(DbSet<>)
+ => penultimateOperator,
+
+ // The EF async counterparts to all the standard scalar-returning terminating operators. These need to be rewritten, as they
+ // inject the sync versions into the query tree.
+ nameof(EntityFrameworkQueryableExtensions.AllAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
+ => RewriteToSync(QueryableMethods.All),
+ nameof(EntityFrameworkQueryableExtensions.AnyAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.AnyWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.AnyAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.AnyWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.AverageAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(
+ QueryableMethods.GetAverageWithoutSelector(method.GetParameters()[0].ParameterType.GenericTypeArguments[0])),
+ nameof(EntityFrameworkQueryableExtensions.AverageAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(
+ QueryableMethods.GetAverageWithSelector(
+ method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])),
+ nameof(EntityFrameworkQueryableExtensions.ContainsAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
+ => RewriteToSync(QueryableMethods.Contains),
+ nameof(EntityFrameworkQueryableExtensions.CountAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.CountWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.CountAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.CountWithPredicate),
+ // nameof(EntityFrameworkQueryableExtensions.DefaultIfEmptyAsync)
+ nameof(EntityFrameworkQueryableExtensions.ElementAtAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
+ => RewriteToSync(QueryableMethods.ElementAt),
+ nameof(EntityFrameworkQueryableExtensions.ElementAtOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
+ => RewriteToSync(QueryableMethods.ElementAtOrDefault),
+ nameof(EntityFrameworkQueryableExtensions.FirstAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.FirstWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.FirstAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.FirstWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.FirstOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.FirstOrDefaultWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.FirstOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.FirstOrDefaultWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.LastAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.LastWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.LastAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.LastWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.LastOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.LastOrDefaultWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.LastOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.LastOrDefaultWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.LongCountAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.LongCountWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.LongCountAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.LongCountWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.MaxAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.MaxWithoutSelector),
+ nameof(EntityFrameworkQueryableExtensions.MaxAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.MaxWithSelector),
+ nameof(EntityFrameworkQueryableExtensions.MinAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.MinWithoutSelector),
+ nameof(EntityFrameworkQueryableExtensions.MinAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.MinWithSelector),
+ // nameof(EntityFrameworkQueryableExtensions.MaxByAsync)
+ // nameof(EntityFrameworkQueryableExtensions.MinByAsync)
+ nameof(EntityFrameworkQueryableExtensions.SingleAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.SingleWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.SingleAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.SingleWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.SingleOrDefaultWithoutPredicate),
+ nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(QueryableMethods.SingleOrDefaultWithPredicate),
+ nameof(EntityFrameworkQueryableExtensions.SumAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2
+ => RewriteToSync(QueryableMethods.GetSumWithoutSelector(method.GetParameters()[0].ParameterType.GenericTypeArguments[0])),
+ nameof(EntityFrameworkQueryableExtensions.SumAsync)
+ when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3
+ => RewriteToSync(
+ QueryableMethods.GetSumWithSelector(
+ method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])),
+
+ // ExecuteDelete/Update behave just like other scalar-returning operators
+ nameof(RelationalQueryableExtensions.ExecuteDeleteAsync) when method.DeclaringType == typeof(RelationalQueryableExtensions)
+ => RewriteToSync(typeof(RelationalQueryableExtensions).GetMethod(nameof(RelationalQueryableExtensions.ExecuteDelete))),
+ nameof(RelationalQueryableExtensions.ExecuteUpdateAsync) when method.DeclaringType == typeof(RelationalQueryableExtensions)
+ => RewriteToSync(typeof(RelationalQueryableExtensions).GetMethod(nameof(RelationalQueryableExtensions.ExecuteUpdate))),
+
+ // In the regular case (sync terminating operator which needs to stay in the query tree), simply compose the terminating
+ // operator over the penultimate and return that.
+ _ => terminatingOperator switch
+ {
+ // This is needed e.g. for GetEnumerator(), DbSet.AsAsyncEnumerable (non-static terminating operators)
+ { Object: Expression }
+ => terminatingOperator.Update(penultimateOperator, terminatingOperator.Arguments),
+ { Arguments: [_, ..] }
+ => terminatingOperator.Update(@object: null, [penultimateOperator, .. terminatingOperator.Arguments.Skip(1)]),
+ _ => throw new UnreachableException()
+ }
+ };
+
+ MethodCallExpression RewriteToSync(MethodInfo? syncMethod)
+ {
+ if (syncMethod is null)
+ {
+ throw new UnreachableException($"Could find replacement method for {method.Name}");
+ }
+
+ if (syncMethod.IsGenericMethodDefinition)
+ {
+ syncMethod = syncMethod.MakeGenericMethod(method.GetGenericArguments());
+ }
+
+ // Replace the first argument with the penultimate argument, and chop off the CancellationToken argument
+ Expression[] syncArguments =
+ [penultimateOperator, .. terminatingOperator.Arguments.Skip(1).Take(terminatingOperator.Arguments.Count - 2)];
+
+ return Expression.Call(terminatingOperator.Object, syncMethod, syncArguments);
+ }
+ }
+
+ ///
+ /// Contains information on a failure to precompile a specific query in the user's source code.
+ /// Includes information about the query, its location, and the exception that occured.
+ ///
+ public sealed record QueryPrecompilationError(SyntaxNode SyntaxNode, Exception Exception);
+
+ private readonly struct Symbols
+ {
+ private readonly Compilation _compilation;
+
+ // ReSharper disable InconsistentNaming
+ public readonly INamedTypeSymbol GenericEnumerable;
+ public readonly INamedTypeSymbol GenericAsyncEnumerable;
+ public readonly INamedTypeSymbol GenericEnumerator;
+ public readonly INamedTypeSymbol IQueryable;
+ public readonly INamedTypeSymbol IOrderedQueryable;
+ public readonly INamedTypeSymbol IIncludableQueryable;
+ public readonly INamedTypeSymbol GenericTask;
+ // ReSharper restore InconsistentNaming
+
+ private Symbols(Compilation compilation)
+ {
+ _compilation = compilation;
+
+ GenericEnumerable =
+ GetTypeSymbolOrThrow("System.Collections.Generic.IEnumerable`1");
+ GenericAsyncEnumerable =
+ GetTypeSymbolOrThrow("System.Collections.Generic.IAsyncEnumerable`1");
+ GenericEnumerator =
+ GetTypeSymbolOrThrow("System.Collections.Generic.IEnumerator`1");
+ IQueryable =
+ GetTypeSymbolOrThrow("System.Linq.IQueryable`1");
+ IOrderedQueryable =
+ GetTypeSymbolOrThrow("System.Linq.IOrderedQueryable`1");
+ IIncludableQueryable =
+ GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.Query.IIncludableQueryable`2");
+ GenericTask =
+ GetTypeSymbolOrThrow("System.Threading.Tasks.Task`1");
+ }
+
+ public static Symbols Load(Compilation compilation)
+ => new(compilation);
+
+ private INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName)
+ => _compilation.GetTypeByMetadataName(fullyQualifiedMetadataName)
+ ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedMetadataName);
+ }
+
+ ///
+ /// A generated file containing LINQ operator interceptors.
+ ///
+ /// The path of the generated file.
+ /// The code of the generated file.
+ public sealed record GeneratedInterceptorFile(string Path, string Code);
+}
diff --git a/src/EFCore.Design/Query/Internal/QueryLocator.cs b/src/EFCore.Design/Query/Internal/QueryLocator.cs
new file mode 100644
index 00000000000..3270d0c74e9
--- /dev/null
+++ b/src/EFCore.Design/Query/Internal/QueryLocator.cs
@@ -0,0 +1,371 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.EntityFrameworkCore.Internal;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal;
+
+///
+/// Statically analyzes user code and locates EF LINQ queries within it, by identifying well-known terminating operators
+/// (e.g. ToList , Single ).
+///
+///
+/// After a is loaded via , is called repeatedly
+/// for all syntax trees in the compilation.
+///
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+public class QueryLocator : CSharpSyntaxWalker
+{
+ private Compilation? _compilation;
+ private Symbols _symbols;
+
+ private SemanticModel _semanticModel = null!;
+ private CancellationToken _cancellationToken;
+ private List _locatedQueries = null!;
+ private List _precompilationErrors = null!;
+
+
+ ///
+ /// Loads a new , representing a user project in which to locate queries.
+ ///
+ /// A representing a user project.
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual void Initialize(Compilation compilation)
+ {
+ _compilation = compilation;
+ _symbols = Symbols.Load(compilation);
+ }
+
+ ///
+ /// Locates EF LINQ queries within the given , which represents user code.
+ ///
+ /// A in which to locate EF LINQ queries.
+ ///
+ /// A list of errors populated with dynamic LINQ queries detected in .
+ ///
+ /// A to observe while waiting for the task to complete.
+ /// A list of EF LINQ queries confirmed to be compatible with precompilation.
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual IReadOnlyList LocateQueries(
+ SyntaxTree syntaxTree,
+ List precompilationErrors,
+ CancellationToken cancellationToken = default)
+ {
+ if (_compilation is null)
+ {
+ throw new InvalidOperationException("A compilation must be loaded.");
+ }
+
+ if (!_compilation.SyntaxTrees.Contains(syntaxTree))
+ {
+ throw new ArgumentException("");
+ }
+
+ _cancellationToken = cancellationToken;
+ _semanticModel = _compilation.GetSemanticModel(syntaxTree);
+ _locatedQueries = new();
+ _precompilationErrors = precompilationErrors;
+ Visit(syntaxTree.GetRoot(cancellationToken));
+
+ return _locatedQueries;
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override void VisitInvocationExpression(InvocationExpressionSyntax invocation)
+ {
+ // TODO: Support non-extension invocation syntax: var blogs = ToList(ctx.Blogs);
+ if (invocation.Expression is MemberAccessExpressionSyntax
+ {
+ Name: IdentifierNameSyntax { Identifier.Text: var identifier },
+ Expression: var innerExpression
+ })
+ {
+ // First, pattern-match on the method name as a string; this avoids accessing the semantic model for each and
+ // every invocation (more efficient).
+ switch (identifier)
+ {
+ // These sync terminating operators exist exist over IEnumerable only, so verify the actual argument is an IQueryable (otherwise
+ // this is just LINQ to Objects)
+ case nameof(Enumerable.AsEnumerable) or nameof(Enumerable.ToArray) or nameof(Enumerable.ToDictionary)
+ or nameof(Enumerable.ToHashSet) or nameof(Enumerable.ToLookup) or nameof(Enumerable.ToList)
+ when IsOnEnumerable() && IsQueryable(innerExpression):
+
+ case nameof(IEnumerable.GetEnumerator)
+ when IsOnIEnumerable() && IsQueryable(innerExpression):
+
+ // The async terminating operators are defined by EF, and accept an IQueryable - no need to look at the argument.
+ case nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable)
+ or nameof(EntityFrameworkQueryableExtensions.ToArrayAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ToDictionaryAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ToHashSetAsync)
+ // or nameof(EntityFrameworkQueryableExtensions.ToLookupAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ToListAsync)
+ when IsOnEfQueryableExtensions():
+
+ case nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable)
+ when IsOnEfQueryableExtensions() || IsOnTypeSymbol(_symbols.DbSet):
+
+ case nameof(Queryable.All)
+ or nameof(Queryable.Any)
+ or nameof(Queryable.Average)
+ or nameof(Queryable.Contains)
+ or nameof(Queryable.Count)
+ or nameof(Queryable.DefaultIfEmpty)
+ or nameof(Queryable.ElementAt)
+ or nameof(Queryable.ElementAtOrDefault)
+ or nameof(Queryable.First)
+ or nameof(Queryable.FirstOrDefault)
+ or nameof(Queryable.Last)
+ or nameof(Queryable.LastOrDefault)
+ or nameof(Queryable.LongCount)
+ or nameof(Queryable.Max)
+ or nameof(Queryable.MaxBy)
+ or nameof(Queryable.Min)
+ or nameof(Queryable.MinBy)
+ or nameof(Queryable.Single)
+ or nameof(Queryable.SingleOrDefault)
+ or nameof(Queryable.Sum)
+ when IsOnQueryable():
+
+ case nameof(EntityFrameworkQueryableExtensions.AllAsync)
+ or nameof(EntityFrameworkQueryableExtensions.AnyAsync)
+ or nameof(EntityFrameworkQueryableExtensions.AverageAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ContainsAsync)
+ or nameof(EntityFrameworkQueryableExtensions.CountAsync)
+ // or nameof(EntityFrameworkQueryableExtensions.DefaultIfEmptyAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ElementAtAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ElementAtOrDefaultAsync)
+ or nameof(EntityFrameworkQueryableExtensions.FirstAsync)
+ or nameof(EntityFrameworkQueryableExtensions.FirstOrDefaultAsync)
+ or nameof(EntityFrameworkQueryableExtensions.LastAsync)
+ or nameof(EntityFrameworkQueryableExtensions.LastOrDefaultAsync)
+ or nameof(EntityFrameworkQueryableExtensions.LongCountAsync)
+ or nameof(EntityFrameworkQueryableExtensions.MaxAsync)
+ // or nameof(EntityFrameworkQueryableExtensions.MaxByAsync)
+ or nameof(EntityFrameworkQueryableExtensions.MinAsync)
+ // or nameof(EntityFrameworkQueryableExtensions.MinByAsync)
+ or nameof(EntityFrameworkQueryableExtensions.SingleAsync)
+ or nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync)
+ or nameof(EntityFrameworkQueryableExtensions.SumAsync)
+ or nameof(EntityFrameworkQueryableExtensions.ForEachAsync)
+ when IsOnEfQueryableExtensions():
+
+ case nameof(RelationalQueryableExtensions.ExecuteDelete)
+ or nameof(RelationalQueryableExtensions.ExecuteUpdate)
+ or nameof(RelationalQueryableExtensions.ExecuteDeleteAsync)
+ or nameof(RelationalQueryableExtensions.ExecuteUpdateAsync)
+ when IsOnEfRelationalQueryableExtensions():
+ if (ProcessQueryCandidate(invocation))
+ {
+ return;
+ }
+
+ break;
+ }
+ }
+
+ base.VisitInvocationExpression(invocation);
+
+ bool IsOnEnumerable()
+ => IsOnTypeSymbol(_symbols.Enumerable);
+
+ bool IsOnIEnumerable()
+ => IsOnTypeSymbol(_symbols.IEnumerableOfT);
+
+ bool IsOnQueryable()
+ => IsOnTypeSymbol(_symbols.Queryable);
+
+ bool IsOnEfQueryableExtensions()
+ => IsOnTypeSymbol(_symbols.EfQueryableExtensions);
+
+ bool IsOnEfRelationalQueryableExtensions()
+ => IsOnTypeSymbol(_symbols.EfRelationalQueryableExtensions);
+
+ bool IsOnTypeSymbol(ITypeSymbol typeSymbol)
+ => _semanticModel.GetSymbolInfo(invocation, _cancellationToken).Symbol is IMethodSymbol methodSymbol
+ && methodSymbol.ContainingType.OriginalDefinition.Equals(typeSymbol, SymbolEqualityComparer.Default);
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override void VisitForEachStatement(ForEachStatementSyntax forEach)
+ {
+ // Note: a LINQ queryable can't be placed directly inside await foreach, since IQueryable does not extend
+ // IAsyncEnumerable. So users need to add our AsAsyncEnumerable, which is detected above as a normal invocation.
+
+ // C# interceptors can (currently) intercept only method calls, not property accesses; this means that we can't
+ // TODO: Support DbSet() method call directly inside foreach/await foreach
+ // TODO: Warn for DbSet property access directly inside foreach (can't be intercepted so not supported)
+ if (forEach.Expression is InvocationExpressionSyntax invocation
+ && IsQueryable(invocation)
+ && ProcessQueryCandidate(invocation))
+ {
+ return;
+ }
+
+ base.VisitForEachStatement(forEach);
+ }
+
+ private bool ProcessQueryCandidate(InvocationExpressionSyntax query)
+ {
+ // TODO: Carefully think about exactly what kind of verification we want to do here: static/non-static, actually get the
+ // TODO: method symbols and confirm it's an IQueryable flowing all the way through, etc.
+ // TODO: Move this code out, for reuse in the inner loop source generator
+
+ // Work backwards through the LINQ operator chain until we reach something that isn't a method invocation
+ ExpressionSyntax operatorSyntax = query;
+ while (operatorSyntax is InvocationExpressionSyntax
+ {
+ Expression: MemberAccessExpressionSyntax { Expression: var innerExpression }
+ })
+ {
+ if (innerExpression is QueryExpressionSyntax or ParenthesizedExpressionSyntax { Expression: QueryExpressionSyntax })
+ {
+ _precompilationErrors.Add(
+ new(query, new InvalidOperationException(DesignStrings.QueryComprehensionSyntaxNotSupportedInPrecompiledQueries)));
+ return false;
+ }
+
+ operatorSyntax = innerExpression;
+ }
+
+ // We've reached a non-invocation.
+
+ // First, check if this is a property access for a DbSet
+ if (operatorSyntax is MemberAccessExpressionSyntax { Expression: var innerExpression2 }
+ && IsDbContext(innerExpression2))
+ {
+ _locatedQueries.Add(query);
+
+ // TODO: Check symbol for DbSet?
+ return true;
+ }
+
+ // If we had context.Set(), the Set() method was skipped like any other method, and we're on the context.
+ if (IsDbContext(operatorSyntax))
+ {
+ _locatedQueries.Add(query);
+ return true;
+ }
+
+ _precompilationErrors.Add(new(query, new InvalidOperationException(DesignStrings.DynamicQueryNotSupported)));
+ return false;
+
+ bool IsDbContext(ExpressionSyntax expression)
+ {
+ return _semanticModel.GetSymbolInfo(expression, _cancellationToken).Symbol switch
+ {
+ ILocalSymbol localSymbol => IsDbContextType(localSymbol.Type),
+ IPropertySymbol propertySymbol => IsDbContextType(propertySymbol.Type),
+ IFieldSymbol fieldSymbol => IsDbContextType(fieldSymbol.Type),
+ IMethodSymbol methodSymbol => IsDbContextType(methodSymbol.ReturnType),
+ _ => false
+ };
+
+ bool IsDbContextType(ITypeSymbol typeSymbol)
+ {
+ while (true)
+ {
+ // TODO: Check for the user's specific DbContext type
+ if (typeSymbol.Equals(_symbols.DbContext, SymbolEqualityComparer.Default))
+ {
+ return true;
+ }
+
+ if (typeSymbol.BaseType is null)
+ {
+ return false;
+ }
+
+ typeSymbol = typeSymbol.BaseType;
+ }
+ }
+ }
+ }
+
+ private bool IsQueryable(ExpressionSyntax expression)
+ => _semanticModel.GetSymbolInfo(expression, _cancellationToken).Symbol switch
+ {
+ IMethodSymbol methodSymbol
+ => methodSymbol.ReturnType.OriginalDefinition.Equals(_symbols.IQueryableOfT, SymbolEqualityComparer.Default)
+ || methodSymbol.ReturnType.OriginalDefinition.AllInterfaces
+ .Contains(_symbols.IQueryable, SymbolEqualityComparer.Default),
+
+ IPropertySymbol propertySymbol => IsDbSet(propertySymbol.Type),
+
+ _ => false
+ };
+
+ // TODO: Handle DbSet subclasses which aren't InternalDbSet?
+ private bool IsDbSet(ITypeSymbol typeSymbol)
+ => SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, _symbols.DbSet);
+
+ private readonly struct Symbols
+ {
+ private readonly Compilation _compilation;
+
+ // ReSharper disable InconsistentNaming
+ public readonly INamedTypeSymbol IQueryableOfT;
+ public readonly INamedTypeSymbol IQueryable;
+ public readonly INamedTypeSymbol DbContext;
+ public readonly INamedTypeSymbol DbSet;
+
+ public readonly INamedTypeSymbol Enumerable;
+ public readonly INamedTypeSymbol IEnumerableOfT;
+ public readonly INamedTypeSymbol Queryable;
+ public readonly INamedTypeSymbol EfQueryableExtensions;
+ public readonly INamedTypeSymbol EfRelationalQueryableExtensions;
+ // ReSharper restore InconsistentNaming
+
+ private Symbols(Compilation compilation)
+ {
+ _compilation = compilation;
+
+ IQueryableOfT = GetTypeSymbolOrThrow("System.Linq.IQueryable`1");
+ IQueryable = GetTypeSymbolOrThrow("System.Linq.IQueryable");
+ DbContext = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.DbContext");
+ DbSet = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.DbSet`1");
+
+ Enumerable = GetTypeSymbolOrThrow("System.Linq.Enumerable");
+ IEnumerableOfT = GetTypeSymbolOrThrow("System.Collections.Generic.IEnumerable`1");
+ Queryable = GetTypeSymbolOrThrow("System.Linq.Queryable");
+ EfQueryableExtensions = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions");
+ EfRelationalQueryableExtensions = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.RelationalQueryableExtensions");
+ }
+
+ public static Symbols Load(Compilation compilation)
+ => new(compilation);
+
+ private INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName)
+ => _compilation.GetTypeByMetadataName(fullyQualifiedMetadataName)
+ ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedMetadataName);
+ }
+}
diff --git a/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs b/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs
index 4036908d8f0..2ab5978e829 100644
--- a/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs
+++ b/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs
@@ -44,10 +44,11 @@ public RuntimeModelLinqToCSharpSyntaxTranslator(SyntaxGenerator syntaxGenerator)
Expression node,
IReadOnlyDictionary? constantReplacements,
IReadOnlyDictionary? memberAccessReplacements,
- ISet collectedNamespaces)
+ ISet collectedNamespaces,
+ ISet unsafeAccessors)
{
_memberAccessReplacements = memberAccessReplacements;
- var result = TranslateStatement(node, constantReplacements, collectedNamespaces);
+ var result = TranslateStatement(node, constantReplacements, collectedNamespaces, unsafeAccessors);
_memberAccessReplacements = null;
return result;
}
@@ -62,10 +63,11 @@ public RuntimeModelLinqToCSharpSyntaxTranslator(SyntaxGenerator syntaxGenerator)
Expression node,
IReadOnlyDictionary? constantReplacements,
IReadOnlyDictionary? memberAccessReplacements,
- ISet collectedNamespaces)
+ ISet collectedNamespaces,
+ ISet unsafeAccessors)
{
_memberAccessReplacements = memberAccessReplacements;
- var result = TranslateExpression(node, constantReplacements, collectedNamespaces);
+ var result = TranslateExpression(node, constantReplacements, collectedNamespaces, unsafeAccessors);
_memberAccessReplacements = null;
return result;
}
@@ -120,7 +122,7 @@ protected override void TranslateNonPublicMemberAccess(MemberExpression memberEx
protected override void TranslateNonPublicMemberAssignment(MemberExpression memberExpression, Expression value, SyntaxKind assignmentKind)
{
var propertyInfo = memberExpression.Member as PropertyInfo;
- var member = propertyInfo?.SetMethod! ?? memberExpression.Member;
+ var member = propertyInfo?.SetMethod ?? memberExpression.Member;
if (_memberAccessReplacements?.TryGetValue(member, out var methodName) == true)
{
AddNamespace(methodName.Namespace);
@@ -134,10 +136,10 @@ protected override void TranslateNonPublicMemberAssignment(MemberExpression memb
Result = InvocationExpression(
IdentifierName(methodName.Name),
ArgumentList(SeparatedList(new[]
- {
- Argument(Translate(memberExpression.Expression)),
- Argument(Translate(value))
- })));
+ {
+ Argument(Translate(memberExpression.Expression)),
+ Argument(Translate(value))
+ })));
}
else
{
diff --git a/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs b/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs
index 88d7f2fe530..b46304c7855 100644
--- a/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs
+++ b/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs
@@ -1074,16 +1074,19 @@ private void
out var structuralGetterExpression,
out var hasStructuralSentinelExpression);
+ // TODO
+ var unsafeAccessors = new HashSet();
+
mainBuilder
.Append(variableName).AppendLine(".SetGetter(")
.IncrementIndent()
- .AppendLines(_code.Expression(getterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(getterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(_code.Expression(hasSentinelExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(hasSentinelExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(_code.Expression(structuralGetterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(structuralGetterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(_code.Expression(hasStructuralSentinelExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(hasStructuralSentinelExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -1092,7 +1095,7 @@ private void
mainBuilder
.Append(variableName).AppendLine(".SetSetter(")
.IncrementIndent()
- .AppendLines(_code.Expression(setterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(setterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -1101,7 +1104,7 @@ private void
mainBuilder
.Append(variableName).AppendLine(".SetMaterializationSetter(")
.IncrementIndent()
- .AppendLines(_code.Expression(materializationSetterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(materializationSetterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -1115,19 +1118,19 @@ private void
mainBuilder
.Append(variableName).AppendLine(".SetAccessors(")
.IncrementIndent()
- .AppendLines(_code.Expression(currentValueGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(currentValueGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(_code.Expression(preStoreGeneratedCurrentValueGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(preStoreGeneratedCurrentValueGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
.AppendLines(originalValueGetter == null
? "null"
- : _code.Expression(originalValueGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(originalValueGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(_code.Expression(relationshipSnapshotGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(relationshipSnapshotGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
.AppendLines(valueBufferGetter == null
? "null"
- : _code.Expression(valueBufferGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(valueBufferGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
}
@@ -1918,6 +1921,9 @@ private void
out var createAndSetCollection,
out var createCollection);
+ // TODO
+ var unsafeAccessors = new HashSet();
+
AddNamespace(propertyType, parameters.Namespaces);
mainBuilder
.Append(parameters.TargetName)
@@ -1925,23 +1931,23 @@ private void
.IncrementIndent()
.AppendLines(getCollection == null
? "null"
- : _code.Expression(getCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(getCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
.AppendLines(setCollection == null
? "null"
- : _code.Expression(setCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(setCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
.AppendLines(setCollectionForMaterialization == null
? "null"
- : _code.Expression(setCollectionForMaterialization, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(setCollectionForMaterialization, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
.AppendLines(createAndSetCollection == null
? "null"
- : _code.Expression(createAndSetCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(createAndSetCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(",")
.AppendLines(createCollection == null
? "null"
- : _code.Expression(createCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ : _code.Expression(createCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
}
@@ -2142,11 +2148,14 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame
var runtimeType = (IRuntimeEntityType)entityType;
+ // TODO
+ var unsafeAccessors = new HashSet();
+
var originalValuesFactory = OriginalValuesFactoryFactory.Instance.CreateExpression(runtimeType);
mainBuilder
.Append(parameters.TargetName).AppendLine(".SetOriginalValuesFactory(")
.IncrementIndent()
- .AppendLines(_code.Expression(originalValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(originalValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -2154,7 +2163,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame
mainBuilder
.Append(parameters.TargetName).AppendLine(".SetStoreGeneratedValuesFactory(")
.IncrementIndent()
- .AppendLines(_code.Expression(storeGeneratedValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(storeGeneratedValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -2162,7 +2171,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame
mainBuilder
.Append(parameters.TargetName).AppendLine(".SetTemporaryValuesFactory(")
.IncrementIndent()
- .AppendLines(_code.Expression(temporaryValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(temporaryValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -2170,7 +2179,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame
mainBuilder
.Append(parameters.TargetName).AppendLine(".SetShadowValuesFactory(")
.IncrementIndent()
- .AppendLines(_code.Expression(shadowValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(shadowValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -2178,7 +2187,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame
mainBuilder
.Append(parameters.TargetName).AppendLine(".SetEmptyShadowValuesFactory(")
.IncrementIndent()
- .AppendLines(_code.Expression(emptyShadowValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(emptyShadowValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
@@ -2186,7 +2195,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame
mainBuilder
.Append(parameters.TargetName).AppendLine(".SetRelationshipSnapshotFactory(")
.IncrementIndent()
- .AppendLines(_code.Expression(relationshipSnapshotFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
+ .AppendLines(_code.Expression(relationshipSnapshotFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true)
.AppendLine(");")
.DecrementIndent();
diff --git a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
index 85715b76f2c..f2abc516164 100644
--- a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
+++ b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
@@ -182,7 +182,14 @@ public static DbCommand CreateDbCommand(this IQueryable source)
sql.GetArguments()));
}
- private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal]
+ public static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
IQueryable source,
string sql,
object?[] arguments,
diff --git a/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs
index 90f2fdd6636..d7df2bb5eb2 100644
--- a/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs
+++ b/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs
@@ -35,6 +35,15 @@ public class RelationalQueryCompilationContextFactory : IQueryCompilationContext
///
protected virtual RelationalQueryCompilationContextDependencies RelationalDependencies { get; }
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual QueryCompilationContext Create(bool async, bool precompiling)
+ => new RelationalQueryCompilationContext(Dependencies, RelationalDependencies, async, precompiling);
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -42,5 +51,5 @@ public class RelationalQueryCompilationContextFactory : IQueryCompilationContext
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
public virtual QueryCompilationContext Create(bool async)
- => new RelationalQueryCompilationContext(Dependencies, RelationalDependencies, async);
+ => throw new UnreachableException("The overload with `precompiling` should be called");
}
diff --git a/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs
index d4d0ff8ee33..826f26d027a 100644
--- a/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs
@@ -85,7 +85,7 @@ private Expression VisitSelect(SelectExpression selectExpression)
var orderings = this.VisitAndConvert(selectExpression.Orderings);
var offset = (SqlExpression?)Visit(selectExpression.Offset);
var limit = (SqlExpression?)Visit(selectExpression.Limit);
- return selectExpression.Update(projections, tables, predicate, groupBy, having, orderings, limit, offset);
+ return selectExpression.Update(tables, predicate, groupBy, having, projections, orderings, offset, limit);
}
private Expression VisitInnerJoin(InnerJoinExpression innerJoinExpression)
diff --git a/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs b/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs
index 970c8eda4b2..2f3935a8184 100644
--- a/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs
+++ b/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs
@@ -22,11 +22,13 @@ public class RelationalQueryCompilationContext : QueryCompilationContext
/// Parameter object containing dependencies for this class.
/// Parameter object containing relational dependencies for this class.
/// A bool value indicating whether it is for async query.
+ /// Indicates whether the query is being precompiled.
public RelationalQueryCompilationContext(
QueryCompilationContextDependencies dependencies,
RelationalQueryCompilationContextDependencies relationalDependencies,
- bool async)
- : base(dependencies, async)
+ bool async,
+ bool precompiling)
+ : base(dependencies, async, precompiling)
{
RelationalDependencies = relationalDependencies;
QuerySplittingBehavior = RelationalOptionsExtension.Extract(ContextOptions).QuerySplittingBehavior;
diff --git a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs
index 9a87151efe4..88f40c34035 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs
@@ -25,7 +25,14 @@ public DeleteExpression(TableExpression table, SelectExpression selectExpression
{
}
- private DeleteExpression(TableExpression table, SelectExpression selectExpression, ISet tags)
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public DeleteExpression(TableExpression table, SelectExpression selectExpression, ISet tags)
{
Table = table;
SelectExpression = selectExpression;
diff --git a/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs
index 06388fcdcfa..e7a4be40eee 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs
@@ -32,7 +32,14 @@ public class ExceptExpression : SetOperationBase
{
}
- private ExceptExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public ExceptExpression(
string alias,
SelectExpression source1,
SelectExpression source2,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs
index a94b187b8ed..d42b54c186a 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs
@@ -55,7 +55,7 @@ public virtual ExistsExpression Update(SelectExpression subquery)
public override Expression Quote()
=> New(
_quotingConstructor ??=
- typeof(ExistsExpression).GetConstructor([typeof(SelectExpression), typeof(bool), typeof(RelationalTypeMapping)])!,
+ typeof(ExistsExpression).GetConstructor([typeof(SelectExpression), typeof(RelationalTypeMapping)])!,
Subquery.Quote(),
RelationalExpressionQuotingUtilities.QuoteTypeMapping(TypeMapping));
diff --git a/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs
index d21c4a1f834..b1050066e0f 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs
@@ -27,7 +27,14 @@ public InnerJoinExpression(TableExpressionBase table, SqlExpression joinPredicat
{
}
- private InnerJoinExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public InnerJoinExpression(
TableExpressionBase table,
SqlExpression joinPredicate,
bool prunable,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs
index 2a3327e66a2..85009fc32c2 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs
@@ -32,7 +32,14 @@ public class IntersectExpression : SetOperationBase
{
}
- private IntersectExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public IntersectExpression(
string alias,
SelectExpression source1,
SelectExpression source2,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs
index e359be4355a..a5e2ea0ebc5 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs
@@ -27,7 +27,14 @@ public LeftJoinExpression(TableExpressionBase table, SqlExpression joinPredicate
{
}
- private LeftJoinExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public LeftJoinExpression(
TableExpressionBase table,
SqlExpression joinPredicate,
bool prunable,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs
index 91e319011f9..6692ce4cc51 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs
@@ -25,7 +25,14 @@ public OuterApplyExpression(TableExpressionBase table)
{
}
- private OuterApplyExpression(TableExpressionBase table, IReadOnlyDictionary? annotations)
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public OuterApplyExpression(TableExpressionBase table, IReadOnlyDictionary? annotations)
: base(table, prunable: false, annotations)
{
}
diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs
index 82031d038d3..82bd96b7d91 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs
@@ -595,14 +595,14 @@ protected override Expression VisitExtension(Expression expression)
return base.VisitExtension(
selectExpression.Update(
- selectExpression.Projection,
visitedTables ?? selectExpression.Tables,
selectExpression.Predicate,
selectExpression.GroupBy,
selectExpression.Having,
+ selectExpression.Projection,
selectExpression.Orderings,
- selectExpression.Limit,
- selectExpression.Offset));
+ selectExpression.Offset,
+ selectExpression.Limit));
}
private void RemapProjections(int[]? map, SelectExpression selectExpression)
diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
index 2334d206162..7720b6b8e8c 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
@@ -63,18 +63,75 @@ public sealed partial class SelectExpression : TableExpressionBase
public SelectExpression(
string? alias,
List tables,
+ SqlExpression? predicate,
List groupBy,
+ SqlExpression? having,
List projections,
+ bool distinct,
List orderings,
+ SqlExpression? offset,
+ SqlExpression? limit,
+ ISet tags,
IReadOnlyDictionary? annotations,
- SqlAliasManager sqlAliasManager)
+ SqlAliasManager? sqlAliasManager,
+ bool isMutable)
: base(alias, annotations)
{
- _projection = projections;
+ Check.DebugAssert(!(isMutable && sqlAliasManager is null), "Need SqlAliasManager when the SelectExpression is mutable");
+
_tables = tables;
+ Predicate = predicate;
_groupBy = groupBy;
+ Having = having;
+ _projection = projections;
+ IsDistinct = distinct;
_orderings = orderings;
- _sqlAliasManager = sqlAliasManager;
+ Offset = offset;
+ Limit = limit;
+ Tags = tags;
+ IsMutable = isMutable;
+
+ _sqlAliasManager = sqlAliasManager!;
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal]
+ public SelectExpression(
+ string? alias,
+ IReadOnlyList tables,
+ SqlExpression? predicate,
+ IReadOnlyList groupBy,
+ SqlExpression? having,
+ IReadOnlyList projections,
+ bool distinct,
+ IReadOnlyList orderings,
+ SqlExpression? offset,
+ SqlExpression? limit,
+ IReadOnlySet tags,
+ IReadOnlyDictionary? annotations)
+ : this(alias, tables.ToList(), predicate, groupBy.ToList(), having, projections.ToList(), distinct, orderings.ToList(),
+ offset, limit, tags.ToHashSet(), annotations, sqlAliasManager: null, isMutable: false)
+ {
+ }
+
+ private SelectExpression(
+ string? alias,
+ List tables,
+ List groupBy,
+ List projections,
+ List orderings,
+ IReadOnlyDictionary? annotations,
+ SqlAliasManager sqlAliasManager)
+ : this(
+ alias, tables, predicate: null, groupBy: groupBy, having: null, projections: projections, distinct: false, orderings: orderings, offset: null,
+ limit: null, tags: new HashSet(),
+ annotations: annotations, sqlAliasManager: sqlAliasManager, isMutable: true)
+ {
}
///
@@ -119,7 +176,11 @@ public SelectExpression(SqlExpression projection, SqlAliasManager sqlAliasManage
// should have an alias manager at all, so this is temporary).
[EntityFrameworkInternal]
public static SelectExpression CreateImmutable(string alias, List tables, List projection)
- => new(alias, tables, groupBy: [], projections: projection, orderings: [], annotations: null, sqlAliasManager: null!) { IsMutable = false };
+ => new(
+ alias, tables, predicate: null, groupBy: [], having: null, projections: projection, distinct: false, orderings: [],
+ offset: null, limit: null,
+ tags: new HashSet(), sqlAliasManager: null, annotations: new Dictionary(),
+ isMutable: false);
///
/// The list of tags applied to this .
@@ -3753,17 +3814,11 @@ private TableExpressionBase Clone(string? alias, ExpressionVisitor cloningExpres
var limit = (SqlExpression?)cloningExpressionVisitor.Visit(Limit);
var newSelectExpression = new SelectExpression(
- alias, newTables, newGroupBy, newProjections, newOrderings, Annotations, _sqlAliasManager)
+ alias, newTables, predicate, newGroupBy, havingExpression, newProjections, IsDistinct, newOrderings, offset, limit,
+ Tags, Annotations, _sqlAliasManager, IsMutable)
{
- Predicate = predicate,
- Having = havingExpression,
- Offset = offset,
- Limit = limit,
- IsDistinct = IsDistinct,
- Tags = Tags,
_projectionMapping = newProjectionMappings,
_clientProjections = newClientProjections,
- IsMutable = IsMutable
};
foreach (var (column, comparer) in _identifier)
@@ -4055,17 +4110,11 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
if (changed)
{
var newSelectExpression = new SelectExpression(
- Alias, newTables, newGroupBy, newProjections, newOrderings, Annotations, _sqlAliasManager)
+ Alias, newTables, predicate, newGroupBy, havingExpression, newProjections, IsDistinct, newOrderings, offset,
+ limit, (IReadOnlySet)Tags, Annotations)
{
_clientProjections = _clientProjections,
- _projectionMapping = _projectionMapping,
- Predicate = predicate,
- Having = havingExpression,
- Offset = offset,
- Limit = limit,
- IsDistinct = IsDistinct,
- Tags = Tags,
- IsMutable = false
+ _projectionMapping = _projectionMapping
};
newSelectExpression._identifier.AddRange(identifier.Zip(_identifier).Select(e => (e.First, e.Second.Comparer)));
@@ -4124,25 +4173,25 @@ List VisitList(List list, bool inPlace, out bool changed)
/// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will
/// return this expression.
///
- /// The property of the result.
/// The property of the result.
/// The property of the result.
/// The property of the result.
/// The property of the result.
+ /// The property of the result.
/// The property of the result.
- /// The property of the result.
/// The property of the result.
+ /// The property of the result.
/// This expression if no children changed, or an expression with the updated children.
// This does not take internal states since when using this method SelectExpression should be finalized
public SelectExpression Update(
- IReadOnlyList projections,
IReadOnlyList tables,
SqlExpression? predicate,
IReadOnlyList groupBy,
SqlExpression? having,
+ IReadOnlyList projections,
IReadOnlyList orderings,
- SqlExpression? limit,
- SqlExpression? offset)
+ SqlExpression? offset,
+ SqlExpression? limit)
{
if (IsMutable)
{
@@ -4155,8 +4204,8 @@ List VisitList(List list, bool inPlace, out bool changed)
&& groupBy == GroupBy
&& having == Having
&& orderings == Orderings
- && limit == Limit
- && offset == Offset)
+ && offset == Offset
+ && limit == Limit)
{
return this;
}
@@ -4168,17 +4217,11 @@ List VisitList(List list, bool inPlace, out bool changed)
}
var newSelectExpression = new SelectExpression(
- Alias, tables.ToList(), groupBy.ToList(), projections.ToList(), orderings.ToList(), Annotations, _sqlAliasManager)
+ Alias, tables, predicate, groupBy, having, projections, IsDistinct, orderings, offset, limit,
+ (IReadOnlySet)Tags, Annotations)
{
_projectionMapping = projectionMapping,
- _clientProjections = _clientProjections.ToList(),
- Predicate = predicate,
- Having = having,
- Offset = offset,
- Limit = limit,
- IsDistinct = IsDistinct,
- Tags = Tags,
- IsMutable = false
+ _clientProjections = _clientProjections.ToList()
};
// We don't copy identifiers because when we are doing reconstruction so projection is already applied.
@@ -4196,17 +4239,11 @@ public override SelectExpression WithAlias(string newAlias)
{
Check.DebugAssert(!IsMutable, "Can't change alias on mutable SelectExpression");
- return new SelectExpression(newAlias, _tables, _groupBy, _projection, _orderings, Annotations, _sqlAliasManager)
+ return new SelectExpression(
+ newAlias, _tables, Predicate, _groupBy, Having, _projection, IsDistinct, _orderings, Offset, Limit, Tags,
+ Annotations, _sqlAliasManager, isMutable: false)
{
- _projectionMapping = _projectionMapping,
- _clientProjections = _clientProjections.ToList(),
- Predicate = Predicate,
- Having = Having,
- Offset = Offset,
- Limit = Limit,
- IsDistinct = IsDistinct,
- Tags = Tags,
- IsMutable = false
+ _projectionMapping = _projectionMapping, _clientProjections = _clientProjections.ToList(),
};
}
@@ -4223,8 +4260,8 @@ public override Expression Quote()
typeof(IReadOnlyList), // projections
typeof(bool), // distinct
typeof(IReadOnlyList), // orderings
- typeof(SqlExpression), // limit
typeof(SqlExpression), // offset
+ typeof(SqlExpression), // limit
typeof(IReadOnlySet), // tags
typeof(IReadOnlyDictionary) // annotations
])!,
@@ -4238,8 +4275,8 @@ public override Expression Quote()
NewArrayInit(typeof(ProjectionExpression), initializers: Projection.Select(p => p.Quote())),
Constant(IsDistinct),
NewArrayInit(typeof(OrderingExpression), initializers: Orderings.Select(o => o.Quote())),
- RelationalExpressionQuotingUtilities.QuoteOrNull(Limit),
RelationalExpressionQuotingUtilities.QuoteOrNull(Offset),
+ RelationalExpressionQuotingUtilities.QuoteOrNull(Limit),
RelationalExpressionQuotingUtilities.QuoteTags(Tags),
RelationalExpressionQuotingUtilities.QuoteAnnotations(Annotations));
diff --git a/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs
index 02f8df8cd9a..ff3dcb48ce3 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs
@@ -183,7 +183,14 @@ public class SqlFunctionExpression : SqlExpression
{
}
- private SqlFunctionExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public SqlFunctionExpression(
SqlExpression? instance,
string? schema,
string name,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs
index 16db24710f0..70b83072cdb 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs
@@ -32,7 +32,14 @@ public class UnionExpression : SetOperationBase
{
}
- private UnionExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public UnionExpression(
string alias,
SelectExpression source1,
SelectExpression source2,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs
index 598f8be6df9..3b0041d2ac5 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs
@@ -31,7 +31,14 @@ public UpdateExpression(TableExpression table, SelectExpression selectExpression
{
}
- private UpdateExpression(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // For precompiled queries
+ public UpdateExpression(
TableExpression table,
SelectExpression selectExpression,
IReadOnlyList columnValueSetters,
diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
index 2f8e60bb01f..ad732e432e7 100644
--- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
+++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
@@ -367,7 +367,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression, bool
var limit = Visit(selectExpression.Limit, out _);
- return selectExpression.Update(projections, tables, predicate, groupBy, having, orderings, limit, offset);
+ return selectExpression.Update(tables, predicate, groupBy, having, projections, orderings, offset, limit);
}
///
@@ -670,9 +670,14 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
var projectionExpression = Visit(subqueryProjection, allowOptimizedExpansion, out var projectionNullable);
inExpression = inExpression.Update(
item, subquery.Update(
- [subquery.Projection[0].Update(projectionExpression)],
- subquery.Tables, subquery.Predicate, subquery.GroupBy, subquery.Having, subquery.Orderings, subquery.Limit,
- subquery.Offset));
+ subquery.Tables,
+ subquery.Predicate,
+ subquery.GroupBy,
+ subquery.Having,
+ projections: [subquery.Projection[0].Update(projectionExpression)],
+ subquery.Orderings,
+ subquery.Offset,
+ subquery.Limit));
if (UseRelationalNulls)
{
@@ -779,14 +784,14 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
// No need for a projection with EXISTS, clear it to get SELECT 1
subquery = subquery.Update(
- [],
subquery.Tables,
subquery.Predicate,
subquery.GroupBy,
subquery.Having,
+ [],
subquery.Orderings,
- subquery.Limit,
- subquery.Offset);
+ subquery.Offset,
+ subquery.Limit);
var predicate = VisitSqlBinary(
_sqlExpressionFactory.Equal(subqueryProjection, item), allowOptimizedExpansion: true, out _);
@@ -2077,14 +2082,15 @@ static bool TryNegate(ExpressionType expressionType, out ExpressionType result)
#pragma warning restore EF1001
rewrittenSelectExpression = rewrittenSelectExpression.Update(
- projection, // TODO: We should change the project column to be non-nullable, but it's too closed down for that.
new[] { rewrittenCollectionTable },
selectExpression.Predicate,
selectExpression.GroupBy,
selectExpression.Having,
+ // TODO: We should change the project column to be non-nullable, but it's too closed down for that.
+ projection,
selectExpression.Orderings,
- selectExpression.Limit,
- selectExpression.Offset);
+ selectExpression.Offset,
+ selectExpression.Limit);
return true;
}
diff --git a/src/EFCore.Relational/Query/SqlTreePruner.cs b/src/EFCore.Relational/Query/SqlTreePruner.cs
index ca03fc79a84..18746cfc999 100644
--- a/src/EFCore.Relational/Query/SqlTreePruner.cs
+++ b/src/EFCore.Relational/Query/SqlTreePruner.cs
@@ -259,6 +259,6 @@ protected virtual SelectExpression PruneSelect(SelectExpression select, bool pre
CurrentTableAlias = parentTableAlias;
return select.Update(
- projections ?? select.Projection, tables ?? select.Tables, predicate, groupBy, having, orderings, limit, offset);
+ tables ?? select.Tables, predicate, groupBy, having, projections ?? select.Projection, orderings, offset, limit);
}
}
diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs
index 55e64112f64..b545b77b761 100644
--- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs
@@ -299,7 +299,7 @@ protected override Expression VisitSelect(SelectExpression selectExpression)
_isSearchCondition = parentSearchCondition;
- return selectExpression.Update(projections, tables, predicate, groupBy, havingExpression, orderings, limit, offset);
+ return selectExpression.Update(tables, predicate, groupBy, havingExpression, projections, orderings, offset, limit);
}
///
diff --git a/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs
index e80a86c1b6c..00ecaab9ee9 100644
--- a/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs
@@ -67,14 +67,14 @@ protected override Expression VisitExtension(Expression extensionExpression)
if (IsZero(selectExpression.Limit))
{
return selectExpression.Update(
- selectExpression.Projection,
selectExpression.Tables,
selectExpression.GroupBy.Count > 0 ? selectExpression.Predicate : _sqlExpressionFactory.Constant(false),
selectExpression.GroupBy,
selectExpression.GroupBy.Count > 0 ? _sqlExpressionFactory.Constant(false) : null,
+ selectExpression.Projection,
new List(0),
- limit: null,
- offset: null);
+ offset: null,
+ limit: null);
}
bool IsZero(SqlExpression? sqlExpression)
diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs
index 161875f7a8b..1ad0e35e8e3 100644
--- a/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs
@@ -138,14 +138,14 @@ public virtual Expression Process(Expression expression)
var newSelectExpression = newTables is not null
? selectExpression.Update(
- selectExpression.Projection,
newTables,
selectExpression.Predicate,
selectExpression.GroupBy,
selectExpression.Having,
+ selectExpression.Projection,
selectExpression.Orderings,
- selectExpression.Limit,
- selectExpression.Offset)
+ selectExpression.Offset,
+ selectExpression.Limit)
: selectExpression;
// when we mark columns for rewrite we don't yet have the updated SelectExpression, so we store the info in temporary dictionary
diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs
index 7f1063149e0..aca56014c1b 100644
--- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs
@@ -23,8 +23,9 @@ public class SqlServerQueryCompilationContext : RelationalQueryCompilationContex
QueryCompilationContextDependencies dependencies,
RelationalQueryCompilationContextDependencies relationalDependencies,
bool async,
+ bool precompiling,
bool multipleActiveResultSetsEnabled)
- : base(dependencies, relationalDependencies, async)
+ : base(dependencies, relationalDependencies, async, precompiling)
{
_multipleActiveResultSetsEnabled = multipleActiveResultSetsEnabled;
}
diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs
index 30b1d39932c..d7aa170c427 100644
--- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs
@@ -47,7 +47,16 @@ public class SqlServerQueryCompilationContextFactory : IQueryCompilationContextF
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- public virtual QueryCompilationContext Create(bool async)
+ public virtual QueryCompilationContext Create(bool async, bool precompiling)
=> new SqlServerQueryCompilationContext(
- Dependencies, RelationalDependencies, async, _sqlServerConnection.IsMultipleActiveResultSetsEnabled);
+ Dependencies, RelationalDependencies, async, precompiling, _sqlServerConnection.IsMultipleActiveResultSetsEnabled);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual QueryCompilationContext Create(bool async)
+ => throw new UnreachableException("The overload with `precompiling` should be called");
}
diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs
index cf9f206c616..4f6b189126a 100644
--- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs
+++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs
@@ -20,8 +20,9 @@ public class SqliteQueryCompilationContext : RelationalQueryCompilationContext
public SqliteQueryCompilationContext(
QueryCompilationContextDependencies dependencies,
RelationalQueryCompilationContextDependencies relationalDependencies,
- bool async)
- : base(dependencies, relationalDependencies, async)
+ bool async,
+ bool precompiling)
+ : base(dependencies, relationalDependencies, async, precompiling)
{
}
diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs
index 0570b91fbc3..c088df764ff 100644
--- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs
+++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs
@@ -35,6 +35,15 @@ public class SqliteQueryCompilationContextFactory : IQueryCompilationContextFact
///
protected virtual RelationalQueryCompilationContextDependencies RelationalDependencies { get; }
+ ///
+ /// Creates a new .
+ ///
+ /// Specifies whether the query is async.
+ /// Indicates whether the query is being precompiled.
+ /// The created query compilation context.
+ public QueryCompilationContext Create(bool async, bool precompiling)
+ => new SqliteQueryCompilationContext(Dependencies, RelationalDependencies, async, precompiling);
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -42,5 +51,5 @@ public class SqliteQueryCompilationContextFactory : IQueryCompilationContextFact
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
public virtual QueryCompilationContext Create(bool async)
- => new SqliteQueryCompilationContext(Dependencies, RelationalDependencies, async);
+ => throw new UnreachableException("The overload with `precompiling` should be called");
}
diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs
index 5698691db62..f6ee51acd22 100644
--- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs
@@ -418,7 +418,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}
- private static string? ConstructLikePatternParameter(
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [EntityFrameworkInternal] // Can be called from precompiled shapers
+ public static string? ConstructLikePatternParameter(
QueryContext queryContext,
string baseParameterName,
bool startsWith)
diff --git a/src/EFCore/Design/ICSharpHelper.cs b/src/EFCore/Design/ICSharpHelper.cs
index 1dd92187d33..e477c22d06d 100644
--- a/src/EFCore/Design/ICSharpHelper.cs
+++ b/src/EFCore/Design/ICSharpHelper.cs
@@ -351,6 +351,7 @@ string Literal(T? value)
///
/// The node to be translated.
/// Any namespaces required by the translated code will be added to this set.
+ /// Any unsafe accessors needed to access private members will be added to this dictionary.
/// Collection of translations for statically known instances.
/// Collection of translations for non-public member accesses.
/// Source code that would produce .
@@ -363,6 +364,7 @@ string Literal(T? value)
[EntityFrameworkInternal]
string Statement(Expression node,
ISet collectedNamespaces,
+ ISet unsafeAccessors,
IReadOnlyDictionary? constantReplacements = null,
IReadOnlyDictionary? memberAccessReplacements = null);
@@ -371,6 +373,7 @@ string Literal(T? value)
///
/// The node to be translated.
/// Any namespaces required by the translated code will be added to this set.
+ /// Any unsafe accessors needed to access private members will be added to this dictionary.
/// Collection of translations for statically known instances.
/// Collection of translations for non-public member accesses.
/// Source code that would produce .
@@ -383,6 +386,7 @@ string Literal(T? value)
[EntityFrameworkInternal]
string Expression(Expression node,
ISet collectedNamespaces,
+ ISet unsafeAccessors,
IReadOnlyDictionary? constantReplacements = null,
IReadOnlyDictionary? memberAccessReplacements = null);
}
diff --git a/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs b/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs
index 9502298b15d..4961e6d98a3 100644
--- a/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs
+++ b/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs
@@ -364,6 +364,9 @@ public static void AddNamespace(Type type, ISet namespaces)
AddNamespace(converter.ModelClrType, parameters.Namespaces);
AddNamespace(converter.ProviderClrType, parameters.Namespaces);
+ // TODO
+ var unsafeAccessors = new HashSet();
+
mainBuilder
.Append("new ValueConverter<")
.Append(codeHelper.Reference(converter.ModelClrType))
@@ -371,10 +374,10 @@ public static void AddNamespace(Type type, ISet namespaces)
.Append(codeHelper.Reference(converter.ProviderClrType))
.AppendLine(">(")
.IncrementIndent()
- .AppendLines(codeHelper.Expression(converter.ConvertToProviderExpression, parameters.Namespaces, null, null),
+ .AppendLines(codeHelper.Expression(converter.ConvertToProviderExpression, parameters.Namespaces, unsafeAccessors),
skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(codeHelper.Expression(converter.ConvertFromProviderExpression, parameters.Namespaces, null, null),
+ .AppendLines(codeHelper.Expression(converter.ConvertFromProviderExpression, parameters.Namespaces, unsafeAccessors),
skipFinalNewline: true);
if (converter.ConvertsNulls)
@@ -425,18 +428,21 @@ public static void AddNamespace(Type type, ISet namespaces)
AddNamespace(typeof(ValueComparer<>), parameters.Namespaces);
AddNamespace(comparer.Type, parameters.Namespaces);
+ // TODO
+ var unsafeAccessors = new HashSet();
+
mainBuilder
.Append("new ValueComparer<")
.Append(codeHelper.Reference(comparer.Type))
.AppendLine(">(")
.IncrementIndent()
- .AppendLines(codeHelper.Expression(comparer.EqualsExpression, parameters.Namespaces, null, null),
+ .AppendLines(codeHelper.Expression(comparer.EqualsExpression, parameters.Namespaces, unsafeAccessors),
skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(codeHelper.Expression(comparer.HashCodeExpression, parameters.Namespaces, null, null),
+ .AppendLines(codeHelper.Expression(comparer.HashCodeExpression, parameters.Namespaces, unsafeAccessors),
skipFinalNewline: true)
.AppendLine(",")
- .AppendLines(codeHelper.Expression(comparer.SnapshotExpression, parameters.Namespaces, null, null),
+ .AppendLines(codeHelper.Expression(comparer.SnapshotExpression, parameters.Namespaces, unsafeAccessors),
skipFinalNewline: true)
.Append(")")
.DecrementIndent();
diff --git a/src/EFCore/Diagnostics/EventDefinitionBase.cs b/src/EFCore/Diagnostics/EventDefinitionBase.cs
index 7c9e94eb9e5..f3d9f50f17f 100644
--- a/src/EFCore/Diagnostics/EventDefinitionBase.cs
+++ b/src/EFCore/Diagnostics/EventDefinitionBase.cs
@@ -16,7 +16,7 @@ public abstract class EventDefinitionBase
/// Creates an event definition instance.
///
/// Logging options.
- /// The .
+ /// The .
/// The at which the event will be logged.
///
/// A string representing the code that should be passed to .
diff --git a/src/EFCore/Internal/InternalDbSet.cs b/src/EFCore/Internal/InternalDbSet.cs
index 440ad39159e..34eab90ecd6 100644
--- a/src/EFCore/Internal/InternalDbSet.cs
+++ b/src/EFCore/Internal/InternalDbSet.cs
@@ -19,6 +19,7 @@ public class InternalDbSet<[DynamicallyAccessedMembers(IEntityType.DynamicallyAc
DbSet,
IQueryable,
IAsyncEnumerable,
+ IInfrastructure,
IInfrastructure,
IResettableService
where TEntity : class
@@ -520,6 +521,15 @@ Expression IQueryable.Expression
IQueryProvider IQueryable.Provider
=> EntityQueryable.Provider;
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ DbContext IInfrastructure.Instance
+ => _context;
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs
index 6cde3b34aa4..3ce10a3ac67 100644
--- a/src/EFCore/Properties/CoreStrings.Designer.cs
+++ b/src/EFCore/Properties/CoreStrings.Designer.cs
@@ -994,6 +994,12 @@ public static string DuplicateTrigger(object? trigger, object? entityType, objec
public static string EFConstantInvoked
=> GetString("EFConstantInvoked");
+ ///
+ /// The EF.Constant<T> method is not supported when using precompiled queries.
+ ///
+ public static string EFConstantNotSupportedInPrecompiledQueries
+ => GetString("EFConstantNotSupportedInPrecompiledQueries");
+
///
/// The EF.Constant<T> method may only be used with an argument that can be evaluated client-side and does not contain any reference to database-side entities.
///
@@ -2233,6 +2239,14 @@ public static string NotCollection(object? entityType, object? property)
GetString("NotCollection", nameof(entityType), nameof(property)),
entityType, property);
+ ///
+ /// When precompiling queries, the '{parameter}' parameter of method '{method}' cannot be parameterized.
+ ///
+ public static string NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries(object? parameter, object? method)
+ => string.Format(
+ GetString("NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries", nameof(parameter), nameof(method)),
+ parameter, method);
+
///
/// The given 'IQueryable' does not support generation of query strings.
///
@@ -2339,6 +2353,12 @@ public static string PoolingContextCtorError(object? contextType)
public static string PoolingOptionsModified
=> GetString("PoolingOptionsModified");
+ ///
+ /// Precompiled queries aren't supported by your EF provider.
+ ///
+ public static string PrecompiledQueryNotSupported
+ => GetString("PrecompiledQueryNotSupported");
+
///
/// The derived type '{derivedType}' cannot have the [PrimaryKey] attribute since primary keys may only be declared on the root type. Move the attribute to '{rootType}', or remove '{rootType}' from the model by using [NotMapped] attribute or calling 'EntityTypeBuilder.Ignore' on the base type in 'OnModelCreating'.
///
diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx
index b1f71053ae5..5e3e94fa50f 100644
--- a/src/EFCore/Properties/CoreStrings.resx
+++ b/src/EFCore/Properties/CoreStrings.resx
@@ -486,6 +486,9 @@
The EF.Constant<T> method may only be used within Entity Framework LINQ queries.
+
+ The EF.Constant<T> method is not supported when using precompiled queries.
+
The EF.Constant<T> method may only be used with an argument that can be evaluated client-side and does not contain any reference to database-side entities.
@@ -1291,6 +1294,9 @@
The property '{entityType}.{property}' cannot be mapped as a collection since it does not implement 'IEnumerable<T>'.
+
+ When precompiling queries, the '{parameter}' parameter of method '{method}' cannot be parameterized.
+
The given 'IQueryable' does not support generation of query strings.
@@ -1333,6 +1339,9 @@
'OnConfiguring' cannot be used to modify DbContextOptions when DbContext pooling is enabled.
+
+ Precompiled queries aren't supported by your EF provider.
+
The derived type '{derivedType}' cannot have the [PrimaryKey] attribute since primary keys may only be declared on the root type. Move the attribute to '{rootType}', or remove '{rootType}' from the model by using [NotMapped] attribute or calling 'EntityTypeBuilder.Ignore' on the base type in 'OnModelCreating'.
diff --git a/src/EFCore/Query/ExpressionPrinter.cs b/src/EFCore/Query/ExpressionPrinter.cs
index ff1abca0724..4c39cf21c09 100644
--- a/src/EFCore/Query/ExpressionPrinter.cs
+++ b/src/EFCore/Query/ExpressionPrinter.cs
@@ -461,13 +461,19 @@ protected override Expression VisitConditional(ConditionalExpression conditional
///
protected override Expression VisitConstant(ConstantExpression constantExpression)
{
- if (constantExpression.Value is IPrintableExpression printable)
+ switch (constantExpression.Value)
{
- printable.Print(this);
- }
- else
- {
- PrintValue(constantExpression.Value);
+ case IPrintableExpression printable:
+ printable.Print(this);
+ break;
+
+ case IQueryable queryable:
+ _stringBuilder.Append(Print(queryable.Expression));
+ break;
+
+ default:
+ PrintValue(constantExpression.Value);
+ break;
}
return constantExpression;
diff --git a/src/EFCore/Query/IQueryCompilationContextFactory.cs b/src/EFCore/Query/IQueryCompilationContextFactory.cs
index cbfcf9c1780..6206a0dc39f 100644
--- a/src/EFCore/Query/IQueryCompilationContextFactory.cs
+++ b/src/EFCore/Query/IQueryCompilationContextFactory.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics.CodeAnalysis;
+
namespace Microsoft.EntityFrameworkCore.Query;
///
@@ -26,4 +28,16 @@ public interface IQueryCompilationContextFactory
/// Specifies whether the query is async.
/// The created query compilation context.
QueryCompilationContext Create(bool async);
+
+ ///
+ /// Creates a new .
+ ///
+ /// Specifies whether the query is async.
+ /// Indicates whether the query is being precompiled.
+ /// The created query compilation context.
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ QueryCompilationContext Create(bool async, bool precompiling)
+ => precompiling
+ ? throw new InvalidOperationException(CoreStrings.PrecompiledQueryNotSupported)
+ : Create(async);
}
diff --git a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs
index 8e01d6b73c6..c905f8d45a5 100644
--- a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs
+++ b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs
@@ -39,6 +39,11 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor
///
private bool _calculatingPath;
+ ///
+ /// Indicates whether performing parameter extraction on a precompiled query.
+ ///
+ private bool _precompiledQuery;
+
///
/// Indicates whether we should parameterize. Is false in compiled query mode, as well as when we're handling query filters from
/// NavigationExpandingExpressionVisitor.
@@ -94,6 +99,10 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor
private static readonly MethodInfo ReadOnlyCollectionIndexerGetter = typeof(ReadOnlyCollection).GetProperties()
.Single(p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)).GetMethod!;
+ private static readonly MethodInfo ReadOnlyElementInitCollectionIndexerGetter = typeof(ReadOnlyCollection)
+ .GetProperties()
+ .Single(p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)).GetMethod!;
+
private static readonly MethodInfo ReadOnlyMemberBindingCollectionIndexerGetter = typeof(ReadOnlyCollection)
.GetProperties()
.Single(p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)).GetMethod!;
@@ -140,11 +149,13 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor
public virtual Expression ExtractParameters(
Expression expression,
IParameterValues parameterValues,
+ bool precompiledQuery,
bool parameterize,
bool clearParameterizedValues)
{
Reset(clearParameterizedValues);
_parameterValues = parameterValues;
+ _precompiledQuery = precompiledQuery;
_parameterize = parameterize;
_calculatingPath = false;
@@ -161,6 +172,23 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor
return root;
}
+ ///
+ /// Resets the funcletizer in preparation for multiple path calculations (i.e. for the same query). After this is called,
+ /// can be called multiple times, preserving state
+ /// between calls.
+ ///
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ public virtual void ResetPathCalculation()
+ {
+ Reset();
+ _calculatingPath = true;
+ _parameterize = true;
+
+ // In precompilation mode we don't actually extract parameter values; but we do need to generate the parameter names, using the
+ // same logic (and via the same code) used in parameter extraction, and that logic requires _parameterValues.
+ _parameterValues = new DummyParameterValues();
+ }
+
///
/// Processes an expression tree, locates references to captured variables and returns information on how to extract them from
/// expression trees with the same shape. Used to generate C# code for query precompilation.
@@ -172,17 +200,74 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ public virtual PathNode? CalculatePathsToEvaluatableRoots(MethodCallExpression linqOperatorMethodCall, int argumentIndex)
+ {
+ var argument = linqOperatorMethodCall.Arguments[argumentIndex];
+ if (argument is UnaryExpression { NodeType: ExpressionType.Quote } quote)
+ {
+ argument = quote.Operand;
+ }
+
+ var root = Visit(argument, out var state);
+
+ // If the top-most node in the tree is evaluatable, that means we have a non-lambda parameter to the LINQ operator (e.g. Skip/Take).
+ // We make sure to return a path containing the argument; note that since we're not in a lambda, the argument will always be
+ // parameterized since we're not inside a lambda (e.g. Skip/Take), except for [NotParameterized].
+ if (state.IsEvaluatable
+ && IsParameterParameterizable(linqOperatorMethodCall.Method, linqOperatorMethodCall.Method.GetParameters()[argumentIndex]))
+ {
+ _ = Evaluate(root, out var parameterName, out _);
+
+ state = new()
+ {
+ StateType = StateType.ContainsEvaluatable,
+ Path = new()
+ {
+ ExpressionType = state.ExpressionType!,
+ ParameterName = parameterName,
+ Children = Array.Empty()
+ }
+ };
+ }
+
+ return state.Path;
+ }
+
+ ///
+ /// Processes an expression tree, locates references to captured variables and returns information on how to extract them from
+ /// expression trees with the same shape. Used to generate C# code for query precompilation.
+ ///
+ /// A tree representing the path to each evaluatable root node in the tree.
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
public virtual PathNode? CalculatePathsToEvaluatableRoots(Expression expression)
{
- Reset();
- _calculatingPath = true;
- _parameterize = true;
+ var root = Visit(expression, out var state);
- // In precompilation mode we don't actually extract parameter values; but we do need to generate the parameter names, using the
- // same logic (and via the same code) used in parameter extraction, and that logic requires _parameterValues.
- _parameterValues = new DummyParameterValues();
+ // If the top-most node in the tree is evaluatable, that means we have a non-lambda parameter to the LINQ operator (e.g. Skip/Take).
+ // We make sure to return a path containing the argument; note that since we're not in a lambda, the argument will always be
+ // parameterized since we're not inside a lambda (e.g. Skip/Take), except for [NotParameterized].
+ if (state.IsEvaluatable)
+ {
+ _ = Evaluate(root, out var parameterName, out _);
- _ = Visit(expression, out var state);
+ state = new()
+ {
+ StateType = StateType.ContainsEvaluatable,
+ Path = new()
+ {
+ ExpressionType = state.ExpressionType!,
+ ParameterName = parameterName,
+ Children = Array.Empty()
+ }
+ };
+ }
return state.Path;
}
@@ -798,7 +883,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall)
{
if (_calculatingPath)
{
- throw new InvalidOperationException("EF.Constant is not supported when using precompiled queries");
+ throw new InvalidOperationException(CoreStrings.EFConstantNotSupportedInPrecompiledQueries);
}
var argument = Visit(methodCall.Arguments[0], out var argumentState);
@@ -888,17 +973,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall)
// To support [NotParameterized] and indexer method arguments - which force evaluation as constant - go over the parameters
// and modify the states as needed
- ParameterInfo[]? parameterInfos = null;
+ ParameterInfo[]? parameters = null;
for (var i = 0; i < methodCall.Arguments.Count; i++)
{
var argumentState = argumentStates[i];
if (argumentState.IsEvaluatable)
{
- parameterInfos ??= methodCall.Method.GetParameters();
- if (parameterInfos[i].GetCustomAttribute() is not null
- || _model.IsIndexerMethod(methodCall.Method))
+ parameters ??= methodCall.Method.GetParameters();
+ if (!IsParameterParameterizable(methodCall.Method, parameters[i]))
{
+ if (argumentState.StateType is StateType.EvaluatableWithCapturedVariable && _precompiledQuery)
+ {
+ throw new InvalidOperationException(
+ CoreStrings.NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries(
+ parameters[i].Name,
+ method.Name));
+ }
+
argumentStates[i] = argumentState with
{
StateType = StateType.EvaluatableWithoutCapturedVariable, ForceConstantization = true
@@ -1140,7 +1232,7 @@ protected override Expression VisitMemberInit(MemberInitExpression memberInit)
// Avoid allocating for the notEvaluatableAsRootHandler closure below unless we actually end up in the evaluatable case
var (memberInit2, new2, newState2, bindings2, bindingStates2) = (memberInit, @new, newState, bindings, bindingStates);
_state = State.CreateEvaluatable(
- typeof(InvocationExpression),
+ typeof(MemberInitExpression),
state is StateType.EvaluatableWithCapturedVariable,
notEvaluatableAsRootHandler: () => EvaluateChildren(memberInit2, new2, newState2, bindings2, bindingStates2));
break;
@@ -1299,7 +1391,7 @@ protected override Expression VisitListInit(ListInitExpression listInit)
{
children =
[
- newState.Path! with { PathFromParent = static e => Property(e, nameof(MethodCallExpression.Object)) }
+ newState.Path! with { PathFromParent = static e => Property(e, nameof(ListInitExpression.NewExpression)) }
];
}
@@ -1307,17 +1399,24 @@ protected override Expression VisitListInit(ListInitExpression listInit)
{
var initializer = initializers[i];
+ // listInit.Initializers[0].Arguments[1]
+ var initializerIndex = i;
var visitedArguments = EvaluateList(
visitedInitializersArguments is null
? initializer.Arguments
: visitedInitializersArguments[i],
initializerArgumentStates[i],
ref children,
- static i => e =>
+ j => e =>
Call(
- Property(e, nameof(MethodCallExpression.Arguments)),
+ Property(
+ Call(
+ Property(e, nameof(ListInitExpression.Initializers)),
+ ReadOnlyElementInitCollectionIndexerGetter,
+ arguments: [Constant(initializerIndex)]),
+ nameof(System.Linq.Expressions.ElementInit.Arguments)),
ReadOnlyCollectionIndexerGetter,
- arguments: [Constant(i)]));
+ arguments: [Constant(j)]));
if (visitedArguments is not null && visitedInitializersArguments is null)
{
@@ -1964,6 +2063,10 @@ private bool IsGenerallyEvaluatable(Expression expression)
// Don't evaluate QueryableMethods if in compiled query
|| !(expression is MethodCallExpression { Method: var method } && method.DeclaringType == typeof(Queryable)));
+ private bool IsParameterParameterizable(MethodInfo method, ParameterInfo parameter)
+ => parameter.GetCustomAttribute() is null
+ && !_model.IsIndexerMethod(method);
+
private enum StateType
{
///
diff --git a/src/EFCore/Query/Internal/IQueryCompiler.cs b/src/EFCore/Query/Internal/IQueryCompiler.cs
index 32749f8a345..277a351417a 100644
--- a/src/EFCore/Query/Internal/IQueryCompiler.cs
+++ b/src/EFCore/Query/Internal/IQueryCompiler.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics.CodeAnalysis;
+
namespace Microsoft.EntityFrameworkCore.Query.Internal;
///
@@ -48,4 +50,13 @@ public interface IQueryCompiler
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
Func CreateCompiledAsyncQuery(Expression query);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ Expression> PrecompileQuery(Expression query, bool async);
}
diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
index 07b5982255b..f91f6f77a5f 100644
--- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
@@ -209,7 +209,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
&& entityQueryRootExpression.GetType() == typeof(EntityQueryRootExpression))
{
var processedDefiningQueryBody = _funcletizer.ExtractParameters(
- definingQuery.Body, _parameters, parameterize: false, clearParameterizedValues: false);
+ definingQuery.Body, _parameters, _queryCompilationContext.IsPrecompiling, parameterize: false,
+ clearParameterizedValues: false);
processedDefiningQueryBody = _queryTranslationPreprocessor.NormalizeQueryableMethod(processedDefiningQueryBody);
processedDefiningQueryBody = _nullCheckRemovingExpressionVisitor.Visit(processedDefiningQueryBody);
processedDefiningQueryBody =
@@ -1753,7 +1754,8 @@ private Expression ApplyQueryFilter(IEntityType entityType, NavigationExpansionE
{
filterPredicate = queryFilter;
filterPredicate = (LambdaExpression)_funcletizer.ExtractParameters(
- filterPredicate, _parameters, parameterize: false, clearParameterizedValues: false);
+ filterPredicate, _parameters, _queryCompilationContext.IsPrecompiling, parameterize: false,
+ clearParameterizedValues: false);
filterPredicate = (LambdaExpression)_queryTranslationPreprocessor.NormalizeQueryableMethod(filterPredicate);
// We need to do entity equality, but that requires a full method call on a query root to properly flow the
diff --git a/src/EFCore/Query/Internal/PrecompiledQueryContext.cs b/src/EFCore/Query/Internal/PrecompiledQueryContext.cs
new file mode 100644
index 00000000000..8587758c602
--- /dev/null
+++ b/src/EFCore/Query/Internal/PrecompiledQueryContext.cs
@@ -0,0 +1,128 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using System.Diagnostics.CodeAnalysis;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal;
+
+///
+/// A context for a precompiled query that's being executed. This wraps the via which the query is being
+/// executed, as well as a regular EF . It is flown through all intercepted LINQ operators, until the
+/// terminating operator interceptor which actually executes the query. Note that it implements so that
+/// it can be flown from one intercepted LINQ operator to another.
+///
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+[Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+public class PrecompiledQueryContext : IOrderedQueryable
+{
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public PrecompiledQueryContext(DbContext dbContext)
+ : this(dbContext, dbContext.GetService().Create())
+ {
+ }
+
+ private PrecompiledQueryContext(DbContext dbContext, QueryContext queryContext)
+ {
+ DbContext = dbContext;
+ QueryContext = queryContext;
+ }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual DbContext DbContext { get; set; }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual QueryContext QueryContext { get; }
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual PrecompiledQueryContext ToType()
+ => new(DbContext, QueryContext);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual IncludablePrecompiledQueryContext ToIncludable()
+ => new(DbContext, QueryContext);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public IEnumerator GetEnumerator()
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public Type ElementType
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public Expression Expression
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public IQueryProvider Provider
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public class IncludablePrecompiledQueryContext(DbContext dbContext, QueryContext queryContext)
+ : PrecompiledQueryContext(dbContext, queryContext), IIncludableQueryable;
+}
diff --git a/src/EFCore/Query/Internal/PrecompiledQueryableAsyncEnumerableAdapter.cs b/src/EFCore/Query/Internal/PrecompiledQueryableAsyncEnumerableAdapter.cs
new file mode 100644
index 00000000000..e01894b27aa
--- /dev/null
+++ b/src/EFCore/Query/Internal/PrecompiledQueryableAsyncEnumerableAdapter.cs
@@ -0,0 +1,72 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using System.Diagnostics.CodeAnalysis;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal;
+
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+[Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+public class PrecompiledQueryableAsyncEnumerableAdapter(IAsyncEnumerable asyncEnumerable)
+ : IQueryable, IAsyncEnumerable
+{
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default)
+ => asyncEnumerable.GetAsyncEnumerator(cancellationToken);
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public IEnumerator GetEnumerator()
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public Type ElementType
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public Expression Expression
+ => throw new NotSupportedException();
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public IQueryProvider Provider
+ => throw new NotSupportedException();
+}
diff --git a/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs b/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs
index 72d21b06c26..857ef233ec4 100644
--- a/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs
+++ b/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs
@@ -27,6 +27,15 @@ public QueryCompilationContextFactory(QueryCompilationContextDependencies depend
///
protected virtual QueryCompilationContextDependencies Dependencies { get; }
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual QueryCompilationContext Create(bool async, bool precompiling)
+ => new(Dependencies, async, precompiling);
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -34,5 +43,5 @@ public QueryCompilationContextFactory(QueryCompilationContextDependencies depend
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
public virtual QueryCompilationContext Create(bool async)
- => new(Dependencies, async);
+ => throw new UnreachableException("The overload with `precompiling` should be called");
}
diff --git a/src/EFCore/Query/Internal/QueryCompiler.cs b/src/EFCore/Query/Internal/QueryCompiler.cs
index a0536aff1bb..3dfa3aea333 100644
--- a/src/EFCore/Query/Internal/QueryCompiler.cs
+++ b/src/EFCore/Query/Internal/QueryCompiler.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
namespace Microsoft.EntityFrameworkCore.Query.Internal;
@@ -125,6 +126,19 @@ var compiledQuery
bool async)
=> database.CompileQuery(query, async);
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ public virtual Expression> PrecompileQuery(Expression query, bool async)
+ {
+ query = ExtractParameters(query, _queryContextFactory.Create(), _logger, precompiledQuery: true);
+ return _database.CompileQueryExpression(query, async);
+ }
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -136,7 +150,8 @@ var compiledQuery
IParameterValues parameterValues,
IDiagnosticsLogger logger,
bool compiledQuery = false,
+ bool precompiledQuery = false,
bool generateContextAccessors = false)
=> new ExpressionTreeFuncletizer(_model, _evaluatableExpressionFilter, _contextType, generateContextAccessors: false, logger)
- .ExtractParameters(query, parameterValues, parameterize: !compiledQuery, clearParameterizedValues: true);
+ .ExtractParameters(query, parameterValues, precompiledQuery, parameterize: !compiledQuery, clearParameterizedValues: true);
}
diff --git a/src/EFCore/Query/LiftableConstantProcessor.cs b/src/EFCore/Query/LiftableConstantProcessor.cs
index bf1935bd227..acd2f796c31 100644
--- a/src/EFCore/Query/LiftableConstantProcessor.cs
+++ b/src/EFCore/Query/LiftableConstantProcessor.cs
@@ -219,6 +219,11 @@ protected virtual ParameterExpression LiftConstant(LiftableConstantExpression li
body = convertNode.Operand;
}
+ if (body.Type != liftableConstant.Type)
+ {
+ body = Expression.Convert(body, liftableConstant.Type);
+ }
+
// Register the lifted constant; note that the name will be uniquified later
var variableParameter = Expression.Parameter(liftableConstant.Type, liftableConstant.VariableName);
_liftedConstants.Add(new(variableParameter, body));
diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs
index 8a86ed0d5ec..a6809179cbd 100644
--- a/src/EFCore/Query/QueryCompilationContext.cs
+++ b/src/EFCore/Query/QueryCompilationContext.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics.CodeAnalysis;
+
namespace Microsoft.EntityFrameworkCore.Query;
///
@@ -69,11 +71,27 @@ public class QueryCompilationContext
public QueryCompilationContext(
QueryCompilationContextDependencies dependencies,
bool async)
+ : this(dependencies, async, precompiling: false)
+ {
+ }
+
+ ///
+ /// Creates a new instance of the class.
+ ///
+ /// Parameter object containing dependencies for this class.
+ /// A bool value indicating whether it is for async query.
+ /// Indicates whether the query is being precompiled.
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ public QueryCompilationContext(
+ QueryCompilationContextDependencies dependencies,
+ bool async,
+ bool precompiling)
{
Dependencies = dependencies;
IsAsync = async;
QueryTrackingBehavior = dependencies.QueryTrackingBehavior;
IsBuffering = ExecutionStrategy.Current?.RetriesOnFailure ?? dependencies.IsRetryingExecutionStrategy;
+ IsPrecompiling = precompiling;
Model = dependencies.Model;
ContextOptions = dependencies.ContextOptions;
ContextType = dependencies.ContextType;
@@ -116,6 +134,11 @@ public class QueryCompilationContext
///
public virtual bool IsBuffering { get; }
+ ///
+ /// Indicates whether the query is being precompiled.
+ ///
+ public virtual bool IsPrecompiling { get; }
+
///
/// A value indicating whether query filters are ignored in this query.
///
@@ -167,7 +190,7 @@ public virtual void AddTag(string tag)
// across invocations of the query.
// In normal mode, these nodes should simply be evaluated, and a ConstantExpression to those instances embedded directly in the
// tree (for precompiled queries we generate C# code for resolving those instances instead).
- var queryExecutorAfterLiftingExpression =
+ var queryExecutorAfterLiftingExpression =
(Expression>)Dependencies.LiftableConstantProcessor.InlineConstants(queryExecutorExpression, SupportsPrecompiledQuery);
try
@@ -186,6 +209,7 @@ public virtual void AddTag(string tag)
/// The result type of this query.
/// The query to generate executor for.
/// Returns which can be invoked to get results of this query.
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
public virtual Expression> CreateQueryExecutorExpression(Expression query)
{
var queryAndEventData = Logger.QueryCompilationStarting(Dependencies.Context, _expressionPrinter, query);
@@ -204,11 +228,9 @@ public virtual void AddTag(string tag)
// wrap the query with code adding those parameters to the query context
query = InsertRuntimeParameters(query);
- var queryExecutorExpression = Expression.Lambda>(
+ return Expression.Lambda>(
query,
QueryContextParameter);
-
- return queryExecutorExpression;
}
///
@@ -223,7 +245,7 @@ public virtual ParameterExpression RegisterRuntimeParameter(string name, LambdaE
{
valueExtractorBody = _runtimeParameterConstantLifter.Visit(valueExtractorBody);
}
-
+
valueExtractor = Expression.Lambda(valueExtractorBody, valueExtractor.Parameters);
if (valueExtractor.Parameters.Count != 1
diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
index 315f457eccd..ef4dd17b185 100644
--- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
+++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
@@ -230,12 +230,12 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio
{ Value: IEntityType entityTypeValue } => liftableConstantFactory.CreateLiftableConstant(
constantExpression.Value,
LiftableConstantExpressionHelpers.BuildMemberAccessLambdaForEntityOrComplexType(entityTypeValue),
- entityTypeValue.Name + "EntityType",
+ entityTypeValue.ShortName() + "EntityType",
constantExpression.Type),
{ Value: IComplexType complexTypeValue } => liftableConstantFactory.CreateLiftableConstant(
constantExpression.Value,
LiftableConstantExpressionHelpers.BuildMemberAccessLambdaForEntityOrComplexType(complexTypeValue),
- complexTypeValue.Name + "ComplexType",
+ complexTypeValue.ShortName() + "ComplexType",
constantExpression.Type),
{ Value: IProperty propertyValue } => liftableConstantFactory.CreateLiftableConstant(
constantExpression.Value,
@@ -563,7 +563,7 @@ private Expression ProcessEntityShaper(StructuralTypeShaperExpression shaper)
LiftableConstantExpressionHelpers.BuildMemberAccessForEntityOrComplexType(typeBase, resolverPrm),
EntityTypeFindPrimaryKeyMethod),
resolverPrm),
- typeBase.Name + "Key",
+ /*typeBase.Name +*/ "key",
typeof(IKey))
: Constant(primaryKey),
NewArrayInit(
@@ -697,8 +697,8 @@ private Expression ProcessEntityShaper(StructuralTypeShaperExpression shaper)
Snapshot.Empty,
static _ => Snapshot.Empty,
"emptySnapshot",
- typeof(Snapshot))
- : Constant(Snapshot.Empty)));
+ typeof(ISnapshot))
+ : Constant(Snapshot.Empty, typeof(ISnapshot))));
var returnType = typeBase.ClrType;
var valueBufferExpression = Call(materializationContextVariable, MaterializationContext.GetValueBufferMethod);
@@ -725,7 +725,7 @@ private Expression ProcessEntityShaper(StructuralTypeShaperExpression shaper)
? _liftableConstantFactory.CreateLiftableConstant(
concreteEntityTypes[i],
LiftableConstantExpressionHelpers.BuildMemberAccessLambdaForEntityOrComplexType(concreteEntityType),
- concreteEntityType.Name + (typeBase is IEntityType ? "EntityType" : "ComplexType"),
+ concreteEntityType.ShortName() + (typeBase is IEntityType ? "EntityType" : "ComplexType"),
typeBase is IEntityType ? typeof(IEntityType) : typeof(IComplexType))
: Constant(concreteEntityTypes[i], typeBase is IEntityType ? typeof(IEntityType) : typeof(IComplexType)));
}
diff --git a/src/EFCore/Storage/Database.cs b/src/EFCore/Storage/Database.cs
index 66344581ce9..95606b3df64 100644
--- a/src/EFCore/Storage/Database.cs
+++ b/src/EFCore/Storage/Database.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics.CodeAnalysis;
+
namespace Microsoft.EntityFrameworkCore.Storage;
///
@@ -64,6 +66,13 @@ protected Database(DatabaseDependencies dependencies)
///
public virtual Func CompileQuery(Expression query, bool async)
=> Dependencies.QueryCompilationContextFactory
- .Create(async)
+ .Create(async, precompiling: false)
.CreateQueryExecutor(query);
+
+ ///
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ public virtual Expression> CompileQueryExpression(Expression query, bool async)
+ => Dependencies.QueryCompilationContextFactory
+ .Create(async, precompiling: true)
+ .CreateQueryExecutorExpression(query);
}
diff --git a/src/EFCore/Storage/IDatabase.cs b/src/EFCore/Storage/IDatabase.cs
index eeeb574a812..69ee1952b1d 100644
--- a/src/EFCore/Storage/IDatabase.cs
+++ b/src/EFCore/Storage/IDatabase.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics.CodeAnalysis;
+
namespace Microsoft.EntityFrameworkCore.Storage;
///
@@ -55,4 +57,14 @@ public interface IDatabase
/// A value indicating whether this is an async query.
/// A which can be invoked to get results of the query.
Func CompileQuery(Expression query, bool async);
+
+ ///
+ /// Compiles the given query to generate an expression tree which can be used to execute the query.
+ ///
+ /// The type of query result.
+ /// The query to compile.
+ /// A value indicating whether this is an async query.
+ /// An expression tree which can be used to execute the query.
+ [Experimental(EFDiagnostics.PrecompiledQueryExperimental)]
+ Expression> CompileQueryExpression(Expression query, bool async);
}
diff --git a/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj b/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj
index 1ce0a1d4856..9562fd1d55e 100644
--- a/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj
+++ b/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj
@@ -56,7 +56,7 @@
-
+
diff --git a/test/EFCore.Design.Tests/Query/CSharpToLinqTranslatorTest.cs b/test/EFCore.Design.Tests/Query/CSharpToLinqTranslatorTest.cs
new file mode 100644
index 00000000000..3fa2f2c2f6f
--- /dev/null
+++ b/test/EFCore.Design.Tests/Query/CSharpToLinqTranslatorTest.cs
@@ -0,0 +1,537 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.CompilerServices;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.EntityFrameworkCore.Query.Internal;
+using static System.Linq.Expressions.Expression;
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+// ReSharper disable InconsistentNaming
+// ReSharper disable RedundantCast
+
+#nullable enable
+
+public class CSharpToLinqTranslatorTest
+{
+ [Fact]
+ public void ArrayCreation()
+ => AssertExpression(
+ () => new int[3],
+ "new int[3]");
+
+ // ReSharper disable RedundantExplicitArrayCreation
+ [Fact]
+ public void ArrayCreation_with_initializer()
+ => AssertExpression(
+ () => new int[] { 1, 2 },
+ "new int[] { 1, 2 }");
+ // ReSharper restore RedundantExplicitArrayCreation
+
+ // ReSharper disable BuiltInTypeReferenceStyle
+ [Fact]
+ public void As()
+ => AssertExpression(
+ () => "foo" as String,
+ """ "foo" as String""");
+ // ReSharper restore BuiltInTypeReferenceStyle
+
+ [Fact]
+ public void As_with_predefined_type()
+ => AssertExpression(
+ () => "foo" as string,
+ """ "foo" as string""");
+
+ [Theory]
+ [InlineData("1 + 2", ExpressionType.Add)]
+ [InlineData("1 - 2", ExpressionType.Subtract)]
+ [InlineData("1 * 2", ExpressionType.Multiply)]
+ [InlineData("1 / 2", ExpressionType.Divide)]
+ [InlineData("1 % 2", ExpressionType.Modulo)]
+ [InlineData("1 & 2", ExpressionType.And)]
+ [InlineData("1 | 2", ExpressionType.Or)]
+ [InlineData("1 ^ 2", ExpressionType.ExclusiveOr)]
+ [InlineData("1 >> 2", ExpressionType.RightShift)]
+ [InlineData("1 << 2", ExpressionType.LeftShift)]
+ [InlineData("1 < 2", ExpressionType.LessThan)]
+ [InlineData("1 <= 2", ExpressionType.LessThanOrEqual)]
+ [InlineData("1 > 2", ExpressionType.GreaterThan)]
+ [InlineData("1 >= 2", ExpressionType.GreaterThanOrEqual)]
+ [InlineData("1 == 2", ExpressionType.Equal)]
+ [InlineData("1 != 2", ExpressionType.NotEqual)]
+ public void Binary_int(string code, ExpressionType binaryType)
+ => AssertExpression(
+ MakeBinary(binaryType, Constant(1), Constant(2)),
+ code);
+
+ [Theory]
+ [InlineData("true && false", ExpressionType.AndAlso)]
+ [InlineData("true || false", ExpressionType.OrElse)]
+ [InlineData("true ^ false", ExpressionType.ExclusiveOr)]
+ public void Binary_bool(string code, ExpressionType binaryType)
+ => AssertExpression(
+ Lambda>(
+ MakeBinary(binaryType, Constant(true), Constant(false))),
+ code);
+
+ [Fact]
+ public void Binary_add_string()
+ => AssertExpression(
+ () => new[] { "foo", "bar" }.Select(s => s + "foo"),
+ """new[] { "foo", "bar" }.Select(s => s + "foo")""");
+
+ [Fact]
+ public void Cast()
+ => AssertExpression(
+ () => (object)1,
+ "(object)1");
+
+ [Fact]
+ public void Coalesce()
+ => AssertExpression(
+ () => (object?)"foo" ?? (object)"bar",
+ """(object?)"foo" ?? (object)"bar" """);
+
+ [Fact]
+ public void ElementAccess_over_array()
+ => AssertExpression(
+ () => new[] { 1, 2, 3 }[1],
+ "new[] { 1, 2, 3 } [1]");
+
+ [Fact]
+ public void ElementAccess_over_list()
+ => AssertExpression(
+ () => new List { 1, 2, 3 }[1],
+ "new List { 1, 2, 3 }[1]");
+
+ [Fact]
+ public void IdentifierName_for_lambda_parameter()
+ => AssertExpression(
+ () => new[] { 1, 2, 3 }.Where(i => i == 2),
+ "new[] { 1, 2, 3 }.Where(i => i == 2);");
+
+ [Fact]
+ public void ImplicitArrayCreation()
+ => AssertExpression(
+ () => new[] { 1, 2 },
+ "new[] { 1, 2 }");
+
+ [Fact]
+ public void Interpolated_string()
+ => AssertExpression(
+ () => string.Format("Foo: {0}", new[] { (object)8 }),
+ """$"Foo: {8}" """);
+
+ [Fact]
+ public void Interpolated_string_formattable()
+ => AssertExpression(
+ () => FormattableStringMethod(FormattableStringFactory.Create("Foo: {0}, {1}", (object)8, (object) 9)),
+ """CSharpToLinqTranslatorTest.FormattableStringMethod($"Foo: {8}, {9}")""");
+
+ [Fact]
+ public void Index_over_array()
+ => AssertExpression(
+ () => new[] { 1, 2 }[0],
+ "new[] { 1, 2 }[0]");
+
+ [Fact]
+ public void Index_over_List()
+ => AssertExpression(
+ () => new List { 1, 2 }[0],
+ "new List { 1, 2 }[0]");
+
+ [Fact]
+ public void Invocation_instance_method()
+ => AssertExpression(
+ () => "foo".Substring(2),
+ """ "foo".Substring(2)""");
+
+ [Fact]
+ public void Invocation_method_with_optional_parameter()
+ => AssertExpression(
+ Call(
+ typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!,
+ Constant(1),
+ Constant(4),
+ NewArrayInit(typeof(int))),
+ "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4)");
+
+ [Fact]
+ public void Invocation_method_with_optional_parameter_missing_argument()
+ => AssertExpression(
+ Call(
+ typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!,
+ Constant(1),
+ Constant(3),
+ NewArrayInit(typeof(int))),
+ "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1)");
+
+ [Fact]
+ public void Invocation_method_with_params_parameter_no_arguments()
+ => AssertExpression(
+ Call(
+ typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!,
+ Constant(1),
+ Constant(4),
+ NewArrayInit(typeof(int))),
+ "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4)");
+
+ [Fact]
+ public void Invocation_method_with_params_parameter_one_argument()
+ => AssertExpression(
+ Call(
+ typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!,
+ Constant(1),
+ Constant(4),
+ NewArrayInit(typeof(int), Constant(5))),
+ "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4, 5)");
+
+ [Fact]
+ public void Invocation_method_with_params_parameter_multiple_arguments()
+ => AssertExpression(
+ Call(
+ typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!,
+ Constant(1),
+ Constant(4),
+ NewArrayInit(typeof(int), Constant(5), Constant(6))),
+ "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4, 5, 6)");
+
+ [Fact]
+ public void Invocation_method_with_params_parameter_missing_argument()
+ => AssertExpression(
+ Call(
+ typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!,
+ Constant(1),
+ Constant(3),
+ NewArrayInit(typeof(int))),
+ "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1)");
+
+ [Fact]
+ public void Invocation_static_method()
+ => AssertExpression(
+ () => DateTime.Parse("2020-01-01"),
+ """DateTime.Parse("2020-01-01")""");
+
+ [Fact]
+ public void Invocation_extension_method()
+ => AssertExpression(
+ () => typeof(string).GetTypeInfo(),
+ "typeof(string).GetTypeInfo()");
+
+ // ReSharper disable InvokeAsExtensionMethod
+ [Fact]
+ public void Invocation_extension_method_with_non_extension_syntax()
+ => AssertExpression(
+ () => IntrospectionExtensions.GetTypeInfo(typeof(string)),
+ "typeof(string).GetTypeInfo()");
+ // ReSharper restore InvokeAsExtensionMethod
+
+ [Fact]
+ public void Invocation_generic_method()
+ => AssertExpression(
+ () => Enumerable.Repeat("foo", 5),
+ """Enumerable.Repeat("foo", 5)""");
+
+ [Fact]
+ public void Invocation_generic_extension_method()
+ => AssertExpression(
+ () => new[] { 1, 2, 3 }.Where(i => i > 1),
+ "new[] { 1, 2, 3 }.Where(i => i > 1)");
+
+ [Fact]
+ public void Invocation_generic_queryable_extension_method()
+ => AssertExpression(
+ () => new[] { 1, 2, 3 }.AsQueryable().Where(i => i > 1),
+ "new[] { 1, 2, 3 }.AsQueryable().Where(i => i > 1)");
+
+ [Fact]
+ public void Invocation_non_generic_method_on_generic_type()
+ => AssertExpression(
+ () => SomeGenericType.SomeFunction(1),
+ "CSharpToLinqTranslatorTest.SomeGenericType.SomeFunction(1)");
+
+ [Fact]
+ public void Invocation_generic_method_on_generic_type()
+ => AssertExpression(
+ () => SomeGenericType.SomeGenericFunction(1, "foo"),
+ """CSharpToLinqTranslatorTest.SomeGenericType.SomeGenericFunction(1, "foo")""");
+
+ [Theory]
+ [InlineData("""
+ "hello"
+ """, "hello")]
+ [InlineData("1", 1)]
+ [InlineData("1L", 1L)]
+ [InlineData("1U", 1U)]
+ [InlineData("1UL", 1UL)]
+ [InlineData("1.5D", 1.5)]
+ [InlineData("1.5F", 1.5F)]
+ [InlineData("true", true)]
+ public void Literal(string csharpLiteral, object expectedValue)
+ => AssertExpression(
+ Constant(expectedValue),
+ csharpLiteral);
+
+ [Fact]
+ public void Literal_decimal()
+ => AssertExpression(
+ () => 1.5m,
+ "1.5m");
+
+ [Fact]
+ public void Literal_null()
+ => AssertExpression(
+ Equal(Constant("foo"), Constant(null, typeof(string))),
+ """ "foo" == null""");
+
+ [Fact]
+ public void Literal_enum()
+ => AssertExpression(
+ () => SomeEnum.Two,
+ "CSharpToLinqTranslatorTest.SomeEnum.Two");
+
+ [Fact]
+ public void Literal_enum_with_multiple_values()
+ => AssertExpression(
+ Convert(
+ Or(
+ Convert(Constant(SomeEnum.One), typeof(int)),
+ Convert(Constant(SomeEnum.Two), typeof(int))),
+ typeof(SomeEnum)),
+ "CSharpToLinqTranslatorTest.SomeEnum.One | CSharpToLinqTranslatorTest.SomeEnum.Two");
+
+ [Fact]
+ public void MemberAccess_array_length()
+ => AssertExpression(
+ () => new[] { 1, 2, 3 }.Length,
+ "new[] { 1, 2, 3 }.Length");
+
+ [Fact]
+ public void MemberAccess_instance_property()
+ => AssertExpression(
+ () => "foo".Length,
+ """ "foo".Length""");
+
+ [Fact]
+ public void MemberAccess_static_property()
+ => AssertExpression(
+ () => DateTime.Now,
+ "DateTime.Now");
+
+ // TODO: MemberAccess on fields
+
+ [Fact]
+ public void Nested_type()
+ => AssertExpression(
+ () => (object)new Blog(),
+ "(object)new CSharpToLinqTranslatorTest.Blog()");
+
+ [Fact]
+ public void Not_boolean()
+ => AssertExpression(
+ Not(Constant(true)),
+ "!true");
+
+ [Fact]
+ public void ObjectCreation()
+ => AssertExpression(
+ () => new List(),
+ "new List()");
+
+ [Fact]
+ public void ObjectCreation_with_arguments()
+ => AssertExpression(
+ () => new List(10),
+ "new List(10)");
+
+ [Fact]
+ public void ObjectCreation_with_initializers()
+ => AssertExpression(
+ () => new Blog(8) { Name = "foo" },
+ """new CSharpToLinqTranslatorTest.Blog(8) { Name = "foo" }""");
+
+ [Fact]
+ public void ObjectCreation_with_parameterless_struct_constructor()
+ => AssertExpression(
+ () => new DateTime(),
+ "new DateTime()");
+
+ [Fact]
+ public void Parenthesized()
+ => AssertExpression(
+ () => 1,
+ "(1)");
+
+ [Theory]
+ [InlineData("+8", 8, ExpressionType.UnaryPlus)]
+ [InlineData("-8", 8, ExpressionType.Negate)]
+ [InlineData("~8", 8, ExpressionType.Not)]
+ public void PrefixUnary(string code, object operandValue, ExpressionType expectedNodeType)
+ => AssertExpression(
+ MakeUnary(expectedNodeType, Constant(8), typeof(int)),
+ code);
+
+ // ReSharper disable RedundantSuppressNullableWarningExpression
+ [Fact]
+ public void SuppressNullableWarningExpression()
+ => AssertExpression(
+ () => "foo"!,
+ """ "foo"! """);
+ // ReSharper restore RedundantSuppressNullableWarningExpression
+
+ [ConditionalFact]
+ public void Typeof()
+ => AssertExpression(
+ () => typeof(string),
+ "typeof(string)");
+
+ [ConditionalFact]
+ public void Array_type()
+ => AssertExpression(
+ () => typeof(ParameterExpression[]),
+ "typeof(ParameterExpression[])");
+
+ protected virtual void AssertExpression(Expression> expected, string code)
+ => AssertExpression(
+ expected.Body,
+ code);
+
+ protected virtual void AssertExpression(Expression expected, string code)
+ {
+ code = $"""
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using Microsoft.EntityFrameworkCore.Query;
+
+_ = {code};
+""";
+
+ var compilation = Compile(code);
+
+ var syntaxTree = compilation.SyntaxTrees.Single();
+
+ if (syntaxTree.GetRoot() is CompilationUnitSyntax { Members: [GlobalStatementSyntax globalStatement, ..] })
+ {
+ var expression = globalStatement switch
+ {
+ { Statement: ExpressionStatementSyntax { Expression: AssignmentExpressionSyntax { Right: var e } } } => e,
+ { Statement: LocalDeclarationStatementSyntax e } => e.Declaration.Variables[0].Initializer!.Value,
+ { Statement: ExpressionStatementSyntax { Expression: var e } } => e,
+
+ _ => throw new InvalidOperationException("Could not find expression to assert on")
+ };
+
+ var actual = Translate(expression, compilation);
+
+ Assert.Equal(expected, actual, ExpressionEqualityComparer.Instance);
+ }
+ else
+ {
+ Assert.Fail("Could not find expression to assert on");
+ }
+ }
+
+ private Compilation Compile(string code)
+ {
+ var syntaxTree = CSharpSyntaxTree.ParseText(code);
+
+ var compilation = CSharpCompilation.Create(
+ "TestCompilation",
+ syntaxTrees: new[] { syntaxTree },
+ references: MetadataReferences);
+
+ var diagnostics = compilation.GetDiagnostics()
+ .Where(d => d.Severity is DiagnosticSeverity.Error)
+ .ToArray();
+
+ if (diagnostics.Any())
+ {
+ var stringBuilder = new StringBuilder()
+ .AppendLine("Compilation errors:");
+
+ foreach (var diagnostic in diagnostics)
+ {
+ stringBuilder.AppendLine(diagnostic.ToString());
+ }
+
+ Assert.Fail(stringBuilder.ToString());
+ }
+
+ return compilation;
+ }
+
+ private Expression Translate(SyntaxNode node, Compilation compilation)
+ {
+ var blogContext = new BlogContext();
+ var translator = new CSharpToLinqTranslator();
+ translator.Load(compilation, blogContext);
+ return translator.Translate(node, compilation.GetSemanticModel(node.SyntaxTree));
+ }
+
+ private static readonly MetadataReference[] MetadataReferences;
+
+ static CSharpToLinqTranslatorTest()
+ {
+ var metadataReferences = new List
+ {
+ MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Queryable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IQueryable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(DbContext).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(BlogContext).Assembly.Location)
+ };
+
+ var netAssemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
+
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "mscorlib.dll")));
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.dll")));
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Core.dll")));
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Runtime.dll")));
+
+ MetadataReferences = metadataReferences.ToArray();
+ }
+
+ [Flags]
+ public enum SomeEnum
+ {
+ One = 1,
+ Two = 2
+ }
+
+ private class BlogContext : DbContext;
+
+ public class Blog
+ {
+ public Blog()
+ {
+ }
+
+ public Blog(int id)
+ => Id = id;
+
+ public int Id { get; set; }
+ public string? Name { get; set; }
+ }
+
+ public class SomeGenericType
+ {
+ public static int SomeFunction(T1 t1)
+ => 0;
+
+ public static int SomeGenericFunction(T1 t1, T2 t2)
+ => 0;
+ }
+
+ public static int ParamsAndOptionalMethod(int a, int b = 3, params int[] c)
+ => throw new NotSupportedException();
+
+ public static int FormattableStringMethod(FormattableString formattableString)
+ => throw new NotSupportedException();
+}
diff --git a/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs b/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs
index 15e7f815257..5f7821429b7 100644
--- a/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs
+++ b/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs
@@ -3,6 +3,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
@@ -16,8 +17,6 @@ namespace Microsoft.EntityFrameworkCore.Query;
public class LinqToCSharpSyntaxTranslatorTest(ITestOutputHelper testOutputHelper)
{
- private readonly ITestOutputHelper _testOutputHelper = testOutputHelper;
-
[Theory]
[InlineData("hello", "\"hello\"")]
[InlineData(1, "1")]
@@ -33,9 +32,7 @@ public class LinqToCSharpSyntaxTranslatorTest(ITestOutputHelper testOutputHelper
[InlineData(true, "true")]
[InlineData(typeof(string), "typeof(string)")]
public void Constant_values(object constantValue, string literalRepresentation)
- => AssertExpression(
- Constant(constantValue),
- literalRepresentation);
+ => AssertExpression(Constant(constantValue), literalRepresentation);
[Fact]
public void Constant_DateTime_default()
@@ -105,14 +102,6 @@ public void Binary_PowerAssign()
PowerAssign(Parameter(typeof(double), "d"), Constant(3.0)),
"d = Math.Pow(d, 3D)");
- [Fact]
- public void Private_instance_field_SimpleAssign()
- => AssertExpression(
- Assign(
- Field(Parameter(typeof(Blog), "blog"), "_privateField"),
- Constant(3)),
- """typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).SetValue(blog, 3)""");
-
[Theory]
[InlineData(ExpressionType.AddAssign, "+=")]
[InlineData(ExpressionType.MultiplyAssign, "*=")]
@@ -130,7 +119,14 @@ public void Private_instance_field_AssignOperators(ExpressionType expressionType
expressionType,
Field(Parameter(typeof(Blog), "blog"), "_privateField"),
Constant(3)),
- $"""typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).SetValue(blog, 3)""");
+ $"UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(blog) {op} 3",
+ unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal(
+ """
+[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")]
+private static extern ref int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(LinqToCSharpSyntaxTranslatorTest.Blog instance);
+""",
+ Assert.Single(unsafeAccessors),
+ ignoreLineEndingDifferences: true));
[Theory]
[InlineData(ExpressionType.AddAssign, "+=")]
@@ -149,11 +145,9 @@ public void Private_instance_field_AssignOperators_with_replacements(ExpressionT
expressionType,
Field(Parameter(typeof(Blog), "blog"), "_privateField"),
Constant(3)),
- $"""AccessPrivateField(blog) {op} Three""",
- new Dictionary() { { 3, "Three" } },
- new Dictionary() {
+ $"""AccessPrivateField(blog) {op} Three""", new Dictionary() { { 3, "Three" } }, new Dictionary() {
{ BlogPrivateField, new QualifiedName("AccessPrivateField", "") }
- });
+ });
[Theory]
[InlineData(ExpressionType.Negate, "-i")]
@@ -192,7 +186,7 @@ public void Unary_statement(ExpressionType expressionType, string expected)
MakeUnary(expressionType, i, typeof(int))),
$$"""
{
- int i;
+ int i = default;
{{expected}};
}
""");
@@ -252,25 +246,64 @@ public void Static_property()
typeof(DateTime).GetProperty(nameof(DateTime.Now))!),
"DateTime.Now");
+ [Fact]
+ public void Indexer_property()
+ => AssertExpression(
+ Call(
+ New(typeof(List)),
+ typeof(List).GetProperties().Single(
+ p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int))
+ .GetMethod!,
+ Constant(1)), "new List()[1]");
+
[Fact]
public void Private_instance_field_read()
=> AssertExpression(
- Field(Parameter(typeof(Blog), "blog"), "_privateField"),
- """(int)typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).GetValue(blog)""");
+ Field(
+ Parameter(typeof(Blog), "blog"),
+ "_privateField"),
+ "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Get(blog)", unsafeAccessorsAsserter: accessors =>
+ Assert.Equal(
+ """
+[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")]
+private static extern int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Get(LinqToCSharpSyntaxTranslatorTest.Blog instance);
+""",
+ Assert.Single(accessors),
+ ignoreLineEndingDifferences: true));
[Fact]
public void Private_instance_field_write()
=> AssertStatement(
Assign(
- Field(Parameter(typeof(Blog), "blog"), "_privateField"),
+ Field(
+ Parameter(typeof(Blog), "blog"),
+ "_privateField"),
Constant(8)),
- """typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).SetValue(blog, 8)""");
+ "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(blog) = 8", unsafeAccessorsAsserter: accessors =>
+ Assert.Equal(
+ """
+[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")]
+private static extern ref int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(LinqToCSharpSyntaxTranslatorTest.Blog instance);
+""",
+ Assert.Single(accessors),
+ ignoreLineEndingDifferences: true));
+
+ // TODO: Also test accessing private static fields
+ // TODO: Also test accessing private properties, instance and static
[Fact]
public void Internal_instance_field_read()
=> AssertExpression(
- Field(Parameter(typeof(Blog), "blog"), "InternalField"),
- """(int)typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("InternalField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).GetValue(blog)""");
+ Field(
+ Parameter(typeof(Blog), "blog"),
+ "InternalField"),
+ "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog_InternalField_Get(blog)", unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal(
+ """
+[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "InternalField")]
+private static extern int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog_InternalField_Get(LinqToCSharpSyntaxTranslatorTest.Blog instance);
+""",
+ Assert.Single(unsafeAccessors),
+ ignoreLineEndingDifferences: true));
[Fact]
public void Not()
@@ -429,7 +462,7 @@ public void Method_call_namespace_is_collected()
{
var (translator, _) = CreateTranslator();
var namespaces = new HashSet();
- _ = translator.TranslateExpression(Call(FooMethod), null, namespaces);
+ _ = translator.TranslateExpression(Call(FooMethod), null, namespaces, new HashSet());
Assert.Collection(
namespaces,
ns => Assert.Equal(typeof(LinqToCSharpSyntaxTranslatorTest).Namespace, ns));
@@ -448,9 +481,9 @@ public void Method_call_with_in_out_ref_parameters()
Call(WithInOutRefParameterMethod, [inParam, outParam, refParam])),
"""
{
- int inParam;
- int outParam;
- int refParam;
+ int inParam = default;
+ int outParam = default;
+ int refParam = default;
LinqToCSharpSyntaxTranslatorTest.WithInOutRefParameter(in inParam, out outParam, ref refParam);
}
""");
@@ -467,19 +500,36 @@ public void Instantiation()
[Fact]
public void Instantiation_with_required_properties_and_parameterless_constructor()
=> AssertExpression(
- New(
- typeof(BlogWithRequiredProperties).GetConstructor([])!),
- """
-Activator.CreateInstance()
-""");
+ New(typeof(BlogWithRequiredProperties).GetConstructor([])!),
+ "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor()",
+ unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal(
+ """
+[UnsafeAccessor(UnsafeAccessorKind.Constructor)]
+private static extern LinqToCSharpSyntaxTranslatorTest.BlogWithRequiredProperties UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor();
+""",
+ Assert.Single(unsafeAccessors),
+ ignoreLineEndingDifferences: true));
+
+// => AssertExpression(
+// New(typeof(BlogWithRequiredProperties).GetConstructor([])!),
+// """
+// Activator.CreateInstance()
+// """);
[Fact]
public void Instantiation_with_required_properties_and_non_parameterless_constructor()
- => Assert.Throws(
- () => AssertExpression(
- New(
- typeof(BlogWithRequiredProperties).GetConstructor([typeof(string)])!,
- Constant("foo")), ""));
+ => AssertExpression(
+ New(
+ typeof(BlogWithRequiredProperties).GetConstructor([typeof(string)])!,
+ Constant("foo")),
+ """UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor("foo")""",
+ unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal(
+ """
+[UnsafeAccessor(UnsafeAccessorKind.Constructor)]
+private static extern LinqToCSharpSyntaxTranslatorTest.BlogWithRequiredProperties UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor(string name);
+""",
+ Assert.Single(unsafeAccessors),
+ ignoreLineEndingDifferences: true));
[Fact]
public void Instantiation_with_required_properties_with_SetsRequiredMembers()
@@ -571,6 +621,17 @@ public void Invocation_with_argument_that_has_side_effects()
""");
}
+ [Fact]
+ public void Invocation_with_property_argument()
+ => AssertExpression(
+ Invoke(
+ Property(
+ expression: null,
+ typeof(LinqToCSharpSyntaxTranslatorTest).GetProperty(
+ nameof(LambdaExpressionProperty), BindingFlags.Public | BindingFlags.Static)!),
+ Constant(8)),
+ "LinqToCSharpSyntaxTranslatorTest.LambdaExpressionProperty(8)");
+
[Fact]
public void Conditional_expression()
=> AssertExpression(
@@ -669,7 +730,7 @@ public void IfThenElse_nested()
Block(Assign(variable, Constant(3)))))),
"""
{
- int i;
+ int i = default;
if (true)
{
i = 1;
@@ -788,7 +849,7 @@ public void Switch_expression_nested()
Constant(200))))),
"""
{
- int k;
+ int k = default;
var j = 8;
var i = j switch
{
@@ -852,7 +913,7 @@ public void Switch_statement_without_default()
SwitchCase(Block(typeof(void), Assign(parameter, Constant(10))), Constant(-10)))),
"""
{
- int i;
+ int i = default;
switch (7)
{
case -9:
@@ -886,7 +947,7 @@ public void Switch_statement_with_default()
SwitchCase(Assign(parameter, Constant(10)), Constant(-10)))),
"""
{
- int i;
+ int i = default;
switch (7)
{
case -9:
@@ -918,7 +979,7 @@ public void Switch_statement_with_multiple_labels()
SwitchCase(Assign(parameter, Constant(10)), Constant(-10)))),
"""
{
- int i;
+ int i = default;
switch (7)
{
case -9:
@@ -1081,8 +1142,7 @@ public void Same_parameter_instance_is_used_twice_in_nested_lambdas()
[Fact]
public void Block_with_non_standalone_expression_as_statement()
- => AssertStatement(
- Block(Add(Constant(1), Constant(2))),
+ => AssertStatement(Block(Add(Constant(1), Constant(2))),
"""
{
_ = 1 + 2;
@@ -1337,7 +1397,7 @@ public void Lift_variable_in_expression_block()
Constant(9))))),
"""
{
- int j;
+ int j = default;
LinqToCSharpSyntaxTranslatorTest.Foo();
j = 8;
var i = 9;
@@ -1420,7 +1480,7 @@ public void Lift_switch_expression()
SwitchCase(Constant(2), Constant(9))))),
"""
{
- int i;
+ int i = default;
var j = 8;
switch (j)
{
@@ -1474,8 +1534,8 @@ public void Lift_nested_switch_expression()
Constant(200))))),
"""
{
- int i;
- int k;
+ int i = default;
+ int k = default;
var j = 8;
switch (j)
{
@@ -1538,7 +1598,7 @@ public void Lift_non_literal_switch_expression()
SwitchCase(Constant(3), Parameter(typeof(Blog), "blog4"))))),
"""
{
- int i;
+ int i = default;
if (blog1 == blog2)
{
LinqToCSharpSyntaxTranslatorTest.ReturnsIntWithParam(8);
@@ -1813,7 +1873,7 @@ public void Try_catch_finally_statement()
{
LinqToCSharpSyntaxTranslatorTest.Bar();
}
-catch (InvalidOperationException e)when (e.Message == "foo")
+catch (InvalidOperationException e)when (((Exception)e).Message == "foo")
{
LinqToCSharpSyntaxTranslatorTest.Baz();
}
@@ -1844,7 +1904,7 @@ public void Try_catch_statement_with_filter()
{
LinqToCSharpSyntaxTranslatorTest.Foo();
}
-catch (InvalidOperationException e)when (e.Message == "foo")
+catch (InvalidOperationException e)when (((Exception)e).Message == "foo")
{
LinqToCSharpSyntaxTranslatorTest.Bar();
}
@@ -1889,19 +1949,29 @@ public void Try_fault_statement()
// TODO: try/catch expressions
- private void AssertStatement(Expression expression, string expected,
+ private void AssertStatement(
+ Expression expression,
+ string expected,
Dictionary? constantReplacements = null,
- Dictionary? memberAccessReplacements = null)
- => AssertCore(expression, isStatement: true, expected, constantReplacements, memberAccessReplacements);
+ Dictionary? memberAccessReplacements = null,
+ Action>? unsafeAccessorsAsserter = null)
+ => AssertCore(expected, isStatement: true, expression, constantReplacements, memberAccessReplacements, unsafeAccessorsAsserter);
- private void AssertExpression(Expression expression, string expected,
+ private void AssertExpression(
+ Expression expression,
+ string expected,
Dictionary? constantReplacements = null,
- Dictionary? memberAccessReplacements = null)
- => AssertCore(expression, isStatement: false, expected, constantReplacements, memberAccessReplacements);
-
- private void AssertCore(Expression expression, bool isStatement, string expected,
+ Dictionary? memberAccessReplacements = null,
+ Action>? unsafeAccessorsAsserter = null)
+ => AssertCore(expected, isStatement: false, expression, constantReplacements, memberAccessReplacements, unsafeAccessorsAsserter);
+
+ private void AssertCore(
+ string expected,
+ bool isStatement,
+ Expression expression,
Dictionary? constantReplacements,
- Dictionary? memberAccessReplacements)
+ Dictionary? memberAccessReplacements,
+ Action>? unsafeAccessorsAsserter)
{
var typeMappingSource = new SqlServerTypeMappingSource(
TestServiceFactory.Instance.Create(),
@@ -1909,14 +1979,15 @@ public void Try_fault_statement()
var translator = new CSharpHelper(typeMappingSource);
var namespaces = new HashSet();
+ var unsafeAccessors = new HashSet();
var actual = isStatement
- ? translator.Statement(expression, namespaces, constantReplacements, memberAccessReplacements)
- : translator.Expression(expression, namespaces, constantReplacements, memberAccessReplacements);
+ ? translator.Statement(expression, namespaces, unsafeAccessors, constantReplacements, memberAccessReplacements)
+ : translator.Expression(expression, namespaces, unsafeAccessors, constantReplacements, memberAccessReplacements);
if (_outputExpressionTrees)
{
- _testOutputHelper.WriteLine("---- Input LINQ expression tree:");
- _testOutputHelper.WriteLine(_expressionPrinter.PrintExpression(expression));
+ testOutputHelper.WriteLine("---- Input LINQ expression tree:");
+ testOutputHelper.WriteLine(_expressionPrinter.PrintExpression(expression));
}
// TODO: Actually compile the output C# code to make sure it's valid.
@@ -1928,17 +1999,26 @@ public void Try_fault_statement()
if (_outputExpressionTrees)
{
- _testOutputHelper.WriteLine("---- Output Roslyn syntax tree:");
- _testOutputHelper.WriteLine(actual);
+ testOutputHelper.WriteLine("---- Output Roslyn syntax tree:");
+ testOutputHelper.WriteLine(actual);
}
}
catch (EqualException)
{
- _testOutputHelper.WriteLine("---- Output Roslyn syntax tree:");
- _testOutputHelper.WriteLine(actual);
+ testOutputHelper.WriteLine("---- Output Roslyn syntax tree:");
+ testOutputHelper.WriteLine(actual);
throw;
}
+
+ if (unsafeAccessorsAsserter is null)
+ {
+ Assert.Empty(unsafeAccessors);
+ }
+ else
+ {
+ unsafeAccessorsAsserter(unsafeAccessors);
+ }
}
private (LinqToCSharpSyntaxTranslator, AdhocWorkspace) CreateTranslator()
@@ -1993,6 +2073,7 @@ public static int Baz()
public static int MethodWithSixParams(int a, int b, int c, int d, int e, int f)
=> a + b + c + d + e + f;
+ public static Expression> LambdaExpressionProperty => f => f > 5;
private static readonly FieldInfo BlogPrivateField
= typeof(Blog).GetField("_privateField", BindingFlags.NonPublic | BindingFlags.Instance)!;
diff --git a/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj b/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj
index c4b531fd2ad..b07088761fb 100644
--- a/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj
+++ b/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj
@@ -48,7 +48,11 @@
-
+
+
+
+
+
diff --git a/test/EFCore.Relational.Specification.Tests/Query/AdHocPrecompiledQueryRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/AdHocPrecompiledQueryRelationalTestBase.cs
new file mode 100644
index 00000000000..dd352965ca3
--- /dev/null
+++ b/test/EFCore.Relational.Specification.Tests/Query/AdHocPrecompiledQueryRelationalTestBase.cs
@@ -0,0 +1,247 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.CompilerServices;
+using Microsoft.EntityFrameworkCore.Query.Internal;
+using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers;
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+public abstract class AdHocPrecompiledQueryRelationalTestBase(ITestOutputHelper testOutputHelper) : NonSharedModelTestBase
+{
+ [ConditionalFact]
+ public virtual async Task Index_no_evaluatability()
+ {
+ var contextFactory = await InitializeAsync();
+ var options = contextFactory.GetOptions();
+
+ await Test(
+ """
+await using var context = new AdHocPrecompiledQueryRelationalTestBase.JsonContext(dbContextOptions);
+await context.Database.BeginTransactionAsync();
+
+var blogs = context.JsonEntities.Where(b => b.IntList[b.Id] == 2).ToList();
+""",
+ typeof(JsonContext),
+ options);
+ }
+
+ [ConditionalFact]
+ public virtual async Task Index_with_captured_variable()
+ {
+ var contextFactory = await InitializeAsync();
+ var options = contextFactory.GetOptions();
+
+ await Test(
+ """
+await using var context = new AdHocPrecompiledQueryRelationalTestBase.JsonContext(dbContextOptions);
+await context.Database.BeginTransactionAsync();
+
+var id = 1;
+var blogs = context.JsonEntities.Where(b => b.IntList[id] == 2).ToList();
+""",
+ typeof(JsonContext),
+ options);
+ }
+
+ [ConditionalFact]
+ public virtual async Task JsonScalar()
+ {
+ var contextFactory = await InitializeAsync();
+ var options = contextFactory.GetOptions();
+
+ await Test(
+ """
+await using var context = new AdHocPrecompiledQueryRelationalTestBase.JsonContext(dbContextOptions);
+await context.Database.BeginTransactionAsync();
+
+_ = context.JsonEntities.Where(b => b.JsonThing.StringProperty == "foo").ToList();
+""",
+ typeof(JsonContext),
+ options);
+ }
+
+ public class JsonContext(DbContextOptions options) : DbContext(options)
+ {
+ public DbSet JsonEntities { get; set; } = null!;
+
+ protected override void OnModelCreating(ModelBuilder modelBuilder)
+ => modelBuilder.Entity().OwnsOne(j => j.JsonThing, n => n.ToJson());
+ }
+
+ public class JsonEntity
+ {
+ public int Id { get; set; }
+ public List IntList { get; set; } = null!;
+ public JsonThing JsonThing { get; set; } = null!;
+ }
+
+ public class JsonThing
+ {
+ public string StringProperty { get; set; } = null!;
+ }
+
+ [ConditionalFact]
+ public virtual async Task Materialize_non_public()
+ {
+ var contextFactory = await InitializeAsync();
+ var options = contextFactory.GetOptions();
+
+ await Test(
+ """
+await using var context = new AdHocPrecompiledQueryRelationalTestBase.NonPublicContext(dbContextOptions);
+
+var nonPublicEntity = (AdHocPrecompiledQueryRelationalTestBase.NonPublicEntity)Activator.CreateInstance(typeof(AdHocPrecompiledQueryRelationalTestBase.NonPublicEntity), nonPublic: true);
+nonPublicEntity.PrivateFieldExposer = 8;
+nonPublicEntity.PrivatePropertyExposer = 9;
+nonPublicEntity.PrivateAutoPropertyExposer = 10;
+context.NonPublicEntities.Add(nonPublicEntity);
+await context.SaveChangesAsync();
+
+context.ChangeTracker.Clear();
+
+var e = await context.NonPublicEntities.SingleAsync();
+Assert.Equal(8, e.PrivateFieldExposer);
+Assert.Equal(9, e.PrivatePropertyExposer);
+Assert.Equal(10, e.PrivateAutoPropertyExposer);
+""",
+ typeof(NonPublicContext),
+ options,
+ interceptorCodeAsserter: code =>
+ {
+ Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")]""", code);
+ Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]""", code);
+ Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]""", code);
+ Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "set_PrivateProperty")]""", code);
+
+ Assert.Contains("UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_NonPublicEntity__privateField(instance) =", code);
+ Assert.Contains("UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_NonPublicEntity_PrivateAutoProperty(instance) =", code);
+ Assert.Contains("UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_NonPublicEntity_set_PrivateProperty(instance,", code);
+ });
+ }
+
+ public class NonPublicContext(DbContextOptions options) : DbContext(options)
+ {
+ public DbSet NonPublicEntities { get; set; } = null!;
+
+ protected override void OnModelCreating(ModelBuilder modelBuilder)
+ => modelBuilder.Entity(
+ b =>
+ {
+ b.Property("_privateField");
+ b.Property("PrivateProperty");
+ b.Property("PrivateAutoProperty");
+ b.Ignore(b => b.PrivateFieldExposer);
+ b.Ignore(b => b.PrivatePropertyExposer);
+ b.Ignore(b => b.PrivateAutoPropertyExposer);
+ });
+ }
+
+#pragma warning disable CS0169
+#pragma warning disable CS0649
+ public class NonPublicEntity
+ {
+ private NonPublicEntity()
+ {
+ }
+
+ public int Id { get; set; }
+
+ private int? _privateField;
+
+ // ReSharper disable once ConvertToAutoProperty
+ private int? PrivateProperty
+ {
+ get => _privatePropertyBackingField;
+ set => _privatePropertyBackingField = value;
+ }
+ private int? _privatePropertyBackingField;
+
+ private int? PrivateAutoProperty { get; set; }
+
+ // ReSharper disable once ConvertToAutoProperty
+ public int? PrivateFieldExposer
+ {
+ get => _privateField;
+ set => _privateField = value;
+ }
+
+ public int? PrivatePropertyExposer
+ {
+ get => PrivateProperty;
+ set => PrivateProperty = value;
+ }
+
+ public int? PrivateAutoPropertyExposer
+ {
+ get => PrivateAutoProperty;
+ set => PrivateAutoProperty = value;
+ }
+ }
+#pragma warning restore CS0649
+#pragma warning restore CS0169
+
+// [ConditionalFact]
+// public virtual Task JsonScalar()
+// => Test(
+// // TODO: Remove Select() to Id after JSON is supported in materialization
+// """_ = context.Blogs.Where(b => b.JsonThing.SomeProperty == "foo").Select(b => b.Id).ToList();""",
+// modelSourceCode: providerOptions => $$"""
+// public class BlogContext : DbContext
+// {
+// public DbSet Blogs { get; set; }
+//
+// protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
+// => optionsBuilder
+// {{providerOptions}}
+// .ReplaceService();
+//
+// protected override void OnModelCreating(ModelBuilder modelBuilder)
+// => modelBuilder.Entity().OwnsOne(b => b.JsonThing, n => n.ToJson());
+// }
+//
+// public class Blog
+// {
+// public int Id { get; set; }
+// public JsonThing JsonThing { get; set; }
+// }
+//
+// public class JsonThing
+// {
+// public string SomeProperty { get; set; }
+// }
+// """);
+
+ protected TestSqlLoggerFactory TestSqlLoggerFactory
+ => (TestSqlLoggerFactory)ListLoggerFactory;
+
+ protected void ClearLog()
+ => TestSqlLoggerFactory.Clear();
+
+ protected void AssertSql(params string[] expected)
+ => TestSqlLoggerFactory.AssertBaseline(expected);
+
+ protected virtual Task Test(
+ string sourceCode,
+ Type dbContextType,
+ DbContextOptions dbContextOptions,
+ Action? interceptorCodeAsserter = null,
+ Action>? precompilationErrorAsserter = null,
+ [CallerMemberName] string callerName = "")
+ => PrecompiledQueryTestHelpers.Test(
+ sourceCode, dbContextOptions, dbContextType, interceptorCodeAsserter, precompilationErrorAsserter, testOutputHelper,
+ AlwaysPrintGeneratedSources,
+ callerName);
+
+ protected virtual bool AlwaysPrintGeneratedSources
+ => false;
+
+ protected abstract PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers { get; }
+
+ protected override IServiceCollection AddServices(IServiceCollection serviceCollection)
+ => base.AddServices(serviceCollection)
+ .AddScoped();
+
+ protected override string StoreName
+ => "AdHocPrecompiledQueryTest";
+}
diff --git a/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalFixture.cs b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalFixture.cs
new file mode 100644
index 00000000000..d0b2cd519dc
--- /dev/null
+++ b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalFixture.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.EntityFrameworkCore.Query.Internal;
+using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers;
+using Blog = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Blog;
+namespace Microsoft.EntityFrameworkCore.Query;
+
+public abstract class PrecompiledQueryRelationalFixture
+ : SharedStoreFixtureBase, ITestSqlLoggerFactory
+{
+ protected override string StoreName
+ => "PrecompiledQueryTest";
+
+ public TestSqlLoggerFactory TestSqlLoggerFactory
+ => (TestSqlLoggerFactory)ListLoggerFactory;
+
+ protected override IServiceCollection AddServices(IServiceCollection serviceCollection)
+ => base.AddServices(serviceCollection)
+ .AddScoped();
+
+ protected override async Task SeedAsync(PrecompiledQueryRelationalTestBase.PrecompiledQueryContext context)
+ {
+ context.Blogs.AddRange(
+ new Blog { Id = 8, Name = "Blog1" },
+ new Blog { Id = 9, Name = "Blog2" });
+ await context.SaveChangesAsync();
+ }
+
+ public abstract PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers { get; }
+}
diff --git a/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalTestBase.cs
new file mode 100644
index 00000000000..77acb29a6ee
--- /dev/null
+++ b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalTestBase.cs
@@ -0,0 +1,1134 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.ComponentModel.DataAnnotations.Schema;
+using System.Runtime.CompilerServices;
+using Microsoft.CodeAnalysis;
+using Microsoft.EntityFrameworkCore.Internal;
+using Microsoft.EntityFrameworkCore.Query.Internal;
+using Xunit.Sdk;
+using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers;
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+// ReSharper disable InconsistentNaming
+
+public class PrecompiledQueryRelationalTestBase
+{
+ public PrecompiledQueryRelationalTestBase(PrecompiledQueryRelationalFixture fixture, ITestOutputHelper testOutputHelper)
+ {
+ Fixture = fixture;
+ TestOutputHelper = testOutputHelper;
+
+ Fixture.TestSqlLoggerFactory.Clear();
+ Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
+ }
+
+ #region Expression types
+
+ [ConditionalFact]
+ public virtual Task BinaryExpression()
+ => Test("""
+var id = 3;
+var blogs = await context.Blogs.Where(b => b.Id > id).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Conditional_no_evaluatable()
+ => Test("""
+var id = 3;
+var blogs = await context.Blogs.Select(b => b.Id == 2 ? "yes" : "no").ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Conditional_contains_captured_variable()
+ => Test("""
+var yes = "yes";
+var blogs = await context.Blogs.Select(b => b.Id == 2 ? yes : "no").ToListAsync();
+""");
+
+ // We do not support embedding Expression builder API calls into the query; this would require CSharpToLinqTranslator to actually
+ // evaluate those APIs and embed the results into the tree. It's (at least potentially) a form of dynamic query, unsupported for now.
+ [ConditionalFact]
+ public virtual Task Invoke_no_evaluatability_is_not_supported()
+ => Test(
+ """
+Expression> lambda = b => b.Name == "foo";
+var parameter = Expression.Parameter(typeof(Blog), "b");
+
+var blogs = await context.Blogs
+ .Where(Expression.Lambda>(Expression.Invoke(lambda, parameter), parameter))
+ .ToListAsync();
+""",
+ errorAsserter: errors => Assert.IsType(errors.Single().Exception));
+
+ [ConditionalFact]
+ public virtual Task ListInit_no_evaluatability()
+ => Test("_ = await context.Blogs.Select(b => new List { b.Id, b.Id + 1 }).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task ListInit_with_evaluatable_with_captured_variable()
+ => Test(
+ """
+var i = 1;
+_ = await context.Blogs.Select(b => new List { b.Id, i }).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task ListInit_with_evaluatable_without_captured_variable()
+ => Test(
+ """
+var i = 1;
+_ = await context.Blogs.Select(b => new List { b.Id, 8 }).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task ListInit_fully_evaluatable()
+ => Test("""
+var blog = await context.Blogs.Where(b => new List { 7, 8 }.Contains(b.Id)).SingleAsync();
+Assert.Equal("Blog1", blog.Name);
+""");
+
+ [ConditionalFact]
+ public virtual Task MethodCallExpression_no_evaluatability()
+ => Test("_ = await context.Blogs.Where(b => b.Name.StartsWith(b.Name)).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task MethodCallExpression_with_evaluatable_with_captured_variable()
+ => Test("""
+var pattern = "foo";
+_ = await context.Blogs.Where(b => b.Name.StartsWith(pattern)).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task MethodCallExpression_with_evaluatable_without_captured_variable()
+ => Test("""_ = await context.Blogs.Where(b => b.Name.StartsWith("foo")).ToListAsync();""");
+
+ [ConditionalFact]
+ public virtual Task MethodCallExpression_fully_evaluatable()
+ => Test("""_ = await context.Blogs.Where(b => "foobar".StartsWith("foo")).ToListAsync();""");
+
+ [ConditionalFact]
+ public virtual Task New_with_no_arguments()
+ => Test(
+ """
+var i = 8;
+_ = await context.Blogs.Where(b => b == new Blog()).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Where_New_with_captured_variable()
+ => Test(
+ """
+var i = 8;
+_ = await context.Blogs.Where(b => b == new Blog(i, b.Name)).ToListAsync();
+""",
+ errorAsserter: errors => Assert.StartsWith("Translation of", errors.Single().Exception.Message));
+
+ [ConditionalFact]
+ public virtual Task Select_New_with_captured_variable()
+ => Test(
+ """
+var i = 8;
+_ = await context.Blogs.Select(b => new Blog(i, b.Name)).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task MemberInit_no_evaluatable()
+ => Test("_ = await context.Blogs.Select(b => new Blog { Id = b.Id, Name = b.Name }).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task MemberInit_contains_captured_variable()
+ => Test(
+ """
+var id = 8;
+_ = await context.Blogs.Select(b => new Blog { Id = id, Name = b.Name }).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task MemberInit_evaluatable_as_constant()
+ => Test("""_ = await context.Blogs.Select(b => new Blog { Id = 1, Name = "foo" }).ToListAsync();""");
+
+ [ConditionalFact]
+ public virtual Task MemberInit_evaluatable_as_parameter()
+ => Test(
+ """
+var id = 8;
+var foo = "foo";
+_ = await context.Blogs.Select(b => new Blog { Id = id, Name = foo }).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task NewArray()
+ => Test(
+ """
+var i = 8;
+_ = await context.Blogs.Select(b => new[] { b.Id, b.Id + i }).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Unary()
+ => Test("_ = await context.Blogs.Where(b => (short)b.Id == (short)8).ToListAsync();");
+
+ #endregion Expression types
+
+ #region Terminating operators
+
+ [ConditionalFact]
+ public virtual Task Terminating_AsEnumerable()
+ => Test("""
+var blogs = context.Blogs.AsEnumerable().ToList();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_AsAsyncEnumerable_on_DbSet()
+ => Test("""
+var sum = 0;
+await foreach (var blog in context.Blogs.AsAsyncEnumerable())
+{
+ sum += blog.Id;
+}
+Assert.Equal(17, sum);
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_AsAsyncEnumerable_on_IQueryable()
+ => Test("""
+var sum = 0;
+await foreach (var blog in context.Blogs.Where(b => b.Id > 8).AsAsyncEnumerable())
+{
+ sum += blog.Id;
+}
+Assert.Equal(9, sum);
+""");
+
+ [ConditionalFact]
+ public virtual Task Foreach_sync_over_operator()
+ => Test(
+ """
+foreach (var blog in context.Blogs.Where(b => b.Id > 8))
+{
+}
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToArray()
+ => Test(
+ """
+var blogs = context.Blogs.ToArray();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToArrayAsync()
+ => Test(
+ """
+var blogs = await context.Blogs.ToArrayAsync();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToDictionary()
+ => Test(
+ """
+var blogs = context.Blogs.ToDictionary(b => b.Id, b => b.Name);
+Assert.Equal(2, blogs.Count);
+Assert.Equal("Blog1", blogs[8]);
+Assert.Equal("Blog2", blogs[9]);
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToDictionaryAsync()
+ => Test(
+ """
+var blogs = await context.Blogs.ToDictionaryAsync(b => b.Id, b => b.Name);
+Assert.Equal(2, blogs.Count);
+Assert.Equal("Blog1", blogs[8]);
+Assert.Equal("Blog2", blogs[9]);
+""");
+
+ [ConditionalFact]
+ public virtual Task ToDictionary_over_anonymous_type()
+ => Test("_ = context.Blogs.Select(b => new { b.Id, b.Name }).ToDictionary(x => x.Id, x => x.Name);");
+
+ [ConditionalFact]
+ public virtual Task ToDictionaryAsync_over_anonymous_type()
+ => Test("_ = await context.Blogs.Select(b => new { b.Id, b.Name }).ToDictionaryAsync(x => x.Id, x => x.Name);");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToHashSet()
+ => Test(
+ """
+var blogs = context.Blogs.ToHashSet();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToHashSetAsync()
+ => Test(
+ """
+var blogs = await context.Blogs.ToHashSetAsync();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToLookup()
+ => Test("_ = context.Blogs.ToLookup(b => b.Name);");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToList()
+ => Test(
+ """
+var blogs = context.Blogs.ToList();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ToListAsync()
+ => Test(
+ """
+var blogs = await context.Blogs.ToListAsync();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ // foreach/await foreach directly over DbSet properties doesn't isn't supported, since we can't intercept property accesses.
+ [ConditionalFact]
+ public virtual async Task Foreach_sync_over_DbSet_property_is_not_supported()
+ {
+ // TODO: Assert diagnostics about non-intercepted query
+ var exception = await Assert.ThrowsAsync(
+ () => Test(
+ """
+foreach (var blog in context.Blogs)
+{
+}
+"""));
+ Assert.Equal(NonCompilingQueryCompiler.ErrorMessage, exception.Message);
+ }
+
+ // foreach/await foreach directly over DbSet properties doesn't isn't supported, since we can't intercept property accesses.
+ [ConditionalFact]
+ public virtual async Task Foreach_async_is_not_supported()
+ {
+ // TODO: Assert diagnostics about non-intercepted query
+ var exception = await Assert.ThrowsAsync(
+ () => Test(
+ """
+await foreach (var blog in context.Blogs)
+{
+}
+"""));
+ Assert.Equal(NonCompilingQueryCompiler.ErrorMessage, exception.Message);
+ }
+
+ #endregion Terminating operators
+
+ #region Reducing terminating operators
+
+ [ConditionalFact]
+ public virtual Task Terminating_All()
+ => Test(
+ """
+Assert.True(context.Blogs.All(b => b.Id > 7));
+Assert.False(context.Blogs.All(b => b.Id > 8));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_AllAsync()
+ => Test(
+ """
+Assert.True(await context.Blogs.AllAsync(b => b.Id > 7));
+Assert.False(await context.Blogs.AllAsync(b => b.Id > 8));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Any()
+ => Test(
+ """
+Assert.True(context.Blogs.Where(b => b.Id > 7).Any());
+Assert.False(context.Blogs.Where(b => b.Id < 7).Any());
+
+Assert.True(context.Blogs.Any(b => b.Id > 7));
+Assert.False(context.Blogs.Any(b => b.Id < 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_AnyAsync()
+ => Test(
+ """
+Assert.True(await context.Blogs.Where(b => b.Id > 7).AnyAsync());
+Assert.False(await context.Blogs.Where(b => b.Id < 7).AnyAsync());
+
+Assert.True(await context.Blogs.AnyAsync(b => b.Id > 7));
+Assert.False(await context.Blogs.AnyAsync(b => b.Id < 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Average()
+ => Test(
+ """
+Assert.Equal(8.5, context.Blogs.Select(b => b.Id).Average());
+Assert.Equal(8.5, context.Blogs.Average(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_AverageAsync()
+ => Test(
+ """
+Assert.Equal(8.5, await context.Blogs.Select(b => b.Id).AverageAsync());
+Assert.Equal(8.5, await context.Blogs.AverageAsync(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Contains()
+ => Test(
+ """
+Assert.True(context.Blogs.Select(b => b.Id).Contains(8));
+Assert.False(context.Blogs.Select(b => b.Id).Contains(7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ContainsAsync()
+ => Test(
+ """
+Assert.True(await context.Blogs.Select(b => b.Id).ContainsAsync(8));
+Assert.False(await context.Blogs.Select(b => b.Id).ContainsAsync(7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Count()
+ => Test(
+ """
+Assert.Equal(2, context.Blogs.Count());
+Assert.Equal(1, context.Blogs.Count(b => b.Id > 8));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_CountAsync()
+ => Test(
+ """
+Assert.Equal(2, await context.Blogs.CountAsync());
+Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id > 8));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ElementAt()
+ => Test(
+ """
+Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).ElementAt(1).Name);
+Assert.Throws(() => context.Blogs.OrderBy(b => b.Id).ElementAt(3));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ElementAtAsync()
+ => Test(
+ """
+Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).ElementAtAsync(1)).Name);
+await Assert.ThrowsAsync(() => context.Blogs.OrderBy(b => b.Id).ElementAtAsync(3));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ElementAtOrDefault()
+ => Test(
+ """
+Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).ElementAtOrDefault(1).Name);
+Assert.Null(context.Blogs.OrderBy(b => b.Id).ElementAtOrDefault(3));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ElementAtOrDefaultAsync()
+ => Test(
+ """
+Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).ElementAtOrDefaultAsync(1)).Name);
+Assert.Null(await context.Blogs.OrderBy(b => b.Id).ElementAtOrDefaultAsync(3));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_First()
+ => Test(
+ """
+Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).First().Name);
+Assert.Throws(() => context.Blogs.Where(b => b.Id == 7).First());
+
+Assert.Equal("Blog1", context.Blogs.First(b => b.Id == 8).Name);
+Assert.Throws(() => context.Blogs.First(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_FirstAsync()
+ => Test(
+ """
+Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).FirstAsync()).Name);
+await Assert.ThrowsAsync(() => context.Blogs.Where(b => b.Id == 7).FirstAsync());
+
+Assert.Equal("Blog1", (await context.Blogs.FirstAsync(b => b.Id == 8)).Name);
+await Assert.ThrowsAsync(() => context.Blogs.FirstAsync(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_FirstOrDefault()
+ => Test(
+ """
+Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).FirstOrDefault().Name);
+Assert.Null(context.Blogs.Where(b => b.Id == 7).FirstOrDefault());
+
+Assert.Equal("Blog1", context.Blogs.FirstOrDefault(b => b.Id == 8).Name);
+Assert.Null(context.Blogs.FirstOrDefault(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_FirstOrDefaultAsync()
+ => Test(
+ """
+Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).FirstOrDefaultAsync()).Name);
+Assert.Null(await context.Blogs.Where(b => b.Id == 7).FirstOrDefaultAsync());
+
+Assert.Equal("Blog1", (await context.Blogs.FirstOrDefaultAsync(b => b.Id == 8)).Name);
+Assert.Null(await context.Blogs.FirstOrDefaultAsync(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_GetEnumerator()
+ => Test(
+ """
+using var enumerator = context.Blogs.Where(b => b.Id == 8).GetEnumerator();
+Assert.True(enumerator.MoveNext());
+Assert.Equal("Blog1", enumerator.Current.Name);
+Assert.False(enumerator.MoveNext());
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Last()
+ => Test(
+ """
+Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).Last().Name);
+Assert.Throws(() => context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).Last());
+
+Assert.Equal("Blog1", context.Blogs.OrderBy(b => b.Id).Last(b => b.Id == 8).Name);
+Assert.Throws(() => context.Blogs.OrderBy(b => b.Id).Last(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_LastAsync()
+ => Test(
+ """
+Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).LastAsync()).Name);
+await Assert.ThrowsAsync(() => context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).LastAsync());
+
+Assert.Equal("Blog1", (await context.Blogs.OrderBy(b => b.Id).LastAsync(b => b.Id == 8)).Name);
+await Assert.ThrowsAsync(() => context.Blogs.OrderBy(b => b.Id).LastAsync(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_LastOrDefault()
+ => Test(
+ """
+Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).LastOrDefault().Name);
+Assert.Null(context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).LastOrDefault());
+
+Assert.Equal("Blog1", context.Blogs.OrderBy(b => b.Id).LastOrDefault(b => b.Id == 8).Name);
+Assert.Null(context.Blogs.OrderBy(b => b.Id).LastOrDefault(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_LastOrDefaultAsync()
+ => Test(
+ """
+Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).LastOrDefaultAsync()).Name);
+Assert.Null(await context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).LastOrDefaultAsync());
+
+Assert.Equal("Blog1", (await context.Blogs.OrderBy(b => b.Id).LastOrDefaultAsync(b => b.Id == 8)).Name);
+Assert.Null(await context.Blogs.OrderBy(b => b.Id).LastOrDefaultAsync(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_LongCount()
+ => Test(
+ """
+Assert.Equal(2, context.Blogs.LongCount());
+Assert.Equal(1, context.Blogs.LongCount(b => b.Id == 8));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_LongCountAsync()
+ => Test(
+ """
+Assert.Equal(2, await context.Blogs.LongCountAsync());
+Assert.Equal(1, await context.Blogs.LongCountAsync(b => b.Id == 8));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Max()
+ => Test(
+ """
+Assert.Equal(9, context.Blogs.Select(b => b.Id).Max());
+Assert.Equal(9, context.Blogs.Max(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_MaxAsync()
+ => Test(
+ """
+Assert.Equal(9, await context.Blogs.Select(b => b.Id).MaxAsync());
+Assert.Equal(9, await context.Blogs.MaxAsync(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Min()
+ => Test(
+ """
+Assert.Equal(8, context.Blogs.Select(b => b.Id).Min());
+Assert.Equal(8, context.Blogs.Min(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_MinAsync()
+ => Test(
+ """
+Assert.Equal(8, await context.Blogs.Select(b => b.Id).MinAsync());
+Assert.Equal(8, await context.Blogs.MinAsync(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Single()
+ => Test(
+ """
+Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).Single().Name);
+Assert.Throws(() => context.Blogs.Where(b => b.Id == 7).Single());
+
+Assert.Equal("Blog1", context.Blogs.Single(b => b.Id == 8).Name);
+Assert.Throws(() => context.Blogs.Single(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_SingleAsync()
+ => Test(
+ """
+Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).SingleAsync()).Name);
+await Assert.ThrowsAsync(() => context.Blogs.Where(b => b.Id == 7).SingleAsync());
+
+Assert.Equal("Blog1", (await context.Blogs.SingleAsync(b => b.Id == 8)).Name);
+await Assert.ThrowsAsync(() => context.Blogs.SingleAsync(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_SingleOrDefault()
+ => Test(
+ """
+Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).SingleOrDefault().Name);
+Assert.Null(context.Blogs.Where(b => b.Id == 7).SingleOrDefault());
+
+Assert.Equal("Blog1", context.Blogs.SingleOrDefault(b => b.Id == 8).Name);
+Assert.Null(context.Blogs.SingleOrDefault(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_SingleOrDefaultAsync()
+ => Test(
+ """
+Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).SingleOrDefaultAsync()).Name);
+Assert.Null(await context.Blogs.Where(b => b.Id == 7).SingleOrDefaultAsync());
+
+Assert.Equal("Blog1", (await context.Blogs.SingleOrDefaultAsync(b => b.Id == 8)).Name);
+Assert.Null(await context.Blogs.SingleOrDefaultAsync(b => b.Id == 7));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_Sum()
+ => Test(
+ """
+Assert.Equal(17, context.Blogs.Select(b => b.Id).Sum());
+Assert.Equal(17, context.Blogs.Sum(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_SumAsync()
+ => Test(
+ """
+Assert.Equal(17, await context.Blogs.Select(b => b.Id).SumAsync());
+Assert.Equal(17, await context.Blogs.SumAsync(b => b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ExecuteDelete()
+ => Test(
+ """
+await context.Database.BeginTransactionAsync();
+
+var rowsAffected = context.Blogs.Where(b => b.Id > 8).ExecuteDelete();
+Assert.Equal(1, rowsAffected);
+Assert.Equal(1, await context.Blogs.CountAsync());
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ExecuteDeleteAsync()
+ => Test(
+ """
+await context.Database.BeginTransactionAsync();
+
+var rowsAffected = await context.Blogs.Where(b => b.Id > 8).ExecuteDeleteAsync();
+Assert.Equal(1, rowsAffected);
+Assert.Equal(1, await context.Blogs.CountAsync());
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ExecuteUpdate()
+ => Test(
+ """
+await context.Database.BeginTransactionAsync();
+
+var suffix = "Suffix";
+var rowsAffected = context.Blogs.Where(b => b.Id > 8).ExecuteUpdate(setters => setters.SetProperty(b => b.Name, b => b.Name + suffix));
+Assert.Equal(1, rowsAffected);
+Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id == 9 && b.Name == "Blog2Suffix"));
+""");
+
+ [ConditionalFact]
+ public virtual Task Terminating_ExecuteUpdateAsync()
+ => Test(
+ """
+await context.Database.BeginTransactionAsync();
+
+var suffix = "Suffix";
+var rowsAffected = await context.Blogs.Where(b => b.Id > 8).ExecuteUpdateAsync(setters => setters.SetProperty(b => b.Name, b => b.Name + suffix));
+Assert.Equal(1, rowsAffected);
+Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id == 9 && b.Name == "Blog2Suffix"));
+""");
+
+ #endregion Reducing terminating operators
+
+ #region SQL expression quotability
+
+ [ConditionalFact]
+ public virtual Task Union()
+ => Test(
+ """
+var blogs = await context.Blogs.Where(b => b.Id > 7)
+ .Union(context.Blogs.Where(b => b.Id < 10))
+ .OrderBy(b => b.Id)
+ .ToListAsync();
+
+Assert.Collection(blogs,
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Concat()
+ => Test(
+ """
+var blogs = await context.Blogs.Where(b => b.Id > 7)
+ .Concat(context.Blogs.Where(b => b.Id < 10))
+ .OrderBy(b => b.Id)
+ .ToListAsync();
+
+Assert.Collection(blogs,
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id),
+ b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Intersect()
+ => Test(
+ """
+var blogs = await context.Blogs.Where(b => b.Id > 7)
+ .Intersect(context.Blogs.Where(b => b.Id > 8))
+ .OrderBy(b => b.Id)
+ .ToListAsync();
+
+Assert.Collection(blogs, b => Assert.Equal(9, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task Except()
+ => Test(
+ """
+var blogs = await context.Blogs.Where(b => b.Id > 7)
+ .Except(context.Blogs.Where(b => b.Id > 8))
+ .OrderBy(b => b.Id)
+ .ToListAsync();
+
+Assert.Collection(blogs, b => Assert.Equal(8, b.Id));
+""");
+
+ [ConditionalFact]
+ public virtual Task ValuesExpression()
+ => Test("_ = await context.Blogs.Where(b => new[] { 7, b.Id }.Count(i => i > 8) == 2).ToListAsync();");
+
+ // Tests e.g. OPENJSON on SQL Server
+ [ConditionalFact]
+ public virtual Task Contains_with_parameterized_collection()
+ => Test(
+ """
+int[] ids = [1, 2, 3];
+_ = await context.Blogs.Where(b => ids.Contains(b.Id)).ToListAsync();
+""");
+
+ // TODO: SQL Server-specific
+ [ConditionalFact]
+ public virtual Task FromSqlRaw()
+ => Test("""_ = await context.Blogs.FromSqlRaw("SELECT * FROM Blogs").OrderBy(b => b.Id).ToListAsync();""");
+
+ [ConditionalFact]
+ public virtual Task FromSql_with_FormattableString_parameters()
+ => Test("""_ = await context.Blogs.FromSql($"SELECT * FROM Blogs WHERE Id > {8} AND Id < {9}").OrderBy(b => b.Id).ToListAsync();""");
+
+ #endregion SQL expression quotability
+
+ #region Different DbContext expressions
+
+ [ConditionalFact]
+ public virtual Task DbContext_as_local_variable()
+ => Test(
+ """
+var context2 = context;
+
+_ = await context2.Blogs.ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task DbContext_as_field()
+ => FullSourceTest(
+ """
+public static class TestContainer
+{
+ private static PrecompiledQueryContext _context;
+
+ public static async Task Test(DbContextOptions dbContextOptions)
+ {
+ using (_context = new PrecompiledQueryContext(dbContextOptions))
+ {
+ var blogs = await _context.Blogs.ToListAsync();
+ Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+ }
+ }
+}
+""");
+
+ [ConditionalFact]
+ public virtual Task DbContext_as_property()
+ => FullSourceTest(
+ """
+public static class TestContainer
+{
+ private static PrecompiledQueryContext Context { get; set; }
+
+ public static async Task Test(DbContextOptions dbContextOptions)
+ {
+ using (Context = new PrecompiledQueryContext(dbContextOptions))
+ {
+ var blogs = await Context.Blogs.ToListAsync();
+ Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+ }
+ }
+}
+""");
+
+ [ConditionalFact]
+ public virtual Task DbContext_as_captured_variable()
+ => Test(
+ """
+Func> foo = () => context.Blogs.ToList();
+_ = foo();
+""");
+
+ [ConditionalFact]
+ public virtual Task DbContext_as_method_invocation_result()
+ => FullSourceTest(
+ """
+public static class TestContainer
+{
+ private static PrecompiledQueryContext _context;
+
+ public static async Task Test(DbContextOptions dbContextOptions)
+ {
+ using (_context = new PrecompiledQueryContext(dbContextOptions))
+ {
+ var blogs = await GetContext().Blogs.ToListAsync();
+ Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+ }
+ }
+
+ private static PrecompiledQueryContext GetContext()
+ => _context;
+}
+""");
+
+ #endregion Different DbContext expressions
+
+ #region Negative cases
+
+ [ConditionalFact]
+ public virtual Task Dynamic_query_does_not_get_precompiled()
+ => Test(
+ """
+var query = context.Blogs;
+var blogs = await query.ToListAsync();
+""",
+ errorAsserter: errors =>
+ {
+ var dynamicQueryError = errors.Single();
+ Assert.IsType(dynamicQueryError.Exception);
+ Assert.Equal(DesignStrings.DynamicQueryNotSupported, dynamicQueryError.Exception.Message);
+ Assert.Equal("query.ToListAsync()", dynamicQueryError.SyntaxNode.NormalizeWhitespace().ToFullString());
+ });
+
+ [ConditionalFact]
+ public virtual Task ToList_over_objects_does_not_get_precompiled()
+ => Test(
+ """
+int[] numbers = [1, 2, 3];
+var lessNumbers = numbers.Where(i => i > 1).ToList();
+""");
+
+ [ConditionalFact]
+ public virtual async Task Query_compilation_failure()
+ => await Test(
+ "_ = await context.Blogs.Where(b => PrecompiledQueryRelationalTestBase.Untranslatable(b.Id) == 999).ToListAsync();",
+ errorAsserter: errors
+ => Assert.Contains(
+ CoreStrings.TranslationFailedWithDetails(
+ "DbSet()\n .Where(b => PrecompiledQueryRelationalTestBase.Untranslatable(b.Id) == 999)",
+ "Translation of method 'Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Untranslatable' failed. If this method can be mapped to your custom function, see https://go.microsoft.com/fwlink/?linkid=2132413 for more information."),
+ errors.Single().Exception.Message));
+
+ public static int Untranslatable(int foo)
+ => throw new InvalidOperationException();
+
+ [ConditionalFact]
+ public virtual Task EF_Constant_is_not_supported()
+ => Test(
+ "_ = await context.Blogs.Where(b => b.Id > EF.Constant(8)).ToListAsync();",
+ errorAsserter: errors
+ => Assert.Equal(CoreStrings.EFConstantNotSupportedInPrecompiledQueries, errors.Single().Exception.Message));
+
+ [ConditionalFact]
+ public virtual Task NotParameterizedAttribute_with_constant()
+ => Test(
+ """
+var blog = await context.Blogs.Where(b => EF.Property(b, "Name") == "Blog2").SingleAsync();
+Assert.Equal(9, blog.Id);
+""");
+
+ [ConditionalFact]
+ public virtual Task NotParameterizedAttribute_is_not_supported_with_non_constant_argument()
+ => Test(
+ """
+var propertyName = "Name";
+var blog = await context.Blogs.Where(b => EF.Property(b, propertyName) == "Blog2").SingleAsync();
+""",
+ errorAsserter: errors
+ => Assert.Equal(
+ CoreStrings.NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries("propertyName", "Property"),
+ errors.Single().Exception.Message));
+
+ [ConditionalFact]
+ public virtual Task Query_syntax_is_not_supported()
+ => Test(
+ """
+var id = 3;
+var blogs = await (
+ from b in context.Blogs
+ where b.Id > 8
+ select b).ToListAsync();
+""",
+ errorAsserter: errors
+ => Assert.Equal(DesignStrings.QueryComprehensionSyntaxNotSupportedInPrecompiledQueries, errors.Single().Exception.Message));
+
+ #endregion Negative cases
+
+ [ConditionalFact]
+ public virtual Task Select_changes_type()
+ => Test("_ = await context.Blogs.Select(b => b.Name).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task OrderBy()
+ => Test("_ = await context.Blogs.OrderBy(b => b.Name).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task Skip()
+ => Test("_ = await context.Blogs.OrderBy(b => b.Name).Skip(1).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task Take()
+ => Test("_ = await context.Blogs.OrderBy(b => b.Name).Take(1).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task Project_anonymous_object()
+ => Test("""_ = await context.Blogs.Select(b => new { Foo = b.Name + "Foo" }).ToListAsync();""");
+
+ [ConditionalFact]
+ public virtual Task Two_captured_variables_in_same_lambda()
+ => Test("""
+var yes = "yes";
+var no = "no";
+var blogs = await context.Blogs.Select(b => b.Id == 3 ? yes : no).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Two_captured_variables_in_different_lambdas()
+ => Test("""
+var starts = "blog";
+var ends = "2";
+var blog = await context.Blogs.Where(b => b.Name.StartsWith(starts)).Where(b => b.Name.EndsWith(ends)).SingleAsync();
+Assert.Equal(9, blog.Id);
+""");
+
+ [ConditionalFact]
+ public virtual Task Same_captured_variable_twice_in_same_lambda()
+ => Test("""
+var foo = "X";
+var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo) && b.Name.EndsWith(foo)).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Same_captured_variable_twice_in_different_lambdas()
+ => Test("""
+var foo = "X";
+var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo)).Where(b => b.Name.EndsWith(foo)).ToListAsync();
+""");
+
+ [ConditionalFact]
+ public virtual Task Include_single()
+ => Test("var blogs = await context.Blogs.Include(b => b.Posts).Where(b => b.Id > 8).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task Include_split()
+ => Test("var blogs = await context.Blogs.AsSplitQuery().Include(b => b.Posts).ToListAsync();");
+
+ [ConditionalFact]
+ public virtual Task Final_GroupBy()
+ => Test("""var blogs = await context.Blogs.GroupBy(b => b.Name).ToListAsync();""");
+
+ [ConditionalFact]
+ public virtual Task Multiple_queries_with_captured_variables()
+ => Test("""
+var id1 = 8;
+var id2 = 9;
+var blogs = await context.Blogs.Where(b => b.Id == id1 || b.Id == id2).ToListAsync();
+var blog1 = await context.Blogs.Where(b => b.Id == id1).SingleAsync();
+Assert.Collection(
+ blogs.OrderBy(b => b.Id),
+ b => Assert.Equal(8, b.Id),
+ b => Assert.Equal(9, b.Id));
+Assert.Equal("Blog1", blog1.Name);
+""");
+
+ [ConditionalFact]
+ public virtual Task Unsafe_accessor_gets_generated_once_for_multiple_queries()
+ => Test("""
+var blogs1 = await context.Blogs.ToListAsync();
+var blogs2 = await context.Blogs.ToListAsync();
+""",
+ interceptorCodeAsserter: code => Assert.Equal(2, code.Split("GetSet_Microsoft_EntityFrameworkCore_Query_Blog_Id").Length));
+
+ public class PrecompiledQueryContext(DbContextOptions options) : DbContext(options)
+ {
+ public DbSet Blogs { get; set; } = null!;
+ public DbSet Posts { get; set; } = null!;
+ }
+
+ protected PrecompiledQueryRelationalFixture Fixture { get; }
+ protected ITestOutputHelper TestOutputHelper { get; }
+
+ protected void AssertSql(params string[] expected)
+ => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
+
+ protected virtual Task Test(
+ string sourceCode,
+ Action? interceptorCodeAsserter = null,
+ Action>? errorAsserter = null,
+ [CallerMemberName] string callerName = "")
+ => Fixture.PrecompiledQueryTestHelpers.Test(
+ """
+await using var context = new PrecompiledQueryContext(dbContextOptions);
+
+""" + sourceCode,
+ Fixture.ServiceProvider.GetRequiredService(),
+ typeof(PrecompiledQueryContext),
+ interceptorCodeAsserter,
+ errorAsserter,
+ TestOutputHelper,
+ AlwaysPrintGeneratedSources,
+ callerName);
+
+ protected virtual Task FullSourceTest(
+ string sourceCode,
+ Action? interceptorCodeAsserter = null,
+ Action>? errorAsserter = null,
+ [CallerMemberName] string callerName = "")
+ => Fixture.PrecompiledQueryTestHelpers.FullSourceTest(
+ sourceCode,
+ Fixture.ServiceProvider.GetRequiredService(),
+ typeof(PrecompiledQueryContext),
+ interceptorCodeAsserter,
+ errorAsserter,
+ TestOutputHelper,
+ AlwaysPrintGeneratedSources,
+ callerName);
+
+ protected virtual bool AlwaysPrintGeneratedSources
+ => false;
+
+ public class Blog
+ {
+ public Blog()
+ {
+ }
+
+ public Blog(int id, string name)
+ {
+ Id = id;
+ Name = name;
+ }
+
+ [DatabaseGenerated(DatabaseGeneratedOption.None)]
+ public int Id { get; set; }
+ public string? Name { get; set; }
+
+ public List Posts { get; set; } = new();
+ }
+
+ public class Post
+ {
+ public int Id { get; set; }
+ public string? Title { get; set; }
+
+ public Blog? Blog { get; set; }
+ }
+
+ public static IEnumerable IsAsyncData = new object[][] { [false], [true] };
+}
diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/PrecompiledQueryTestHelpers.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/PrecompiledQueryTestHelpers.cs
new file mode 100644
index 00000000000..960b3bab4f2
--- /dev/null
+++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/PrecompiledQueryTestHelpers.cs
@@ -0,0 +1,297 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.ComponentModel;
+using System.ComponentModel.DataAnnotations.Schema;
+using System.Runtime.Loader;
+using System.Text.Encodings.Web;
+using System.Text.Json;
+using System.Text.RegularExpressions;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.Editing;
+using Microsoft.EntityFrameworkCore.Query.Internal;
+using Microsoft.Extensions.Caching.Memory;
+
+namespace Microsoft.EntityFrameworkCore.TestUtilities;
+
+public abstract class PrecompiledQueryTestHelpers
+{
+ private readonly MetadataReference[] _metadataReferences;
+
+ protected PrecompiledQueryTestHelpers()
+ => _metadataReferences = BuildMetadataReferences().ToArray();
+
+ public Task Test(
+ string sourceCode,
+ DbContextOptions dbContextOptions,
+ Type dbContextType,
+ Action? interceptorCodeAsserter,
+ Action>? errorAsserter,
+ ITestOutputHelper testOutputHelper,
+ bool alwaysPrintGeneratedSources,
+ string callerName)
+ {
+ var source = $$"""
+public static class TestContainer
+{
+ public static async Task Test(DbContextOptions dbContextOptions)
+ {
+{{sourceCode}}
+ }
+}
+""";
+ return FullSourceTest(
+ source, dbContextOptions, dbContextType, interceptorCodeAsserter, errorAsserter, testOutputHelper, alwaysPrintGeneratedSources,
+ callerName);
+ }
+
+ public async Task FullSourceTest(
+ string sourceCode,
+ DbContextOptions dbContextOptions,
+ Type dbContextType,
+ Action? interceptorCodeAsserter,
+ Action>? errorAsserter,
+ ITestOutputHelper testOutputHelper,
+ bool alwaysPrintGeneratedSources,
+ string callerName)
+ {
+ // The overall end-to-end testing for precompiled queries is as follows:
+ // 1. Compile the user code, produce an assembly from it and load it. We need to do this since precompiled query generation requires
+ // an actual DbContext instance, from which we get the model, services, ec.
+ // 2. Do precompiled query generation. This outputs additional source files (syntax trees) containing interceptors for the located
+ // EF LINQ queries.
+ // 3. Integrate the additional syntax trees into the compilation, and again, produce an assembly from it and load it.
+ // 4. Use reflection to find the EntryPoint (Main method) on this assembly, and invoke it.
+ var source = $"""
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Threading.Tasks;
+using System.Text.RegularExpressions;
+using Microsoft.EntityFrameworkCore;
+using Microsoft.EntityFrameworkCore.Query;
+using Xunit;
+using static Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase;
+//using Microsoft.EntityFrameworkCore.PrecompiledQueryTest;
+
+{sourceCode}
+""";
+
+ // This turns on the interceptors feature for the designated namespace(s).
+ var parseOptions = new CSharpParseOptions().WithFeatures(
+ new[]
+ {
+ new KeyValuePair("InterceptorsPreviewNamespaces", "Microsoft.EntityFrameworkCore.GeneratedInterceptors")
+ });
+
+ var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions, path: "Test.cs");
+
+ var compilation = CSharpCompilation.Create(
+ "TestCompilation",
+ syntaxTrees: [syntaxTree],
+ _metadataReferences,
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+
+ IReadOnlyList? generatedFiles = null;
+
+ try
+ {
+ // The test code compiled - emit and assembly and load it.
+ var (assemblyLoadContext, assembly) = EmitAndLoadAssembly(compilation, callerName + "_Original");
+ try
+ {
+ var workspace = new AdhocWorkspace();
+ var syntaxGenerator = SyntaxGenerator.GetGenerator(workspace, LanguageNames.CSharp);
+
+ // TODO: Look up as regular dependencies
+ var precompiledQueryCodeGenerator = new PrecompiledQueryCodeGenerator();
+
+ await using var dbContext = (DbContext)Activator.CreateInstance(dbContextType, args: [dbContextOptions])!;
+
+ // Perform precompilation
+ var precompilationErrors = new List();
+ generatedFiles = precompiledQueryCodeGenerator.GeneratePrecompiledQueries(
+ compilation, syntaxGenerator, dbContext, precompilationErrors, additionalAssembly: assembly);
+
+ if (errorAsserter is null)
+ {
+ if (precompilationErrors.Count > 0)
+ {
+ Assert.Fail("Precompilation error: " + precompilationErrors[0].Exception);
+ }
+ }
+ else
+ {
+ errorAsserter(precompilationErrors);
+ return;
+ }
+ }
+ finally
+ {
+ assemblyLoadContext.Unload();
+ }
+
+ // We now have the code-generated interceptors; add them to the compilation and re-emit.
+ compilation = compilation.AddSyntaxTrees(
+ generatedFiles.Select(f => CSharpSyntaxTree.ParseText(f.Code, parseOptions, f.Path)));
+
+ // We have the final compilation, including the interceptors. Emit and load it, and then invoke its entry point, which contains
+ // the original test code with the EF LINQ query, etc.
+ (assemblyLoadContext, assembly) = EmitAndLoadAssembly(compilation, callerName + "_WithInterceptors");
+ try
+ {
+ await using var dbContext = (DbContext)Activator.CreateInstance(dbContextType, dbContextOptions)!;
+
+ var testContainer = assembly.ExportedTypes.Single(t => t.Name == "TestContainer");
+ var testMethod = testContainer.GetMethod("Test")!;
+ await (Task)testMethod.Invoke(obj: null, parameters: [dbContextOptions])!;
+ }
+ finally
+ {
+ assemblyLoadContext.Unload();
+ }
+ }
+ catch
+ {
+ PrintGeneratedSources();
+
+ throw;
+ }
+
+ if (alwaysPrintGeneratedSources)
+ {
+ PrintGeneratedSources();
+ }
+
+ void PrintGeneratedSources()
+ {
+ if (generatedFiles is not null)
+ {
+ foreach (var generatedFile in generatedFiles)
+ {
+ testOutputHelper.WriteLine($"Generated file {generatedFile.Path}: ");
+ testOutputHelper.WriteLine("");
+ testOutputHelper.WriteLine(generatedFile.Code);
+ }
+ }
+ }
+
+ static (AssemblyLoadContext, Assembly) EmitAndLoadAssembly(Compilation compilation, string assemblyLoadContextName)
+ {
+ var errorDiagnostics = compilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error).ToList();
+ if (errorDiagnostics.Count > 0)
+ {
+ var stringBuilder = new StringBuilder();
+ stringBuilder.AppendLine("Compilation failed:").AppendLine();
+
+ foreach (var errorDiagnostic in errorDiagnostics)
+ {
+ stringBuilder.AppendLine(errorDiagnostic.ToString());
+
+ var textLines = errorDiagnostic.Location.SourceTree!.GetText().Lines;
+ var startLine = errorDiagnostic.Location.GetLineSpan().StartLinePosition.Line;
+ var endLine = errorDiagnostic.Location.GetLineSpan().EndLinePosition.Line;
+
+ if (startLine == endLine)
+ {
+ stringBuilder.Append("Line: ").AppendLine(textLines[startLine].ToString().TrimStart());
+ }
+ else
+ {
+ stringBuilder.AppendLine("Lines:");
+ for (var i = startLine; i <= endLine; i++)
+ {
+ stringBuilder.AppendLine(textLines[i].ToString());
+ }
+ }
+ }
+
+ throw new InvalidOperationException("Compilation failed:" + stringBuilder);
+ }
+
+ using var memoryStream = new MemoryStream();
+ var emitResult = compilation.Emit(memoryStream);
+ memoryStream.Position = 0;
+
+ errorDiagnostics = emitResult.Diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToList();
+ if (errorDiagnostics.Count > 0)
+ {
+ throw new InvalidOperationException(
+ "Compilation emit failed:" + Environment.NewLine + string.Join(Environment.NewLine, errorDiagnostics));
+ }
+
+ var assemblyLoadContext = new AssemblyLoadContext(assemblyLoadContextName, isCollectible: true);
+ var assembly = assemblyLoadContext.LoadFromStream(memoryStream);
+ return (assemblyLoadContext, assembly);
+ }
+ }
+
+ protected virtual IEnumerable BuildMetadataReferences()
+ {
+ var netAssemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
+
+ return new[]
+ {
+ MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Queryable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IQueryable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(List<>).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Regex).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(JsonSerializer).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(JavaScriptEncoder).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(DatabaseGeneratedAttribute).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(DbContext).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(RelationalOptionsExtension).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(DbConnection).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IListSource).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IServiceProvider).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IMemoryCache).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Assert).Assembly.Location),
+ // This is to allow referencing types from this file, e.g. NonCompilingQueryCompiler
+ MetadataReference.CreateFromFile(Assembly.GetExecutingAssembly().Location),
+ MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "mscorlib.dll")),
+ MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.dll")),
+ MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Core.dll")),
+ MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Runtime.dll")),
+ MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Collections.dll"))
+ }
+ .Concat(BuildProviderMetadataReferences());
+ }
+
+ protected abstract IEnumerable BuildProviderMetadataReferences();
+
+ // Used from inside the tested code to ensure that we never end up compiling queries at runtime.
+ // TODO: Probably remove this later, once we have a regular mechanism for failing non-intercepted queries at runtime.
+ // ReSharper disable once UnusedMember.Global
+ public class NonCompilingQueryCompiler(
+ IQueryContextFactory queryContextFactory,
+ ICompiledQueryCache compiledQueryCache,
+ ICompiledQueryCacheKeyGenerator compiledQueryCacheKeyGenerator,
+ IDatabase database,
+ IDiagnosticsLogger logger,
+ ICurrentDbContext currentContext,
+ IEvaluatableExpressionFilter evaluatableExpressionFilter,
+ IModel model)
+ : QueryCompiler(queryContextFactory, compiledQueryCache, compiledQueryCacheKeyGenerator, database, logger,
+ currentContext, evaluatableExpressionFilter, model)
+ {
+ public const string ErrorMessage =
+ "A query reached the query compilation pipeline, indicating that it was not intercepted as a precompiled query.";
+
+ public override TResult Execute(Expression query)
+ {
+ Assert.Fail(ErrorMessage);
+ throw new UnreachableException();
+ }
+
+ public override TResult ExecuteAsync(Expression query, CancellationToken cancellationToken = default)
+ {
+ Assert.Fail(ErrorMessage);
+ throw new UnreachableException();
+ }
+ }
+}
diff --git a/test/EFCore.Specification.Tests/JsonTypesTestBase.cs b/test/EFCore.Specification.Tests/JsonTypesTestBase.cs
index 5a79e1a355e..4b6a3638399 100644
--- a/test/EFCore.Specification.Tests/JsonTypesTestBase.cs
+++ b/test/EFCore.Specification.Tests/JsonTypesTestBase.cs
@@ -3924,7 +3924,6 @@ protected class BinaryListArrayArrayListType
{
var contextFactory = await CreateContextFactory(
buildModel,
- addServices: AddServices,
configureConventions: configureConventions);
using var context = contextFactory.CreateContext();
@@ -4002,9 +4001,6 @@ protected class BinaryListArrayArrayListType
protected override string StoreName
=> "JsonTypesTest";
- protected virtual IServiceCollection AddServices(IServiceCollection serviceCollection)
- => serviceCollection;
-
protected virtual void AssertElementFacets(IElementType element, Dictionary? facets)
{
Assert.Equal(FacetValue(CoreAnnotationNames.Precision), element.GetPrecision());
diff --git a/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs b/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs
index ff738874eb2..7f8f924020e 100644
--- a/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs
+++ b/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs
@@ -72,10 +72,11 @@ public virtual Task InitializeAsync()
: CreateTestStore();
shouldLogCategory ??= _ => false;
- var services = (useServiceProvider
+ var services = AddServices(
+ (useServiceProvider
? TestStoreFactory.AddProviderServices(new ServiceCollection())
: new ServiceCollection())
- .AddSingleton(TestStoreFactory.CreateListLoggerFactory(shouldLogCategory));
+ .AddSingleton(TestStoreFactory.CreateListLoggerFactory(shouldLogCategory)));
if (onModelCreating != null)
{
@@ -117,6 +118,9 @@ public virtual Task InitializeAsync()
return optionsBuilder;
}
+ protected virtual IServiceCollection AddServices(IServiceCollection serviceCollection)
+ => serviceCollection;
+
protected virtual DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder)
=> builder
.EnableSensitiveDataLogging()
@@ -178,5 +182,8 @@ public virtual TContext CreateContext()
=> UsePooling
? PooledContextFactory!.CreateDbContext()
: (TContext)ServiceProvider.GetRequiredService(typeof(TContext));
+
+ public virtual DbContextOptions GetOptions()
+ => ServiceProvider.GetRequiredService();
}
}
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/AdHocPrecompiledQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/AdHocPrecompiledQuerySqlServerTest.cs
new file mode 100644
index 00000000000..99148b0d900
--- /dev/null
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/AdHocPrecompiledQuerySqlServerTest.cs
@@ -0,0 +1,89 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+public class AdHocPrecompiledQuerySqlServerTest(ITestOutputHelper testOutputHelper)
+ : AdHocPrecompiledQueryRelationalTestBase(testOutputHelper)
+{
+ protected override bool AlwaysPrintGeneratedSources
+ => false;
+
+ public override async Task Index_no_evaluatability()
+ {
+ await base.Index_no_evaluatability();
+
+ AssertSql("""
+SELECT [j].[Id], [j].[IntList], [j].[JsonThing]
+FROM [JsonEntities] AS [j]
+WHERE CAST(JSON_VALUE([j].[IntList], '$[' + CAST([j].[Id] AS nvarchar(max)) + ']') AS int) = 2
+""");
+ }
+
+ public override async Task Index_with_captured_variable()
+ {
+ await base.Index_with_captured_variable();
+
+ AssertSql("""
+@__id_0='1'
+
+SELECT [j].[Id], [j].[IntList], [j].[JsonThing]
+FROM [JsonEntities] AS [j]
+WHERE CAST(JSON_VALUE([j].[IntList], '$[' + CAST(@__id_0 AS nvarchar(max)) + ']') AS int) = 2
+""");
+ }
+
+ public override async Task JsonScalar()
+ {
+ await base.JsonScalar();
+
+ AssertSql("""
+SELECT [j].[Id], [j].[IntList], [j].[JsonThing]
+FROM [JsonEntities] AS [j]
+WHERE JSON_VALUE([j].[JsonThing], '$.StringProperty') = N'foo'
+""");
+ }
+
+ public override async Task Materialize_non_public()
+ {
+ await base.Materialize_non_public();
+
+ AssertSql(
+ """
+@p0='10' (Nullable = true)
+@p1='9' (Nullable = true)
+@p2='8' (Nullable = true)
+
+SET IMPLICIT_TRANSACTIONS OFF;
+SET NOCOUNT ON;
+INSERT INTO [NonPublicEntities] ([PrivateAutoProperty], [PrivateProperty], [_privateField])
+OUTPUT INSERTED.[Id]
+VALUES (@p0, @p1, @p2);
+""",
+ //
+ """
+SELECT TOP(2) [n].[Id], [n].[PrivateAutoProperty], [n].[PrivateProperty], [n].[_privateField]
+FROM [NonPublicEntities] AS [n]
+""");
+ }
+
+ [ConditionalFact]
+ public virtual void Check_all_tests_overridden()
+ => TestHelpers.AssertAllMethodsOverridden(GetType());
+
+ protected override ITestStoreFactory TestStoreFactory
+ => SqlServerTestStoreFactory.Instance;
+
+ protected override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers
+ => SqlServerPrecompiledQueryTestHelpers.Instance;
+
+ protected override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder)
+ {
+ builder = base.AddOptions(builder);
+
+ // TODO: Figure out if there's a nice way to continue using the retrying strategy
+ var sqlServerOptionsBuilder = new SqlServerDbContextOptionsBuilder(builder);
+ sqlServerOptionsBuilder.ExecutionStrategy(d => new NonRetryingExecutionStrategy(d));
+ return builder;
+ }
+}
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrecompiledQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrecompiledQuerySqlServerTest.cs
new file mode 100644
index 00000000000..342808d8613
--- /dev/null
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrecompiledQuerySqlServerTest.cs
@@ -0,0 +1,2018 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+// ReSharper disable InconsistentNaming
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+public class PrecompiledQuerySqlServerTest(
+ PrecompiledQuerySqlServerTest.PrecompiledQuerySqlServerFixture fixture,
+ ITestOutputHelper testOutputHelper)
+ : PrecompiledQueryRelationalTestBase(fixture, testOutputHelper),
+ IClassFixture
+{
+ protected override bool AlwaysPrintGeneratedSources
+ => false;
+
+ #region Expression types
+
+ public override async Task BinaryExpression()
+ {
+ await base.BinaryExpression();
+
+ AssertSql(
+ """
+@__id_0='3'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > @__id_0
+""");
+ }
+
+ public override async Task Conditional_no_evaluatable()
+ {
+ await base.Conditional_no_evaluatable();
+
+ AssertSql(
+ """
+SELECT CASE
+ WHEN [b].[Id] = 2 THEN N'yes'
+ ELSE N'no'
+END
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Conditional_contains_captured_variable()
+ {
+ await base.Conditional_contains_captured_variable();
+
+ AssertSql(
+ """
+@__yes_0='yes' (Size = 4000)
+
+SELECT CASE
+ WHEN [b].[Id] = 2 THEN @__yes_0
+ ELSE N'no'
+END
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Invoke_no_evaluatability_is_not_supported()
+ {
+ await base.Invoke_no_evaluatability_is_not_supported();
+
+ AssertSql();
+ }
+
+ public override async Task ListInit_no_evaluatability()
+ {
+ await base.ListInit_no_evaluatability();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Id] + 1
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task ListInit_with_evaluatable_with_captured_variable()
+ {
+ await base.ListInit_with_evaluatable_with_captured_variable();
+
+ AssertSql(
+ """
+SELECT [b].[Id]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task ListInit_with_evaluatable_without_captured_variable()
+ {
+ await base.ListInit_with_evaluatable_without_captured_variable();
+
+ AssertSql(
+ """
+SELECT [b].[Id]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task ListInit_fully_evaluatable()
+ {
+ await base.ListInit_fully_evaluatable();
+
+ AssertSql(
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] IN (7, 8)
+""");
+ }
+
+ public override async Task MethodCallExpression_no_evaluatability()
+ {
+ await base.MethodCallExpression_no_evaluatability();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] IS NOT NULL AND LEFT([b].[Name], LEN([b].[Name])) = [b].[Name]
+""");
+ }
+
+ public override async Task MethodCallExpression_with_evaluatable_with_captured_variable()
+ {
+ await base.MethodCallExpression_with_evaluatable_with_captured_variable();
+
+ AssertSql(
+ """
+@__pattern_0_startswith='foo%' (Size = 4000)
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] LIKE @__pattern_0_startswith ESCAPE N'\'
+""");
+ }
+
+ public override async Task MethodCallExpression_with_evaluatable_without_captured_variable()
+ {
+ await base.MethodCallExpression_with_evaluatable_without_captured_variable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] LIKE N'foo%'
+""");
+ }
+
+ public override async Task MethodCallExpression_fully_evaluatable()
+ {
+ await base.MethodCallExpression_fully_evaluatable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task New_with_no_arguments()
+ {
+ await base.New_with_no_arguments();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 0
+""");
+ }
+
+ public override async Task Where_New_with_captured_variable()
+ {
+ await base.Where_New_with_captured_variable();
+
+ AssertSql();
+ }
+
+ public override async Task Select_New_with_captured_variable()
+ {
+ await base.Select_New_with_captured_variable();
+
+ AssertSql(
+ """
+SELECT [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task MemberInit_no_evaluatable()
+ {
+ await base.MemberInit_no_evaluatable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task MemberInit_contains_captured_variable()
+ {
+ await base.MemberInit_contains_captured_variable();
+
+ AssertSql(
+ """
+@__id_0='8'
+
+SELECT @__id_0 AS [Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task MemberInit_evaluatable_as_constant()
+ {
+ await base.MemberInit_evaluatable_as_constant();
+
+ AssertSql(
+ """
+SELECT 1 AS [Id], N'foo' AS [Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task MemberInit_evaluatable_as_parameter()
+ {
+ await base.MemberInit_evaluatable_as_parameter();
+
+ AssertSql(
+ """
+SELECT 1
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task NewArray()
+ {
+ await base.NewArray();
+
+ AssertSql(
+ """
+@__i_0='8'
+
+SELECT [b].[Id], [b].[Id] + @__i_0
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Unary()
+ {
+ await base.Unary();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE CAST([b].[Id] AS smallint) = CAST(8 AS smallint)
+""");
+ }
+
+ public virtual async Task Collate()
+ {
+ await Test("""_ = context.Blogs.Where(b => EF.Functions.Collate(b.Name, "German_PhoneBook_CI_AS") == "foo").ToList();""");
+
+ AssertSql();
+ }
+
+ #endregion Expression types
+
+ #region Terminating operators
+
+ public override async Task Terminating_AsEnumerable()
+ {
+ await base.Terminating_AsEnumerable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_AsAsyncEnumerable_on_DbSet()
+ {
+ await base.Terminating_AsAsyncEnumerable_on_DbSet();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_AsAsyncEnumerable_on_IQueryable()
+ {
+ await base.Terminating_AsAsyncEnumerable_on_IQueryable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""");
+ }
+
+ public override async Task Foreach_sync_over_operator()
+ {
+ await base.Foreach_sync_over_operator();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""");
+ }
+
+ public override async Task Terminating_ToArray()
+ {
+ await base.Terminating_ToArray();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToArrayAsync()
+ {
+ await base.Terminating_ToArrayAsync();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToDictionary()
+ {
+ await base.Terminating_ToDictionary();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToDictionaryAsync()
+ {
+ await base.Terminating_ToDictionaryAsync();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task ToDictionary_over_anonymous_type()
+ {
+ await base.ToDictionary_over_anonymous_type();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task ToDictionaryAsync_over_anonymous_type()
+ {
+ await base.ToDictionaryAsync_over_anonymous_type();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToHashSet()
+ {
+ await base.Terminating_ToHashSet();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToHashSetAsync()
+ {
+ await base.Terminating_ToHashSetAsync();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToLookup()
+ {
+ await base.Terminating_ToLookup();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToList()
+ {
+ await base.Terminating_ToList();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ToListAsync()
+ {
+ await base.Terminating_ToListAsync();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Foreach_sync_over_DbSet_property_is_not_supported()
+ {
+ await base.Foreach_sync_over_DbSet_property_is_not_supported();
+
+ AssertSql();
+ }
+
+ public override async Task Foreach_async_is_not_supported()
+ {
+ await base.Foreach_async_is_not_supported();
+
+ AssertSql();
+ }
+
+ #endregion Terminating operators
+
+ #region Reducing terminating operators
+
+ public override async Task Terminating_All()
+ {
+ await base.Terminating_All();
+
+ AssertSql(
+ """
+SELECT CASE
+ WHEN NOT EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] <= 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN NOT EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] <= 8) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""");
+ }
+
+ public override async Task Terminating_AllAsync()
+ {
+ await base.Terminating_AllAsync();
+
+ AssertSql(
+ """
+SELECT CASE
+ WHEN NOT EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] <= 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN NOT EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] <= 8) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""");
+ }
+
+ public override async Task Terminating_Any()
+ {
+ await base.Terminating_Any();
+
+ AssertSql(
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] < 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] < 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""");
+ }
+
+ public override async Task Terminating_AnyAsync()
+ {
+ await base.Terminating_AnyAsync();
+
+ AssertSql(
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] < 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] < 7) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""");
+ }
+
+ public override async Task Terminating_Average()
+ {
+ await base.Terminating_Average();
+
+ AssertSql(
+ """
+SELECT AVG(CAST([b].[Id] AS float))
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT AVG(CAST([b].[Id] AS float))
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_AverageAsync()
+ {
+ await base.Terminating_AverageAsync();
+
+ AssertSql(
+ """
+SELECT AVG(CAST([b].[Id] AS float))
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT AVG(CAST([b].[Id] AS float))
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_Contains()
+ {
+ await base.Terminating_Contains();
+
+ AssertSql(
+ """
+@__p_0='8'
+
+SELECT CASE
+ WHEN @__p_0 IN (
+ SELECT [b].[Id]
+ FROM [Blogs] AS [b]
+ ) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+@__p_0='7'
+
+SELECT CASE
+ WHEN @__p_0 IN (
+ SELECT [b].[Id]
+ FROM [Blogs] AS [b]
+ ) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""");
+ }
+
+ public override async Task Terminating_ContainsAsync()
+ {
+ await base.Terminating_ContainsAsync();
+
+ AssertSql(
+ """
+@__p_0='8'
+
+SELECT CASE
+ WHEN @__p_0 IN (
+ SELECT [b].[Id]
+ FROM [Blogs] AS [b]
+ ) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""",
+ //
+ """
+@__p_0='7'
+
+SELECT CASE
+ WHEN @__p_0 IN (
+ SELECT [b].[Id]
+ FROM [Blogs] AS [b]
+ ) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END
+""");
+ }
+
+ public override async Task Terminating_Count()
+ {
+ await base.Terminating_Count();
+
+ AssertSql(
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""");
+ }
+
+ public override async Task Terminating_CountAsync()
+ {
+ await base.Terminating_CountAsync();
+
+ AssertSql(
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""");
+ }
+
+ public override async Task Terminating_ElementAt()
+ {
+ await base.Terminating_ElementAt();
+
+ AssertSql(
+ """
+@__p_0='1'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""",
+ //
+ """
+@__p_0='3'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""");
+ }
+
+ public override async Task Terminating_ElementAtAsync()
+ {
+ await base.Terminating_ElementAtAsync();
+
+ AssertSql(
+ """
+@__p_0='1'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""",
+ //
+ """
+@__p_0='3'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""");
+ }
+
+ public override async Task Terminating_ElementAtOrDefault()
+ {
+ await base.Terminating_ElementAtOrDefault();
+
+ AssertSql(
+ """
+@__p_0='1'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""",
+ //
+ """
+@__p_0='3'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""");
+ }
+
+ public override async Task Terminating_ElementAtOrDefaultAsync()
+ {
+ await base.Terminating_ElementAtOrDefaultAsync();
+
+ AssertSql(
+ """
+@__p_0='1'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""",
+ //
+ """
+@__p_0='3'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY
+""");
+ }
+
+ public override async Task Terminating_First()
+ {
+ await base.Terminating_First();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_FirstAsync()
+ {
+ await base.Terminating_FirstAsync();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_FirstOrDefault()
+ {
+ await base.Terminating_FirstOrDefault();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_FirstOrDefaultAsync()
+ {
+ await base.Terminating_FirstOrDefaultAsync();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_GetEnumerator()
+ {
+ await base.Terminating_GetEnumerator();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""");
+ }
+
+ public override async Task Terminating_Last()
+ {
+ await base.Terminating_Last();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""");
+ }
+
+ public override async Task Terminating_LastAsync()
+ {
+ await base.Terminating_LastAsync();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""");
+ }
+
+ public override async Task Terminating_LastOrDefault()
+ {
+ await base.Terminating_LastOrDefault();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""");
+ }
+
+ public override async Task Terminating_LastOrDefaultAsync()
+ {
+ await base.Terminating_LastOrDefaultAsync();
+
+ AssertSql(
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+ORDER BY [b].[Id] DESC
+""",
+ //
+ """
+SELECT TOP(1) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+ORDER BY [b].[Id] DESC
+""");
+ }
+
+ public override async Task Terminating_LongCount()
+ {
+ await base.Terminating_LongCount();
+
+ AssertSql(
+ """
+SELECT COUNT_BIG(*)
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT COUNT_BIG(*)
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""");
+ }
+
+ public override async Task Terminating_LongCountAsync()
+ {
+ await base.Terminating_LongCountAsync();
+
+ AssertSql(
+ """
+SELECT COUNT_BIG(*)
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT COUNT_BIG(*)
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""");
+ }
+
+ public override async Task Terminating_Max()
+ {
+ await base.Terminating_Max();
+
+ AssertSql(
+ """
+SELECT MAX([b].[Id])
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT MAX([b].[Id])
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_MaxAsync()
+ {
+ await base.Terminating_MaxAsync();
+
+ AssertSql(
+ """
+SELECT MAX([b].[Id])
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT MAX([b].[Id])
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_Min()
+ {
+ await base.Terminating_Min();
+
+ AssertSql(
+ """
+SELECT MIN([b].[Id])
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT MIN([b].[Id])
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_MinAsync()
+ {
+ await base.Terminating_MinAsync();
+
+ AssertSql(
+ """
+SELECT MIN([b].[Id])
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT MIN([b].[Id])
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_Single()
+ {
+ await base.Terminating_Single();
+
+ AssertSql(
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_SingleAsync()
+ {
+ await base.Terminating_SingleAsync();
+
+ AssertSql(
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_SingleOrDefault()
+ {
+ await base.Terminating_SingleOrDefault();
+
+ AssertSql(
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_SingleOrDefaultAsync()
+ {
+ await base.Terminating_SingleOrDefaultAsync();
+
+ AssertSql(
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 8
+""",
+ //
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 7
+""");
+ }
+
+ public override async Task Terminating_Sum()
+ {
+ await base.Terminating_Sum();
+
+ AssertSql(
+ """
+SELECT COALESCE(SUM([b].[Id]), 0)
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT COALESCE(SUM([b].[Id]), 0)
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_SumAsync()
+ {
+ await base.Terminating_SumAsync();
+
+ AssertSql(
+ """
+SELECT COALESCE(SUM([b].[Id]), 0)
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT COALESCE(SUM([b].[Id]), 0)
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ExecuteDelete()
+ {
+ await base.Terminating_ExecuteDelete();
+
+ AssertSql(
+ """
+DELETE FROM [b]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""",
+ //
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ExecuteDeleteAsync()
+ {
+ await base.Terminating_ExecuteDeleteAsync();
+
+ AssertSql(
+ """
+DELETE FROM [b]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""",
+ //
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Terminating_ExecuteUpdate()
+ {
+ await base.Terminating_ExecuteUpdate();
+
+ AssertSql(
+ """
+@__suffix_0='Suffix' (Size = 4000)
+
+UPDATE [b]
+SET [b].[Name] = COALESCE([b].[Name], N'') + @__suffix_0
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""",
+ //
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 9 AND [b].[Name] = N'Blog2Suffix'
+""");
+ }
+
+ public override async Task Terminating_ExecuteUpdateAsync()
+ {
+ await base.Terminating_ExecuteUpdateAsync();
+
+ AssertSql(
+ """
+@__suffix_0='Suffix' (Size = 4000)
+
+UPDATE [b]
+SET [b].[Name] = COALESCE([b].[Name], N'') + @__suffix_0
+FROM [Blogs] AS [b]
+WHERE [b].[Id] > 8
+""",
+ //
+ """
+SELECT COUNT(*)
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = 9 AND [b].[Name] = N'Blog2Suffix'
+""");
+ }
+
+ #endregion Reducing terminating operators
+
+ #region SQL expression quotability
+
+ public override async Task Union()
+ {
+ await base.Union();
+
+ AssertSql(
+ """
+SELECT [u].[Id], [u].[Name]
+FROM (
+ SELECT [b].[Id], [b].[Name]
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7
+ UNION
+ SELECT [b0].[Id], [b0].[Name]
+ FROM [Blogs] AS [b0]
+ WHERE [b0].[Id] < 10
+) AS [u]
+ORDER BY [u].[Id]
+""");
+ }
+
+ public override async Task Concat()
+ {
+ await base.Concat();
+
+ AssertSql(
+ """
+SELECT [u].[Id], [u].[Name]
+FROM (
+ SELECT [b].[Id], [b].[Name]
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7
+ UNION ALL
+ SELECT [b0].[Id], [b0].[Name]
+ FROM [Blogs] AS [b0]
+ WHERE [b0].[Id] < 10
+) AS [u]
+ORDER BY [u].[Id]
+""");
+ }
+
+ public override async Task Intersect()
+ {
+ await base.Intersect();
+
+ AssertSql(
+ """
+SELECT [i].[Id], [i].[Name]
+FROM (
+ SELECT [b].[Id], [b].[Name]
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7
+ INTERSECT
+ SELECT [b0].[Id], [b0].[Name]
+ FROM [Blogs] AS [b0]
+ WHERE [b0].[Id] > 8
+) AS [i]
+ORDER BY [i].[Id]
+""");
+ }
+
+ public override async Task Except()
+ {
+ await base.Except();
+
+ AssertSql(
+ """
+SELECT [e].[Id], [e].[Name]
+FROM (
+ SELECT [b].[Id], [b].[Name]
+ FROM [Blogs] AS [b]
+ WHERE [b].[Id] > 7
+ EXCEPT
+ SELECT [b0].[Id], [b0].[Name]
+ FROM [Blogs] AS [b0]
+ WHERE [b0].[Id] > 8
+) AS [e]
+ORDER BY [e].[Id]
+""");
+ }
+
+ public override async Task ValuesExpression()
+ {
+ await base.ValuesExpression();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE (
+ SELECT COUNT(*)
+ FROM (VALUES (CAST(7 AS int)), ([b].[Id])) AS [v]([Value])
+ WHERE [v].[Value] > 8) = 2
+""");
+ }
+
+ public override async Task Contains_with_parameterized_collection()
+ {
+ await base.Contains_with_parameterized_collection();
+
+ AssertSql(
+ """
+@__ids_0='[1,2,3]' (Size = 4000)
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] IN (
+ SELECT [i].[value]
+ FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i]
+)
+""");
+ }
+
+ public override async Task FromSqlRaw()
+ {
+ await base.FromSqlRaw();
+
+ AssertSql(
+ """
+SELECT [m].[Id], [m].[Name]
+FROM (
+ SELECT * FROM Blogs
+) AS [m]
+ORDER BY [m].[Id]
+""");
+ }
+
+ public override async Task FromSql_with_FormattableString_parameters()
+ {
+ await base.FromSql_with_FormattableString_parameters();
+
+ AssertSql(
+ """
+p0='8'
+p1='9'
+
+SELECT [m].[Id], [m].[Name]
+FROM (
+ SELECT * FROM Blogs WHERE Id > @p0 AND Id < @p1
+) AS [m]
+ORDER BY [m].[Id]
+""");
+ }
+
+ [ConditionalFact]
+ public virtual async Task SqlServerAggregateFunctionExpression()
+ {
+ await Test(
+ """
+_ = context.Blogs
+ .GroupBy(b => b.Id)
+ .Select(g => string.Join(", ", g.OrderBy(b => b.Name).Select(b => b.Name)))
+ .ToList();
+""");
+
+ AssertSql(
+ """
+SELECT COALESCE(STRING_AGG(COALESCE([b].[Name], N''), N', ') WITHIN GROUP (ORDER BY [b].[Name]), N'')
+FROM [Blogs] AS [b]
+GROUP BY [b].[Id]
+""");
+ }
+
+ // SqlServerOpenJsonExpression is covered by PrecompiledQueryRelationalTestBase.Contains_with_parameterized_collection
+
+// [ConditionalFact]
+// public virtual Task TableValuedFunctionExpression_toplevel()
+// => Test(
+// "_ = context.GetBlogsWithAtLeast(9).ToList();",
+// modelSourceCode: providerOptions => $$"""
+// public class BlogContext : DbContext
+// {
+// public DbSet Blogs { get; set; }
+//
+// protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
+// => optionsBuilder
+// {{providerOptions}}
+// .ReplaceService();
+//
+// protected override void OnModelCreating(ModelBuilder modelBuilder)
+// {
+// modelBuilder.HasDbFunction(typeof(BlogContext).GetMethod(nameof(GetBlogsWithAtLeast)));
+// }
+//
+// public IQueryable GetBlogsWithAtLeast(int minBlogId) => FromExpression(() => GetBlogsWithAtLeast(minBlogId));
+// }
+//
+// public class Blog
+// {
+// [DatabaseGenerated(DatabaseGeneratedOption.None)]
+// public int Id { get; set; }
+// public string StringProperty { get; set; }
+// }
+// """,
+// setupSql: """
+// CREATE FUNCTION dbo.GetBlogsWithAtLeast(@minBlogId int)
+// RETURNS TABLE AS RETURN
+// (
+// SELECT [b].[Id], [b].[Name] FROM [Blogs] AS [b] WHERE [b].[Id] >= @minBlogId
+// )
+// """,
+// cleanupSql: "DROP FUNCTION dbo.GetBlogsWithAtLeast;");
+//
+// [ConditionalFact]
+// public virtual Task TableValuedFunctionExpression_non_toplevel()
+// => Test(
+// "_ = context.Blogs.Where(b => context.GetPosts(b.Id).Count() == 2).ToList();",
+// modelSourceCode: providerOptions => $$"""
+// public class BlogContext : DbContext
+// {
+// public DbSet Blogs { get; set; }
+//
+// protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
+// => optionsBuilder
+// {{providerOptions}}
+// .ReplaceService();
+//
+// protected override void OnModelCreating(ModelBuilder modelBuilder)
+// {
+// modelBuilder.HasDbFunction(typeof(BlogContext).GetMethod(nameof(GetPosts)));
+// }
+//
+// public IQueryable GetPosts(int blogId) => FromExpression(() => GetPosts(blogId));
+// }
+//
+// public class Blog
+// {
+// public int Id { get; set; }
+// public string StringProperty { get; set; }
+// public List Post { get; set; }
+// }
+//
+// public class Post
+// {
+// public int Id { get; set; }
+// public string Title { get; set; }
+//
+// public Blog Blog { get; set; }
+// }
+// """,
+// setupSql: """
+// CREATE FUNCTION dbo.GetPosts(@blogId int)
+// RETURNS TABLE AS RETURN
+// (
+// SELECT [p].[Id], [p].[Title], [p].[BlogId] FROM [Posts] AS [p] WHERE [p].[BlogId] = @blogId
+// )
+// """,
+// cleanupSql: "DROP FUNCTION dbo.GetPosts;");
+
+ #endregion SQL expression quotability
+
+ #region Different query roots
+
+ public override async Task DbContext_as_local_variable()
+ {
+ await base.DbContext_as_local_variable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task DbContext_as_field()
+ {
+ await base.DbContext_as_field();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task DbContext_as_property()
+ {
+ await base.DbContext_as_property();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task DbContext_as_captured_variable()
+ {
+ await base.DbContext_as_captured_variable();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task DbContext_as_method_invocation_result()
+ {
+ await base.DbContext_as_method_invocation_result();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ #endregion Different query roots
+
+ #region Negative cases
+
+ public override async Task Dynamic_query_does_not_get_precompiled()
+ {
+ await base.Dynamic_query_does_not_get_precompiled();
+
+ AssertSql();
+ }
+
+ public override async Task ToList_over_objects_does_not_get_precompiled()
+ {
+ await base.ToList_over_objects_does_not_get_precompiled();
+
+ AssertSql();
+ }
+
+ public override async Task Query_compilation_failure()
+ {
+ await base.Query_compilation_failure();
+
+ AssertSql();
+ }
+
+ public override async Task EF_Constant_is_not_supported()
+ {
+ await base.EF_Constant_is_not_supported();
+
+ AssertSql();
+ }
+
+ public override async Task NotParameterizedAttribute_with_constant()
+ {
+ await base.NotParameterizedAttribute_with_constant();
+
+ AssertSql(
+ """
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] = N'Blog2'
+""");
+ }
+
+ public override async Task NotParameterizedAttribute_is_not_supported_with_non_constant_argument()
+ {
+ await base.NotParameterizedAttribute_is_not_supported_with_non_constant_argument();
+
+ AssertSql();
+ }
+
+ public override async Task Query_syntax_is_not_supported()
+ {
+ await base.Query_syntax_is_not_supported();
+
+ AssertSql();
+ }
+
+ #endregion Negative cases
+
+ public override async Task Select_changes_type()
+ {
+ await base.Select_changes_type();
+
+ AssertSql(
+ """
+SELECT [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task OrderBy()
+ {
+ await base.OrderBy();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Name]
+""");
+ }
+
+ public override async Task Skip()
+ {
+ await base.Skip();
+
+ AssertSql(
+ """
+@__p_0='1'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Name]
+OFFSET @__p_0 ROWS
+""");
+ }
+
+ public override async Task Take()
+ {
+ await base.Take();
+
+ AssertSql(
+ """
+@__p_0='1'
+
+SELECT TOP(@__p_0) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Name]
+""");
+ }
+
+ public override async Task Project_anonymous_object()
+ {
+ await base.Project_anonymous_object();
+
+ AssertSql(
+ """
+SELECT COALESCE([b].[Name], N'') + N'Foo' AS [Foo]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Two_captured_variables_in_same_lambda()
+ {
+ await base.Two_captured_variables_in_same_lambda();
+
+ AssertSql(
+ """
+@__yes_0='yes' (Size = 4000)
+@__no_1='no' (Size = 4000)
+
+SELECT CASE
+ WHEN [b].[Id] = 3 THEN @__yes_0
+ ELSE @__no_1
+END
+FROM [Blogs] AS [b]
+""");
+ }
+
+ public override async Task Two_captured_variables_in_different_lambdas()
+ {
+ await base.Two_captured_variables_in_different_lambdas();
+
+ AssertSql(
+ """
+@__starts_0_startswith='blog%' (Size = 4000)
+@__ends_1_endswith='%2' (Size = 4000)
+
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] LIKE @__starts_0_startswith ESCAPE N'\' AND [b].[Name] LIKE @__ends_1_endswith ESCAPE N'\'
+""");
+ }
+
+ public override async Task Same_captured_variable_twice_in_same_lambda()
+ {
+ await base.Same_captured_variable_twice_in_same_lambda();
+
+ AssertSql(
+ """
+@__foo_0_startswith='X%' (Size = 4000)
+@__foo_0_endswith='%X' (Size = 4000)
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] LIKE @__foo_0_startswith ESCAPE N'\' AND [b].[Name] LIKE @__foo_0_endswith ESCAPE N'\'
+""");
+ }
+
+ public override async Task Same_captured_variable_twice_in_different_lambdas()
+ {
+ await base.Same_captured_variable_twice_in_different_lambdas();
+
+ AssertSql(
+ """
+@__foo_0_startswith='X%' (Size = 4000)
+@__foo_0_endswith='%X' (Size = 4000)
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Name] LIKE @__foo_0_startswith ESCAPE N'\' AND [b].[Name] LIKE @__foo_0_endswith ESCAPE N'\'
+""");
+ }
+
+ public override async Task Include_single()
+ {
+ await base.Include_single();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name], [p].[Id], [p].[BlogId], [p].[Title]
+FROM [Blogs] AS [b]
+LEFT JOIN [Posts] AS [p] ON [b].[Id] = [p].[BlogId]
+WHERE [b].[Id] > 8
+ORDER BY [b].[Id]
+""");
+ }
+
+ public override async Task Include_split()
+ {
+ await base.Include_split();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Id]
+""",
+ //
+ """
+SELECT [p].[Id], [p].[BlogId], [p].[Title], [b].[Id]
+FROM [Blogs] AS [b]
+INNER JOIN [Posts] AS [p] ON [b].[Id] = [p].[BlogId]
+ORDER BY [b].[Id]
+""");
+ }
+
+ public override async Task Final_GroupBy()
+ {
+ await base.Final_GroupBy();
+
+ AssertSql(
+ """
+SELECT [b].[Name], [b].[Id]
+FROM [Blogs] AS [b]
+ORDER BY [b].[Name]
+""");
+ }
+
+ public override async Task Multiple_queries_with_captured_variables()
+ {
+ await base.Multiple_queries_with_captured_variables();
+
+ AssertSql(
+ """
+@__id1_0='8'
+@__id2_1='9'
+
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = @__id1_0 OR [b].[Id] = @__id2_1
+""",
+ //
+ """
+@__id1_0='8'
+
+SELECT TOP(2) [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+WHERE [b].[Id] = @__id1_0
+""");
+ }
+
+ public override async Task Unsafe_accessor_gets_generated_once_for_multiple_queries()
+ {
+ await base.Unsafe_accessor_gets_generated_once_for_multiple_queries();
+
+ AssertSql(
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""",
+ //
+ """
+SELECT [b].[Id], [b].[Name]
+FROM [Blogs] AS [b]
+""");
+ }
+
+ [ConditionalFact]
+ public virtual void Check_all_tests_overridden()
+ => TestHelpers.AssertAllMethodsOverridden(GetType());
+
+ public class PrecompiledQuerySqlServerFixture : PrecompiledQueryRelationalFixture
+ {
+ protected override ITestStoreFactory TestStoreFactory
+ => SqlServerTestStoreFactory.Instance;
+
+ public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder)
+ {
+ builder = base.AddOptions(builder);
+
+ // TODO: Figure out if there's a nice way to continue using the retrying strategy
+ var sqlServerOptionsBuilder = new SqlServerDbContextOptionsBuilder(builder);
+ sqlServerOptionsBuilder.ExecutionStrategy(d => new NonRetryingExecutionStrategy(d));
+ return builder;
+ }
+
+ public override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers => SqlServerPrecompiledQueryTestHelpers.Instance;
+ }
+}
diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs
index 47969a88022..943b0f500f6 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs
@@ -50,15 +50,15 @@ private FunctionParameterTypeMappingContextModel()
param.TypeMapping = StringTypeMapping.Default.Clone(
comparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
keyComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
providerValueComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "varchar",
@@ -67,15 +67,15 @@ private FunctionParameterTypeMappingContextModel()
getSqlFragmentStatic.TypeMapping = SqlServerStringTypeMapping.Default.Clone(
comparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
keyComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
providerValueComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "nvarchar(max)",
diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs
index b3068bed402..d9d601aa499 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs
@@ -50,15 +50,15 @@ private FunctionTypeMappingContextModel()
param.TypeMapping = SqlServerStringTypeMapping.Default.Clone(
comparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
keyComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
providerValueComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "nvarchar(max)",
@@ -69,15 +69,15 @@ private FunctionTypeMappingContextModel()
getSqlFragmentStatic.TypeMapping = StringTypeMapping.Default.Clone(
comparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
keyComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
providerValueComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "varchar",
diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs
index e2d187cd8ce..4a979a96d2b 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs
@@ -60,7 +60,7 @@ public static RuntimeEntityType Create(RuntimeModel model, RuntimeEntityType bas
blob.TypeMapping = SqlServerByteArrayTypeMapping.Default.Clone(
comparer: new ValueComparer(
(byte[] v1, byte[] v2) => StructuralComparisons.StructuralEqualityComparer.Equals((object)v1, (object)v2),
- (byte[] v) => v.GetHashCode(),
+ (byte[] v) => ((object)v).GetHashCode(),
(byte[] v) => v),
keyComparer: new ValueComparer(
(byte[] v1, byte[] v2) => StructuralComparisons.StructuralEqualityComparer.Equals((object)v1, (object)v2),
@@ -86,7 +86,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
(InternalEntityEntry source) =>
{
var entity = (CompiledModelTestBase.Data)source.Entity;
- return (ISnapshot)new Snapshot(source.GetCurrentValue(blob) == null ? null : ((ValueComparer)blob.GetValueComparer()).Snapshot(source.GetCurrentValue(blob)));
+ return (ISnapshot)new Snapshot(source.GetCurrentValue(blob) == null ? null : ((ValueComparer)((IProperty)blob).GetValueComparer()).Snapshot(source.GetCurrentValue(blob)));
});
runtimeEntityType.SetStoreGeneratedValuesFactory(
() => Snapshot.Empty);
diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs
index b5f009363bc..47595a59c8e 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs
@@ -70,15 +70,15 @@ private DbFunctionContextModel()
id.TypeMapping = GuidTypeMapping.Default.Clone(
comparer: new ValueComparer(
(Guid v1, Guid v2) => v1 == v2,
- (Guid v) => v.GetHashCode(),
+ (Guid v) => ((object)v).GetHashCode(),
(Guid v) => v),
keyComparer: new ValueComparer(
(Guid v1, Guid v2) => v1 == v2,
- (Guid v) => v.GetHashCode(),
+ (Guid v) => ((object)v).GetHashCode(),
(Guid v) => v),
providerValueComparer: new ValueComparer(
(Guid v1, Guid v2) => v1 == v2,
- (Guid v) => v.GetHashCode(),
+ (Guid v) => ((object)v).GetHashCode(),
(Guid v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "uniqueidentifier"));
@@ -92,15 +92,15 @@ private DbFunctionContextModel()
condition.TypeMapping = SqlServerStringTypeMapping.Default.Clone(
comparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
keyComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
providerValueComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "nchar(256)",
@@ -197,15 +197,15 @@ private DbFunctionContextModel()
date.TypeMapping = SqlServerStringTypeMapping.Default.Clone(
comparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
keyComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
providerValueComparer: new ValueComparer(
(string v1, string v2) => v1 == v2,
- (string v) => v.GetHashCode(),
+ (string v) => ((object)v).GetHashCode(),
(string v) => v),
mappingInfo: new RelationalTypeMappingInfo(
storeTypeName: "nchar(256)",
@@ -217,15 +217,15 @@ private DbFunctionContextModel()
isDateStatic.TypeMapping = SqlServerBoolTypeMapping.Default.Clone(
comparer: new ValueComparer(
(bool v1, bool v2) => v1 == v2,
- (bool v) => v.GetHashCode(),
+ (bool v) => ((object)v).GetHashCode(),
(bool v) => v),
keyComparer: new ValueComparer(
(bool v1, bool v2) => v1 == v2,
- (bool v) => v.GetHashCode(),
+ (bool v) => ((object)v).GetHashCode(),
(bool v) => v),
providerValueComparer: new ValueComparer(
(bool v1, bool v2) => v1 == v2,
- (bool v) => v.GetHashCode(),
+ (bool v) => ((object)v).GetHashCode(),
(bool v) => v));
isDateStatic.AddAnnotation("MyGuid", new Guid("00000000-0000-0000-0000-000000000000"));
functions["Microsoft.EntityFrameworkCore.Scaffolding.CompiledModelRelationalTestBase+DbFunctionContext.IsDateStatic(string)"] = isDateStatic;
diff --git a/test/EFCore.SqlServer.FunctionalTests/TestUtilities/SqlServerPrecompiledQueryTestHelpers.cs b/test/EFCore.SqlServer.FunctionalTests/TestUtilities/SqlServerPrecompiledQueryTestHelpers.cs
new file mode 100644
index 00000000000..e25d0a1df12
--- /dev/null
+++ b/test/EFCore.SqlServer.FunctionalTests/TestUtilities/SqlServerPrecompiledQueryTestHelpers.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.CodeAnalysis;
+using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;
+
+namespace Microsoft.EntityFrameworkCore.TestUtilities;
+
+public class SqlServerPrecompiledQueryTestHelpers : PrecompiledQueryTestHelpers
+{
+ public static SqlServerPrecompiledQueryTestHelpers Instance = new();
+
+ protected override IEnumerable BuildProviderMetadataReferences()
+ {
+ yield return MetadataReference.CreateFromFile(typeof(SqlServerOptionsExtension).Assembly.Location);
+ yield return MetadataReference.CreateFromFile(Assembly.GetExecutingAssembly().Location);
+ }
+}
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/AdHocPrecompiledQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocPrecompiledQuerySqliteTest.cs
new file mode 100644
index 00000000000..076e532fb33
--- /dev/null
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocPrecompiledQuerySqliteTest.cs
@@ -0,0 +1,17 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+public class AdHocPrecompiledQuerySqliteTest(ITestOutputHelper testOutputHelper)
+ : AdHocPrecompiledQueryRelationalTestBase(testOutputHelper)
+{
+ protected override bool AlwaysPrintGeneratedSources
+ => false;
+
+ protected override ITestStoreFactory TestStoreFactory
+ => SqliteTestStoreFactory.Instance;
+
+ protected override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers
+ => SqlitePrecompiledQueryTestHelpers.Instance;
+}
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/PrecompiledQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/PrecompiledQuerySqliteTest.cs
new file mode 100644
index 00000000000..f124247ed7f
--- /dev/null
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/PrecompiledQuerySqliteTest.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+// ReSharper disable InconsistentNaming
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+public class PrecompiledQuerySqliteTest(
+ PrecompiledQuerySqliteTest.PrecompiledQuerySqliteFixture fixture,
+ ITestOutputHelper testOutputHelper)
+ : PrecompiledQueryRelationalTestBase(fixture, testOutputHelper),
+ IClassFixture
+{
+ protected override bool AlwaysPrintGeneratedSources
+ => false;
+
+ [ConditionalFact]
+ public virtual Task Glob()
+ => Test("""_ = context.Blogs.Where(b => EF.Functions.Glob(b.Name, "*foo*")).ToList();""");
+
+ [ConditionalFact]
+ public virtual Task Regexp()
+ => Test("""_ = context.Blogs.Where(b => Regex.IsMatch(b.Name, "^foo")).ToList();""");
+
+ public class PrecompiledQuerySqliteFixture : PrecompiledQueryRelationalFixture
+ {
+ protected override ITestStoreFactory TestStoreFactory
+ => SqliteTestStoreFactory.Instance;
+
+ public override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers => SqlitePrecompiledQueryTestHelpers.Instance;
+ }
+}
diff --git a/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqlitePrecompiledQueryTestHelpers.cs b/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqlitePrecompiledQueryTestHelpers.cs
new file mode 100644
index 00000000000..4e9984a139c
--- /dev/null
+++ b/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqlitePrecompiledQueryTestHelpers.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.CodeAnalysis;
+using Microsoft.EntityFrameworkCore.Sqlite.Infrastructure.Internal;
+
+namespace Microsoft.EntityFrameworkCore.TestUtilities;
+
+public class SqlitePrecompiledQueryTestHelpers : PrecompiledQueryTestHelpers
+{
+ public static SqlitePrecompiledQueryTestHelpers Instance = new();
+
+ protected override IEnumerable BuildProviderMetadataReferences()
+ {
+ yield return MetadataReference.CreateFromFile(typeof(SqliteOptionsExtension).Assembly.Location);
+ yield return MetadataReference.CreateFromFile(Assembly.GetExecutingAssembly().Location);
+ }
+}
diff --git a/test/EFCore.Tests/EFCore.Tests.csproj b/test/EFCore.Tests/EFCore.Tests.csproj
index f81c6d6cc98..c38c3eb46b9 100644
--- a/test/EFCore.Tests/EFCore.Tests.csproj
+++ b/test/EFCore.Tests/EFCore.Tests.csproj
@@ -6,6 +6,7 @@
Microsoft.EntityFrameworkCore
disable
true
+ $(NoWarn);EF9100
diff --git a/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs b/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs
index aa8517389b9..6144b826051 100644
--- a/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs
+++ b/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs
@@ -22,19 +22,19 @@ public TestNavigationExpandingExpressionVisitor()
null,
new QueryCompilationContext(
new QueryCompilationContextDependencies(
- null,
- null,
- null,
- null,
- null,
- null,
- null,
+ model: null,
+ queryTranslationPreprocessorFactory: null,
+ queryableMethodTranslatingExpressionVisitorFactory: null,
+ queryTranslationPostprocessorFactory: null,
+ shapedQueryCompilingExpressionVisitorFactory: null,
+ liftableConstantFactory: null,
+ liftableConstantProcessor: null,
new ExecutionStrategyTest.TestExecutionStrategy(new MyDemoContext()),
new CurrentDbContext(new MyDemoContext()),
- null,
- null,
+ contextOptions: null,
+ logger: null,
new TestInterceptors()
- ), false),
+ ), async: false, precompiling: false),
null,
null)
{
@@ -77,7 +77,7 @@ private class A
}
[ConditionalFact]
- public void Visits_extention_childrens()
+ public void Visits_extension_children()
{
var model = new Model();
var e = model.AddEntityType(typeof(A), false, ConfigurationSource.Explicit);