diff --git a/EFCore.sln.DotSettings b/EFCore.sln.DotSettings index 6881961da98..d9724cb9987 100644 --- a/EFCore.sln.DotSettings +++ b/EFCore.sln.DotSettings @@ -301,9 +301,9 @@ The .NET Foundation licenses this file to you under the MIT license. True True True - True True True + True True True True @@ -325,7 +325,10 @@ The .NET Foundation licenses this file to you under the MIT license. True True True + True True + True + True True True True diff --git a/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs b/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs index 60d4ef6301a..ce83a43ec36 100644 --- a/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs +++ b/src/EFCore.Design/Design/DesignTimeServiceCollectionExtensions.cs @@ -3,6 +3,7 @@ using Microsoft.EntityFrameworkCore.Design.Internal; using Microsoft.EntityFrameworkCore.Migrations.Internal; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Scaffolding.Internal; namespace Microsoft.EntityFrameworkCore.Design; diff --git a/src/EFCore.Design/Design/Internal/CSharpHelper.cs b/src/EFCore.Design/Design/Internal/CSharpHelper.cs index 01d2f2c1acd..8b5cb3c4bac 100644 --- a/src/EFCore.Design/Design/Internal/CSharpHelper.cs +++ b/src/EFCore.Design/Design/Internal/CSharpHelper.cs @@ -1587,13 +1587,25 @@ private string ToSourceCode(SyntaxNode node) public virtual string Statement( Expression node, ISet collectedNamespaces, + ISet unsafeAccessors, IReadOnlyDictionary? constantReplacements, IReadOnlyDictionary? memberAccessReplacements) - => ToSourceCode(_translator.TranslateStatement( - node, - constantReplacements, - memberAccessReplacements, - collectedNamespaces)); + { + var unsafeAccessorDeclarations = new HashSet(); + + var code = ToSourceCode( + _translator.TranslateStatement( + node, + constantReplacements, + memberAccessReplacements, + collectedNamespaces, + unsafeAccessorDeclarations)); + + // TODO: Possibly improve this (e.g. expose a single string that contains all the accessors concatenated?) + unsafeAccessors.UnionWith(unsafeAccessorDeclarations.Select(ToSourceCode)); + + return code; + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -1604,13 +1616,25 @@ private string ToSourceCode(SyntaxNode node) public virtual string Expression( Expression node, ISet collectedNamespaces, + ISet unsafeAccessors, IReadOnlyDictionary? constantReplacements, IReadOnlyDictionary? memberAccessReplacements) - => ToSourceCode(_translator.TranslateExpression( - node, - constantReplacements, - memberAccessReplacements, - collectedNamespaces)); + { + var unsafeAccessorDeclarations = new HashSet(); + + var code = ToSourceCode( + _translator.TranslateExpression( + node, + constantReplacements, + memberAccessReplacements, + collectedNamespaces, + unsafeAccessorDeclarations)); + + // TODO: Possibly improve this (e.g. expose a single string that contains all the accessors concatenated?) + unsafeAccessors.UnionWith(unsafeAccessorDeclarations.Select(ToSourceCode)); + + return code; + } private static bool IsIdentifierStartCharacter(char ch) { diff --git a/src/EFCore.Design/EFCore.Design.csproj b/src/EFCore.Design/EFCore.Design.csproj index 7db02fcc958..8916afeb532 100644 --- a/src/EFCore.Design/EFCore.Design.csproj +++ b/src/EFCore.Design/EFCore.Design.csproj @@ -58,6 +58,8 @@ + + diff --git a/src/EFCore.Design/Properties/DesignStrings.Designer.cs b/src/EFCore.Design/Properties/DesignStrings.Designer.cs index 4eff1157449..7c8b90a1f7e 100644 --- a/src/EFCore.Design/Properties/DesignStrings.Designer.cs +++ b/src/EFCore.Design/Properties/DesignStrings.Designer.cs @@ -207,6 +207,12 @@ public static string DuplicateMigrationName(object? migrationName) GetString("DuplicateMigrationName", nameof(migrationName)), migrationName); + /// + /// Dynamic LINQ queries are not supported when precompiling queries. + /// + public static string DynamicQueryNotSupported + => GetString("DynamicQueryNotSupported"); + /// /// The encoding '{encoding}' specified in the output directive will be ignored. EF Core always scaffolds files using the encoding 'utf-8'. /// @@ -609,6 +615,12 @@ public static string ProviderReturnedNullModel(object? providerTypeName) GetString("ProviderReturnedNullModel", nameof(providerTypeName)), providerTypeName); + /// + /// LINQ query comprehension syntax is currently unsupported in precompiled queries. + /// + public static string QueryComprehensionSyntaxNotSupportedInPrecompiledQueries + => GetString("QueryComprehensionSyntaxNotSupportedInPrecompiledQueries"); + /// /// No files were generated in directory '{outputDirectoryName}'. The following file(s) already exist(s) and must be made writeable to continue: {readOnlyFiles}. /// diff --git a/src/EFCore.Design/Properties/DesignStrings.resx b/src/EFCore.Design/Properties/DesignStrings.resx index 842e2eb9170..c8ed2a6006b 100644 --- a/src/EFCore.Design/Properties/DesignStrings.resx +++ b/src/EFCore.Design/Properties/DesignStrings.resx @@ -192,6 +192,9 @@ The name '{migrationName}' is used by an existing migration. + + Dynamic LINQ queries are not supported when precompiling queries. + The encoding '{encoding}' specified in the output directive will be ignored. EF Core always scaffolds files using the encoding 'utf-8'. @@ -359,6 +362,9 @@ Change your target project to the migrations project by using the Package Manage Metadata model returned should not be null. Provider: {providerTypeName}. + + LINQ query comprehension syntax is currently unsupported in precompiled queries. + No files were generated in directory '{outputDirectoryName}'. The following file(s) already exist(s) and must be made writeable to continue: {readOnlyFiles}. diff --git a/src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs b/src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs new file mode 100644 index 00000000000..7df792f374d --- /dev/null +++ b/src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs @@ -0,0 +1,1309 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.EntityFrameworkCore.Internal; +using static System.Linq.Expressions.Expression; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// Translates a Roslyn syntax tree into a LINQ expression tree. +/// +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class CSharpToLinqTranslator : CSharpSyntaxVisitor +{ + private static readonly SymbolDisplayFormat QualifiedTypeNameSymbolDisplayFormat = new( + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces); + + private Compilation? _compilation; + +#pragma warning disable CS8618 // Uninitialized non-nullable fields. We check _compilation to make sure LoadCompilation was invoked. + private DbContext _userDbContext; + private Assembly? _additionalAssembly; + private INamedTypeSymbol _userDbContextSymbol; + private INamedTypeSymbol _formattableStringSymbol; +#pragma warning restore CS8618 + + private SemanticModel _semanticModel = null!; + + private static MethodInfo? _stringConcatMethod; + private static MethodInfo? _stringFormatMethod; + private static MethodInfo? _formattableStringFactoryCreateMethod; + + /// + /// Loads the given and prepares to translate queries using the given . + /// + /// A containing the syntax nodes to be translated. + /// An instance of the user's . + /// An optional additional assemblies to resolve CLR types from. + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void Load(Compilation compilation, DbContext userDbContext, Assembly? additionalAssembly = null) + { + _compilation = compilation; + _userDbContext = userDbContext; + _additionalAssembly = additionalAssembly; + _userDbContextSymbol = GetTypeSymbolOrThrow(userDbContext.GetType().FullName!); + _formattableStringSymbol = GetTypeSymbolOrThrow("System.FormattableString"); + + INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName) + => _compilation.GetTypeByMetadataName(fullyQualifiedMetadataName) + ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedMetadataName); + } + + private readonly Stack> _parameterStack + = new(new[] { ImmutableDictionary.Empty }); + + private readonly Dictionary _capturedVariables = new(SymbolEqualityComparer.Default); + + /// + /// Translates a Roslyn syntax tree into a LINQ expression tree. + /// + /// The Roslyn syntax node to be translated. + /// + /// The for the Roslyn of which is a part. + /// + /// A LINQ expression tree translated from the provided . + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Translate(SyntaxNode node, SemanticModel semanticModel) + { + if (_compilation is null) + { + throw new InvalidOperationException("A compilation must be loaded."); + } + + Check.DebugAssert( + ReferenceEquals(semanticModel.SyntaxTree, node.SyntaxTree), + "Provided semantic model doesn't match the provided syntax node"); + + _semanticModel = semanticModel; + + // Perform data flow analysis to detect all captured data (closure parameters) + _capturedVariables.Clear(); + foreach (var captured in _semanticModel.AnalyzeDataFlow(node).Captured) + { + _capturedVariables[captured] = null; + } + + var result = Visit(node); + + // TODO: Sanity check: make sure all captured variables in _capturedVariables have non-null values + // (i.e. have been encountered and referenced) + + Debug.Assert(_parameterStack.Count == 1); + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [return: NotNullIfNotNull("node")] + public override Expression? Visit(SyntaxNode? node) + => base.Visit(node); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitAnonymousObjectCreationExpression(AnonymousObjectCreationExpressionSyntax anonymousObjectCreation) + { + // Creating an actual anonymous object means creating a new type, which can only be done with Reflection.Emit. + // At least for EF's purposes, it doesn't matter, so we build a placeholder. + if (_semanticModel.GetSymbolInfo(anonymousObjectCreation).Symbol is not IMethodSymbol constructorSymbol) + { + throw new InvalidOperationException( + "Could not find symbol for anonymous object creation initializer: " + anonymousObjectCreation); + } + + var anonymousType = ResolveType(constructorSymbol.ContainingType); + + var parameters = constructorSymbol.Parameters.ToArray(); + + var parameterInfos = new ParameterInfo[parameters.Length]; + var memberInfos = new MemberInfo[parameters.Length]; + var arguments = new Expression[parameters.Length]; + + foreach (var initializer in anonymousObjectCreation.Initializers) + { + // If the initializer's name isn't explicitly specified, infer it from the initializer's expression like the compiler does + var name = initializer.NameEquals is not null + ? initializer.NameEquals.Name.Identifier.Text + : initializer.Expression is MemberAccessExpressionSyntax memberAccess + ? memberAccess.Name.Identifier.Text + : throw new InvalidOperationException( + $"AnonymousObjectCreation: unnamed initializer with non-MemberAccess expression: {initializer.Expression}"); + + var position = Array.FindIndex(parameters, p => p.Name == name); + var parameter = parameters[position]; + var parameterType = ResolveType(parameter.Type) ?? throw new InvalidOperationException( + "Could not resolve type symbol for: " + parameter.Type); + + parameterInfos[position] = new FakeParameterInfo(name, parameterType, position); + arguments[position] = Visit(initializer.Expression); + memberInfos[position] = anonymousType.GetProperty(parameter.Name)!; + } + + return New( + new FakeConstructorInfo(anonymousType, parameterInfos), + arguments: arguments, + memberInfos); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitArgument(ArgumentSyntax argument) + { + if (!argument.RefKindKeyword.IsKind(SyntaxKind.None)) + { + throw new InvalidOperationException($"Argument with ref/out: {argument}"); + } + + return Visit(argument.Expression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitArrayCreationExpression(ArrayCreationExpressionSyntax arrayCreation) + { + if (_semanticModel.GetTypeInfo(arrayCreation).Type is not IArrayTypeSymbol arrayTypeSymbol) + { + throw new InvalidOperationException($"ArrayCreation: non-array type symbol: {arrayCreation}"); + } + + if (arrayTypeSymbol.Rank > 1) + { + throw new NotImplementedException($"ArrayCreation: multi-dimensional array: {arrayCreation}"); + } + + var elementType = ResolveType(arrayTypeSymbol.ElementType); + Check.DebugAssert(elementType is not null, "elementType is not null"); + + return arrayCreation.Initializer is null + ? NewArrayBounds(elementType, Visit(arrayCreation.Type.RankSpecifiers[0].Sizes[0])) + : NewArrayInit(elementType, arrayCreation.Initializer.Expressions.Select(e => Visit(e))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitBinaryExpression(BinaryExpressionSyntax binary) + { + var left = Visit(binary.Left); + var right = Visit(binary.Right); + + // https://learn.microsoft.com/dotnet/api/Microsoft.CodeAnalysis.CSharp.Syntax.BinaryExpressionSyntax + return binary.Kind() switch + { + // String concatenation + SyntaxKind.AddExpression + when left.Type == typeof(string) && right.Type == typeof(string) + => Add(left, right, + _stringConcatMethod ??= + typeof(string).GetMethod(nameof(string.Concat), new[] { typeof(string), typeof(string) })), + + SyntaxKind.AddExpression => Add(left, right), + SyntaxKind.SubtractExpression => Subtract(left, right), + SyntaxKind.MultiplyExpression => Multiply(left, right), + SyntaxKind.DivideExpression => Divide(left, right), + SyntaxKind.ModuloExpression => Modulo(left, right), + SyntaxKind.LeftShiftExpression => LeftShift(left, right), + SyntaxKind.RightShiftExpression => RightShift(left, right), + // TODO UnsignedRightShiftExpression + SyntaxKind.LogicalOrExpression => OrElse(left, right), + SyntaxKind.LogicalAndExpression => AndAlso(left, right), + + // For bitwise operations over enums, we the enum to its underlying type before the bitwise operation, and then back to the + // enum afterwards (this is corresponds to the LINQ expression tree that the compiler generates) + SyntaxKind.BitwiseOrExpression when left.Type.IsEnum || right.Type.IsEnum + => Convert(Or(Convert(left, left.Type.GetEnumUnderlyingType()), Convert(right, right.Type.GetEnumUnderlyingType())), left.Type), + SyntaxKind.BitwiseAndExpression when left.Type.IsEnum || right.Type.IsEnum + => Convert(And(Convert(left, left.Type.GetEnumUnderlyingType()), Convert(right, right.Type.GetEnumUnderlyingType())), left.Type), + SyntaxKind.ExclusiveOrExpression when left.Type.IsEnum || right.Type.IsEnum + => Convert(ExclusiveOr(Convert(left, left.Type.GetEnumUnderlyingType()), Convert(right, right.Type.GetEnumUnderlyingType())), left.Type), + + SyntaxKind.BitwiseOrExpression => Or(left, right), + SyntaxKind.BitwiseAndExpression => And(left, right), + SyntaxKind.ExclusiveOrExpression => ExclusiveOr(left, right), + + SyntaxKind.EqualsExpression => Equal(left, right), + SyntaxKind.NotEqualsExpression => NotEqual(left, right), + SyntaxKind.LessThanExpression => LessThan(left, right), + SyntaxKind.LessThanOrEqualExpression => LessThanOrEqual(left, right), + SyntaxKind.GreaterThanExpression => GreaterThan(left, right), + SyntaxKind.GreaterThanOrEqualExpression => GreaterThanOrEqual(left, right), + SyntaxKind.IsExpression => TypeIs(left, right is ConstantExpression { Value : Type type } + ? type + : throw new InvalidOperationException( + $"Encountered {SyntaxKind.IsExpression} with non-constant type right argument: {right}")), + SyntaxKind.AsExpression => TypeAs(left, right is ConstantExpression { Value : Type type } + ? type + : throw new InvalidOperationException( + $"Encountered {SyntaxKind.AsExpression} with non-constant type right argument: {right}")), + SyntaxKind.CoalesceExpression => Coalesce(left, right), + + _ => throw new ArgumentOutOfRangeException($"BinaryExpressionSyntax with {binary.Kind()}") + }; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitCastExpression(CastExpressionSyntax cast) + => Convert(Visit(cast.Expression), ResolveType(cast.Type)); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitConditionalExpression(ConditionalExpressionSyntax conditional) + => Condition( + Visit(conditional.Condition), + Visit(conditional.WhenTrue), + Visit(conditional.WhenFalse)); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitElementAccessExpression(ElementAccessExpressionSyntax elementAccessExpression) + { + var arguments = elementAccessExpression.ArgumentList.Arguments; + var visitedExpression = Visit(elementAccessExpression.Expression); + + switch (_semanticModel.GetTypeInfo(elementAccessExpression.Expression).ConvertedType) + { + case IArrayTypeSymbol: + Check.DebugAssert(elementAccessExpression.ArgumentList.Arguments.Count == 1, + $"ElementAccessExpressionSyntax over array with {arguments.Count} arguments"); + return ArrayIndex(visitedExpression, Visit(arguments[0].Expression)); + + case INamedTypeSymbol: + var property = visitedExpression.Type + .GetProperties() + .Select(p => new { Property = p, IndexParameters = p.GetIndexParameters() }) + .Where( + t => t.IndexParameters.Length == arguments.Count + && t.IndexParameters + .Select(p => p.ParameterType) + .SequenceEqual(arguments.Select(a => ResolveType(a.Expression)))) + .Select(t => t.Property) + .FirstOrDefault(); + + if (property?.GetMethod is null) + { + throw new UnreachableException("No matching property found for ElementAccessExpressionSyntax"); + } + + return Call(visitedExpression, property.GetMethod, arguments.Select(a => Visit(a.Expression))); + + case null: + throw new InvalidOperationException( + $"No type for expression {elementAccessExpression.Expression} in {nameof(ElementAccessExpressionSyntax)}"); + + default: + throw new NotImplementedException($"{nameof(ElementAccessExpressionSyntax)} over non-array"); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitIdentifierName(IdentifierNameSyntax identifierName) + { + if (_parameterStack.Peek().TryGetValue(identifierName.Identifier.Text, out var parameter)) + { + return parameter; + } + + var symbol = _semanticModel.GetSymbolInfo(identifierName).Symbol; + + ITypeSymbol typeSymbol; + switch (symbol) + { + case INamedTypeSymbol s: + return Constant(ResolveType(s)); + case ILocalSymbol s: + typeSymbol = s.Type; + break; + case IFieldSymbol s: + typeSymbol = s.Type; + break; + case IPropertySymbol s: + typeSymbol = s.Type; + break; + case null: + throw new InvalidOperationException($"Identifier without symbol: {identifierName}"); + default: + throw new NotImplementedException($"IdentifierName of type {symbol.GetType().Name}: {identifierName}"); + } + + // TODO: Separate out EF Core-specific logic (EF Core would extend this visitor) + if (typeSymbol.Name.Contains("DbSet")) + { + throw new NotImplementedException("DbSet local symbol"); + } + + // We have an identifier which isn't in our parameters stack. + + // First, if the identifier type is the user's DbContext type (e.g. DbContext local variable, or field/property), + // return a constant over that. + if (typeSymbol.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default)) + { + return Constant(_userDbContext); + } + + // The Translate entry point into the translator uses Roslyn's data flow analysis to locate all captured variables, and populates + // the _capturedVariable dictionary with them (with null values). + // TODO: Test closure over class member (not local variable) + if (symbol is ILocalSymbol localSymbol && _capturedVariables.TryGetValue(localSymbol, out var memberExpression)) + { + // The first time we see a captured variable, we create MemberExpression for it and cache it in _capturedVariables. + return memberExpression + ?? (_capturedVariables[localSymbol] = + Field( + Constant(new FakeClosureFrameClass()), + new FakeFieldInfo( + typeof(FakeClosureFrameClass), + ResolveType(localSymbol.Type), + localSymbol.Name))); + } + + throw new InvalidOperationException( + $"Encountered unknown identifier name '{identifierName}', which doesn't correspond to a lambda parameter or captured variable"); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitImplicitArrayCreationExpression(ImplicitArrayCreationExpressionSyntax implicitArrayCreation) + { + if (_semanticModel.GetTypeInfo(implicitArrayCreation).Type is not IArrayTypeSymbol arrayTypeSymbol) + { + throw new InvalidOperationException($"ArrayCreation: non-array type symbol: {implicitArrayCreation}"); + } + + if (arrayTypeSymbol.Rank > 1) + { + throw new NotImplementedException($"ArrayCreation: multi-dimensional array: {implicitArrayCreation}"); + } + + var elementType = ResolveType(arrayTypeSymbol.ElementType); + Check.DebugAssert(elementType is not null, "elementType is not null"); + + var initializers = implicitArrayCreation.Initializer.Expressions.Select(e => Visit(e)); + + return NewArrayInit(elementType, initializers); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitInterpolatedStringExpression(InterpolatedStringExpressionSyntax interpolatedString) + { + var formatBuilder = new StringBuilder(); + var arguments = new List(); + foreach (var fragment in interpolatedString.Contents) + { + switch (fragment) + { + case InterpolatedStringTextSyntax text: + formatBuilder.Append(text); + break; + case InterpolationSyntax interpolation: + var interpolationExpression = Visit(interpolation.Expression); + if (interpolationExpression.Type != typeof(object)) + { + interpolationExpression = Convert(interpolationExpression, typeof(object)); + } + arguments.Add(interpolationExpression); + formatBuilder.Append('{').Append(arguments.Count - 1).Append('}'); + break; + default: + throw new UnreachableException(); + } + } + + // Return a call to string.Format(), unless we have an implicit conversion to FormattableString, in which case return a call to + // FormattableStringFactory.Create(). + return Call( + _semanticModel.GetTypeInfo(interpolatedString).ConvertedType switch + { + { } t when t.Equals(_formattableStringSymbol, SymbolEqualityComparer.Default) + => _formattableStringFactoryCreateMethod ??= typeof(FormattableStringFactory).GetMethod( + nameof(FormattableStringFactory.Create), [typeof(string), typeof(object[])])!, + + _ => _stringFormatMethod ??= typeof(string).GetMethod(nameof(string.Format), [typeof(string), typeof(object[])])! + }, + Constant(formatBuilder.ToString()), + NewArrayInit(typeof(object), arguments)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitInvocationExpression(InvocationExpressionSyntax invocation) + { + if (_semanticModel.GetSymbolInfo(invocation).Symbol is not IMethodSymbol methodSymbol) + { + throw new InvalidOperationException("Could not find symbol for method invocation: " + invocation); + } + + // First, if the method return type is the user's DbContext type (e.g. DbContext local variable, or field/property), return a + // constant over that DbContext type; the invocation can serve as the root for a LINQ query we can precompile. + if (methodSymbol.ReturnType.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default)) + { + return Constant(_userDbContext); + } + + var declaringType = ResolveType(methodSymbol.ContainingType); + + Expression? instance = null; + if (!methodSymbol.IsStatic || methodSymbol.IsExtensionMethod) + { + // In normal method calls (the ones we support), the invocation node is composed on top of a member access + if (invocation.Expression is not MemberAccessExpressionSyntax { Expression: var receiver }) + { + throw new NotSupportedException($"Invocation over non-member access: {invocation}"); + } + + instance = Visit(receiver); + } + + MethodInfo? methodInfo; + + if (methodSymbol.IsGenericMethod) + { + var originalDefinition = methodSymbol.OriginalDefinition; + if (originalDefinition.ReducedFrom is not null) + { + originalDefinition = originalDefinition.ReducedFrom; + } + + // To accurately find the right open generic method definition based on the Roslyn symbol, we need to create a mapping between + // generic type parameter names (based on the Roslyn side) and .NET reflection Types representing those type parameters. + // This includes both type parameters immediately on the generic method, as well as type parameters from the method's + // containing type (and recursively, its containing types) + var typeTypeParameterMap = new Dictionary(Foo(methodSymbol.ContainingType)); + + IEnumerable> Foo(INamedTypeSymbol typeSymbol) + { + // TODO: We match Roslyn type parameters by name, not sure that's right; also for the method's generic type parameters + + if (typeSymbol.ContainingType is INamedTypeSymbol containingTypeSymbol) + { + foreach (var kvp in Foo(containingTypeSymbol)) + { + yield return kvp; + } + } + + var type = ResolveType(typeSymbol); + var genericArguments = type.GetGenericArguments(); + + Check.DebugAssert( + genericArguments.Length == typeSymbol.TypeParameters.Length, + "genericArguments.Length == typeSymbol.TypeParameters.Length"); + + foreach (var (typeParamSymbol, typeParamType) in typeSymbol.TypeParameters.Zip(genericArguments)) + { + yield return new KeyValuePair(typeParamSymbol.Name, typeParamType); + } + } + + var definitionMethodInfos = declaringType.GetMethods() + .Where(m => + { + if (m.Name == methodSymbol.Name + && m.IsGenericMethodDefinition + && m.GetGenericArguments() is var candidateGenericArguments + && candidateGenericArguments.Length == originalDefinition.TypeParameters.Length + && m.GetParameters() is var candidateParams + && candidateParams.Length == originalDefinition.Parameters.Length) + { + var methodTypeParameterMap = new Dictionary(typeTypeParameterMap); + + // Prepare a dictionary that will be used to resolve generic type parameters (ITypeParameterSymbol) to the + // corresponding reflection Type. This is needed to correctly (and recursively) resolve the type of parameters + // below. + foreach (var (symbol, type) in methodSymbol.TypeParameters.Zip(candidateGenericArguments)) + { + if (symbol.Name != type.Name) + { + return false; + } + + methodTypeParameterMap[symbol.Name] = type; + } + + for (var i = 0; i < candidateParams.Length; i++) + { + var translatedParamType = ResolveType(originalDefinition.Parameters[i].Type, methodTypeParameterMap); + if (translatedParamType != candidateParams[i].ParameterType) + { + return false; + } + } + + return true; + } + + return false; + }).ToArray(); + + if (definitionMethodInfos.Length != 1) + { + throw new InvalidOperationException($"Invocation: Found {definitionMethodInfos.Length} matches for generic method: {invocation}"); + } + + var definitionMethodInfo = definitionMethodInfos[0]; + var typeParams = methodSymbol.TypeArguments.Select(a => ResolveType(a)).ToArray(); + methodInfo = definitionMethodInfo.MakeGenericMethod(typeParams); + } + else + { + // Non-generic method + var reducedMethodSymbol = methodSymbol.ReducedFrom ?? methodSymbol; + + methodInfo = declaringType.GetMethod( + methodSymbol.Name, + BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static, + reducedMethodSymbol.Parameters.Select(p => ResolveType(p.Type)).ToArray()); + + if (methodInfo is null) + { + throw new InvalidOperationException( + $"Invocation: couldn't find method '{methodSymbol.Name}' on type '{declaringType.Name}': {invocation}"); + } + } + + // We have the reflection MethodInfo for the method, prepare the arguments. + + // We can have less arguments than parameters when the method has optional parameters; fill in the missing ones with the default + // value. + // If the method also has a "params" parameter, we also need to take care of that - the syntactic arguments will need to be packed + // into the "params" array etc. + var parameters = methodInfo.GetParameters(); + var sourceArguments = invocation.ArgumentList.Arguments; + var destArguments = new Expression?[parameters.Length]; + var paramIndex = 0; + + // At the syntactic level, an extension method invocation looks like a normal instance's. + // Prepend the instance to the argument list. + // TODO: Test invoking extension without extension syntax (as static) + if (methodSymbol is { IsExtensionMethod: true /*, ReceiverType: { } */ }) + { + destArguments[0] = instance; + paramIndex = 1; + instance = null; + } + + for (var sourceArgIndex = 0; paramIndex < parameters.Length; paramIndex++) + { + var parameter = parameters[paramIndex]; + if (parameter.IsDefined(typeof(ParamArrayAttribute))) + { + // We've reached a "params" parameter; pack all the remaining args (possibly zero) into a NewArrayExpression + var elementType = parameter.ParameterType.GetElementType()!; + var paramsArguments = new Expression[sourceArguments.Count - sourceArgIndex]; + for (var paramsArgIndex = 0; sourceArgIndex < sourceArguments.Count; sourceArgIndex++, paramsArgIndex++) + { + var arg = invocation.ArgumentList.Arguments[sourceArgIndex]; + Check.DebugAssert(arg.NameColon is null, "Named argument in params"); + + paramsArguments[paramsArgIndex] = Visit(arg); + } + + destArguments[paramIndex] = NewArrayInit(elementType, paramsArguments); + Check.DebugAssert(paramIndex == parameters.Length - 1, "Parameters after params"); + break; + } + + if (sourceArgIndex >= sourceArguments.Count) + { + // Fewer arguments than there are parameters - we have optional parameters. + Check.DebugAssert(parameter.IsOptional, "Missing non-optional argument"); + + destArguments[paramIndex] = Constant( + parameter.DefaultValue is null && parameter.ParameterType.IsValueType + ? Activator.CreateInstance(parameter.ParameterType) + : parameter.DefaultValue, + parameter.ParameterType); + continue; + } + + var argument = invocation.ArgumentList.Arguments[sourceArgIndex++]; + + // Positional argument + if (argument.NameColon is null) + { + destArguments[paramIndex] = Visit(argument); + continue; + } + + // Named argument + throw new NotImplementedException("Named argument"); + } + + Check.DebugAssert(destArguments.All(a => a is not null), "arguments.All(a => a is not null)"); + + // TODO: Generic type arguments + return Call(instance, methodInfo, destArguments!); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitLiteralExpression(LiteralExpressionSyntax literal) + => _semanticModel.GetTypeInfo(literal) is { ConvertedType: ITypeSymbol type } + ? Constant(literal.Token.Value, ResolveType(type)) + : Constant(literal.Token.Value); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitMemberAccessExpression(MemberAccessExpressionSyntax memberAccess) + { + var expression = Visit(memberAccess.Expression); + + if (_semanticModel.GetSymbolInfo(memberAccess).Symbol is not ISymbol memberSymbol) + { + throw new InvalidOperationException($"MemberAccess: Couldn't find symbol for member: {memberAccess}"); + } + + var containingType = ResolveType(memberSymbol.ContainingType); + var memberInfo = memberSymbol switch + { + IPropertySymbol p => (MemberInfo?)containingType.GetProperty(p.Name), + IFieldSymbol f => containingType.GetField(f.Name), + INamedTypeSymbol t => containingType.GetNestedType(t.Name), + + null => throw new InvalidOperationException($"MemberAccess: Couldn't find symbol for member: {memberAccess}"), + _ => throw new NotSupportedException($"MemberAccess: unsupported member symbol '{memberSymbol.GetType().Name}': {memberAccess}") + }; + + switch (memberInfo) + { + case Type nestedType: + return Constant(nestedType); + + case null: + throw new InvalidOperationException($"MemberAccess: couldn't find member '{memberSymbol.Name}': {memberAccess}"); + } + + // Enum field constant + if (containingType.IsEnum) + { + return Constant(Enum.Parse(containingType, memberInfo.Name), containingType); + } + + // array.Length + if (expression.Type.IsArray && memberInfo.Name == "Length") + { + if (expression.Type.GetArrayRank() != 1) + { + throw new NotImplementedException("MemberAccess on multi-dimensional array"); + } + + return ArrayLength(expression); + } + + return MakeMemberAccess( + expression is ConstantExpression { Value: Type } ? null : expression, + memberInfo); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitObjectCreationExpression(ObjectCreationExpressionSyntax objectCreation) + { + if (_semanticModel.GetSymbolInfo(objectCreation).Symbol is not IMethodSymbol constructorSymbol) + { + throw new InvalidOperationException($"ObjectCreation: couldn't find IMethodSymbol for constructor: {objectCreation}"); + } + + Check.DebugAssert(constructorSymbol.MethodKind == MethodKind.Constructor, "constructorSymbol.MethodKind == MethodKind.Constructor"); + + var type = ResolveType(constructorSymbol.ContainingType); + + // Find the reflection constructor that matches the constructor symbol's signature + var parameterTypes = constructorSymbol.Parameters.Select(ps => ResolveType(ps.Type)).ToArray(); + var constructor = type.GetConstructor(parameterTypes); + + var newExpression = constructor is not null + ? New( + constructor, + objectCreation.ArgumentList?.Arguments.Select(a => Visit(a)) ?? Array.Empty()) + : parameterTypes.Length == 0 // For structs, there's no actual parameterless constructor + ? New(type) + : throw new InvalidOperationException($"ObjectCreation: Missing constructor: {objectCreation}"); + + switch (objectCreation.Initializer) + { + // No initializers, just return the NewExpression + case null or { Expressions: [] }: + return newExpression; + + // Assignment initializer (new Blog { Name = "foo" }) + case { Expressions: [AssignmentExpressionSyntax, ..] }: + return MemberInit( + newExpression, + objectCreation.Initializer.Expressions.Select( + e => + { + if (e is not AssignmentExpressionSyntax { Left: var lValue, Right: var value }) + { + throw new NotSupportedException( + $"ObjectCreation: non-assignment initializer expression of type '{e.GetType().Name}': {objectCreation}"); + } + + var lValueSymbol = _semanticModel.GetSymbolInfo(lValue).Symbol; + var memberInfo = lValueSymbol switch + { + IPropertySymbol p => (MemberInfo?)type.GetProperty(p.Name), + IFieldSymbol f => type.GetField(f.Name), + + _ => throw new InvalidOperationException( + $"ObjectCreation: unsupported initializer for member of type '{lValueSymbol?.GetType().Name}': {e}") + }; + + if (memberInfo is null) + { + throw new InvalidOperationException( + $"ObjectCreation: couldn't find initialized member '{lValueSymbol.Name}': {e}"); + } + + return Bind(memberInfo, Visit(value)); + })); + + // Non-assignment initializer => list initializer (new List { 1, 2, 3 }) + default: + // Find the correct Add() method on the collection type + // TODO: This doesn't work if there are multiple Add() methods (contrived). Complete solution would be to find the base + // Type for all initializer expressions and find an Add overload of that type (or a superclass thereof) + var addMethod = type.GetMethods().SingleOrDefault(m => m.Name == "Add" && m.GetParameters().Length == 1); + if (addMethod is null) + { + throw new InvalidOperationException( + $"Couldn't find single Add method on type '{type.Name}', required for list initializer"); + } + + // TODO: Dictionary initializer, where each ElementInit has more than one expression + + return ListInit( + newExpression, + objectCreation.Initializer.Expressions.Select(e => ElementInit(addMethod, Visit(e)))); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitParenthesizedExpression(ParenthesizedExpressionSyntax parenthesized) + => Visit(parenthesized.Expression); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitParenthesizedLambdaExpression(ParenthesizedLambdaExpressionSyntax lambda) + => VisitLambdaExpression(lambda); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitPredefinedType(PredefinedTypeSyntax predefinedType) + => Constant(ResolveType(predefinedType), typeof(Type)); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitPrefixUnaryExpression(PrefixUnaryExpressionSyntax unary) + { + var operand = Visit(unary.Operand); + + // https://learn.microsoft.com/dotnet/api/Microsoft.CodeAnalysis.CSharp.Syntax.PrefixUnaryExpressionSyntax + + return unary.Kind() switch + { + SyntaxKind.UnaryPlusExpression => UnaryPlus(operand), + SyntaxKind.UnaryMinusExpression => Negate(operand), + SyntaxKind.BitwiseNotExpression => Not(operand), + SyntaxKind.LogicalNotExpression => Not(operand), + + SyntaxKind.AddressOfExpression => throw NotSupportedInExpressionTrees(), + SyntaxKind.IndexExpression => throw NotSupportedInExpressionTrees(), + SyntaxKind.PointerIndirectionExpression => throw NotSupportedInExpressionTrees(), + SyntaxKind.PreDecrementExpression => throw NotSupportedInExpressionTrees(), + SyntaxKind.PreIncrementExpression => throw NotSupportedInExpressionTrees(), + + _ => throw new UnreachableException( + $"Unexpected syntax kind '{unary.Kind()}' when visiting a {nameof(PrefixUnaryExpressionSyntax)}") + }; + + NotSupportedException NotSupportedInExpressionTrees() + => throw new UnreachableException( + $"Unary expression of type {unary.Kind()} is not supported in expression trees"); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitPostfixUnaryExpression(PostfixUnaryExpressionSyntax unary) + { + var operand = Visit(unary.Operand); + + // https://learn.microsoft.com/dotnet/api/Microsoft.CodeAnalysis.CSharp.Syntax.PostfixUnaryExpressionSyntax + + return unary.Kind() switch + { + SyntaxKind.SuppressNullableWarningExpression => operand, + + SyntaxKind.PostIncrementExpression => throw NotSupportedInExpressionTrees(), + SyntaxKind.PostDecrementExpression => throw NotSupportedInExpressionTrees(), + + _ => throw new UnreachableException( + $"Unexpected syntax kind '{unary.Kind()}' when visiting a {nameof(PostfixUnaryExpressionSyntax)}") + }; + + NotSupportedException NotSupportedInExpressionTrees() + => throw new UnreachableException( + $"Unary expression of type {unary.Kind()} is not supported in expression trees"); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitQueryExpression(QueryExpressionSyntax node) + => throw new NotSupportedException(DesignStrings.QueryComprehensionSyntaxNotSupportedInPrecompiledQueries); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax lambda) + => VisitLambdaExpression(lambda); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression VisitTypeOfExpression(TypeOfExpressionSyntax typeOf) + { + if (_semanticModel.GetSymbolInfo(typeOf.Type).Symbol is not ITypeSymbol typeSymbol) + { + throw new InvalidOperationException( + "Could not find symbol for typeof() expression: " + typeOf); + } + + var type = ResolveType(typeSymbol); + return Constant(type, typeof(Type)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression DefaultVisit(SyntaxNode node) + => throw new NotSupportedException($"Unsupported syntax node of type '{node.GetType()}': {node}"); + + private Expression VisitLambdaExpression(AnonymousFunctionExpressionSyntax lambda) + { + if (lambda.ExpressionBody is null) + { + throw new NotSupportedException("Lambda with null expression body"); + } + + if (lambda.Modifiers.Any()) + { + throw new NotSupportedException("Lambda with modifiers not supported: " + lambda.Modifiers); + } + + if (!lambda.AsyncKeyword.IsKind(SyntaxKind.None)) + { + throw new NotSupportedException("Async lambdas are not supported"); + } + + var lambdaParameters = lambda switch + { + SimpleLambdaExpressionSyntax simpleLambda => SyntaxFactory.SingletonSeparatedList(simpleLambda.Parameter), + ParenthesizedLambdaExpressionSyntax parenthesizedLambda => parenthesizedLambda.ParameterList.Parameters, + + _ => throw new UnreachableException() + }; + + var translatedParameters = new List(); + foreach (var parameter in lambdaParameters) + { + if (_semanticModel.GetDeclaredSymbol(parameter) is not { } parameterSymbol || + ResolveType(parameterSymbol.Type) is not { } parameterType) + { + throw new InvalidOperationException("Could not found symbol for parameter lambda: " + parameter); + } + + translatedParameters.Add(Parameter(parameterType, parameter.Identifier.Text)); + } + + _parameterStack.Push(_parameterStack.Peek() + .AddRange(translatedParameters.Select(p => new KeyValuePair(p.Name ?? throw new NotImplementedException(), p)))); + + try + { + var body = Visit(lambda.ExpressionBody); + return Lambda(body, translatedParameters); + } + finally + { + _parameterStack.Pop(); + } + } + + /// + /// Given a Roslyn type symbol, returns a .NET reflection . + /// + /// The type symbol to be translated. + /// A .NET reflection that corresponds to . + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Type TranslateType(ITypeSymbol typeSymbol) + => ResolveType(typeSymbol); + + private Type ResolveType(SyntaxNode node) + => _semanticModel.GetTypeInfo(node).Type is { } typeSymbol + ? ResolveType(typeSymbol) + : throw new InvalidOperationException("Could not find type symbol for: " + node); + + private Type ResolveType(ITypeSymbol typeSymbol, Dictionary? genericParameterMap = null) + { + switch (typeSymbol) + { + case INamedTypeSymbol { IsAnonymousType: true } anonymousTypeSymbol: + _anonymousTypeDefinitions ??= LoadAnonymousTypes(anonymousTypeSymbol.ContainingAssembly); + var properties = anonymousTypeSymbol.GetMembers().OfType().ToArray(); + var found = _anonymousTypeDefinitions.TryGetValue(properties.Select(p => p.Name).ToArray(), + out var anonymousTypeGenericDefinition); + Debug.Assert(found, "Anonymous type not found"); + + var constructorParameters = anonymousTypeGenericDefinition!.GetConstructors()[0].GetParameters(); + var genericTypeArguments = new Type[constructorParameters.Length]; + + for (var i = 0; i < constructorParameters.Length; i++) + { + genericTypeArguments[i] = + ResolveType(properties.FirstOrDefault(p => p.Name == constructorParameters[i].Name)!.Type); + } + + // TODO: Cache closed anonymous types + + return anonymousTypeGenericDefinition.MakeGenericType(genericTypeArguments); + + case INamedTypeSymbol { IsDefinition: true } genericTypeSymbol: + return GetClrType(genericTypeSymbol); + + case INamedTypeSymbol { IsGenericType: true } genericTypeSymbol: + { + var definition = GetClrType(genericTypeSymbol.OriginalDefinition); + var typeArguments = genericTypeSymbol.TypeArguments.Select(a => ResolveType(a, genericParameterMap)).ToArray(); + return definition.MakeGenericType(typeArguments); + } + + case ITypeParameterSymbol typeParameterSymbol: + return genericParameterMap?.TryGetValue(typeParameterSymbol.Name, out var type) == true + ? type + : throw new InvalidOperationException($"Unknown generic type parameter symbol {typeParameterSymbol}"); + + case INamedTypeSymbol namedTypeSymbol: + return GetClrType(namedTypeSymbol); + + case IArrayTypeSymbol arrayTypeSymbol: + // The ContainingAssembly of array type symbols can be null; recurse down the element types (down to the non-array element + // type) to get the assembly. + var containingAssembly = arrayTypeSymbol.ContainingAssembly; + ITypeSymbol currentSymbol = arrayTypeSymbol; + while (containingAssembly is null && currentSymbol is IArrayTypeSymbol { ElementType: var nestedTypeSymbol }) + { + currentSymbol = nestedTypeSymbol; + containingAssembly = currentSymbol.ContainingAssembly; + } + + return GetClrTypeFromAssembly( + containingAssembly, + typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)); + + default: + return GetClrTypeFromAssembly( + typeSymbol.ContainingAssembly, + typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)); + } + + Type GetClrType(INamedTypeSymbol symbol) + { + var name = symbol.ContainingType is null + ? typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat) + : typeSymbol.Name; + + if (symbol.IsGenericType) + { + name += '`' + symbol.Arity.ToString(); + } + + if (symbol.ContainingType is not null) + { + var containingType = ResolveType(symbol.ContainingType); + + return containingType.GetNestedType(name) + ?? throw new InvalidOperationException( + $"Couldn't find nested type '{name}' on containing type '{containingType.Name}'"); + } + + return GetClrTypeFromAssembly(typeSymbol.ContainingAssembly, name); + } + + Type GetClrTypeFromAssembly(IAssemblySymbol? assemblySymbol, string name) + => (assemblySymbol is null + ? Type.GetType(name)! + : Type.GetType($"{name}, {assemblySymbol.Name}")) + // If we can't find the Type, check the assembly where the user's DbContext type lives; this is primarily to support + // testing, where user code is in an assembly that's built as part of the the test and loaded into a specific + // AssemblyLoadContext (which gets unloaded later). + ?? _additionalAssembly?.GetType(name) + ?? throw new InvalidOperationException( + $"Couldn't resolve CLR type '{name}' in assembly '{assemblySymbol?.Name}'"); + + Dictionary LoadAnonymousTypes(IAssemblySymbol assemblySymbol) + { + Assembly? assembly; + try + { + assembly = Assembly.Load(assemblySymbol.Name); + } + catch (FileNotFoundException) + { + // If we can't find the assembly, use the assembly where the user's DbContext type lives; this is primarily to support + // testing, where user code is in an assembly that's built as part of the the test and loaded into a specific + // AssemblyLoadContext (which gets unloaded later). + + // TODO: Strings + assembly = _additionalAssembly + ?? throw new InvalidOperationException($"Could not load assembly for IAssemblySymbol '{assemblySymbol.Name}'"); + } + + // Get all the anonymous type in the assembly, and index them by the ordered names of their properties. + // Note that anonymous types are generic, so we don't have property types in the key. + + // TODO: An alternative strategy would be to just generate the types as we need them (with ref.emit) - that's probably safer. + // TODO: Though it may mean that the resulting CLR Type can't be anonymous (Type.IsAnonymousType()) - not sure that matters. + return assembly.GetTypes() + .Where(t => t.IsAnonymousType()) + .ToDictionary(t => t.GetProperties().Select(x => x.Name).ToArray(), t => t, new ArrayStructuralComparer()); + } + } + + private sealed class ArrayStructuralComparer : IEqualityComparer + { + public bool Equals(T[]? x, T[]? y) + => x is null ? y is null : y is not null && x.SequenceEqual(y); + + public int GetHashCode(T[] obj) + { + var hashcode = new HashCode(); + + foreach (var value in obj) + { + hashcode.Add(value); + } + + return hashcode.ToHashCode(); + } + } + + private Dictionary? _anonymousTypeDefinitions; + + [CompilerGenerated] + private sealed class FakeClosureFrameClass; + + private sealed class FakeFieldInfo(Type declaringType, Type fieldType, string name) : FieldInfo + { + public override object[] GetCustomAttributes(bool inherit) + => Array.Empty(); + + public override object[] GetCustomAttributes(Type attributeType, bool inherit) + => Array.Empty(); + + public override bool IsDefined(Type attributeType, bool inherit) + => false; + + public override Type DeclaringType { get; } = declaringType; + + public override string Name { get; } = name; + + public override Type? ReflectedType => null; + + // We implement GetValue since ExpressionTreeFuncletizer calls it to get the parameter value. In AOT generation time, we obviously + // have no parameter values, nor do we need them for the first part of the query pipeline. + public override object? GetValue(object? obj) + => FieldType.IsValueType + ? Activator.CreateInstance(FieldType) + : FieldType == typeof(string) + ? "" + : null; + + public override void SetValue(object? obj, object? value, BindingFlags invokeAttr, Binder? binder, + CultureInfo? culture) + => throw new NotSupportedException(); + + public override FieldAttributes Attributes + => FieldAttributes.Public; + + public override RuntimeFieldHandle FieldHandle + => throw new NotSupportedException(); + + public override Type FieldType { get; } = fieldType; + } + + private sealed class FakeConstructorInfo(Type type, ParameterInfo[] parameters) : ConstructorInfo + { + public override object[] GetCustomAttributes(bool inherit) + => Array.Empty(); + + public override object[] GetCustomAttributes(Type attributeType, bool inherit) + => Array.Empty(); + + public override bool IsDefined(Type attributeType, bool inherit) + => false; + + public override Type DeclaringType { get; } = type; + + public override string Name + => ".ctor"; + + public override Type ReflectedType + => DeclaringType; + + public override MethodImplAttributes GetMethodImplementationFlags() + => MethodImplAttributes.Managed; + + public override ParameterInfo[] GetParameters() + => parameters; + + public override MethodAttributes Attributes + => MethodAttributes.Public; + + public override RuntimeMethodHandle MethodHandle + => throw new NotSupportedException(); + + public override object Invoke(object? obj, BindingFlags invokeAttr, Binder? binder, object?[]? parameters, + CultureInfo? culture) + => throw new NotSupportedException(); + + public override object Invoke(BindingFlags invokeAttr, Binder? binder, object?[]? parameters, + CultureInfo? culture) + => throw new NotSupportedException(); + } + + private sealed class FakeParameterInfo(string name, Type parameterType, int position) : ParameterInfo + { + public override ParameterAttributes Attributes + => ParameterAttributes.In; + + public override string? Name { get; } = name; + public override Type ParameterType { get; } = parameterType; + public override int Position { get; } = position; + + public override MemberInfo Member + => throw new NotSupportedException(); + } +} diff --git a/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs b/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs index e7b5f0d1d45..4ff06e5fcda 100644 --- a/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs +++ b/src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs @@ -10,11 +10,9 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Editing; -using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Design.Internal; using Microsoft.EntityFrameworkCore.Internal; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using E = System.Linq.Expressions.Expression; namespace Microsoft.EntityFrameworkCore.Query.Internal; @@ -43,6 +41,19 @@ private sealed record LiftedState internal readonly Dictionary Variables = new(); internal readonly HashSet VariableNames = []; internal readonly List UnassignedVariableDeclarations = []; + + internal LiftedState CreateChild() + { + var child = new LiftedState(); + + foreach (var (parameter, name) in Variables) + { + child.Variables.Add(parameter, name); + } + child.VariableNames.UnionWith(VariableNames); + + return child; + } } private LiftedState _liftedState = new(); @@ -53,13 +64,15 @@ private sealed record LiftedState private readonly HashSet _capturedVariables = []; private ISet _collectedNamespaces = null!; + private readonly Dictionary _methodUnsafeAccessors = new(); + private readonly Dictionary<(FieldInfo Field, bool ForWrite), MethodDeclarationSyntax> _fieldUnsafeAccessors = new(); - private static MethodInfo? _activatorCreateInstanceMethod; private static MethodInfo? _mathPowMethod; private readonly SideEffectDetectionSyntaxWalker _sideEffectDetector = new(); private readonly ConstantDetectionSyntaxWalker _constantDetector = new(); private readonly SyntaxGenerator _g; + private readonly StringBuilder _stringBuilder = new(); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -98,8 +111,9 @@ public virtual IReadOnlySet CapturedVariables public virtual SyntaxNode TranslateStatement( Expression node, IReadOnlyDictionary? constantReplacements, - ISet collectedNamespaces) - => TranslateCore(node, constantReplacements, collectedNamespaces, statementContext: true); + ISet collectedNamespaces, + ISet unsafeAccessors) + => TranslateCore(node, constantReplacements, collectedNamespaces, unsafeAccessors, statementContext: true); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -110,8 +124,9 @@ public virtual IReadOnlySet CapturedVariables public virtual SyntaxNode TranslateExpression( Expression node, IReadOnlyDictionary? constantReplacements, - ISet collectedNamespaces) - => TranslateCore(node, constantReplacements, collectedNamespaces, statementContext: false); + ISet collectedNamespaces, + ISet unsafeAccessors) + => TranslateCore(node, constantReplacements, collectedNamespaces, unsafeAccessors, statementContext: false); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -123,6 +138,7 @@ public virtual IReadOnlySet CapturedVariables Expression node, IReadOnlyDictionary? constantReplacements, ISet collectedNamespaces, + ISet unsafeAccessors, bool statementContext) { _capturedVariables.Clear(); @@ -146,6 +162,11 @@ public virtual IReadOnlySet CapturedVariables Check.DebugAssert(_stack.Peek().Labels.Count == 0, "_stack.Peek().Labels.Count == 0"); Check.DebugAssert(_stack.Peek().UnnamedLabelNames.Count == 0, "_stack.Peek().UnnamedLabelNames.Count == 0"); + foreach (var unsafeAccessor in _fieldUnsafeAccessors.Values.Concat(_methodUnsafeAccessors.Values)) + { + unsafeAccessors.Add(unsafeAccessor); + } + return Result!; } @@ -288,7 +309,7 @@ protected override Expression VisitBinary(BinaryExpression binary) case ExpressionType.Power when binary.Left.Type == typeof(double) && binary.Right.Type == typeof(double): return Visit( - E.Call( + Expression.Call( _mathPowMethod ??= typeof(Math).GetMethod( nameof(Math.Pow), BindingFlags.Static | BindingFlags.Public, [typeof(double), typeof(double)])!, binary.Left, @@ -299,9 +320,9 @@ protected override Expression VisitBinary(BinaryExpression binary) case ExpressionType.PowerAssign: return Visit( - E.Assign( + Expression.Assign( binary.Left, - E.Power( + Expression.Power( binary.Left, binary.Right))); } @@ -461,7 +482,7 @@ protected override Expression VisitBlock(BlockExpression block) if (blockContext != ExpressionContext.Expression) { ownStackFrame = PushNewStackFrame(); - _liftedState = new LiftedState(); + _liftedState = parentLiftedState.CreateChild(); } var stackFrame = _stack.Peek(); @@ -469,217 +490,211 @@ protected override Expression VisitBlock(BlockExpression block) // Do a 1st pass to identify and register any labels, since goto can appear before its label. PreprocessLabels(); - try + // Go over the block's variables, assign names to any unnamed ones and uniquify. Then add them to our stack frame, unless + // this is an expression block that will get lifted. + foreach (var parameter in block.Variables) { - // Go over the block's variables, assign names to any unnamed ones and uniquify. Then add them to our stack frame, unless - // this is an expression block that will get lifted. - foreach (var parameter in block.Variables) - { - var (variables, variableNames) = (stackFrame.Variables, stackFrame.VariableNames); + var (variables, variableNames) = (stackFrame.Variables, stackFrame.VariableNames); - var uniquifiedName = UniquifyVariableName(parameter.Name ?? "unnamed"); + var uniquifiedName = UniquifyVariableName(parameter.Name ?? "unnamed"); - if (blockContext == ExpressionContext.Expression) + if (blockContext == ExpressionContext.Expression) + { + if (!_liftedState.Variables.TryAdd(parameter, uniquifiedName)) { - if (_liftedState.Variables.ContainsKey(parameter)) - { - throw new NotSupportedException("Parameter clash during expression lifting for: " + parameter.Name); - } - - _liftedState.Variables.Add(parameter, uniquifiedName); - _liftedState.VariableNames.Add(uniquifiedName); + throw new NotSupportedException("Parameter clash during expression lifting for: " + parameter.Name); } - else - { - if (!variables.TryAdd(parameter, uniquifiedName)) - { - throw new InvalidOperationException( - DesignStrings.SameParameterExpressionDeclaredAsVariableInNestedBlocks(parameter.Name ?? "")); - } - variableNames.Add(uniquifiedName); + _liftedState.VariableNames.Add(uniquifiedName); + } + else + { + if (!variables.TryAdd(parameter, uniquifiedName)) + { + throw new InvalidOperationException( + DesignStrings.SameParameterExpressionDeclaredAsVariableInNestedBlocks(parameter.Name ?? "")); } + + variableNames.Add(uniquifiedName); } + } - var unassignedVariables = block.Variables.ToList(); + var unassignedVariables = block.Variables.ToList(); - var statements = new List(); - LabeledStatementSyntax? pendingLabeledStatement = null; + var statements = new List(); + LabeledStatementSyntax? pendingLabeledStatement = null; - // Now visit the block's expressions - for (var i = 0; i < block.Expressions.Count; i++) - { - var expression = block.Expressions[i]; - var onLastBlockLine = i == block.Expressions.Count - 1; - _onLastLambdaLine = parentOnLastLambdaLine && onLastBlockLine; + // Now visit the block's expressions + for (var i = 0; i < block.Expressions.Count; i++) + { + var expression = block.Expressions[i]; + var onLastBlockLine = i == block.Expressions.Count - 1; + _onLastLambdaLine = parentOnLastLambdaLine && onLastBlockLine; - // Any lines before the last are evaluated in statement context (they aren't returned); the last line is evaluated in the - // context of the block as a whole. _context now refers to the statement's context, blockContext to the block's. - var statementContext = onLastBlockLine ? _context : ExpressionContext.Statement; + // Any lines before the last are evaluated in statement context (they aren't returned); the last line is evaluated in the + // context of the block as a whole. _context now refers to the statement's context, blockContext to the block's. + var statementContext = onLastBlockLine ? _context : ExpressionContext.Statement; - SyntaxNode translated; - using (ChangeContext(statementContext)) + SyntaxNode translated; + using (ChangeContext(statementContext)) + { + translated = Translate(expression); + } + + // If we have a labeled statement, unwrap it and keep the label as pending. VisitLabel returns a dummy statement (since + // LINQ labels don't have a statement, unlike C#), so we'll skip that statement and add the label to the next real one. + if (translated is LabeledStatementSyntax labeledStatement) + { + if (pendingLabeledStatement is not null) { - translated = Translate(expression); + throw new NotImplementedException("Multiple labels on the same statement"); } - // If we have a labeled statement, unwrap it and keep the label as pending. VisitLabel returns a dummy statement (since - // LINQ labels don't have a statement, unlike C#), so we'll skip that statement and add the label to the next real one. - if (translated is LabeledStatementSyntax labeledStatement) - { - if (pendingLabeledStatement is not null) - { - throw new NotImplementedException("Multiple labels on the same statement"); - } + pendingLabeledStatement = labeledStatement; + translated = labeledStatement.Statement; + } - pendingLabeledStatement = labeledStatement; - translated = labeledStatement.Statement; - } + // Syntax optimization. This is an assignment of a block variable to some value. Render this as: + // var x = ; + // ... instead of: + // int x; + // x = ; + // ... except for expression context (i.e. on the last line), where we just return the value if needed. + if (expression is BinaryExpression { NodeType: ExpressionType.Assign, Left: ParameterExpression lValue } + && translated is AssignmentExpressionSyntax { Right: var valueSyntax } + && statementContext == ExpressionContext.Statement + && unassignedVariables.Remove(lValue)) + { + var useExplicitVariableType = valueSyntax.Kind() == SyntaxKind.NullLiteralExpression; - // Syntax optimization. This is an assignment of a block variable to some value. Render this as: - // var x = ; - // ... instead of: - // int x; - // x = ; - // ... except for expression context (i.e. on the last line), where we just return the value if needed. - if (expression is BinaryExpression { NodeType: ExpressionType.Assign, Left: ParameterExpression lValue } - && translated is AssignmentExpressionSyntax { Right: var valueSyntax } - && statementContext == ExpressionContext.Statement - && unassignedVariables.Remove(lValue)) - { - var useExplicitVariableType = valueSyntax.Kind() == SyntaxKind.NullLiteralExpression; + translated = useExplicitVariableType + ? _g.LocalDeclarationStatement(Generate(lValue.Type), LookupVariableName(lValue), valueSyntax) + : _g.LocalDeclarationStatement(LookupVariableName(lValue), valueSyntax); + } - translated = useExplicitVariableType - ? _g.LocalDeclarationStatement(Generate(lValue.Type), LookupVariableName(lValue), valueSyntax) - : _g.LocalDeclarationStatement(LookupVariableName(lValue), valueSyntax); - } + if (statementContext == ExpressionContext.Expression) + { + // We're on the last line of a block in expression context - the block is being lifted out. + // All statements before the last line (this one) have already been added to _liftedStatements, just return the last + // expression. + Check.DebugAssert(onLastBlockLine, "onLastBlockLine"); + Result = translated; + break; + } - if (statementContext == ExpressionContext.Expression) + if (blockContext != ExpressionContext.Expression) + { + if (_liftedState.Statements.Count > 0) { - // We're on the last line of a block in expression context - the block is being lifted out. - // All statements before the last line (this one) have already been added to _liftedStatements, just return the last - // expression. - Check.DebugAssert(onLastBlockLine, "onLastBlockLine"); - Result = translated; - break; + // If any expressions were lifted out of the current expression, flatten them into our own block, just before the + // expression from which they were lifted. Note that we don't do this in Expression context, since our own block is + // lifted out. + statements.AddRange(_liftedState.Statements); + _liftedState.Statements.Clear(); } - if (blockContext != ExpressionContext.Expression) + // Same for any variables being lifted out of the block; we add them to our own stack frame so that we can do proper + // variable name uniquification etc. + if (_liftedState.Variables.Count > 0) { - if (_liftedState.Statements.Count > 0) + foreach (var (parameter, name) in _liftedState.Variables) { - // If any expressions were lifted out of the current expression, flatten them into our own block, just before the - // expression from which they were lifted. Note that we don't do this in Expression context, since our own block is - // lifted out. - statements.AddRange(_liftedState.Statements); - _liftedState.Statements.Clear(); + stackFrame.Variables[parameter] = name; + stackFrame.VariableNames.Add(name); } - // Same for any variables being lifted out of the block; we add them to our own stack frame so that we can do proper - // variable name uniquification etc. - if (_liftedState.Variables.Count > 0) - { - foreach (var (parameter, name) in _liftedState.Variables) - { - stackFrame.Variables[parameter] = name; - stackFrame.VariableNames.Add(name); - } - - _liftedState.Variables.Clear(); - } - } - - // Skip useless expressions with no side effects in statement context (these can be the result of switch/conditional lifting - // with assignment lowering) - if (statementContext == ExpressionContext.Statement && !_sideEffectDetector.MayHaveSideEffects(translated)) - { - continue; + _liftedState.Variables.Clear(); } + } - var statement = translated switch - { - StatementSyntax s => s, + // Skip useless expressions with no side effects in statement context (these can be the result of switch/conditional lifting + // with assignment lowering) + if (statementContext == ExpressionContext.Statement && !_sideEffectDetector.MayHaveSideEffects(translated)) + { + continue; + } - // If this is the last line in an expression lambda, wrap it in a return statement. - ExpressionSyntax e when _onLastLambdaLine && statementContext == ExpressionContext.ExpressionLambda - => ReturnStatement(e), + var statement = translated switch + { + StatementSyntax s => s, - // If we're in statement context and we have an expression that can't stand alone (e.g. literal), assign it to discard - ExpressionSyntax e when statementContext == ExpressionContext.Statement && !IsExpressionValidAsStatement(e) - => ExpressionStatement((ExpressionSyntax)_g.AssignmentStatement(_g.IdentifierName("_"), e)), + // If this is the last line in an expression lambda, wrap it in a return statement. + ExpressionSyntax e when _onLastLambdaLine && statementContext == ExpressionContext.ExpressionLambda + => ReturnStatement(e), - ExpressionSyntax e => ExpressionStatement(e), + // If we're in statement context and we have an expression that can't stand alone (e.g. literal), assign it to discard + ExpressionSyntax e when statementContext == ExpressionContext.Statement && !IsExpressionValidAsStatement(e) + => ExpressionStatement((ExpressionSyntax)_g.AssignmentStatement(_g.IdentifierName("_"), e)), - _ => throw new ArgumentOutOfRangeException() - }; + ExpressionSyntax e => ExpressionStatement(e), - if (blockContext == ExpressionContext.Expression) - { - // This block is in expression context, and so will be lifted (we won't be returning a block). - _liftedState.Statements.Add(statement); - } - else - { - if (pendingLabeledStatement is not null) - { - statement = pendingLabeledStatement.WithStatement(statement); - pendingLabeledStatement = null; - } + _ => throw new ArgumentOutOfRangeException() + }; - statements.Add(statement); - } + if (blockContext == ExpressionContext.Expression) + { + // This block is in expression context, and so will be lifted (we won't be returning a block). + _liftedState.Statements.Add(statement); } - - // If a label existed on the last line of the block, add an empty statement (since C# requires it); for expression blocks we'd - // have to lift that, not supported for now. - if (pendingLabeledStatement is not null) + else { - if (blockContext == ExpressionContext.Expression) + if (pendingLabeledStatement is not null) { - throw new NotImplementedException("Label on last expression of an expression block"); + statement = pendingLabeledStatement.WithStatement(statement); + pendingLabeledStatement = null; } - else - { - statements.Add(pendingLabeledStatement.WithStatement(EmptyStatement())); - } - } - // Above we transform top-level assignments (i = 8) to var-declarations with initializers (var i = 8); those variables have - // already been taken care of and removed from the list. - // But there may still be variables that get assigned inside nested blocks or other situations; prepare declarations for those - // and either add them to the block, or lift them if we're an expression block. - var unassignedVariableDeclarations = - unassignedVariables.Select( - v => (LocalDeclarationStatementSyntax)_g.LocalDeclarationStatement(Generate(v.Type), LookupVariableName(v))); + statements.Add(statement); + } + } + // If a label existed on the last line of the block, add an empty statement (since C# requires it); for expression blocks we'd + // have to lift that, not supported for now. + if (pendingLabeledStatement is not null) + { if (blockContext == ExpressionContext.Expression) { - _liftedState.UnassignedVariableDeclarations.AddRange(unassignedVariableDeclarations); + throw new NotImplementedException("Label on last expression of an expression block"); } else { - statements.InsertRange(0, unassignedVariableDeclarations.Concat(_liftedState.UnassignedVariableDeclarations)); - _liftedState.UnassignedVariableDeclarations.Clear(); - - // We're done. If the block is in an expression context, it needs to be lifted out; but not if it's in a lambda (in that - // case we just added return above). - Result = Block(statements); + statements.Add(pendingLabeledStatement.WithStatement(EmptyStatement())); } + } + + // Above we transform top-level assignments (i = 8) to var-declarations with initializers (var i = 8); those variables have + // already been taken care of and removed from the list. + // But there may still be variables that get assigned inside nested blocks or other situations; prepare declarations for those + // and either add them to the block, or lift them if we're an expression block. + var unassignedVariableDeclarations = + unassignedVariables.Select( + v => (LocalDeclarationStatementSyntax)_g.LocalDeclarationStatement(Generate(v.Type), LookupVariableName(v), initializer: _g.DefaultExpression(Generate(v.Type)))); - return block; + if (blockContext == ExpressionContext.Expression) + { + _liftedState.UnassignedVariableDeclarations.AddRange(unassignedVariableDeclarations); } - finally + else { - _onLastLambdaLine = parentOnLastLambdaLine; - _liftedState = parentLiftedState; + statements.InsertRange(0, unassignedVariableDeclarations.Concat(_liftedState.UnassignedVariableDeclarations)); + _liftedState.UnassignedVariableDeclarations.Clear(); - if (ownStackFrame is not null) - { - var popped = _stack.Pop(); - Check.DebugAssert(popped.Equals(ownStackFrame), "popped.Equals(ownStackFrame)"); - } + // We're done. If the block is in an expression context, it needs to be lifted out; but not if it's in a lambda (in that + // case we just added return above). + Result = Block(statements); + } + + if (ownStackFrame is not null) + { + var popped = _stack.Pop(); + Check.DebugAssert(popped.Equals(ownStackFrame), "popped.Equals(ownStackFrame)"); } + _onLastLambdaLine = parentOnLastLambdaLine; + _liftedState = parentLiftedState; + + return block; + // Returns true for expressions which have side-effects, and can therefore appear alone as a statement static bool IsExpressionValidAsStatement(ExpressionSyntax expression) => expression.Kind() switch @@ -826,7 +841,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional } var parentLiftedState = _liftedState; - _liftedState = new LiftedState(); + _liftedState = parentLiftedState.CreateChild(); // If we're in a lambda body, we try to translate as an expression if possible (i.e. no blocks in the true/false arms). using (ChangeContext(ExpressionContext.Expression)) @@ -858,13 +873,13 @@ protected override Expression VisitConditional(ConditionalExpression conditional TranslateConditionalStatement( conditional.Update( conditional.Test, - conditional.IfTrue is BlockExpression ? conditional.IfTrue : E.Block(conditional.IfTrue), - conditional.IfFalse is BlockExpression ? conditional.IfFalse : E.Block(conditional.IfFalse)))); + conditional.IfTrue is BlockExpression ? conditional.IfTrue : Expression.Block(conditional.IfTrue), + conditional.IfFalse is BlockExpression ? conditional.IfFalse : Expression.Block(conditional.IfFalse)))); } // We're in regular expression context, and there are lifted expressions inside one of the arms; we translate to an if/else // statement but lowering an assignment into both sides of the condition - _liftedState = new LiftedState(); + _liftedState = parentLiftedState.CreateChild(); IdentifierNameSyntax assignmentVariable; TypeSyntax? loweredAssignmentVariableType = null; @@ -872,7 +887,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional if (lowerableAssignmentVariable is null) { var name = UniquifyVariableName("liftedConditional"); - var parameter = E.Parameter(conditional.Type, name); + var parameter = Expression.Parameter(conditional.Type, name); assignmentVariable = IdentifierName(name); loweredAssignmentVariableType = Generate(parameter.Type); } @@ -1101,11 +1116,6 @@ IEqualityComparer c Generate(typeof(Encoding)), IdentifierName(nameof(Encoding.Default))), - FieldInfo fieldInfo - => HandleFieldInfo(fieldInfo), - - //TODO: Handle PropertyInfo - _ => GenerateUnknownValue(value) }; @@ -1162,27 +1172,6 @@ ExpressionSyntax HandleValueTuple(ITuple tuple) return TupleExpression(SeparatedList(arguments)); } - - ExpressionSyntax HandleFieldInfo(FieldInfo fieldInfo) - => fieldInfo.DeclaringType is null - ? throw new NotSupportedException("Field without a declaring type: " + fieldInfo.Name) - : (ExpressionSyntax)InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(Generate(fieldInfo.DeclaringType)), - IdentifierName(nameof(Type.GetField))), - ArgumentList( - SeparatedList(new[] { - Argument(LiteralExpression( - SyntaxKind.StringLiteralExpression, - Literal(fieldInfo.Name))), - Argument(BinaryExpression( - SyntaxKind.BitwiseOrExpression, - HandleEnum(fieldInfo.IsStatic ? BindingFlags.Static : BindingFlags.Instance), - BinaryExpression( - SyntaxKind.BitwiseOrExpression, - HandleEnum(fieldInfo.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic), - HandleEnum(BindingFlags.DeclaredOnly)))) }))); } /// @@ -1200,6 +1189,12 @@ protected virtual ExpressionSyntax GenerateUnknownValue(object value) return DefaultExpression(Generate(type)); } + if (value is IRelationalQuotableExpression relationalQuotableExpression + && Translate(relationalQuotableExpression.Quote()) is ExpressionSyntax expressionSyntax) + { + return expressionSyntax; + } + throw new NotSupportedException( $"Encountered a constant of unsupported type '{value.GetType().Name}'. Only primitive constant nodes are supported." + Environment.NewLine + value); @@ -1227,35 +1222,48 @@ protected override Expression VisitGoto(GotoExpression gotoNode) /// protected override Expression VisitInvocation(InvocationExpression invocation) { - var lambda = (LambdaExpression)invocation.Expression; - - // We need to inline the lambda invocation into the tree, by replacing parameters in the lambda body with the invocation arguments. - // However, if an argument to the invocation can have side effects (e.g. a method call), and it's referenced multiple times from - // the body, then that would cause multiple evaluation, which is wrong (same if the arguments are evaluated only once but in reverse - // order). - // So we have to lift such arguments. - var arguments = new Expression[invocation.Arguments.Count]; - - for (var i = 0; i < arguments.Length; i++) + if (invocation.Expression is LambdaExpression lambda) { - var argument = invocation.Arguments[i]; + // We need to inline the lambda invocation into the tree, by replacing parameters in the lambda body with the invocation arguments. + // However, if an argument to the invocation can have side effects (e.g. a method call), and it's referenced multiple times from + // the body, then that would cause multiple evaluation, which is wrong (same if the arguments are evaluated only once but in reverse + // order). + // So we have to lift such arguments. + var arguments = new Expression[invocation.Arguments.Count]; - if (argument is ConstantExpression) + for (var i = 0; i < arguments.Length; i++) { - // No need to evaluate into a separate variable, just pass directly - arguments[i] = argument; - continue; + var argument = invocation.Arguments[i]; + + if (argument is ConstantExpression) + { + // No need to evaluate into a separate variable, just pass directly + arguments[i] = argument; + continue; + } + + // Need to lift + var name = UniquifyVariableName(lambda.Parameters[i].Name ?? "lifted"); + var parameter = Expression.Parameter(argument.Type, name); + _liftedState.Statements.Add(GenerateVarDeclaration(name, Translate(argument))); + _liftedState.VariableNames.Add(name); + arguments[i] = parameter; } - // Need to lift - var name = UniquifyVariableName(lambda.Parameters[i].Name ?? "lifted"); - var parameter = E.Parameter(argument.Type, name); - _liftedState.Statements.Add(GenerateVarDeclaration(name, Translate(argument))); - arguments[i] = parameter; + var replacedBody = new ReplacingExpressionVisitor(lambda.Parameters, arguments).Visit(lambda.Body); + Result = Translate(replacedBody); } + else + { + // The invocation is over a non-inline lambda expression (i.e. field/property/method) + var expression = (ExpressionSyntax)Translate(invocation.Expression); + + var translatedExpressions = TranslateList(invocation.Arguments); - var replacedBody = new ReplacingExpressionVisitor(lambda.Parameters, arguments).Visit(lambda.Body); - Result = Translate(replacedBody); + Result = InvocationExpression( + expression, + ArgumentList(SeparatedList(translatedExpressions.Select(Argument)))); + } return invocation; } @@ -1330,14 +1338,7 @@ protected virtual TypeSyntax Generate(Type type) generic); } - if (type.IsNested) - { - AddNamespace(type.DeclaringType!); - } - else if (type.Namespace != null) - { - _collectedNamespaces.Add(type.Namespace); - } + AddNamespace(type); return generic; } @@ -1509,10 +1510,10 @@ protected override Expression VisitLoop(LoopExpression loop) if (loop.ContinueLabel is not null) { - var blockBody = loop.Body is BlockExpression b ? b : E.Block(loop.Body); + var blockBody = loop.Body is BlockExpression b ? b : Expression.Block(loop.Body); blockBody = blockBody.Update( blockBody.Variables, - new[] { E.Label(loop.ContinueLabel) }.Concat(blockBody.Expressions)); + new[] { Expression.Label(loop.ContinueLabel) }.Concat(blockBody.Expressions)); rewrittenLoop1 = loop.Update( loop.BreakLabel, @@ -1525,9 +1526,9 @@ protected override Expression VisitLoop(LoopExpression loop) if (loop.BreakLabel is not null) { rewrittenLoop2 = - E.Block( + Expression.Block( rewrittenLoop1.Update(breakLabel: null, rewrittenLoop1.ContinueLabel, rewrittenLoop1.Body), - E.Label(loop.BreakLabel)); + Expression.Label(loop.BreakLabel)); } if (rewrittenLoop2 != loop) @@ -1567,18 +1568,27 @@ protected override Expression VisitMember(MemberExpression member) when constantExpression.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) && System.Attribute.IsDefined(constantExpression.Type, typeof(CompilerGeneratedAttribute), inherit: true): // Unwrap closure - VisitConstant(E.Constant(closureField.GetValue(constantExpression.Value), member.Type)); + VisitConstant(Expression.Constant(closureField.GetValue(constantExpression.Value), member.Type)); break; // TODO: private event default: - Result = MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - member.Expression is null - ? Generate(member.Member.DeclaringType!) // static - : Translate(member.Expression), - IdentifierName(member.Member.Name)); + var expression = member switch + { + // Static member + { Expression: null } => Generate(member.Member.DeclaringType!), + + // If the member isn't declared on the same type as the expression, (e.g. explicit interface implementation), add + // a cast up to the declaring type. + _ when member.Member.DeclaringType is Type declaringType && declaringType != member.Expression.Type + => ParenthesizedExpression( + CastExpression(Generate(member.Member.DeclaringType), Translate(member.Expression))), + + _ => Translate(member.Expression) + }; + + Result = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, expression, IdentifierName(member.Member.Name)); break; } @@ -1591,24 +1601,28 @@ when constantExpression.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected virtual void TranslateNonPublicMemberAccess(MemberExpression member) + protected virtual void TranslateNonPublicMemberAccess(MemberExpression memberExpression) { - if (member.Expression is null) + if (memberExpression.Expression is null) { throw new NotImplementedException("Private static field access"); } - var translatedExpression = Translate(member.Expression); - Result = ParenthesizedExpression( - CastExpression( - Generate(member.Type), - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - GenerateValue(member.Member), - IdentifierName(nameof(FieldInfo.GetValue))), - ArgumentList( - SingletonSeparatedList(Argument(translatedExpression)))))); + // Get an unsafe accessor for this field/property (this internally caches and adds it to the output list of unsafe accessors) + + // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")] + // static extern ref int UnsafeAccessor_Foo_Name(Foo f); + var unsafeAccessorDeclaration = GetUnsafeAccessorDeclaration( + memberExpression.Member is PropertyInfo propertyInfo + ? propertyInfo.GetMethod ?? throw new UnreachableException("Attempting to read from property without getter") + : memberExpression.Member, + forWrite: false); + + // The unsafe accessor declaration has been created; invoke it. + Result = + _g.InvocationExpression( + _g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text), + Translate(memberExpression.Expression)); } /// @@ -1622,19 +1636,205 @@ protected virtual void TranslateNonPublicMemberAccess(MemberExpression member) Expression value, SyntaxKind assignmentKind) { - // LINQ expression trees can directly access private members. Use the .NET [UnsafeAccessor] feature. + // LINQ expression trees can directly access private members, but C# code cannot. Use the .NET [UnsafeAccessor] feature. if (memberExpression.Expression is null) { throw new NotImplementedException("Private static field assignment"); } - Result = InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - GenerateValue(memberExpression.Member), - IdentifierName(nameof(FieldInfo.SetValue))), - ArgumentList( - SeparatedList(new[] { Argument(Translate(memberExpression.Expression)), Argument(Translate(value)) }))); + // Get an unsafe accessor for this field/property (this internally caches and adds it to the output list of unsafe accessors) + + // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")] + // static extern ref int UnsafeAccessor_Foo_Name(Foo f); + var unsafeAccessorDeclaration = GetUnsafeAccessorDeclaration( + memberExpression.Member is PropertyInfo propertyInfo + ? propertyInfo.SetMethod ?? throw new UnreachableException("Attempting to assign to property without setter") + : memberExpression.Member, + forWrite: true); + + // The unsafe accessor declaration has been created; invoke it. + Result = memberExpression.Member switch + { + FieldInfo => AssignmentExpression( + assignmentKind, + (ExpressionSyntax)_g.InvocationExpression( + _g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text), + Translate(memberExpression.Expression)), + Translate(value)), + + PropertyInfo => + _g.InvocationExpression( + _g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text), Translate(memberExpression.Expression), + assignmentKind is SyntaxKind.SimpleAssignmentExpression + ? Translate(value) + : throw new NotImplementedException("Compound assignment of private property not yet supported")), + + _ => throw new UnreachableException() + }; + } + + private MethodDeclarationSyntax GetUnsafeAccessorDeclaration(MemberInfo member, bool forWrite = false) + { + MethodDeclarationSyntax? unsafeAccessorDeclaration; + + switch (member) + { + case FieldInfo field: + { + // Note that we generate two accessors for fields (get/set), since the get accessor needs to be used in expression trees, + // which don't support ref return + if (_fieldUnsafeAccessors.TryGetValue((field, forWrite), out unsafeAccessorDeclaration)) + { + return unsafeAccessorDeclaration; + } + + break; + } + + case MethodBase method: // Also constructors + { + if (_methodUnsafeAccessors.TryGetValue(method, out unsafeAccessorDeclaration)) + { + return unsafeAccessorDeclaration; + } + + break; + } + + default: + throw new UnreachableException(); + } + + _stringBuilder.Clear().Append("UnsafeAccessor_"); + + if (member.DeclaringType?.Namespace?.Replace(".", "_") is string typeNamespace) + { + _stringBuilder.Append(typeNamespace).Append('_'); + } + + _stringBuilder.Append(member.DeclaringType!.Name).Append('_'); + + var memberName = member.Name; + _stringBuilder.Append( + member switch + { + // If this is the backing field of an auto-property, extract the name of the property from its compiler-generated name + // (e.g. k__BackingField) + FieldInfo when memberName[0] == '<' && memberName.IndexOf(">k__BackingField", StringComparison.Ordinal) is > 1 and var pos + => memberName[1..pos], + ConstructorInfo => "Ctor", + _ => memberName + }); + + var unsafeAccessorName = _stringBuilder.ToString(); + + switch (member) + { + case FieldInfo field: + { + // Unsafe accessor for fields: + // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_bar")] + // private static extern ref int GetSetPrivateField(Foo f); + // Note that we generate two accessors for fields (get/set), since the get accessor needs to be used in expression trees, + // which don't support ref return + unsafeAccessorDeclaration = (MethodDeclarationSyntax)_g.MethodDeclaration( + unsafeAccessorName + (forWrite ? "_Set" : "_Get"), + accessibility: Accessibility.Private, + modifiers: DeclarationModifiers.Static | DeclarationModifiers.Extern, + returnType: forWrite + ? RefType(Generate(field.FieldType)) + : Generate(field.FieldType), + parameters: [_g.ParameterDeclaration("instance", Generate(member.DeclaringType))]); + + unsafeAccessorDeclaration = + (MethodDeclarationSyntax)_g.AddAttributes( + unsafeAccessorDeclaration, + _g.Attribute( + "UnsafeAccessor", + _g.MemberAccessExpression(Generate(typeof(UnsafeAccessorKind)), nameof(UnsafeAccessorKind.Field)), + _g.AttributeArgument( + nameof(UnsafeAccessorAttribute.Name), _g.LiteralExpression(member.Name)))); + break; + } + + case MethodInfo { IsStatic: false } method: + { + // Unsafe accessor for methods. Note that this is used also for property getter and setter: + // [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Bar")] + // private static void SetPrivateProperty(Foo f, int value); + unsafeAccessorDeclaration = (MethodDeclarationSyntax)_g.MethodDeclaration( + unsafeAccessorName, + accessibility: Accessibility.Private, + modifiers: DeclarationModifiers.Static | DeclarationModifiers.Extern, + parameters: + [ + _g.ParameterDeclaration("instance", Generate(member.DeclaringType)), + .. method.GetParameters() + .Select( + p => _g.ParameterDeclaration( + p.Name ?? throw new UnreachableException("Missing parameter name"), + Generate(p.ParameterType))) + ]); + + unsafeAccessorDeclaration = + (MethodDeclarationSyntax)_g.AddAttributes( + unsafeAccessorDeclaration, + _g.Attribute( + "UnsafeAccessor", + _g.MemberAccessExpression(Generate(typeof(UnsafeAccessorKind)), nameof(UnsafeAccessorKind.Method)), + _g.AttributeArgument( + nameof(UnsafeAccessorAttribute.Name), _g.LiteralExpression(memberName)))); + + break; + } + + case ConstructorInfo constructor: + { + // Unsafe accessor for constructors: + // [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + // extern static Class PrivateCtor(int i); + unsafeAccessorDeclaration = (MethodDeclarationSyntax)_g.MethodDeclaration( + unsafeAccessorName, + accessibility: Accessibility.Private, + modifiers: DeclarationModifiers.Static | DeclarationModifiers.Extern, + returnType: Generate(member.DeclaringType), + parameters: constructor.GetParameters() + .Select( + p => _g.ParameterDeclaration( + p.Name ?? throw new UnreachableException("Missing parameter name"), + Generate(p.ParameterType)))); + + unsafeAccessorDeclaration = + (MethodDeclarationSyntax)_g.AddAttributes( + unsafeAccessorDeclaration, + _g.Attribute( + "UnsafeAccessor", + _g.MemberAccessExpression(Generate(typeof(UnsafeAccessorKind)), nameof(UnsafeAccessorKind.Constructor)))); + + break; + } + + default: + throw new UnreachableException("Unsafe declaration for unknown member type: " + member.GetType().Name); + } + + unsafeAccessorDeclaration = unsafeAccessorDeclaration + .WithBody(null) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + switch (member) + { + case FieldInfo field: + _fieldUnsafeAccessors[(field, forWrite)] = unsafeAccessorDeclaration; + break; + case MethodBase method: + _methodUnsafeAccessors[method] = unsafeAccessorDeclaration; + break; + default: + throw new UnreachableException(); + } + + return unsafeAccessorDeclaration; } /// @@ -1711,24 +1911,25 @@ protected override Expression VisitMethodCall(MethodCallExpression call) } else { - ExpressionSyntax expression; - if (call.Object is null) + var expression = call switch { - // Static method call. Recursively add MemberAccessExpressions for all declaring types (for methods on nested types) - expression = GetMemberAccessesForAllDeclaringTypes(call.Method.DeclaringType); + { Method.IsStatic: true } => GetMemberAccessesForAllDeclaringTypes(call.Method.DeclaringType), - ExpressionSyntax GetMemberAccessesForAllDeclaringTypes(Type type) - => type.DeclaringType is null - ? Generate(type) - : MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - GetMemberAccessesForAllDeclaringTypes(type.DeclaringType), - IdentifierName(type.Name)); - } - else - { - expression = Translate(call.Object); - } + // If the member isn't declared on the same type as the expression, (e.g. explicit interface implementation), add + // a cast up to the declaring type. + { Method.DeclaringType: Type declaringType, Object.Type: Type objectType, } when declaringType != objectType + => ParenthesizedExpression(CastExpression(Generate(declaringType), Translate(call.Object))), + + _ => Translate(call.Object) + }; + + ExpressionSyntax GetMemberAccessesForAllDeclaringTypes(Type type) + => type.DeclaringType is null + ? Generate(type) + : MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + GetMemberAccessesForAllDeclaringTypes(type.DeclaringType), + IdentifierName(type.Name)); if (call.Method.Name.StartsWith("get_", StringComparison.Ordinal) && call.Method.GetParameters().Length == 1 @@ -1834,32 +2035,16 @@ protected override Expression VisitNew(NewExpression node) return node; } - // If the type has any required properties and the constructor doesn't have [SetsRequiredMembers], we can't just generate an - // instantiation expression. - // TODO: Currently matching attributes by name since we target .NET 6.0. If/when we target .NET 7.0 and above, match the type. - if (node.Type.GetCustomAttributes(inherit: true) - .Any(a => a.GetType().FullName == "System.Runtime.CompilerServices.RequiredMemberAttribute") - && node.Constructor is not null - && node.Constructor.GetCustomAttributes() - .Any(a => a.GetType().FullName == "System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute") - != true) - { - // If the constructor is parameterless, we generate Activator.Create() which is almost as fast (<10ns difference). - // For constructors with parameters, we currently throw as not supported (we can pass parameters, but boxing, probably - // speed degradation etc.). - if (node.Constructor.GetParameters().Length == 0) - { - Result = - Translate( - E.Call( - (_activatorCreateInstanceMethod ??= typeof(Activator).GetMethod( - nameof(Activator.CreateInstance), [])!) - .MakeGenericMethod(node.Type))); - } - else - { - throw new NotImplementedException("Instantiation of type with required properties via constructor that has parameters"); - } + // If the constructor isn't public, or it has required properties and the constructor doesn't have [SetsRequiredMembers], we can't + // just generate a regular instantiation expression (won't compile). Generate an unsafe accessor instead. + if (node.Constructor is ConstructorInfo constructor + && (!constructor.IsPublic + || node.Type.GetCustomAttribute() is not null + && constructor.GetCustomAttribute() is null)) + { + var unsafeAccessorDeclaration = GetUnsafeAccessorDeclaration(constructor); + + Result = _g.InvocationExpression(_g.IdentifierName(unsafeAccessorDeclaration.Identifier.Text), arguments); } else { @@ -1870,9 +2055,9 @@ protected override Expression VisitNew(NewExpression node) initializer: null); } - if (node.Constructor?.DeclaringType is not null) + if (node.Type.Namespace is not null) { - AddNamespace(node.Constructor?.DeclaringType!); + AddNamespace(node.Type); } return node; @@ -1940,7 +2125,7 @@ protected virtual CSharpSyntaxNode TranslateSwitch(SwitchExpression switchNode, case ExpressionContext.Statement: { var parentLiftedState = _liftedState; - _liftedState = new LiftedState(); + _liftedState = parentLiftedState.CreateChild(); var cases = List( switchNode.Cases.Select( @@ -1993,7 +2178,7 @@ SyntaxList ProcessArmBody(Expression body) } var parentLiftedState = _liftedState; - _liftedState = new LiftedState(); + _liftedState = parentLiftedState.CreateChild(); // Translate all arms var arms = SeparatedList( @@ -2020,7 +2205,7 @@ SyntaxList ProcessArmBody(Expression body) // There are lifted expressions inside some of the arms, we must lift the entire switch expression, rewriting it to // a switch statement. - _liftedState = new LiftedState(); + _liftedState = parentLiftedState.CreateChild(); IdentifierNameSyntax assignmentVariable; TypeSyntax? loweredAssignmentVariableType = null; @@ -2028,7 +2213,7 @@ SyntaxList ProcessArmBody(Expression body) if (lowerableAssignmentVariable is null) { var name = UniquifyVariableName("liftedSwitch"); - var parameter = E.Parameter(switchNode.Type, name); + var parameter = Expression.Parameter(switchNode.Type, name); assignmentVariable = IdentifierName(name); loweredAssignmentVariableType = Generate(parameter.Type); } @@ -2116,8 +2301,8 @@ static ConditionalExpression RewriteSwitchToConditionals(SwitchExpression node) .Aggregate( node.DefaultBody, (expression, arm) => expression is null - ? E.IfThen(E.Equal(node.SwitchValue, arm.Label), arm.Body) - : E.IfThenElse(E.Equal(node.SwitchValue, arm.Label), arm.Body, expression)) + ? Expression.IfThen(Expression.Equal(node.SwitchValue, arm.Label), arm.Body) + : Expression.IfThenElse(Expression.Equal(node.SwitchValue, arm.Label), arm.Body, expression)) ?? throw new NotImplementedException("Empty switch statement")); } @@ -2128,8 +2313,8 @@ static ConditionalExpression RewriteSwitchToConditionals(SwitchExpression node) .Reverse() .Aggregate( node.DefaultBody, - (expression, arm) => E.Condition( - E.Equal(node.SwitchValue, arm.Label), + (expression, arm) => Expression.Condition( + Expression.Equal(node.SwitchValue, arm.Label), arm.Body, expression)); } @@ -2162,7 +2347,7 @@ protected override Expression VisitTry(TryExpression tryNode) Result = _g.TryCatchStatement( translatedBody, - catchClauses: [TranslateCatchBlock(E.Catch(typeof(Exception), tryNode.Fault), noType: true)]); + catchClauses: [TranslateCatchBlock(Expression.Catch(typeof(Exception), tryNode.Fault), noType: true)]); return tryNode; } @@ -2239,8 +2424,8 @@ protected override Expression VisitUnary(UnaryExpression unary) ExpressionType.Quote => operand, ExpressionType.UnaryPlus => PrefixUnaryExpression(SyntaxKind.UnaryPlusExpression, operand), ExpressionType.Unbox => operand, - ExpressionType.Increment => Translate(E.Add(unary.Operand, E.Constant(1))), - ExpressionType.Decrement => Translate(E.Subtract(unary.Operand, E.Constant(1))), + ExpressionType.Increment => Translate(Expression.Add(unary.Operand, Expression.Constant(1))), + ExpressionType.Decrement => Translate(Expression.Subtract(unary.Operand, Expression.Constant(1))), ExpressionType.PostIncrementAssign => PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, operand), ExpressionType.PostDecrementAssign => PostfixUnaryExpression(SyntaxKind.PostDecrementExpression, operand), ExpressionType.PreIncrementAssign => PrefixUnaryExpression(SyntaxKind.PreIncrementExpression, operand), @@ -2633,6 +2818,10 @@ public bool MayHaveSideEffects(SyntaxNode node) public override void Visit(SyntaxNode node) { _mayHaveSideEffects |= MayHaveSideEffectsCore(node); + if (_mayHaveSideEffects) + { + return; + } base.Visit(node); } @@ -2640,9 +2829,10 @@ public override void Visit(SyntaxNode node) private static bool MayHaveSideEffectsCore(SyntaxNode node) => node switch { - IdentifierNameSyntax or LiteralExpressionSyntax => false, + IdentifierNameSyntax or LiteralExpressionSyntax or PredefinedTypeSyntax => false, ExpressionStatementSyntax e => MayHaveSideEffectsCore(e.Expression), EmptyStatementSyntax => false, + DefaultExpressionSyntax => false, // TODO: we can exempt most binary and unary expressions as well, e.g. i + 5, but not anything involving assignment _ => true diff --git a/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs new file mode 100644 index 00000000000..11ae12ebae7 --- /dev/null +++ b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs @@ -0,0 +1,1128 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Runtime.ExceptionServices; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Editing; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class PrecompiledQueryCodeGenerator +{ + private readonly QueryLocator _queryLocator; + private readonly CSharpToLinqTranslator _csharpToLinqTranslator; + + private SyntaxGenerator _g = null!; + private IQueryCompiler _queryCompiler = null!; + private ExpressionTreeFuncletizer _funcletizer = null!; + private LinqToCSharpSyntaxTranslator _linqToCSharpTranslator = null!; + private LiftableConstantProcessor _liftableConstantProcessor = null!; + + private Symbols _symbols; + + private readonly HashSet _namespaces = new(); + private readonly HashSet _unsafeAccessors = new(); + private readonly IndentedStringBuilder _code = new(); + + private const string InterceptorsNamespace = "Microsoft.EntityFrameworkCore.GeneratedInterceptors"; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public PrecompiledQueryCodeGenerator() + { + _queryLocator = new QueryLocator(); + _csharpToLinqTranslator = new CSharpToLinqTranslator(); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IReadOnlyList GeneratePrecompiledQueries( + Compilation compilation, + SyntaxGenerator syntaxGenerator, + DbContext dbContext, + List precompilationErrors, + Assembly? additionalAssembly = null, + CancellationToken cancellationToken = default) + { + _queryLocator.Initialize(compilation); + _symbols = Symbols.Load(compilation); + _g = syntaxGenerator; + _linqToCSharpTranslator = new LinqToCSharpSyntaxTranslator(_g); + _liftableConstantProcessor = new LiftableConstantProcessor(null!); + _queryCompiler = dbContext.GetService(); + _unsafeAccessors.Clear(); + _funcletizer = new ExpressionTreeFuncletizer( + dbContext.Model, + dbContext.GetService(), + dbContext.GetType(), + generateContextAccessors: false, + dbContext.GetService>()); + + // This must be done after we complete generating the final compilation above + _csharpToLinqTranslator.Load(compilation, dbContext, additionalAssembly); + + // TODO: Ignore our auto-generated code! Also compiled model, generated code (comment, filename...?). + var generatedSyntaxTrees = new List(); + foreach (var syntaxTree in compilation.SyntaxTrees) + { + if (_queryLocator.LocateQueries(syntaxTree, precompilationErrors, cancellationToken) is not { Count: > 0 } locatedQueries) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(syntaxTree); + var generatedSyntaxTree = ProcessSyntaxTreeAsync( + syntaxTree, semanticModel, locatedQueries, precompilationErrors, cancellationToken); + if (generatedSyntaxTree is not null) + { + generatedSyntaxTrees.Add(generatedSyntaxTree); + } + } + + return generatedSyntaxTrees; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual GeneratedInterceptorFile? ProcessSyntaxTreeAsync( + SyntaxTree syntaxTree, + SemanticModel semanticModel, + IReadOnlyList locatedQueries, + List precompilationErrors, + CancellationToken cancellationToken) + { + var queriesPrecompiledInFile = 0; + _namespaces.Clear(); + _code.Clear(); + _code + .AppendLine() + .AppendLine("#pragma warning disable EF9100 // Precompiled query is experimental") + .AppendLine() + .Append("namespace ").AppendLine(InterceptorsNamespace) + .AppendLine("{") + .IncrementIndent() + .AppendLine("file static class EntityFrameworkCoreInterceptors") + .AppendLine("{") + .IncrementIndent(); + + for (var queryNum = 0; queryNum < locatedQueries.Count; queryNum++) + { + var querySyntax = locatedQueries[queryNum]; + + try + { + // We have a query lambda, as a Roslyn syntax tree. Translate to LINQ expression tree. + // TODO: Add verification that this is an EF query over our user's context. If translation returns null the moment + // there's another query root (another context or another LINQ provider), that's fine. + if (_csharpToLinqTranslator.Translate(querySyntax, semanticModel) is not MethodCallExpression terminatingOperator) + { + throw new UnreachableException("Non-method call encountered as the root of a LINQ query"); + } + + // We have a LINQ representation of the query tree as it appears in the user's source code, but this isn't the same as the + // LINQ tree the EF query pipeline needs to get; the latter is the result of evaluating the queryable operators in the user's + // source code. For example, in the user's code the root is a DbSet as the root, but the expression tree we require needs to + // contain an EntityQueryRootExpression. To get the LINQ tree for EF, we need to evaluate the operator chain, building an + // expression tree as usual. + + // However, we cannot evaluate the last operator, since that would execute the query instead of returning an expression tree. + // So we need to chop off the last operator before evaluation, and then (optionally) recompose it back afterwards. + // For ToList(), we don't actually recompose it (since ToList() isn't a node in the expression tree), and for async operators, + // we need to rewrite them to their sync counterparts (since that's what gets injected into the query tree). + var penultimateOperator = terminatingOperator switch + { + // This is needed e.g. for GetEnumerator(), DbSet.AsAsyncEnumerable (non-static terminating operators) + { Object: Expression @object } => @object, + { Arguments: [var sourceArgument, ..] } => sourceArgument, + _ => throw new UnreachableException() + }; + + penultimateOperator = Expression.Lambda>(penultimateOperator) + .Compile(preferInterpretation: true)().Expression; + + // Pass the query through EF's query pipeline; this returns the query's executor function, which can produce an enumerable + // that invokes the query. + // Note that we cannot recompose the terminating operator on top of the evaluated penultimate, since method signatures + // may not allow that (e.g. DbSet.AsAsyncEnumerable() requires a DbSet, but the evaluated value for a DbSet is + // EntityQueryRootExpression. So we handle the penultimate and the terminating separately. + var queryExecutor = CompileQuery(penultimateOperator, terminatingOperator); + + // The query has been compiled successfully by the EF query pipeline. + // Now go over each LINQ operator, generating an interceptor for it. + _code.AppendLine($"#region Query{queryNum + 1}").AppendLine(); + + try + { + _funcletizer.ResetPathCalculation(); + + if (querySyntax is not { Expression: MemberAccessExpressionSyntax { Expression: var penultimateOperatorSyntax } }) + { + throw new UnreachableException(); + } + + // Generate interceptors for all LINQ operators in the query, starting from the root up until the penultimate. + // Then generate the interceptor for the terminating operator, and finally the query's executor. + GenerateOperatorInterceptorsRecursively( + _code, penultimateOperator, penultimateOperatorSyntax, semanticModel, queryNum + 1, out var operatorNum, + cancellationToken: cancellationToken); + + GenerateOperatorInterceptor( + _code, terminatingOperator, querySyntax, semanticModel, queryNum + 1, operatorNum + 1, isTerminatingOperator: true, + cancellationToken); + + GenerateQueryExecutor(_code, queryNum + 1, queryExecutor, _namespaces, _unsafeAccessors); + } + finally + { + _code + .AppendLine() + .AppendLine($"#endregion Query{queryNum + 1}"); + } + } + catch (Exception e) + { + precompilationErrors.Add(new(querySyntax, e)); + continue; + } + + // We're done generating the interceptors for the query's LINQ operators. + + queriesPrecompiledInFile++; + } + + if (queriesPrecompiledInFile == 0) + { + return null; + } + + // Output all the unsafe accessors that were generated for all intercepted shapers, e.g.: + // [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")] + // static extern ref int GetSet_Foo_Name(Foo f); + if (_unsafeAccessors.Count > 0) + { + _code.AppendLine("#region Unsafe accessors"); + foreach (var unsafeAccessor in _unsafeAccessors) + { + _code.AppendLine(unsafeAccessor.NormalizeWhitespace().ToFullString()); + } + _code.AppendLine("#endregion Unsafe accessors"); + } + + _code + .DecrementIndent().AppendLine("}") + .DecrementIndent().AppendLine("}"); + + var mainCode = _code.ToString(); + + _code.Clear(); + _code.AppendLine("// ").AppendLine(); + + foreach (var ns in _namespaces + // In addition to the namespaces auto-detected by LinqToCSharpTranslator, we manually add these namespaces which are required + // by manually generated code above. + .Append("System") + .Append("System.Collections.Concurrent") + .Append("System.Collections.Generic") + .Append("System.Linq") + .Append("System.Linq.Expressions") + .Append("System.Runtime.CompilerServices") + .Append("System.Reflection") + .Append("System.Threading.Tasks") + .Append("Microsoft.EntityFrameworkCore") + .Append("Microsoft.EntityFrameworkCore.ChangeTracking.Internal") + .Append("Microsoft.EntityFrameworkCore.Diagnostics") + .Append("Microsoft.EntityFrameworkCore.Infrastructure") + .Append("Microsoft.EntityFrameworkCore.Infrastructure.Internal") + .Append("Microsoft.EntityFrameworkCore.Internal") + .Append("Microsoft.EntityFrameworkCore.Metadata") + .Append("Microsoft.EntityFrameworkCore.Query") + .Append("Microsoft.EntityFrameworkCore.Query.Internal") + .Append("Microsoft.EntityFrameworkCore.Storage") + .OrderBy( + ns => ns switch + { + _ when ns.StartsWith("System.", StringComparison.Ordinal) => 10, + _ when ns.StartsWith("Microsoft.", StringComparison.Ordinal) => 9, + _ => 0 + }) + .ThenBy(ns => ns)) + { + _code.Append("using ").Append(ns).AppendLine(";"); + } + + _code.AppendLine(mainCode); + + _code.AppendLine( + """ +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) { } + } +} +"""); + + return new( + $"{Path.GetFileNameWithoutExtension(syntaxTree.FilePath)}.EFInterceptors.g{Path.GetExtension(syntaxTree.FilePath)}", + _code.ToString()); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual Expression CompileQuery(Expression penultimateOperator, MethodCallExpression terminatingOperator) + { + // First, check whether this is an async query. + var async = terminatingOperator.Type.IsGenericType + && terminatingOperator.Type.GetGenericTypeDefinition() is var genericDefinition + && (genericDefinition == typeof(Task<>) || genericDefinition == typeof(ValueTask<>)); + + var preparedQuery = PrepareQueryForCompilation(penultimateOperator, terminatingOperator); + + // We now need to figure out the return type of the query's executor. + // Non-scalar query expressions (e.g. ToList()) return an IQueryable; the query executor will return an enumerable (sync or async). + // Scalar query expressions just return the scalar type. + var returnType = preparedQuery.Type.IsGenericType + && preparedQuery.Type.GetGenericTypeDefinition().IsAssignableTo(typeof(IQueryable)) + ? (async + ? typeof(IAsyncEnumerable<>) + : typeof(IEnumerable<>)).MakeGenericType(preparedQuery.Type.GetGenericArguments()[0]) + : terminatingOperator.Type; + + // We now have the query as a finalized LINQ expression tree, ready for compilation. + // Compile the query, invoking CompileQueryToExpression on the IQueryCompiler from the user's context instance. + try + { + return (Expression)_queryCompiler.GetType() + .GetMethod(nameof(IQueryCompiler.PrecompileQuery))! + .MakeGenericMethod(returnType) + .Invoke(_queryCompiler, [preparedQuery, async])!; + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // Unwrap the TargetInvocationException wrapper we get from Invoke() + ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + throw; + } + } + + private void GenerateOperatorInterceptorsRecursively( + IndentedStringBuilder code, + Expression operatorExpression, + ExpressionSyntax operatorSyntax, + SemanticModel semanticModel, + int queryNum, + out int operatorNum, + CancellationToken cancellationToken) + { + // For non-root operators, we get here with an InvocationExpressionSyntax and its corresponding LINQ MethodCallExpression. + // For the query root, we usually don't get called here: a regular EntityQueryRootExpression corresponds to a DbSet (either + // property access on DbContext or a Set<>() method invocation). We can't intercept property accesses, and in any case there's + // nothing to intercept there. + // However, for FromSql specifically, we get here with an InvocationExpressionSyntax (representing the FromSql() invocation), but + // with a corresponding FromSqlQueryRootExpression - not a MethodCallExpression. We must pass this query root through the + // funcletizer as usual to mimic the normal flow. + switch (operatorExpression) + { + // Regular, non-root LINQ operator; the LINQ method call must correspond to a Roslyn syntax invocation. + // We first recurse to handle the nested operator (i.e. generate the interceptor from the root outer). + case MethodCallExpression operatorMethodCall: + if (operatorSyntax is not InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax { Expression: var nestedOperatorSyntax } + }) + { + throw new UnreachableException(); + } + + // We're an operator (not the query root). + // Continue recursing down - we want to handle from the root up. + + var nestedOperatorExpression = operatorMethodCall switch + { + // This is needed e.g. for GetEnumerator(), DbSet.AsAsyncEnumerable (non-static terminating operators) + { Object: Expression @object } => @object, + { Arguments: [var sourceArgument, ..] } => sourceArgument, + _ => throw new UnreachableException() + }; + + GenerateOperatorInterceptorsRecursively( + code, nestedOperatorExpression, nestedOperatorSyntax, semanticModel, queryNum, out operatorNum, + cancellationToken: cancellationToken); + + operatorNum++; + + GenerateOperatorInterceptor( + code, operatorExpression, operatorSyntax, semanticModel, queryNum, operatorNum, isTerminatingOperator: false, + cancellationToken); + return; + + // For FromSql() queries, an InvocationExpressionSyntax (representing the FromSql() invocation), but with a corresponding + // FromSqlQueryRootExpression - not a MethodCallExpression. + // We must generate an interceptor for FromSql() and pass the arguments array through the funcletizer as usual. + case FromSqlQueryRootExpression: + operatorNum = 1; + GenerateOperatorInterceptor( + code, operatorExpression, operatorSyntax, semanticModel, queryNum, operatorNum, isTerminatingOperator: false, + cancellationToken); + return; + + // For other query roots, we don't generate interceptors - there are no possible captured variables that need to be + // pass through funcletization (as with FromSqlQueryRootExpression). Simply return to process the first non-root operator. + case QueryRootExpression: + operatorNum = 0; + return; + + default: + throw new UnreachableException(); + } + } + + private void GenerateOperatorInterceptor( + IndentedStringBuilder code, + Expression operatorExpression, + ExpressionSyntax operatorSyntax, + SemanticModel semanticModel, + int queryNum, + int operatorNum, + bool isTerminatingOperator, + CancellationToken cancellationToken) + { + // At this point we know we're intercepting a method call invocation. + // Extract the MemberAccessExpressionSyntax for the invocation, representing the method being called. + var memberAccessSyntax = (operatorSyntax as InvocationExpressionSyntax)?.Expression as MemberAccessExpressionSyntax + ?? throw new UnreachableException(); + + // Create the parameter list for our interceptor method from the LINQ operator method's parameter list + if (semanticModel.GetSymbolInfo(memberAccessSyntax, cancellationToken).Symbol is not IMethodSymbol operatorSymbol) + { + throw new InvalidOperationException("Couldn't find method symbol for: " + memberAccessSyntax); + } + + // Throughout the code generation below, we will only be dealing with the original generic definition of the operator (and + // generating a generic interceptor); we'll never be dealing with the concrete types for this invocation, since these may + // be unspeakable anonymous types which we can't embed in generated code. + operatorSymbol = operatorSymbol.OriginalDefinition; + + // For extension methods, this provides the form which has the "this" as its first parameter. + // TODO: Validate the below, throw informative (e.g. top-level TVF fails here because non-generic) + var reducedOperatorSymbol = operatorSymbol.GetConstructedReducedFrom() ?? operatorSymbol; + + var (sourceVariableName, sourceTypeSymbol) = reducedOperatorSymbol.IsStatic + ? (reducedOperatorSymbol.Parameters[0].Name, reducedOperatorSymbol.Parameters[0].Type) + : ("source", reducedOperatorSymbol.ReceiverType!); + + if (sourceTypeSymbol is not INamedTypeSymbol { TypeArguments: [var sourceElementTypeSymbol]}) + { + throw new UnreachableException($"Non-IQueryable first parameter in LINQ operator '{operatorSymbol.Name}'"); + } + + var sourceElementTypeName = sourceElementTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + var returnTypeSymbol = reducedOperatorSymbol.ReturnType; + + // Unwrap Task to get the element type (e.g. Task>) + var returnTypeWithoutTask = returnTypeSymbol is INamedTypeSymbol namedReturnType + && returnTypeSymbol.OriginalDefinition.Equals(_symbols.GenericTask, SymbolEqualityComparer.Default) + ? namedReturnType.TypeArguments[0] + : returnTypeSymbol; + + var returnElementTypeSymbol = returnTypeWithoutTask switch + { + IArrayTypeSymbol arrayTypeSymbol => arrayTypeSymbol.ElementType, + INamedTypeSymbol namedReturnType2 + when namedReturnType2.AllInterfaces.Prepend(namedReturnType2) + .Any( + i => i.OriginalDefinition.Equals(_symbols.GenericEnumerable, SymbolEqualityComparer.Default) + || i.OriginalDefinition.Equals(_symbols.GenericAsyncEnumerable, SymbolEqualityComparer.Default) + || i.OriginalDefinition.Equals(_symbols.GenericEnumerator, SymbolEqualityComparer.Default)) + => namedReturnType2.TypeArguments[0], + _ => null + }; + + // Output the interceptor method signature preceded by the [InterceptsLocation] attribute. + var startPosition = operatorSyntax.SyntaxTree.GetLineSpan(memberAccessSyntax.Name.Span, cancellationToken).StartLinePosition; + var interceptorName = $"Query{queryNum}_{memberAccessSyntax.Name}{operatorNum}"; + code.AppendLine($"""[InterceptsLocation("{operatorSyntax.SyntaxTree.FilePath}", {startPosition.Line + 1}, {startPosition.Character + 1})]"""); + GenerateInterceptorMethodSignature(); + code.AppendLine("{").IncrementIndent(); + + // If this is the first query operator (no nested operator), cast the input source to IInfrastructure and extract the + // DbContext, create a new QueryContext, and wrap it all in a PrecompiledQueryContext that will flow through to the + // terminating operator, where the query will actually get executed. + // Otherwise, if this is a non-first operator, receive the PrecompiledQueryContext from the nested operator and flow it forward. + code.AppendLine( + "var precompiledQueryContext = " + + (operatorNum == 1 + ? $"new PrecompiledQueryContext<{sourceElementTypeName}>(((IInfrastructure){sourceVariableName}).Instance);" + : $"(PrecompiledQueryContext<{sourceElementTypeName}>){sourceVariableName};")); + + var declaredQueryContextVariable = false; + + ProcessCapturedVariables(); + + if (isTerminatingOperator) + { + // We're intercepting the query's terminating operator - this is where the query actually gets executed. + if (!declaredQueryContextVariable) + { + code.AppendLine("var queryContext = precompiledQueryContext.QueryContext;"); + } + + var executorFieldIdentifier = $"Query{queryNum}_Executor"; + code.AppendLine( + $"{executorFieldIdentifier} ??= Query{queryNum}_GenerateExecutor(precompiledQueryContext.DbContext, precompiledQueryContext.QueryContext);"); + + if (returnElementTypeSymbol is null) + { + // The query returns a scalar, not an enumerable (e.g. the terminating operator is Max()). + // The executor directly returns the needed result (e.g. int), so just return that. + var returnType = returnTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + code.AppendLine($"return ((Func)({executorFieldIdentifier}))(queryContext);"); + } + else + { + // The query returns an IEnumerable/IAsyncEnumerable/IQueryable, which is a bit trickier: the executor doesn't return a + // simple value as in the scalar case, but rather e.g. SingleQueryingEnumerable; we need to compose the terminating + // operator (e.g. ToList()) on top of that. Cast the executor delegate to Func> + // (contravariance). + var isAsync = + operatorExpression.Type.IsGenericType + && operatorExpression.Type.GetGenericTypeDefinition() is var genericDefinition + && ( + genericDefinition == typeof(Task<>) + || genericDefinition == typeof(ValueTask<>) + || genericDefinition == typeof(IAsyncEnumerable<>)); + + var isQueryable = !isAsync + && operatorExpression.Type.IsGenericType + && operatorExpression.Type.GetGenericTypeDefinition() == typeof(IQueryable<>); + + var returnValue = isAsync + ? $"IAsyncEnumerable<{sourceElementTypeName}>" + : $"IEnumerable<{sourceElementTypeName}>"; + + code.AppendLine( + $"var queryingEnumerable = ((Func)({executorFieldIdentifier}))(queryContext);"); + + if (isQueryable) + { + // If the terminating operator returns IQueryable, that means the query is actually evaluated via foreach + // (i.e. there's no method such as AsEnumerable/ToList which evaluates). Note that this is necessarily sync only - + // IQueryable can't be directly inside await foreach (AsAsyncEnumerable() is required). + // For this case, we need to compose AsQueryable() on top, to make the querying enumerable compatible with the + // operator signature. + code.AppendLine("return queryingEnumerable.AsQueryable();"); + } + else + { + if (isAsync) + { + // For sync queries, we get an IEnumerable above, and can just compose the original terminating operator + // directly on top of that (ToList(), ToDictionary()...). + // But for async queries, we get an IAsyncEnumerable above, but cannot directly compose the original + // terminating operator (ToListAsync(), ToDictionaryAsync()...), since those require an IQueryable in their + // signature (which they internally case to IAsyncEnumerable). + // So we introduce an adapter in the middle, which implements both IQueryable (to be able to compose + // ToListAsync() on top), and IAsyncEnumerable (so that the actual implementation of ToListAsync() works). + // TODO: This is an additional runtime allocation; if we had System.Linq.Async we wouldn't need this. We could + // have additional versions of all async terminating operators over IAsyncEnumerable (effectively duplicating + // System.Linq.Async) as an alternative. + code.AppendLine($"var asyncQueryingEnumerable = new PrecompiledQueryableAsyncEnumerableAdapter<{sourceElementTypeName}>(queryingEnumerable);"); + code.Append("return asyncQueryingEnumerable"); + } + else + { + code.Append("return queryingEnumerable"); + } + + // Invoke the original terminating operator (e.g. ToList(), ToDictionary()...) on the querying enumerable, passing + // through the interceptor's arguments. + code.AppendLine( + $".{memberAccessSyntax.Name}({string.Join(", ", operatorSymbol.Parameters.Select(p => p.Name))});"); + } + } + } + else + { + // Non-terminating operator - we need to flow precompiledQueryContext forward. + + // The operator returns a different IQueryable type as its source (e.g. Select), convert the precompiledQueryContext + // before returning it. + Check.DebugAssert(returnElementTypeSymbol is not null, "Non-terminating operator must return IEnumerable"); + + code.AppendLine( + returnTypeSymbol switch + { + // The operator return IQueryable or IOrderedQueryable. + // If T is the same as the source, simply return our context as is (note that PrecompiledQueryContext implements + // IOrderedQueryable). Otherwise, e.g. Select() is being applied - change the context's type. + _ when returnTypeSymbol.OriginalDefinition.Equals(_symbols.IQueryable, SymbolEqualityComparer.Default) + || returnTypeSymbol.OriginalDefinition.Equals(_symbols.IOrderedQueryable, SymbolEqualityComparer.Default) + => SymbolEqualityComparer.Default.Equals(sourceElementTypeSymbol, returnElementTypeSymbol) + ? "return precompiledQueryContext;" + : $"return precompiledQueryContext.ToType<{returnElementTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>();", + + _ when returnTypeSymbol.OriginalDefinition.Equals(_symbols.IIncludableQueryable, SymbolEqualityComparer.Default) + && returnTypeSymbol is INamedTypeSymbol { OriginalDefinition.TypeArguments: [_, var includedPropertySymbol] } + => $"return precompiledQueryContext.ToIncludable<{includedPropertySymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>();", + + _ => throw new UnreachableException() + }); + } + + code.DecrementIndent().AppendLine("}").AppendLine(); + + void GenerateInterceptorMethodSignature() + { + code + .Append("public static ") + .Append(_g.TypeExpression(reducedOperatorSymbol.ReturnType).ToFullString()) + .Append(' ') + .Append(interceptorName); + + var (typeParameters, constraints) = (reducedOperatorSymbol.IsGenericMethod, reducedOperatorSymbol.ContainingType.IsGenericType) switch + { + (true, false) => (reducedOperatorSymbol.TypeParameters, ((MethodDeclarationSyntax)_g.MethodDeclaration(reducedOperatorSymbol)).ConstraintClauses), + (false, true) => (reducedOperatorSymbol.ContainingType.TypeParameters, ((TypeDeclarationSyntax)_g.Declaration(reducedOperatorSymbol.ContainingType)).ConstraintClauses), + (false, false) => ([], []), + (true, true) => throw new NotImplementedException("Generic method on generic type not supported") + }; + + if (typeParameters.Length > 0) + { + code.Append('<'); + for (var i = 0; i < typeParameters.Length; i++) + { + if (i > 0) + { + code.Append(", "); + } + + code.Append(_g.TypeExpression(typeParameters[i]).ToFullString()); + } + + code.Append('>'); + } + + code.Append('('); + + // For instance methods (IEnumerable.GetEnumerator(), DbSet.GetAsyncEnumerable()...), we generate an extension method + // (with this) for the interceptor. + if (reducedOperatorSymbol is { IsStatic: false, ReceiverType: not null }) + { + code + .Append("this ") + .Append(_g.TypeExpression(reducedOperatorSymbol.ReceiverType).ToFullString()) + .Append(' ') + .Append(sourceVariableName); + } + + for (var i = 0; i < reducedOperatorSymbol.Parameters.Length; i++) + { + var parameter = reducedOperatorSymbol.Parameters[i]; + + if (i == 0) + { + switch (reducedOperatorSymbol) + { + case { IsExtensionMethod: true }: + code.Append("this "); + break; + + // For instance methods we already added a this parameter above + case { IsStatic: false, ReceiverType: not null }: + code.Append(", "); + break; + + default: + throw new NotImplementedException("Non-extension static method not supported"); + } + } + else + { + code.Append(", "); + } + + code + .Append(_g.TypeExpression(parameter.Type).ToFullString()) + .Append(' ') + .Append(parameter.Name); + } + + code.AppendLine(")"); + + foreach (var f in constraints) + { + code.AppendLine(f.NormalizeWhitespace().ToFullString()); + } + } + + void ProcessCapturedVariables() + { + // Go over the operator's arguments (skipping the first, which is the source). + // For those which have captured variables, run them through our funcletizer, which will return code for extracting any captured + // variables from them. + switch (operatorExpression) + { + // Regular case: this is an operator method + case MethodCallExpression operatorMethodCall: + { + var parameters = operatorMethodCall.Method.GetParameters(); + + for (var i = 1; i < parameters.Length; i++) + { + var parameter = parameters[i]; + + if (parameter.ParameterType == typeof(CancellationToken)) + { + continue; + } + + if (_funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i) is not ExpressionTreeFuncletizer.PathNode + evaluatableRootPaths) + { + // There are no captured variables in this lambda argument - skip the argument + continue; + } + + // We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code + // which extracts them and sets them on our query context. + if (!declaredQueryContextVariable) + { + code.AppendLine("var queryContext = precompiledQueryContext.QueryContext;"); + declaredQueryContextVariable = true; + } + + if (!parameter.ParameterType.IsSubclassOf(typeof(Expression))) + { + // Special case: this is a non-lambda argument (Skip/Take/FromSql). + // Simply add the argument directly as a parameter + code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameter.Name});"""); + continue; + } + + var variableCounter = 0; + + // Lambda argument. Recurse through evaluatable path trees. + foreach (var child in evaluatableRootPaths.Children!) + { + GenerateCapturedVariableExtractors(parameter.Name!, parameter.ParameterType, child); + + void GenerateCapturedVariableExtractors( + string currentIdentifier, + Type currentType, + ExpressionTreeFuncletizer.PathNode capturedVariablesPathTree) + { + var linqPathSegment = + capturedVariablesPathTree.PathFromParent!(Expression.Parameter(currentType, currentIdentifier)); + var collectedNamespaces = new HashSet(); + var unsafeAccessors = new HashSet(); + var roslynPathSegment = _linqToCSharpTranslator.TranslateExpression( + linqPathSegment, constantReplacements: null, collectedNamespaces, unsafeAccessors); + + var variableName = capturedVariablesPathTree.ExpressionType.Name; + variableName = char.ToLower(variableName[0]) + variableName[1..^"Expression".Length] + ++variableCounter; + code.AppendLine( + $"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};"); + + if (capturedVariablesPathTree.Children?.Count > 0) + { + // This is an intermediate node which has captured variables in the children. Continue recursing down. + foreach (var child in capturedVariablesPathTree.Children) + { + GenerateCapturedVariableExtractors(variableName, capturedVariablesPathTree.ExpressionType, child); + } + + return; + } + + // We've reached a leaf, meaning that it's an evaluatable node that contains captured variables. + // Generate code to evaluate this node and assign the result to the parameters dictionary: + + // TODO: For the common case of a simple parameter (member access over closure type), generate reflection code directly + // TODO: instead of going through the interpreter, as we do in the funcletizer itself (for perf) + // TODO: Remove the convert to object. We can flow out the actual type of the evaluatable root, and just stick it + // in Func<> instead of object. + // TODO: For specific cases, don't go through the interpreter, but just integrate code that extracts the value directly. + // (see ExpressionTreeFuncletizer.Evaluate()). + // TODO: Basically this means that the evaluator should come from ExpressionTreeFuncletizer itself, as part of its outputs + // TODO: Integrate try/catch around the evaluation? + code.AppendLine("queryContext.AddParameter("); + using (code.Indent()) + { + code + .Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",") + .AppendLine($"Expression.Lambda>(Expression.Convert({variableName}, typeof(object)))") + .AppendLine(".Compile(preferInterpretation: true)") + .AppendLine(".Invoke());"); + } + } + } + } + + break; + } + + // Special case: this is a FromSql query root; we're intercepting the invocation syntax for the FromSql() call, but on the LINQ + // side we have a query root (i.e. not the MethodCallExpression for the FromSql(), but rather its evaluated result) + case FromSqlQueryRootExpression fromSqlQueryRoot: + { + if (_funcletizer.CalculatePathsToEvaluatableRoots(fromSqlQueryRoot.Argument) is not ExpressionTreeFuncletizer.PathNode + evaluatableRootPaths) + { + // There are no captured variables in this FromSqlQueryRootExpression, skip it. + break; + } + + // We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code + // which extracts them and sets them on our query context. + if (!declaredQueryContextVariable) + { + code.AppendLine("var queryContext = precompiledQueryContext.QueryContext;"); + declaredQueryContextVariable = true; + } + + var argumentsParameter = reducedOperatorSymbol switch + { + { Name: "FromSqlRaw", Parameters: [_, _, { Name: "parameters" }] } => "parameters", + { Name: "FromSql", Parameters: [_, { Name: "sql" }] } => "sql.GetArguments()", + { Name: "FromSqlInterpolated", Parameters: [_, { Name: "sql" }] } => "sql.GetArguments()", + _ => throw new UnreachableException() + }; + + code.AppendLine( + $"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {argumentsParameter});"""); + + break; + } + + default: + throw new UnreachableException(); + } + } + } + + private void GenerateQueryExecutor( + IndentedStringBuilder code, + int queryNum, + Expression queryExecutor, + HashSet namespaces, + HashSet unsafeAccessors) + { + // We're going to generate the method which will create the query executor (Func). + // Note that the we store the executor itself (and return it) as object, not as a typed Func. + // We can't strong-type it since it may return an anonymous type, which is unspeakable; so instead we cast down from object to + // the real strongly-typed signature inside the interceptor, where the return value is represented as a generic type parameter + // (which can be an anonymous type). + code + .AppendLine($"private static object Query{queryNum}_GenerateExecutor(DbContext dbContext, QueryContext queryContext)") + .AppendLine("{") + .IncrementIndent() + .AppendLine("var relationalModel = dbContext.Model.GetRelationalModel();") + .AppendLine("var relationalTypeMappingSource = dbContext.GetService();") + .AppendLine("var materializerLiftableConstantContext = new RelationalMaterializerLiftableConstantContext(dbContext.GetService(), dbContext.GetService());"); + + HashSet variableNames = ["relationalModel", "relationalTypeMappingSource", "materializerLiftableConstantContext"]; + + var materializerLiftableConstantContext = + Expression.Parameter(typeof(RelationalMaterializerLiftableConstantContext), "materializerLiftableConstantContext"); + + // The materializer expression tree contains LiftedConstantExpression nodes, which contain instructions on how to resolve + // constant values which need to be lifted. + var queryExecutorAfterLiftingExpression = + _liftableConstantProcessor.LiftConstants(queryExecutor, materializerLiftableConstantContext, variableNames); + + foreach (var liftedConstant in _liftableConstantProcessor.LiftedConstants) + { + var variableValueSyntax = _linqToCSharpTranslator.TranslateExpression( + liftedConstant.Expression, constantReplacements: null, namespaces, unsafeAccessors); + // code.AppendLine($"{liftedConstant.Parameter.Type.Name} {liftedConstant.Parameter.Name} = {variableValueSyntax.NormalizeWhitespace().ToFullString()};"); + code.AppendLine($"var {liftedConstant.Parameter.Name} = {variableValueSyntax.NormalizeWhitespace().ToFullString()};"); + } + + var queryExecutorSyntaxTree = + (AnonymousFunctionExpressionSyntax)_linqToCSharpTranslator.TranslateExpression( + queryExecutorAfterLiftingExpression, + constantReplacements: null, + namespaces, + unsafeAccessors); + + code + .AppendLine($"return {queryExecutorSyntaxTree.NormalizeWhitespace().ToFullString()};") + .DecrementIndent() + .AppendLine("}") + .AppendLine() + .AppendLine($"private static object Query{queryNum}_Executor;"); + } + + /// + /// Performs processing of a query's terminating operator before handing the query off for EF compilation. + /// This involves removing the operator when it shouldn't be in the tree (e.g. ToList()), and rewriting async terminating operators + /// to their sync counterparts (e.g. MaxAsync() -> Max()). This only needs to be modified/overridden if a new terminating operator + /// is introduced which needs to be rewritten. + /// + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual Expression PrepareQueryForCompilation(Expression penultimateOperator, MethodCallExpression terminatingOperator) + { + var method = terminatingOperator.Method; + + return method.Name switch + { + // These sync terminating operators are defined over IEnumerable, and don't inject a node into the query tree. Simply remove them. + nameof(Enumerable.AsEnumerable) + or nameof(Enumerable.ToArray) + or nameof(Enumerable.ToDictionary) + or nameof(Enumerable.ToHashSet) + or nameof(Enumerable.ToLookup) + or nameof(Enumerable.ToList) + when method.DeclaringType == typeof(Enumerable) + => penultimateOperator, + + nameof(IEnumerable.GetEnumerator) + when method.DeclaringType is { IsConstructedGenericType: true } declaringType + && declaringType.GetGenericTypeDefinition() == typeof(IEnumerable<>) + => penultimateOperator, + + // Async ToListAsync, ToArrayAsync and AsAsyncEnumerable don't inject a node into the query tree - remove these as well. + nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable) + or nameof(EntityFrameworkQueryableExtensions.ToArrayAsync) + or nameof(EntityFrameworkQueryableExtensions.ToDictionaryAsync) + or nameof(EntityFrameworkQueryableExtensions.ToHashSetAsync) + // or nameof(EntityFrameworkQueryableExtensions.ToLookupAsync) + or nameof(EntityFrameworkQueryableExtensions.ToListAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + => penultimateOperator, + + // There's also an instance method version of AsAsyncEnumerable on DbSet, remove that as well. + nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable) + when method.DeclaringType?.IsConstructedGenericType == true + && method.DeclaringType.GetGenericTypeDefinition() == typeof(DbSet<>) + => penultimateOperator, + + // The EF async counterparts to all the standard scalar-returning terminating operators. These need to be rewritten, as they + // inject the sync versions into the query tree. + nameof(EntityFrameworkQueryableExtensions.AllAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + => RewriteToSync(QueryableMethods.All), + nameof(EntityFrameworkQueryableExtensions.AnyAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.AnyWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.AnyAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.AnyWithPredicate), + nameof(EntityFrameworkQueryableExtensions.AverageAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync( + QueryableMethods.GetAverageWithoutSelector(method.GetParameters()[0].ParameterType.GenericTypeArguments[0])), + nameof(EntityFrameworkQueryableExtensions.AverageAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync( + QueryableMethods.GetAverageWithSelector( + method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])), + nameof(EntityFrameworkQueryableExtensions.ContainsAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + => RewriteToSync(QueryableMethods.Contains), + nameof(EntityFrameworkQueryableExtensions.CountAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.CountWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.CountAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.CountWithPredicate), + // nameof(EntityFrameworkQueryableExtensions.DefaultIfEmptyAsync) + nameof(EntityFrameworkQueryableExtensions.ElementAtAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + => RewriteToSync(QueryableMethods.ElementAt), + nameof(EntityFrameworkQueryableExtensions.ElementAtOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + => RewriteToSync(QueryableMethods.ElementAtOrDefault), + nameof(EntityFrameworkQueryableExtensions.FirstAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.FirstWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.FirstAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.FirstWithPredicate), + nameof(EntityFrameworkQueryableExtensions.FirstOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.FirstOrDefaultWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.FirstOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.FirstOrDefaultWithPredicate), + nameof(EntityFrameworkQueryableExtensions.LastAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.LastWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.LastAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.LastWithPredicate), + nameof(EntityFrameworkQueryableExtensions.LastOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.LastOrDefaultWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.LastOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.LastOrDefaultWithPredicate), + nameof(EntityFrameworkQueryableExtensions.LongCountAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.LongCountWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.LongCountAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.LongCountWithPredicate), + nameof(EntityFrameworkQueryableExtensions.MaxAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.MaxWithoutSelector), + nameof(EntityFrameworkQueryableExtensions.MaxAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.MaxWithSelector), + nameof(EntityFrameworkQueryableExtensions.MinAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.MinWithoutSelector), + nameof(EntityFrameworkQueryableExtensions.MinAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.MinWithSelector), + // nameof(EntityFrameworkQueryableExtensions.MaxByAsync) + // nameof(EntityFrameworkQueryableExtensions.MinByAsync) + nameof(EntityFrameworkQueryableExtensions.SingleAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.SingleWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.SingleAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.SingleWithPredicate), + nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.SingleOrDefaultWithoutPredicate), + nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync(QueryableMethods.SingleOrDefaultWithPredicate), + nameof(EntityFrameworkQueryableExtensions.SumAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 2 + => RewriteToSync(QueryableMethods.GetSumWithoutSelector(method.GetParameters()[0].ParameterType.GenericTypeArguments[0])), + nameof(EntityFrameworkQueryableExtensions.SumAsync) + when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) && method.GetParameters().Length == 3 + => RewriteToSync( + QueryableMethods.GetSumWithSelector( + method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])), + + // ExecuteDelete/Update behave just like other scalar-returning operators + nameof(RelationalQueryableExtensions.ExecuteDeleteAsync) when method.DeclaringType == typeof(RelationalQueryableExtensions) + => RewriteToSync(typeof(RelationalQueryableExtensions).GetMethod(nameof(RelationalQueryableExtensions.ExecuteDelete))), + nameof(RelationalQueryableExtensions.ExecuteUpdateAsync) when method.DeclaringType == typeof(RelationalQueryableExtensions) + => RewriteToSync(typeof(RelationalQueryableExtensions).GetMethod(nameof(RelationalQueryableExtensions.ExecuteUpdate))), + + // In the regular case (sync terminating operator which needs to stay in the query tree), simply compose the terminating + // operator over the penultimate and return that. + _ => terminatingOperator switch + { + // This is needed e.g. for GetEnumerator(), DbSet.AsAsyncEnumerable (non-static terminating operators) + { Object: Expression } + => terminatingOperator.Update(penultimateOperator, terminatingOperator.Arguments), + { Arguments: [_, ..] } + => terminatingOperator.Update(@object: null, [penultimateOperator, .. terminatingOperator.Arguments.Skip(1)]), + _ => throw new UnreachableException() + } + }; + + MethodCallExpression RewriteToSync(MethodInfo? syncMethod) + { + if (syncMethod is null) + { + throw new UnreachableException($"Could find replacement method for {method.Name}"); + } + + if (syncMethod.IsGenericMethodDefinition) + { + syncMethod = syncMethod.MakeGenericMethod(method.GetGenericArguments()); + } + + // Replace the first argument with the penultimate argument, and chop off the CancellationToken argument + Expression[] syncArguments = + [penultimateOperator, .. terminatingOperator.Arguments.Skip(1).Take(terminatingOperator.Arguments.Count - 2)]; + + return Expression.Call(terminatingOperator.Object, syncMethod, syncArguments); + } + } + + /// + /// Contains information on a failure to precompile a specific query in the user's source code. + /// Includes information about the query, its location, and the exception that occured. + /// + public sealed record QueryPrecompilationError(SyntaxNode SyntaxNode, Exception Exception); + + private readonly struct Symbols + { + private readonly Compilation _compilation; + + // ReSharper disable InconsistentNaming + public readonly INamedTypeSymbol GenericEnumerable; + public readonly INamedTypeSymbol GenericAsyncEnumerable; + public readonly INamedTypeSymbol GenericEnumerator; + public readonly INamedTypeSymbol IQueryable; + public readonly INamedTypeSymbol IOrderedQueryable; + public readonly INamedTypeSymbol IIncludableQueryable; + public readonly INamedTypeSymbol GenericTask; + // ReSharper restore InconsistentNaming + + private Symbols(Compilation compilation) + { + _compilation = compilation; + + GenericEnumerable = + GetTypeSymbolOrThrow("System.Collections.Generic.IEnumerable`1"); + GenericAsyncEnumerable = + GetTypeSymbolOrThrow("System.Collections.Generic.IAsyncEnumerable`1"); + GenericEnumerator = + GetTypeSymbolOrThrow("System.Collections.Generic.IEnumerator`1"); + IQueryable = + GetTypeSymbolOrThrow("System.Linq.IQueryable`1"); + IOrderedQueryable = + GetTypeSymbolOrThrow("System.Linq.IOrderedQueryable`1"); + IIncludableQueryable = + GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.Query.IIncludableQueryable`2"); + GenericTask = + GetTypeSymbolOrThrow("System.Threading.Tasks.Task`1"); + } + + public static Symbols Load(Compilation compilation) + => new(compilation); + + private INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName) + => _compilation.GetTypeByMetadataName(fullyQualifiedMetadataName) + ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedMetadataName); + } + + /// + /// A generated file containing LINQ operator interceptors. + /// + /// The path of the generated file. + /// The code of the generated file. + public sealed record GeneratedInterceptorFile(string Path, string Code); +} diff --git a/src/EFCore.Design/Query/Internal/QueryLocator.cs b/src/EFCore.Design/Query/Internal/QueryLocator.cs new file mode 100644 index 00000000000..3270d0c74e9 --- /dev/null +++ b/src/EFCore.Design/Query/Internal/QueryLocator.cs @@ -0,0 +1,371 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.EntityFrameworkCore.Internal; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// Statically analyzes user code and locates EF LINQ queries within it, by identifying well-known terminating operators +/// (e.g. ToList, Single). +/// +/// +/// After a is loaded via , is called repeatedly +/// for all syntax trees in the compilation. +/// +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class QueryLocator : CSharpSyntaxWalker +{ + private Compilation? _compilation; + private Symbols _symbols; + + private SemanticModel _semanticModel = null!; + private CancellationToken _cancellationToken; + private List _locatedQueries = null!; + private List _precompilationErrors = null!; + + + /// + /// Loads a new , representing a user project in which to locate queries. + /// + /// A representing a user project. + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void Initialize(Compilation compilation) + { + _compilation = compilation; + _symbols = Symbols.Load(compilation); + } + + /// + /// Locates EF LINQ queries within the given , which represents user code. + /// + /// A in which to locate EF LINQ queries. + /// + /// A list of errors populated with dynamic LINQ queries detected in . + /// + /// A to observe while waiting for the task to complete. + /// A list of EF LINQ queries confirmed to be compatible with precompilation. + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IReadOnlyList LocateQueries( + SyntaxTree syntaxTree, + List precompilationErrors, + CancellationToken cancellationToken = default) + { + if (_compilation is null) + { + throw new InvalidOperationException("A compilation must be loaded."); + } + + if (!_compilation.SyntaxTrees.Contains(syntaxTree)) + { + throw new ArgumentException(""); + } + + _cancellationToken = cancellationToken; + _semanticModel = _compilation.GetSemanticModel(syntaxTree); + _locatedQueries = new(); + _precompilationErrors = precompilationErrors; + Visit(syntaxTree.GetRoot(cancellationToken)); + + return _locatedQueries; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override void VisitInvocationExpression(InvocationExpressionSyntax invocation) + { + // TODO: Support non-extension invocation syntax: var blogs = ToList(ctx.Blogs); + if (invocation.Expression is MemberAccessExpressionSyntax + { + Name: IdentifierNameSyntax { Identifier.Text: var identifier }, + Expression: var innerExpression + }) + { + // First, pattern-match on the method name as a string; this avoids accessing the semantic model for each and + // every invocation (more efficient). + switch (identifier) + { + // These sync terminating operators exist exist over IEnumerable only, so verify the actual argument is an IQueryable (otherwise + // this is just LINQ to Objects) + case nameof(Enumerable.AsEnumerable) or nameof(Enumerable.ToArray) or nameof(Enumerable.ToDictionary) + or nameof(Enumerable.ToHashSet) or nameof(Enumerable.ToLookup) or nameof(Enumerable.ToList) + when IsOnEnumerable() && IsQueryable(innerExpression): + + case nameof(IEnumerable.GetEnumerator) + when IsOnIEnumerable() && IsQueryable(innerExpression): + + // The async terminating operators are defined by EF, and accept an IQueryable - no need to look at the argument. + case nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable) + or nameof(EntityFrameworkQueryableExtensions.ToArrayAsync) + or nameof(EntityFrameworkQueryableExtensions.ToDictionaryAsync) + or nameof(EntityFrameworkQueryableExtensions.ToHashSetAsync) + // or nameof(EntityFrameworkQueryableExtensions.ToLookupAsync) + or nameof(EntityFrameworkQueryableExtensions.ToListAsync) + when IsOnEfQueryableExtensions(): + + case nameof(EntityFrameworkQueryableExtensions.AsAsyncEnumerable) + when IsOnEfQueryableExtensions() || IsOnTypeSymbol(_symbols.DbSet): + + case nameof(Queryable.All) + or nameof(Queryable.Any) + or nameof(Queryable.Average) + or nameof(Queryable.Contains) + or nameof(Queryable.Count) + or nameof(Queryable.DefaultIfEmpty) + or nameof(Queryable.ElementAt) + or nameof(Queryable.ElementAtOrDefault) + or nameof(Queryable.First) + or nameof(Queryable.FirstOrDefault) + or nameof(Queryable.Last) + or nameof(Queryable.LastOrDefault) + or nameof(Queryable.LongCount) + or nameof(Queryable.Max) + or nameof(Queryable.MaxBy) + or nameof(Queryable.Min) + or nameof(Queryable.MinBy) + or nameof(Queryable.Single) + or nameof(Queryable.SingleOrDefault) + or nameof(Queryable.Sum) + when IsOnQueryable(): + + case nameof(EntityFrameworkQueryableExtensions.AllAsync) + or nameof(EntityFrameworkQueryableExtensions.AnyAsync) + or nameof(EntityFrameworkQueryableExtensions.AverageAsync) + or nameof(EntityFrameworkQueryableExtensions.ContainsAsync) + or nameof(EntityFrameworkQueryableExtensions.CountAsync) + // or nameof(EntityFrameworkQueryableExtensions.DefaultIfEmptyAsync) + or nameof(EntityFrameworkQueryableExtensions.ElementAtAsync) + or nameof(EntityFrameworkQueryableExtensions.ElementAtOrDefaultAsync) + or nameof(EntityFrameworkQueryableExtensions.FirstAsync) + or nameof(EntityFrameworkQueryableExtensions.FirstOrDefaultAsync) + or nameof(EntityFrameworkQueryableExtensions.LastAsync) + or nameof(EntityFrameworkQueryableExtensions.LastOrDefaultAsync) + or nameof(EntityFrameworkQueryableExtensions.LongCountAsync) + or nameof(EntityFrameworkQueryableExtensions.MaxAsync) + // or nameof(EntityFrameworkQueryableExtensions.MaxByAsync) + or nameof(EntityFrameworkQueryableExtensions.MinAsync) + // or nameof(EntityFrameworkQueryableExtensions.MinByAsync) + or nameof(EntityFrameworkQueryableExtensions.SingleAsync) + or nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync) + or nameof(EntityFrameworkQueryableExtensions.SumAsync) + or nameof(EntityFrameworkQueryableExtensions.ForEachAsync) + when IsOnEfQueryableExtensions(): + + case nameof(RelationalQueryableExtensions.ExecuteDelete) + or nameof(RelationalQueryableExtensions.ExecuteUpdate) + or nameof(RelationalQueryableExtensions.ExecuteDeleteAsync) + or nameof(RelationalQueryableExtensions.ExecuteUpdateAsync) + when IsOnEfRelationalQueryableExtensions(): + if (ProcessQueryCandidate(invocation)) + { + return; + } + + break; + } + } + + base.VisitInvocationExpression(invocation); + + bool IsOnEnumerable() + => IsOnTypeSymbol(_symbols.Enumerable); + + bool IsOnIEnumerable() + => IsOnTypeSymbol(_symbols.IEnumerableOfT); + + bool IsOnQueryable() + => IsOnTypeSymbol(_symbols.Queryable); + + bool IsOnEfQueryableExtensions() + => IsOnTypeSymbol(_symbols.EfQueryableExtensions); + + bool IsOnEfRelationalQueryableExtensions() + => IsOnTypeSymbol(_symbols.EfRelationalQueryableExtensions); + + bool IsOnTypeSymbol(ITypeSymbol typeSymbol) + => _semanticModel.GetSymbolInfo(invocation, _cancellationToken).Symbol is IMethodSymbol methodSymbol + && methodSymbol.ContainingType.OriginalDefinition.Equals(typeSymbol, SymbolEqualityComparer.Default); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override void VisitForEachStatement(ForEachStatementSyntax forEach) + { + // Note: a LINQ queryable can't be placed directly inside await foreach, since IQueryable does not extend + // IAsyncEnumerable. So users need to add our AsAsyncEnumerable, which is detected above as a normal invocation. + + // C# interceptors can (currently) intercept only method calls, not property accesses; this means that we can't + // TODO: Support DbSet() method call directly inside foreach/await foreach + // TODO: Warn for DbSet property access directly inside foreach (can't be intercepted so not supported) + if (forEach.Expression is InvocationExpressionSyntax invocation + && IsQueryable(invocation) + && ProcessQueryCandidate(invocation)) + { + return; + } + + base.VisitForEachStatement(forEach); + } + + private bool ProcessQueryCandidate(InvocationExpressionSyntax query) + { + // TODO: Carefully think about exactly what kind of verification we want to do here: static/non-static, actually get the + // TODO: method symbols and confirm it's an IQueryable flowing all the way through, etc. + // TODO: Move this code out, for reuse in the inner loop source generator + + // Work backwards through the LINQ operator chain until we reach something that isn't a method invocation + ExpressionSyntax operatorSyntax = query; + while (operatorSyntax is InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax { Expression: var innerExpression } + }) + { + if (innerExpression is QueryExpressionSyntax or ParenthesizedExpressionSyntax { Expression: QueryExpressionSyntax }) + { + _precompilationErrors.Add( + new(query, new InvalidOperationException(DesignStrings.QueryComprehensionSyntaxNotSupportedInPrecompiledQueries))); + return false; + } + + operatorSyntax = innerExpression; + } + + // We've reached a non-invocation. + + // First, check if this is a property access for a DbSet + if (operatorSyntax is MemberAccessExpressionSyntax { Expression: var innerExpression2 } + && IsDbContext(innerExpression2)) + { + _locatedQueries.Add(query); + + // TODO: Check symbol for DbSet? + return true; + } + + // If we had context.Set(), the Set() method was skipped like any other method, and we're on the context. + if (IsDbContext(operatorSyntax)) + { + _locatedQueries.Add(query); + return true; + } + + _precompilationErrors.Add(new(query, new InvalidOperationException(DesignStrings.DynamicQueryNotSupported))); + return false; + + bool IsDbContext(ExpressionSyntax expression) + { + return _semanticModel.GetSymbolInfo(expression, _cancellationToken).Symbol switch + { + ILocalSymbol localSymbol => IsDbContextType(localSymbol.Type), + IPropertySymbol propertySymbol => IsDbContextType(propertySymbol.Type), + IFieldSymbol fieldSymbol => IsDbContextType(fieldSymbol.Type), + IMethodSymbol methodSymbol => IsDbContextType(methodSymbol.ReturnType), + _ => false + }; + + bool IsDbContextType(ITypeSymbol typeSymbol) + { + while (true) + { + // TODO: Check for the user's specific DbContext type + if (typeSymbol.Equals(_symbols.DbContext, SymbolEqualityComparer.Default)) + { + return true; + } + + if (typeSymbol.BaseType is null) + { + return false; + } + + typeSymbol = typeSymbol.BaseType; + } + } + } + } + + private bool IsQueryable(ExpressionSyntax expression) + => _semanticModel.GetSymbolInfo(expression, _cancellationToken).Symbol switch + { + IMethodSymbol methodSymbol + => methodSymbol.ReturnType.OriginalDefinition.Equals(_symbols.IQueryableOfT, SymbolEqualityComparer.Default) + || methodSymbol.ReturnType.OriginalDefinition.AllInterfaces + .Contains(_symbols.IQueryable, SymbolEqualityComparer.Default), + + IPropertySymbol propertySymbol => IsDbSet(propertySymbol.Type), + + _ => false + }; + + // TODO: Handle DbSet subclasses which aren't InternalDbSet? + private bool IsDbSet(ITypeSymbol typeSymbol) + => SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, _symbols.DbSet); + + private readonly struct Symbols + { + private readonly Compilation _compilation; + + // ReSharper disable InconsistentNaming + public readonly INamedTypeSymbol IQueryableOfT; + public readonly INamedTypeSymbol IQueryable; + public readonly INamedTypeSymbol DbContext; + public readonly INamedTypeSymbol DbSet; + + public readonly INamedTypeSymbol Enumerable; + public readonly INamedTypeSymbol IEnumerableOfT; + public readonly INamedTypeSymbol Queryable; + public readonly INamedTypeSymbol EfQueryableExtensions; + public readonly INamedTypeSymbol EfRelationalQueryableExtensions; + // ReSharper restore InconsistentNaming + + private Symbols(Compilation compilation) + { + _compilation = compilation; + + IQueryableOfT = GetTypeSymbolOrThrow("System.Linq.IQueryable`1"); + IQueryable = GetTypeSymbolOrThrow("System.Linq.IQueryable"); + DbContext = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.DbContext"); + DbSet = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.DbSet`1"); + + Enumerable = GetTypeSymbolOrThrow("System.Linq.Enumerable"); + IEnumerableOfT = GetTypeSymbolOrThrow("System.Collections.Generic.IEnumerable`1"); + Queryable = GetTypeSymbolOrThrow("System.Linq.Queryable"); + EfQueryableExtensions = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions"); + EfRelationalQueryableExtensions = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.RelationalQueryableExtensions"); + } + + public static Symbols Load(Compilation compilation) + => new(compilation); + + private INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName) + => _compilation.GetTypeByMetadataName(fullyQualifiedMetadataName) + ?? throw new InvalidOperationException("Could not find type symbol for: " + fullyQualifiedMetadataName); + } +} diff --git a/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs b/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs index 4036908d8f0..2ab5978e829 100644 --- a/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs +++ b/src/EFCore.Design/Query/Internal/RuntimeModelLinqToCSharpSyntaxTranslator.cs @@ -44,10 +44,11 @@ public RuntimeModelLinqToCSharpSyntaxTranslator(SyntaxGenerator syntaxGenerator) Expression node, IReadOnlyDictionary? constantReplacements, IReadOnlyDictionary? memberAccessReplacements, - ISet collectedNamespaces) + ISet collectedNamespaces, + ISet unsafeAccessors) { _memberAccessReplacements = memberAccessReplacements; - var result = TranslateStatement(node, constantReplacements, collectedNamespaces); + var result = TranslateStatement(node, constantReplacements, collectedNamespaces, unsafeAccessors); _memberAccessReplacements = null; return result; } @@ -62,10 +63,11 @@ public RuntimeModelLinqToCSharpSyntaxTranslator(SyntaxGenerator syntaxGenerator) Expression node, IReadOnlyDictionary? constantReplacements, IReadOnlyDictionary? memberAccessReplacements, - ISet collectedNamespaces) + ISet collectedNamespaces, + ISet unsafeAccessors) { _memberAccessReplacements = memberAccessReplacements; - var result = TranslateExpression(node, constantReplacements, collectedNamespaces); + var result = TranslateExpression(node, constantReplacements, collectedNamespaces, unsafeAccessors); _memberAccessReplacements = null; return result; } @@ -120,7 +122,7 @@ protected override void TranslateNonPublicMemberAccess(MemberExpression memberEx protected override void TranslateNonPublicMemberAssignment(MemberExpression memberExpression, Expression value, SyntaxKind assignmentKind) { var propertyInfo = memberExpression.Member as PropertyInfo; - var member = propertyInfo?.SetMethod! ?? memberExpression.Member; + var member = propertyInfo?.SetMethod ?? memberExpression.Member; if (_memberAccessReplacements?.TryGetValue(member, out var methodName) == true) { AddNamespace(methodName.Namespace); @@ -134,10 +136,10 @@ protected override void TranslateNonPublicMemberAssignment(MemberExpression memb Result = InvocationExpression( IdentifierName(methodName.Name), ArgumentList(SeparatedList(new[] - { - Argument(Translate(memberExpression.Expression)), - Argument(Translate(value)) - }))); + { + Argument(Translate(memberExpression.Expression)), + Argument(Translate(value)) + }))); } else { diff --git a/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs b/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs index 88d7f2fe530..b46304c7855 100644 --- a/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs +++ b/src/EFCore.Design/Scaffolding/Internal/CSharpRuntimeModelCodeGenerator.cs @@ -1074,16 +1074,19 @@ private void out var structuralGetterExpression, out var hasStructuralSentinelExpression); + // TODO + var unsafeAccessors = new HashSet(); + mainBuilder .Append(variableName).AppendLine(".SetGetter(") .IncrementIndent() - .AppendLines(_code.Expression(getterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(getterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") - .AppendLines(_code.Expression(hasSentinelExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(hasSentinelExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") - .AppendLines(_code.Expression(structuralGetterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(structuralGetterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") - .AppendLines(_code.Expression(hasStructuralSentinelExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(hasStructuralSentinelExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -1092,7 +1095,7 @@ private void mainBuilder .Append(variableName).AppendLine(".SetSetter(") .IncrementIndent() - .AppendLines(_code.Expression(setterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(setterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -1101,7 +1104,7 @@ private void mainBuilder .Append(variableName).AppendLine(".SetMaterializationSetter(") .IncrementIndent() - .AppendLines(_code.Expression(materializationSetterExpression, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(materializationSetterExpression, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -1115,19 +1118,19 @@ private void mainBuilder .Append(variableName).AppendLine(".SetAccessors(") .IncrementIndent() - .AppendLines(_code.Expression(currentValueGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(currentValueGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") - .AppendLines(_code.Expression(preStoreGeneratedCurrentValueGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(preStoreGeneratedCurrentValueGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") .AppendLines(originalValueGetter == null ? "null" - : _code.Expression(originalValueGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(originalValueGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") - .AppendLines(_code.Expression(relationshipSnapshotGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(relationshipSnapshotGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") .AppendLines(valueBufferGetter == null ? "null" - : _code.Expression(valueBufferGetter, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(valueBufferGetter, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); } @@ -1918,6 +1921,9 @@ private void out var createAndSetCollection, out var createCollection); + // TODO + var unsafeAccessors = new HashSet(); + AddNamespace(propertyType, parameters.Namespaces); mainBuilder .Append(parameters.TargetName) @@ -1925,23 +1931,23 @@ private void .IncrementIndent() .AppendLines(getCollection == null ? "null" - : _code.Expression(getCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(getCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") .AppendLines(setCollection == null ? "null" - : _code.Expression(setCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(setCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") .AppendLines(setCollectionForMaterialization == null ? "null" - : _code.Expression(setCollectionForMaterialization, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(setCollectionForMaterialization, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") .AppendLines(createAndSetCollection == null ? "null" - : _code.Expression(createAndSetCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(createAndSetCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(",") .AppendLines(createCollection == null ? "null" - : _code.Expression(createCollection, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + : _code.Expression(createCollection, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); } @@ -2142,11 +2148,14 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame var runtimeType = (IRuntimeEntityType)entityType; + // TODO + var unsafeAccessors = new HashSet(); + var originalValuesFactory = OriginalValuesFactoryFactory.Instance.CreateExpression(runtimeType); mainBuilder .Append(parameters.TargetName).AppendLine(".SetOriginalValuesFactory(") .IncrementIndent() - .AppendLines(_code.Expression(originalValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(originalValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -2154,7 +2163,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame mainBuilder .Append(parameters.TargetName).AppendLine(".SetStoreGeneratedValuesFactory(") .IncrementIndent() - .AppendLines(_code.Expression(storeGeneratedValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(storeGeneratedValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -2162,7 +2171,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame mainBuilder .Append(parameters.TargetName).AppendLine(".SetTemporaryValuesFactory(") .IncrementIndent() - .AppendLines(_code.Expression(temporaryValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(temporaryValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -2170,7 +2179,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame mainBuilder .Append(parameters.TargetName).AppendLine(".SetShadowValuesFactory(") .IncrementIndent() - .AppendLines(_code.Expression(shadowValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(shadowValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -2178,7 +2187,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame mainBuilder .Append(parameters.TargetName).AppendLine(".SetEmptyShadowValuesFactory(") .IncrementIndent() - .AppendLines(_code.Expression(emptyShadowValuesFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(emptyShadowValuesFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); @@ -2186,7 +2195,7 @@ private void Create(ITrigger trigger, CSharpRuntimeAnnotationCodeGeneratorParame mainBuilder .Append(parameters.TargetName).AppendLine(".SetRelationshipSnapshotFactory(") .IncrementIndent() - .AppendLines(_code.Expression(relationshipSnapshotFactory, parameters.Namespaces, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) + .AppendLines(_code.Expression(relationshipSnapshotFactory, parameters.Namespaces, unsafeAccessors, (IReadOnlyDictionary)parameters.ScopeVariables, memberAccessReplacements), skipFinalNewline: true) .AppendLine(");") .DecrementIndent(); diff --git a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs index 85715b76f2c..f2abc516164 100644 --- a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs @@ -182,7 +182,14 @@ public static DbCommand CreateDbCommand(this IQueryable source) sql.GetArguments())); } - private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] + public static FromSqlQueryRootExpression GenerateFromSqlQueryRoot( IQueryable source, string sql, object?[] arguments, diff --git a/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs index 90f2fdd6636..d7df2bb5eb2 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalQueryCompilationContextFactory.cs @@ -35,6 +35,15 @@ public class RelationalQueryCompilationContextFactory : IQueryCompilationContext /// protected virtual RelationalQueryCompilationContextDependencies RelationalDependencies { get; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual QueryCompilationContext Create(bool async, bool precompiling) + => new RelationalQueryCompilationContext(Dependencies, RelationalDependencies, async, precompiling); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -42,5 +51,5 @@ public class RelationalQueryCompilationContextFactory : IQueryCompilationContext /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual QueryCompilationContext Create(bool async) - => new RelationalQueryCompilationContext(Dependencies, RelationalDependencies, async); + => throw new UnreachableException("The overload with `precompiling` should be called"); } diff --git a/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs index d4d0ff8ee33..826f26d027a 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalValueConverterCompensatingExpressionVisitor.cs @@ -85,7 +85,7 @@ private Expression VisitSelect(SelectExpression selectExpression) var orderings = this.VisitAndConvert(selectExpression.Orderings); var offset = (SqlExpression?)Visit(selectExpression.Offset); var limit = (SqlExpression?)Visit(selectExpression.Limit); - return selectExpression.Update(projections, tables, predicate, groupBy, having, orderings, limit, offset); + return selectExpression.Update(tables, predicate, groupBy, having, projections, orderings, offset, limit); } private Expression VisitInnerJoin(InnerJoinExpression innerJoinExpression) diff --git a/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs b/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs index 970c8eda4b2..2f3935a8184 100644 --- a/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs +++ b/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs @@ -22,11 +22,13 @@ public class RelationalQueryCompilationContext : QueryCompilationContext /// Parameter object containing dependencies for this class. /// Parameter object containing relational dependencies for this class. /// A bool value indicating whether it is for async query. + /// Indicates whether the query is being precompiled. public RelationalQueryCompilationContext( QueryCompilationContextDependencies dependencies, RelationalQueryCompilationContextDependencies relationalDependencies, - bool async) - : base(dependencies, async) + bool async, + bool precompiling) + : base(dependencies, async, precompiling) { RelationalDependencies = relationalDependencies; QuerySplittingBehavior = RelationalOptionsExtension.Extract(ContextOptions).QuerySplittingBehavior; diff --git a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs index 9a87151efe4..88f40c34035 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs @@ -25,7 +25,14 @@ public DeleteExpression(TableExpression table, SelectExpression selectExpression { } - private DeleteExpression(TableExpression table, SelectExpression selectExpression, ISet tags) + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public DeleteExpression(TableExpression table, SelectExpression selectExpression, ISet tags) { Table = table; SelectExpression = selectExpression; diff --git a/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs index 06388fcdcfa..e7a4be40eee 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/ExceptExpression.cs @@ -32,7 +32,14 @@ public class ExceptExpression : SetOperationBase { } - private ExceptExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public ExceptExpression( string alias, SelectExpression source1, SelectExpression source2, diff --git a/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs index a94b187b8ed..d42b54c186a 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/ExistsExpression.cs @@ -55,7 +55,7 @@ public virtual ExistsExpression Update(SelectExpression subquery) public override Expression Quote() => New( _quotingConstructor ??= - typeof(ExistsExpression).GetConstructor([typeof(SelectExpression), typeof(bool), typeof(RelationalTypeMapping)])!, + typeof(ExistsExpression).GetConstructor([typeof(SelectExpression), typeof(RelationalTypeMapping)])!, Subquery.Quote(), RelationalExpressionQuotingUtilities.QuoteTypeMapping(TypeMapping)); diff --git a/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs index d21c4a1f834..b1050066e0f 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/InnerJoinExpression.cs @@ -27,7 +27,14 @@ public InnerJoinExpression(TableExpressionBase table, SqlExpression joinPredicat { } - private InnerJoinExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public InnerJoinExpression( TableExpressionBase table, SqlExpression joinPredicate, bool prunable, diff --git a/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs index 2a3327e66a2..85009fc32c2 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/IntersectExpression.cs @@ -32,7 +32,14 @@ public class IntersectExpression : SetOperationBase { } - private IntersectExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public IntersectExpression( string alias, SelectExpression source1, SelectExpression source2, diff --git a/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs index e359be4355a..a5e2ea0ebc5 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/LeftJoinExpression.cs @@ -27,7 +27,14 @@ public LeftJoinExpression(TableExpressionBase table, SqlExpression joinPredicate { } - private LeftJoinExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public LeftJoinExpression( TableExpressionBase table, SqlExpression joinPredicate, bool prunable, diff --git a/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs index 91e319011f9..6692ce4cc51 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/OuterApplyExpression.cs @@ -25,7 +25,14 @@ public OuterApplyExpression(TableExpressionBase table) { } - private OuterApplyExpression(TableExpressionBase table, IReadOnlyDictionary? annotations) + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public OuterApplyExpression(TableExpressionBase table, IReadOnlyDictionary? annotations) : base(table, prunable: false, annotations) { } diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index 82031d038d3..82bd96b7d91 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -595,14 +595,14 @@ protected override Expression VisitExtension(Expression expression) return base.VisitExtension( selectExpression.Update( - selectExpression.Projection, visitedTables ?? selectExpression.Tables, selectExpression.Predicate, selectExpression.GroupBy, selectExpression.Having, + selectExpression.Projection, selectExpression.Orderings, - selectExpression.Limit, - selectExpression.Offset)); + selectExpression.Offset, + selectExpression.Limit)); } private void RemapProjections(int[]? map, SelectExpression selectExpression) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 2334d206162..7720b6b8e8c 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -63,18 +63,75 @@ public sealed partial class SelectExpression : TableExpressionBase public SelectExpression( string? alias, List tables, + SqlExpression? predicate, List groupBy, + SqlExpression? having, List projections, + bool distinct, List orderings, + SqlExpression? offset, + SqlExpression? limit, + ISet tags, IReadOnlyDictionary? annotations, - SqlAliasManager sqlAliasManager) + SqlAliasManager? sqlAliasManager, + bool isMutable) : base(alias, annotations) { - _projection = projections; + Check.DebugAssert(!(isMutable && sqlAliasManager is null), "Need SqlAliasManager when the SelectExpression is mutable"); + _tables = tables; + Predicate = predicate; _groupBy = groupBy; + Having = having; + _projection = projections; + IsDistinct = distinct; _orderings = orderings; - _sqlAliasManager = sqlAliasManager; + Offset = offset; + Limit = limit; + Tags = tags; + IsMutable = isMutable; + + _sqlAliasManager = sqlAliasManager!; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] + public SelectExpression( + string? alias, + IReadOnlyList tables, + SqlExpression? predicate, + IReadOnlyList groupBy, + SqlExpression? having, + IReadOnlyList projections, + bool distinct, + IReadOnlyList orderings, + SqlExpression? offset, + SqlExpression? limit, + IReadOnlySet tags, + IReadOnlyDictionary? annotations) + : this(alias, tables.ToList(), predicate, groupBy.ToList(), having, projections.ToList(), distinct, orderings.ToList(), + offset, limit, tags.ToHashSet(), annotations, sqlAliasManager: null, isMutable: false) + { + } + + private SelectExpression( + string? alias, + List tables, + List groupBy, + List projections, + List orderings, + IReadOnlyDictionary? annotations, + SqlAliasManager sqlAliasManager) + : this( + alias, tables, predicate: null, groupBy: groupBy, having: null, projections: projections, distinct: false, orderings: orderings, offset: null, + limit: null, tags: new HashSet(), + annotations: annotations, sqlAliasManager: sqlAliasManager, isMutable: true) + { } /// @@ -119,7 +176,11 @@ public SelectExpression(SqlExpression projection, SqlAliasManager sqlAliasManage // should have an alias manager at all, so this is temporary). [EntityFrameworkInternal] public static SelectExpression CreateImmutable(string alias, List tables, List projection) - => new(alias, tables, groupBy: [], projections: projection, orderings: [], annotations: null, sqlAliasManager: null!) { IsMutable = false }; + => new( + alias, tables, predicate: null, groupBy: [], having: null, projections: projection, distinct: false, orderings: [], + offset: null, limit: null, + tags: new HashSet(), sqlAliasManager: null, annotations: new Dictionary(), + isMutable: false); /// /// The list of tags applied to this . @@ -3753,17 +3814,11 @@ private TableExpressionBase Clone(string? alias, ExpressionVisitor cloningExpres var limit = (SqlExpression?)cloningExpressionVisitor.Visit(Limit); var newSelectExpression = new SelectExpression( - alias, newTables, newGroupBy, newProjections, newOrderings, Annotations, _sqlAliasManager) + alias, newTables, predicate, newGroupBy, havingExpression, newProjections, IsDistinct, newOrderings, offset, limit, + Tags, Annotations, _sqlAliasManager, IsMutable) { - Predicate = predicate, - Having = havingExpression, - Offset = offset, - Limit = limit, - IsDistinct = IsDistinct, - Tags = Tags, _projectionMapping = newProjectionMappings, _clientProjections = newClientProjections, - IsMutable = IsMutable }; foreach (var (column, comparer) in _identifier) @@ -4055,17 +4110,11 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) if (changed) { var newSelectExpression = new SelectExpression( - Alias, newTables, newGroupBy, newProjections, newOrderings, Annotations, _sqlAliasManager) + Alias, newTables, predicate, newGroupBy, havingExpression, newProjections, IsDistinct, newOrderings, offset, + limit, (IReadOnlySet)Tags, Annotations) { _clientProjections = _clientProjections, - _projectionMapping = _projectionMapping, - Predicate = predicate, - Having = havingExpression, - Offset = offset, - Limit = limit, - IsDistinct = IsDistinct, - Tags = Tags, - IsMutable = false + _projectionMapping = _projectionMapping }; newSelectExpression._identifier.AddRange(identifier.Zip(_identifier).Select(e => (e.First, e.Second.Comparer))); @@ -4124,25 +4173,25 @@ List VisitList(List list, bool inPlace, out bool changed) /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will /// return this expression. /// - /// The property of the result. /// The property of the result. /// The property of the result. /// The property of the result. /// The property of the result. + /// The property of the result. /// The property of the result. - /// The property of the result. /// The property of the result. + /// The property of the result. /// This expression if no children changed, or an expression with the updated children. // This does not take internal states since when using this method SelectExpression should be finalized public SelectExpression Update( - IReadOnlyList projections, IReadOnlyList tables, SqlExpression? predicate, IReadOnlyList groupBy, SqlExpression? having, + IReadOnlyList projections, IReadOnlyList orderings, - SqlExpression? limit, - SqlExpression? offset) + SqlExpression? offset, + SqlExpression? limit) { if (IsMutable) { @@ -4155,8 +4204,8 @@ List VisitList(List list, bool inPlace, out bool changed) && groupBy == GroupBy && having == Having && orderings == Orderings - && limit == Limit - && offset == Offset) + && offset == Offset + && limit == Limit) { return this; } @@ -4168,17 +4217,11 @@ List VisitList(List list, bool inPlace, out bool changed) } var newSelectExpression = new SelectExpression( - Alias, tables.ToList(), groupBy.ToList(), projections.ToList(), orderings.ToList(), Annotations, _sqlAliasManager) + Alias, tables, predicate, groupBy, having, projections, IsDistinct, orderings, offset, limit, + (IReadOnlySet)Tags, Annotations) { _projectionMapping = projectionMapping, - _clientProjections = _clientProjections.ToList(), - Predicate = predicate, - Having = having, - Offset = offset, - Limit = limit, - IsDistinct = IsDistinct, - Tags = Tags, - IsMutable = false + _clientProjections = _clientProjections.ToList() }; // We don't copy identifiers because when we are doing reconstruction so projection is already applied. @@ -4196,17 +4239,11 @@ public override SelectExpression WithAlias(string newAlias) { Check.DebugAssert(!IsMutable, "Can't change alias on mutable SelectExpression"); - return new SelectExpression(newAlias, _tables, _groupBy, _projection, _orderings, Annotations, _sqlAliasManager) + return new SelectExpression( + newAlias, _tables, Predicate, _groupBy, Having, _projection, IsDistinct, _orderings, Offset, Limit, Tags, + Annotations, _sqlAliasManager, isMutable: false) { - _projectionMapping = _projectionMapping, - _clientProjections = _clientProjections.ToList(), - Predicate = Predicate, - Having = Having, - Offset = Offset, - Limit = Limit, - IsDistinct = IsDistinct, - Tags = Tags, - IsMutable = false + _projectionMapping = _projectionMapping, _clientProjections = _clientProjections.ToList(), }; } @@ -4223,8 +4260,8 @@ public override Expression Quote() typeof(IReadOnlyList), // projections typeof(bool), // distinct typeof(IReadOnlyList), // orderings - typeof(SqlExpression), // limit typeof(SqlExpression), // offset + typeof(SqlExpression), // limit typeof(IReadOnlySet), // tags typeof(IReadOnlyDictionary) // annotations ])!, @@ -4238,8 +4275,8 @@ public override Expression Quote() NewArrayInit(typeof(ProjectionExpression), initializers: Projection.Select(p => p.Quote())), Constant(IsDistinct), NewArrayInit(typeof(OrderingExpression), initializers: Orderings.Select(o => o.Quote())), - RelationalExpressionQuotingUtilities.QuoteOrNull(Limit), RelationalExpressionQuotingUtilities.QuoteOrNull(Offset), + RelationalExpressionQuotingUtilities.QuoteOrNull(Limit), RelationalExpressionQuotingUtilities.QuoteTags(Tags), RelationalExpressionQuotingUtilities.QuoteAnnotations(Annotations)); diff --git a/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs index 02f8df8cd9a..ff3dcb48ce3 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SqlFunctionExpression.cs @@ -183,7 +183,14 @@ public class SqlFunctionExpression : SqlExpression { } - private SqlFunctionExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public SqlFunctionExpression( SqlExpression? instance, string? schema, string name, diff --git a/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs index 16db24710f0..70b83072cdb 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/UnionExpression.cs @@ -32,7 +32,14 @@ public class UnionExpression : SetOperationBase { } - private UnionExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public UnionExpression( string alias, SelectExpression source1, SelectExpression source2, diff --git a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs index 598f8be6df9..3b0041d2ac5 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs @@ -31,7 +31,14 @@ public UpdateExpression(TableExpression table, SelectExpression selectExpression { } - private UpdateExpression( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // For precompiled queries + public UpdateExpression( TableExpression table, SelectExpression selectExpression, IReadOnlyList columnValueSetters, diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 2f8e60bb01f..ad732e432e7 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -367,7 +367,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression, bool var limit = Visit(selectExpression.Limit, out _); - return selectExpression.Update(projections, tables, predicate, groupBy, having, orderings, limit, offset); + return selectExpression.Update(tables, predicate, groupBy, having, projections, orderings, offset, limit); } /// @@ -670,9 +670,14 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt var projectionExpression = Visit(subqueryProjection, allowOptimizedExpansion, out var projectionNullable); inExpression = inExpression.Update( item, subquery.Update( - [subquery.Projection[0].Update(projectionExpression)], - subquery.Tables, subquery.Predicate, subquery.GroupBy, subquery.Having, subquery.Orderings, subquery.Limit, - subquery.Offset)); + subquery.Tables, + subquery.Predicate, + subquery.GroupBy, + subquery.Having, + projections: [subquery.Projection[0].Update(projectionExpression)], + subquery.Orderings, + subquery.Offset, + subquery.Limit)); if (UseRelationalNulls) { @@ -779,14 +784,14 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt // No need for a projection with EXISTS, clear it to get SELECT 1 subquery = subquery.Update( - [], subquery.Tables, subquery.Predicate, subquery.GroupBy, subquery.Having, + [], subquery.Orderings, - subquery.Limit, - subquery.Offset); + subquery.Offset, + subquery.Limit); var predicate = VisitSqlBinary( _sqlExpressionFactory.Equal(subqueryProjection, item), allowOptimizedExpansion: true, out _); @@ -2077,14 +2082,15 @@ static bool TryNegate(ExpressionType expressionType, out ExpressionType result) #pragma warning restore EF1001 rewrittenSelectExpression = rewrittenSelectExpression.Update( - projection, // TODO: We should change the project column to be non-nullable, but it's too closed down for that. new[] { rewrittenCollectionTable }, selectExpression.Predicate, selectExpression.GroupBy, selectExpression.Having, + // TODO: We should change the project column to be non-nullable, but it's too closed down for that. + projection, selectExpression.Orderings, - selectExpression.Limit, - selectExpression.Offset); + selectExpression.Offset, + selectExpression.Limit); return true; } diff --git a/src/EFCore.Relational/Query/SqlTreePruner.cs b/src/EFCore.Relational/Query/SqlTreePruner.cs index ca03fc79a84..18746cfc999 100644 --- a/src/EFCore.Relational/Query/SqlTreePruner.cs +++ b/src/EFCore.Relational/Query/SqlTreePruner.cs @@ -259,6 +259,6 @@ protected virtual SelectExpression PruneSelect(SelectExpression select, bool pre CurrentTableAlias = parentTableAlias; return select.Update( - projections ?? select.Projection, tables ?? select.Tables, predicate, groupBy, having, orderings, limit, offset); + tables ?? select.Tables, predicate, groupBy, having, projections ?? select.Projection, orderings, offset, limit); } } diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs index 55e64112f64..b545b77b761 100644 --- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs @@ -299,7 +299,7 @@ protected override Expression VisitSelect(SelectExpression selectExpression) _isSearchCondition = parentSearchCondition; - return selectExpression.Update(projections, tables, predicate, groupBy, havingExpression, orderings, limit, offset); + return selectExpression.Update(tables, predicate, groupBy, havingExpression, projections, orderings, offset, limit); } /// diff --git a/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs index e80a86c1b6c..00ecaab9ee9 100644 --- a/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SkipTakeCollapsingExpressionVisitor.cs @@ -67,14 +67,14 @@ protected override Expression VisitExtension(Expression extensionExpression) if (IsZero(selectExpression.Limit)) { return selectExpression.Update( - selectExpression.Projection, selectExpression.Tables, selectExpression.GroupBy.Count > 0 ? selectExpression.Predicate : _sqlExpressionFactory.Constant(false), selectExpression.GroupBy, selectExpression.GroupBy.Count > 0 ? _sqlExpressionFactory.Constant(false) : null, + selectExpression.Projection, new List(0), - limit: null, - offset: null); + offset: null, + limit: null); } bool IsZero(SqlExpression? sqlExpression) diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs index 161875f7a8b..1ad0e35e8e3 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerJsonPostprocessor.cs @@ -138,14 +138,14 @@ public virtual Expression Process(Expression expression) var newSelectExpression = newTables is not null ? selectExpression.Update( - selectExpression.Projection, newTables, selectExpression.Predicate, selectExpression.GroupBy, selectExpression.Having, + selectExpression.Projection, selectExpression.Orderings, - selectExpression.Limit, - selectExpression.Offset) + selectExpression.Offset, + selectExpression.Limit) : selectExpression; // when we mark columns for rewrite we don't yet have the updated SelectExpression, so we store the info in temporary dictionary diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs index 7f1063149e0..aca56014c1b 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs @@ -23,8 +23,9 @@ public class SqlServerQueryCompilationContext : RelationalQueryCompilationContex QueryCompilationContextDependencies dependencies, RelationalQueryCompilationContextDependencies relationalDependencies, bool async, + bool precompiling, bool multipleActiveResultSetsEnabled) - : base(dependencies, relationalDependencies, async) + : base(dependencies, relationalDependencies, async, precompiling) { _multipleActiveResultSetsEnabled = multipleActiveResultSetsEnabled; } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs index 30b1d39932c..d7aa170c427 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContextFactory.cs @@ -47,7 +47,16 @@ public class SqlServerQueryCompilationContextFactory : IQueryCompilationContextF /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual QueryCompilationContext Create(bool async) + public virtual QueryCompilationContext Create(bool async, bool precompiling) => new SqlServerQueryCompilationContext( - Dependencies, RelationalDependencies, async, _sqlServerConnection.IsMultipleActiveResultSetsEnabled); + Dependencies, RelationalDependencies, async, precompiling, _sqlServerConnection.IsMultipleActiveResultSetsEnabled); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual QueryCompilationContext Create(bool async) + => throw new UnreachableException("The overload with `precompiling` should be called"); } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs index cf9f206c616..4f6b189126a 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContext.cs @@ -20,8 +20,9 @@ public class SqliteQueryCompilationContext : RelationalQueryCompilationContext public SqliteQueryCompilationContext( QueryCompilationContextDependencies dependencies, RelationalQueryCompilationContextDependencies relationalDependencies, - bool async) - : base(dependencies, relationalDependencies, async) + bool async, + bool precompiling) + : base(dependencies, relationalDependencies, async, precompiling) { } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs index 0570b91fbc3..c088df764ff 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryCompilationContextFactory.cs @@ -35,6 +35,15 @@ public class SqliteQueryCompilationContextFactory : IQueryCompilationContextFact /// protected virtual RelationalQueryCompilationContextDependencies RelationalDependencies { get; } + /// + /// Creates a new . + /// + /// Specifies whether the query is async. + /// Indicates whether the query is being precompiled. + /// The created query compilation context. + public QueryCompilationContext Create(bool async, bool precompiling) + => new SqliteQueryCompilationContext(Dependencies, RelationalDependencies, async, precompiling); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -42,5 +51,5 @@ public class SqliteQueryCompilationContextFactory : IQueryCompilationContextFact /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual QueryCompilationContext Create(bool async) - => new SqliteQueryCompilationContext(Dependencies, RelationalDependencies, async); + => throw new UnreachableException("The overload with `precompiling` should be called"); } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 5698691db62..f6ee51acd22 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -418,7 +418,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } } - private static string? ConstructLikePatternParameter( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] // Can be called from precompiled shapers + public static string? ConstructLikePatternParameter( QueryContext queryContext, string baseParameterName, bool startsWith) diff --git a/src/EFCore/Design/ICSharpHelper.cs b/src/EFCore/Design/ICSharpHelper.cs index 1dd92187d33..e477c22d06d 100644 --- a/src/EFCore/Design/ICSharpHelper.cs +++ b/src/EFCore/Design/ICSharpHelper.cs @@ -351,6 +351,7 @@ string Literal(T? value) /// /// The node to be translated. /// Any namespaces required by the translated code will be added to this set. + /// Any unsafe accessors needed to access private members will be added to this dictionary. /// Collection of translations for statically known instances. /// Collection of translations for non-public member accesses. /// Source code that would produce . @@ -363,6 +364,7 @@ string Literal(T? value) [EntityFrameworkInternal] string Statement(Expression node, ISet collectedNamespaces, + ISet unsafeAccessors, IReadOnlyDictionary? constantReplacements = null, IReadOnlyDictionary? memberAccessReplacements = null); @@ -371,6 +373,7 @@ string Literal(T? value) /// /// The node to be translated. /// Any namespaces required by the translated code will be added to this set. + /// Any unsafe accessors needed to access private members will be added to this dictionary. /// Collection of translations for statically known instances. /// Collection of translations for non-public member accesses. /// Source code that would produce . @@ -383,6 +386,7 @@ string Literal(T? value) [EntityFrameworkInternal] string Expression(Expression node, ISet collectedNamespaces, + ISet unsafeAccessors, IReadOnlyDictionary? constantReplacements = null, IReadOnlyDictionary? memberAccessReplacements = null); } diff --git a/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs b/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs index 9502298b15d..4961e6d98a3 100644 --- a/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs +++ b/src/EFCore/Design/Internal/CSharpRuntimeAnnotationCodeGenerator.cs @@ -364,6 +364,9 @@ public static void AddNamespace(Type type, ISet namespaces) AddNamespace(converter.ModelClrType, parameters.Namespaces); AddNamespace(converter.ProviderClrType, parameters.Namespaces); + // TODO + var unsafeAccessors = new HashSet(); + mainBuilder .Append("new ValueConverter<") .Append(codeHelper.Reference(converter.ModelClrType)) @@ -371,10 +374,10 @@ public static void AddNamespace(Type type, ISet namespaces) .Append(codeHelper.Reference(converter.ProviderClrType)) .AppendLine(">(") .IncrementIndent() - .AppendLines(codeHelper.Expression(converter.ConvertToProviderExpression, parameters.Namespaces, null, null), + .AppendLines(codeHelper.Expression(converter.ConvertToProviderExpression, parameters.Namespaces, unsafeAccessors), skipFinalNewline: true) .AppendLine(",") - .AppendLines(codeHelper.Expression(converter.ConvertFromProviderExpression, parameters.Namespaces, null, null), + .AppendLines(codeHelper.Expression(converter.ConvertFromProviderExpression, parameters.Namespaces, unsafeAccessors), skipFinalNewline: true); if (converter.ConvertsNulls) @@ -425,18 +428,21 @@ public static void AddNamespace(Type type, ISet namespaces) AddNamespace(typeof(ValueComparer<>), parameters.Namespaces); AddNamespace(comparer.Type, parameters.Namespaces); + // TODO + var unsafeAccessors = new HashSet(); + mainBuilder .Append("new ValueComparer<") .Append(codeHelper.Reference(comparer.Type)) .AppendLine(">(") .IncrementIndent() - .AppendLines(codeHelper.Expression(comparer.EqualsExpression, parameters.Namespaces, null, null), + .AppendLines(codeHelper.Expression(comparer.EqualsExpression, parameters.Namespaces, unsafeAccessors), skipFinalNewline: true) .AppendLine(",") - .AppendLines(codeHelper.Expression(comparer.HashCodeExpression, parameters.Namespaces, null, null), + .AppendLines(codeHelper.Expression(comparer.HashCodeExpression, parameters.Namespaces, unsafeAccessors), skipFinalNewline: true) .AppendLine(",") - .AppendLines(codeHelper.Expression(comparer.SnapshotExpression, parameters.Namespaces, null, null), + .AppendLines(codeHelper.Expression(comparer.SnapshotExpression, parameters.Namespaces, unsafeAccessors), skipFinalNewline: true) .Append(")") .DecrementIndent(); diff --git a/src/EFCore/Diagnostics/EventDefinitionBase.cs b/src/EFCore/Diagnostics/EventDefinitionBase.cs index 7c9e94eb9e5..f3d9f50f17f 100644 --- a/src/EFCore/Diagnostics/EventDefinitionBase.cs +++ b/src/EFCore/Diagnostics/EventDefinitionBase.cs @@ -16,7 +16,7 @@ public abstract class EventDefinitionBase /// Creates an event definition instance. /// /// Logging options. - /// The . + /// The . /// The at which the event will be logged. /// /// A string representing the code that should be passed to . diff --git a/src/EFCore/Internal/InternalDbSet.cs b/src/EFCore/Internal/InternalDbSet.cs index 440ad39159e..34eab90ecd6 100644 --- a/src/EFCore/Internal/InternalDbSet.cs +++ b/src/EFCore/Internal/InternalDbSet.cs @@ -19,6 +19,7 @@ public class InternalDbSet<[DynamicallyAccessedMembers(IEntityType.DynamicallyAc DbSet, IQueryable, IAsyncEnumerable, + IInfrastructure, IInfrastructure, IResettableService where TEntity : class @@ -520,6 +521,15 @@ Expression IQueryable.Expression IQueryProvider IQueryable.Provider => EntityQueryable.Provider; + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + DbContext IInfrastructure.Instance + => _context; + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 6cde3b34aa4..3ce10a3ac67 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -994,6 +994,12 @@ public static string DuplicateTrigger(object? trigger, object? entityType, objec public static string EFConstantInvoked => GetString("EFConstantInvoked"); + /// + /// The EF.Constant<T> method is not supported when using precompiled queries. + /// + public static string EFConstantNotSupportedInPrecompiledQueries + => GetString("EFConstantNotSupportedInPrecompiledQueries"); + /// /// The EF.Constant<T> method may only be used with an argument that can be evaluated client-side and does not contain any reference to database-side entities. /// @@ -2233,6 +2239,14 @@ public static string NotCollection(object? entityType, object? property) GetString("NotCollection", nameof(entityType), nameof(property)), entityType, property); + /// + /// When precompiling queries, the '{parameter}' parameter of method '{method}' cannot be parameterized. + /// + public static string NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries(object? parameter, object? method) + => string.Format( + GetString("NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries", nameof(parameter), nameof(method)), + parameter, method); + /// /// The given 'IQueryable' does not support generation of query strings. /// @@ -2339,6 +2353,12 @@ public static string PoolingContextCtorError(object? contextType) public static string PoolingOptionsModified => GetString("PoolingOptionsModified"); + /// + /// Precompiled queries aren't supported by your EF provider. + /// + public static string PrecompiledQueryNotSupported + => GetString("PrecompiledQueryNotSupported"); + /// /// The derived type '{derivedType}' cannot have the [PrimaryKey] attribute since primary keys may only be declared on the root type. Move the attribute to '{rootType}', or remove '{rootType}' from the model by using [NotMapped] attribute or calling 'EntityTypeBuilder.Ignore' on the base type in 'OnModelCreating'. /// diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index b1f71053ae5..5e3e94fa50f 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -486,6 +486,9 @@ The EF.Constant<T> method may only be used within Entity Framework LINQ queries. + + The EF.Constant<T> method is not supported when using precompiled queries. + The EF.Constant<T> method may only be used with an argument that can be evaluated client-side and does not contain any reference to database-side entities. @@ -1291,6 +1294,9 @@ The property '{entityType}.{property}' cannot be mapped as a collection since it does not implement 'IEnumerable<T>'. + + When precompiling queries, the '{parameter}' parameter of method '{method}' cannot be parameterized. + The given 'IQueryable' does not support generation of query strings. @@ -1333,6 +1339,9 @@ 'OnConfiguring' cannot be used to modify DbContextOptions when DbContext pooling is enabled. + + Precompiled queries aren't supported by your EF provider. + The derived type '{derivedType}' cannot have the [PrimaryKey] attribute since primary keys may only be declared on the root type. Move the attribute to '{rootType}', or remove '{rootType}' from the model by using [NotMapped] attribute or calling 'EntityTypeBuilder.Ignore' on the base type in 'OnModelCreating'. diff --git a/src/EFCore/Query/ExpressionPrinter.cs b/src/EFCore/Query/ExpressionPrinter.cs index ff1abca0724..4c39cf21c09 100644 --- a/src/EFCore/Query/ExpressionPrinter.cs +++ b/src/EFCore/Query/ExpressionPrinter.cs @@ -461,13 +461,19 @@ protected override Expression VisitConditional(ConditionalExpression conditional /// protected override Expression VisitConstant(ConstantExpression constantExpression) { - if (constantExpression.Value is IPrintableExpression printable) + switch (constantExpression.Value) { - printable.Print(this); - } - else - { - PrintValue(constantExpression.Value); + case IPrintableExpression printable: + printable.Print(this); + break; + + case IQueryable queryable: + _stringBuilder.Append(Print(queryable.Expression)); + break; + + default: + PrintValue(constantExpression.Value); + break; } return constantExpression; diff --git a/src/EFCore/Query/IQueryCompilationContextFactory.cs b/src/EFCore/Query/IQueryCompilationContextFactory.cs index cbfcf9c1780..6206a0dc39f 100644 --- a/src/EFCore/Query/IQueryCompilationContextFactory.cs +++ b/src/EFCore/Query/IQueryCompilationContextFactory.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; + namespace Microsoft.EntityFrameworkCore.Query; /// @@ -26,4 +28,16 @@ public interface IQueryCompilationContextFactory /// Specifies whether the query is async. /// The created query compilation context. QueryCompilationContext Create(bool async); + + /// + /// Creates a new . + /// + /// Specifies whether the query is async. + /// Indicates whether the query is being precompiled. + /// The created query compilation context. + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + QueryCompilationContext Create(bool async, bool precompiling) + => precompiling + ? throw new InvalidOperationException(CoreStrings.PrecompiledQueryNotSupported) + : Create(async); } diff --git a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs index 8e01d6b73c6..c905f8d45a5 100644 --- a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs +++ b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs @@ -39,6 +39,11 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor /// private bool _calculatingPath; + /// + /// Indicates whether performing parameter extraction on a precompiled query. + /// + private bool _precompiledQuery; + /// /// Indicates whether we should parameterize. Is false in compiled query mode, as well as when we're handling query filters from /// NavigationExpandingExpressionVisitor. @@ -94,6 +99,10 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor private static readonly MethodInfo ReadOnlyCollectionIndexerGetter = typeof(ReadOnlyCollection).GetProperties() .Single(p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)).GetMethod!; + private static readonly MethodInfo ReadOnlyElementInitCollectionIndexerGetter = typeof(ReadOnlyCollection) + .GetProperties() + .Single(p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)).GetMethod!; + private static readonly MethodInfo ReadOnlyMemberBindingCollectionIndexerGetter = typeof(ReadOnlyCollection) .GetProperties() .Single(p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)).GetMethod!; @@ -140,11 +149,13 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor public virtual Expression ExtractParameters( Expression expression, IParameterValues parameterValues, + bool precompiledQuery, bool parameterize, bool clearParameterizedValues) { Reset(clearParameterizedValues); _parameterValues = parameterValues; + _precompiledQuery = precompiledQuery; _parameterize = parameterize; _calculatingPath = false; @@ -161,6 +172,23 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor return root; } + /// + /// Resets the funcletizer in preparation for multiple path calculations (i.e. for the same query). After this is called, + /// can be called multiple times, preserving state + /// between calls. + /// + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + public virtual void ResetPathCalculation() + { + Reset(); + _calculatingPath = true; + _parameterize = true; + + // In precompilation mode we don't actually extract parameter values; but we do need to generate the parameter names, using the + // same logic (and via the same code) used in parameter extraction, and that logic requires _parameterValues. + _parameterValues = new DummyParameterValues(); + } + /// /// Processes an expression tree, locates references to captured variables and returns information on how to extract them from /// expression trees with the same shape. Used to generate C# code for query precompilation. @@ -172,17 +200,74 @@ public class ExpressionTreeFuncletizer : ExpressionVisitor /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + public virtual PathNode? CalculatePathsToEvaluatableRoots(MethodCallExpression linqOperatorMethodCall, int argumentIndex) + { + var argument = linqOperatorMethodCall.Arguments[argumentIndex]; + if (argument is UnaryExpression { NodeType: ExpressionType.Quote } quote) + { + argument = quote.Operand; + } + + var root = Visit(argument, out var state); + + // If the top-most node in the tree is evaluatable, that means we have a non-lambda parameter to the LINQ operator (e.g. Skip/Take). + // We make sure to return a path containing the argument; note that since we're not in a lambda, the argument will always be + // parameterized since we're not inside a lambda (e.g. Skip/Take), except for [NotParameterized]. + if (state.IsEvaluatable + && IsParameterParameterizable(linqOperatorMethodCall.Method, linqOperatorMethodCall.Method.GetParameters()[argumentIndex])) + { + _ = Evaluate(root, out var parameterName, out _); + + state = new() + { + StateType = StateType.ContainsEvaluatable, + Path = new() + { + ExpressionType = state.ExpressionType!, + ParameterName = parameterName, + Children = Array.Empty() + } + }; + } + + return state.Path; + } + + /// + /// Processes an expression tree, locates references to captured variables and returns information on how to extract them from + /// expression trees with the same shape. Used to generate C# code for query precompilation. + /// + /// A tree representing the path to each evaluatable root node in the tree. + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] public virtual PathNode? CalculatePathsToEvaluatableRoots(Expression expression) { - Reset(); - _calculatingPath = true; - _parameterize = true; + var root = Visit(expression, out var state); - // In precompilation mode we don't actually extract parameter values; but we do need to generate the parameter names, using the - // same logic (and via the same code) used in parameter extraction, and that logic requires _parameterValues. - _parameterValues = new DummyParameterValues(); + // If the top-most node in the tree is evaluatable, that means we have a non-lambda parameter to the LINQ operator (e.g. Skip/Take). + // We make sure to return a path containing the argument; note that since we're not in a lambda, the argument will always be + // parameterized since we're not inside a lambda (e.g. Skip/Take), except for [NotParameterized]. + if (state.IsEvaluatable) + { + _ = Evaluate(root, out var parameterName, out _); - _ = Visit(expression, out var state); + state = new() + { + StateType = StateType.ContainsEvaluatable, + Path = new() + { + ExpressionType = state.ExpressionType!, + ParameterName = parameterName, + Children = Array.Empty() + } + }; + } return state.Path; } @@ -798,7 +883,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall) { if (_calculatingPath) { - throw new InvalidOperationException("EF.Constant is not supported when using precompiled queries"); + throw new InvalidOperationException(CoreStrings.EFConstantNotSupportedInPrecompiledQueries); } var argument = Visit(methodCall.Arguments[0], out var argumentState); @@ -888,17 +973,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall) // To support [NotParameterized] and indexer method arguments - which force evaluation as constant - go over the parameters // and modify the states as needed - ParameterInfo[]? parameterInfos = null; + ParameterInfo[]? parameters = null; for (var i = 0; i < methodCall.Arguments.Count; i++) { var argumentState = argumentStates[i]; if (argumentState.IsEvaluatable) { - parameterInfos ??= methodCall.Method.GetParameters(); - if (parameterInfos[i].GetCustomAttribute() is not null - || _model.IsIndexerMethod(methodCall.Method)) + parameters ??= methodCall.Method.GetParameters(); + if (!IsParameterParameterizable(methodCall.Method, parameters[i])) { + if (argumentState.StateType is StateType.EvaluatableWithCapturedVariable && _precompiledQuery) + { + throw new InvalidOperationException( + CoreStrings.NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries( + parameters[i].Name, + method.Name)); + } + argumentStates[i] = argumentState with { StateType = StateType.EvaluatableWithoutCapturedVariable, ForceConstantization = true @@ -1140,7 +1232,7 @@ protected override Expression VisitMemberInit(MemberInitExpression memberInit) // Avoid allocating for the notEvaluatableAsRootHandler closure below unless we actually end up in the evaluatable case var (memberInit2, new2, newState2, bindings2, bindingStates2) = (memberInit, @new, newState, bindings, bindingStates); _state = State.CreateEvaluatable( - typeof(InvocationExpression), + typeof(MemberInitExpression), state is StateType.EvaluatableWithCapturedVariable, notEvaluatableAsRootHandler: () => EvaluateChildren(memberInit2, new2, newState2, bindings2, bindingStates2)); break; @@ -1299,7 +1391,7 @@ protected override Expression VisitListInit(ListInitExpression listInit) { children = [ - newState.Path! with { PathFromParent = static e => Property(e, nameof(MethodCallExpression.Object)) } + newState.Path! with { PathFromParent = static e => Property(e, nameof(ListInitExpression.NewExpression)) } ]; } @@ -1307,17 +1399,24 @@ protected override Expression VisitListInit(ListInitExpression listInit) { var initializer = initializers[i]; + // listInit.Initializers[0].Arguments[1] + var initializerIndex = i; var visitedArguments = EvaluateList( visitedInitializersArguments is null ? initializer.Arguments : visitedInitializersArguments[i], initializerArgumentStates[i], ref children, - static i => e => + j => e => Call( - Property(e, nameof(MethodCallExpression.Arguments)), + Property( + Call( + Property(e, nameof(ListInitExpression.Initializers)), + ReadOnlyElementInitCollectionIndexerGetter, + arguments: [Constant(initializerIndex)]), + nameof(System.Linq.Expressions.ElementInit.Arguments)), ReadOnlyCollectionIndexerGetter, - arguments: [Constant(i)])); + arguments: [Constant(j)])); if (visitedArguments is not null && visitedInitializersArguments is null) { @@ -1964,6 +2063,10 @@ private bool IsGenerallyEvaluatable(Expression expression) // Don't evaluate QueryableMethods if in compiled query || !(expression is MethodCallExpression { Method: var method } && method.DeclaringType == typeof(Queryable))); + private bool IsParameterParameterizable(MethodInfo method, ParameterInfo parameter) + => parameter.GetCustomAttribute() is null + && !_model.IsIndexerMethod(method); + private enum StateType { /// diff --git a/src/EFCore/Query/Internal/IQueryCompiler.cs b/src/EFCore/Query/Internal/IQueryCompiler.cs index 32749f8a345..277a351417a 100644 --- a/src/EFCore/Query/Internal/IQueryCompiler.cs +++ b/src/EFCore/Query/Internal/IQueryCompiler.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; + namespace Microsoft.EntityFrameworkCore.Query.Internal; /// @@ -48,4 +50,13 @@ public interface IQueryCompiler /// doing so can result in application failures when updating to a new Entity Framework Core release. /// Func CreateCompiledAsyncQuery(Expression query); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + Expression> PrecompileQuery(Expression query, bool async); } diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index 07b5982255b..f91f6f77a5f 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -209,7 +209,8 @@ protected override Expression VisitExtension(Expression extensionExpression) && entityQueryRootExpression.GetType() == typeof(EntityQueryRootExpression)) { var processedDefiningQueryBody = _funcletizer.ExtractParameters( - definingQuery.Body, _parameters, parameterize: false, clearParameterizedValues: false); + definingQuery.Body, _parameters, _queryCompilationContext.IsPrecompiling, parameterize: false, + clearParameterizedValues: false); processedDefiningQueryBody = _queryTranslationPreprocessor.NormalizeQueryableMethod(processedDefiningQueryBody); processedDefiningQueryBody = _nullCheckRemovingExpressionVisitor.Visit(processedDefiningQueryBody); processedDefiningQueryBody = @@ -1753,7 +1754,8 @@ private Expression ApplyQueryFilter(IEntityType entityType, NavigationExpansionE { filterPredicate = queryFilter; filterPredicate = (LambdaExpression)_funcletizer.ExtractParameters( - filterPredicate, _parameters, parameterize: false, clearParameterizedValues: false); + filterPredicate, _parameters, _queryCompilationContext.IsPrecompiling, parameterize: false, + clearParameterizedValues: false); filterPredicate = (LambdaExpression)_queryTranslationPreprocessor.NormalizeQueryableMethod(filterPredicate); // We need to do entity equality, but that requires a full method call on a query root to properly flow the diff --git a/src/EFCore/Query/Internal/PrecompiledQueryContext.cs b/src/EFCore/Query/Internal/PrecompiledQueryContext.cs new file mode 100644 index 00000000000..8587758c602 --- /dev/null +++ b/src/EFCore/Query/Internal/PrecompiledQueryContext.cs @@ -0,0 +1,128 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// A context for a precompiled query that's being executed. This wraps the via which the query is being +/// executed, as well as a regular EF . It is flown through all intercepted LINQ operators, until the +/// terminating operator interceptor which actually executes the query. Note that it implements so that +/// it can be flown from one intercepted LINQ operator to another. +/// +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +[Experimental(EFDiagnostics.PrecompiledQueryExperimental)] +public class PrecompiledQueryContext : IOrderedQueryable +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public PrecompiledQueryContext(DbContext dbContext) + : this(dbContext, dbContext.GetService().Create()) + { + } + + private PrecompiledQueryContext(DbContext dbContext, QueryContext queryContext) + { + DbContext = dbContext; + QueryContext = queryContext; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual DbContext DbContext { get; set; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual QueryContext QueryContext { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual PrecompiledQueryContext ToType() + => new(DbContext, QueryContext); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IncludablePrecompiledQueryContext ToIncludable() + => new(DbContext, QueryContext); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IEnumerator GetEnumerator() + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + IEnumerator IEnumerable.GetEnumerator() + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public Type ElementType + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public Expression Expression + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IQueryProvider Provider + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class IncludablePrecompiledQueryContext(DbContext dbContext, QueryContext queryContext) + : PrecompiledQueryContext(dbContext, queryContext), IIncludableQueryable; +} diff --git a/src/EFCore/Query/Internal/PrecompiledQueryableAsyncEnumerableAdapter.cs b/src/EFCore/Query/Internal/PrecompiledQueryableAsyncEnumerableAdapter.cs new file mode 100644 index 00000000000..e01894b27aa --- /dev/null +++ b/src/EFCore/Query/Internal/PrecompiledQueryableAsyncEnumerableAdapter.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +[Experimental(EFDiagnostics.PrecompiledQueryExperimental)] +public class PrecompiledQueryableAsyncEnumerableAdapter(IAsyncEnumerable asyncEnumerable) + : IQueryable, IAsyncEnumerable +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => asyncEnumerable.GetAsyncEnumerator(cancellationToken); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IEnumerator GetEnumerator() + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + IEnumerator IEnumerable.GetEnumerator() + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public Type ElementType + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public Expression Expression + => throw new NotSupportedException(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IQueryProvider Provider + => throw new NotSupportedException(); +} diff --git a/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs b/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs index 72d21b06c26..857ef233ec4 100644 --- a/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs +++ b/src/EFCore/Query/Internal/QueryCompilationContextFactory.cs @@ -27,6 +27,15 @@ public QueryCompilationContextFactory(QueryCompilationContextDependencies depend /// protected virtual QueryCompilationContextDependencies Dependencies { get; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual QueryCompilationContext Create(bool async, bool precompiling) + => new(Dependencies, async, precompiling); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -34,5 +43,5 @@ public QueryCompilationContextFactory(QueryCompilationContextDependencies depend /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual QueryCompilationContext Create(bool async) - => new(Dependencies, async); + => throw new UnreachableException("The overload with `precompiling` should be called"); } diff --git a/src/EFCore/Query/Internal/QueryCompiler.cs b/src/EFCore/Query/Internal/QueryCompiler.cs index a0536aff1bb..3dfa3aea333 100644 --- a/src/EFCore/Query/Internal/QueryCompiler.cs +++ b/src/EFCore/Query/Internal/QueryCompiler.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; namespace Microsoft.EntityFrameworkCore.Query.Internal; @@ -125,6 +126,19 @@ var compiledQuery bool async) => database.CompileQuery(query, async); + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + public virtual Expression> PrecompileQuery(Expression query, bool async) + { + query = ExtractParameters(query, _queryContextFactory.Create(), _logger, precompiledQuery: true); + return _database.CompileQueryExpression(query, async); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -136,7 +150,8 @@ var compiledQuery IParameterValues parameterValues, IDiagnosticsLogger logger, bool compiledQuery = false, + bool precompiledQuery = false, bool generateContextAccessors = false) => new ExpressionTreeFuncletizer(_model, _evaluatableExpressionFilter, _contextType, generateContextAccessors: false, logger) - .ExtractParameters(query, parameterValues, parameterize: !compiledQuery, clearParameterizedValues: true); + .ExtractParameters(query, parameterValues, precompiledQuery, parameterize: !compiledQuery, clearParameterizedValues: true); } diff --git a/src/EFCore/Query/LiftableConstantProcessor.cs b/src/EFCore/Query/LiftableConstantProcessor.cs index bf1935bd227..acd2f796c31 100644 --- a/src/EFCore/Query/LiftableConstantProcessor.cs +++ b/src/EFCore/Query/LiftableConstantProcessor.cs @@ -219,6 +219,11 @@ protected virtual ParameterExpression LiftConstant(LiftableConstantExpression li body = convertNode.Operand; } + if (body.Type != liftableConstant.Type) + { + body = Expression.Convert(body, liftableConstant.Type); + } + // Register the lifted constant; note that the name will be uniquified later var variableParameter = Expression.Parameter(liftableConstant.Type, liftableConstant.VariableName); _liftedConstants.Add(new(variableParameter, body)); diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs index 8a86ed0d5ec..a6809179cbd 100644 --- a/src/EFCore/Query/QueryCompilationContext.cs +++ b/src/EFCore/Query/QueryCompilationContext.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; + namespace Microsoft.EntityFrameworkCore.Query; /// @@ -69,11 +71,27 @@ public class QueryCompilationContext public QueryCompilationContext( QueryCompilationContextDependencies dependencies, bool async) + : this(dependencies, async, precompiling: false) + { + } + + /// + /// Creates a new instance of the class. + /// + /// Parameter object containing dependencies for this class. + /// A bool value indicating whether it is for async query. + /// Indicates whether the query is being precompiled. + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + public QueryCompilationContext( + QueryCompilationContextDependencies dependencies, + bool async, + bool precompiling) { Dependencies = dependencies; IsAsync = async; QueryTrackingBehavior = dependencies.QueryTrackingBehavior; IsBuffering = ExecutionStrategy.Current?.RetriesOnFailure ?? dependencies.IsRetryingExecutionStrategy; + IsPrecompiling = precompiling; Model = dependencies.Model; ContextOptions = dependencies.ContextOptions; ContextType = dependencies.ContextType; @@ -116,6 +134,11 @@ public class QueryCompilationContext /// public virtual bool IsBuffering { get; } + /// + /// Indicates whether the query is being precompiled. + /// + public virtual bool IsPrecompiling { get; } + /// /// A value indicating whether query filters are ignored in this query. /// @@ -167,7 +190,7 @@ public virtual void AddTag(string tag) // across invocations of the query. // In normal mode, these nodes should simply be evaluated, and a ConstantExpression to those instances embedded directly in the // tree (for precompiled queries we generate C# code for resolving those instances instead). - var queryExecutorAfterLiftingExpression = + var queryExecutorAfterLiftingExpression = (Expression>)Dependencies.LiftableConstantProcessor.InlineConstants(queryExecutorExpression, SupportsPrecompiledQuery); try @@ -186,6 +209,7 @@ public virtual void AddTag(string tag) /// The result type of this query. /// The query to generate executor for. /// Returns which can be invoked to get results of this query. + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] public virtual Expression> CreateQueryExecutorExpression(Expression query) { var queryAndEventData = Logger.QueryCompilationStarting(Dependencies.Context, _expressionPrinter, query); @@ -204,11 +228,9 @@ public virtual void AddTag(string tag) // wrap the query with code adding those parameters to the query context query = InsertRuntimeParameters(query); - var queryExecutorExpression = Expression.Lambda>( + return Expression.Lambda>( query, QueryContextParameter); - - return queryExecutorExpression; } /// @@ -223,7 +245,7 @@ public virtual ParameterExpression RegisterRuntimeParameter(string name, LambdaE { valueExtractorBody = _runtimeParameterConstantLifter.Visit(valueExtractorBody); } - + valueExtractor = Expression.Lambda(valueExtractorBody, valueExtractor.Parameters); if (valueExtractor.Parameters.Count != 1 diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index 315f457eccd..ef4dd17b185 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -230,12 +230,12 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio { Value: IEntityType entityTypeValue } => liftableConstantFactory.CreateLiftableConstant( constantExpression.Value, LiftableConstantExpressionHelpers.BuildMemberAccessLambdaForEntityOrComplexType(entityTypeValue), - entityTypeValue.Name + "EntityType", + entityTypeValue.ShortName() + "EntityType", constantExpression.Type), { Value: IComplexType complexTypeValue } => liftableConstantFactory.CreateLiftableConstant( constantExpression.Value, LiftableConstantExpressionHelpers.BuildMemberAccessLambdaForEntityOrComplexType(complexTypeValue), - complexTypeValue.Name + "ComplexType", + complexTypeValue.ShortName() + "ComplexType", constantExpression.Type), { Value: IProperty propertyValue } => liftableConstantFactory.CreateLiftableConstant( constantExpression.Value, @@ -563,7 +563,7 @@ private Expression ProcessEntityShaper(StructuralTypeShaperExpression shaper) LiftableConstantExpressionHelpers.BuildMemberAccessForEntityOrComplexType(typeBase, resolverPrm), EntityTypeFindPrimaryKeyMethod), resolverPrm), - typeBase.Name + "Key", + /*typeBase.Name +*/ "key", typeof(IKey)) : Constant(primaryKey), NewArrayInit( @@ -697,8 +697,8 @@ private Expression ProcessEntityShaper(StructuralTypeShaperExpression shaper) Snapshot.Empty, static _ => Snapshot.Empty, "emptySnapshot", - typeof(Snapshot)) - : Constant(Snapshot.Empty))); + typeof(ISnapshot)) + : Constant(Snapshot.Empty, typeof(ISnapshot)))); var returnType = typeBase.ClrType; var valueBufferExpression = Call(materializationContextVariable, MaterializationContext.GetValueBufferMethod); @@ -725,7 +725,7 @@ private Expression ProcessEntityShaper(StructuralTypeShaperExpression shaper) ? _liftableConstantFactory.CreateLiftableConstant( concreteEntityTypes[i], LiftableConstantExpressionHelpers.BuildMemberAccessLambdaForEntityOrComplexType(concreteEntityType), - concreteEntityType.Name + (typeBase is IEntityType ? "EntityType" : "ComplexType"), + concreteEntityType.ShortName() + (typeBase is IEntityType ? "EntityType" : "ComplexType"), typeBase is IEntityType ? typeof(IEntityType) : typeof(IComplexType)) : Constant(concreteEntityTypes[i], typeBase is IEntityType ? typeof(IEntityType) : typeof(IComplexType))); } diff --git a/src/EFCore/Storage/Database.cs b/src/EFCore/Storage/Database.cs index 66344581ce9..95606b3df64 100644 --- a/src/EFCore/Storage/Database.cs +++ b/src/EFCore/Storage/Database.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; + namespace Microsoft.EntityFrameworkCore.Storage; /// @@ -64,6 +66,13 @@ protected Database(DatabaseDependencies dependencies) /// public virtual Func CompileQuery(Expression query, bool async) => Dependencies.QueryCompilationContextFactory - .Create(async) + .Create(async, precompiling: false) .CreateQueryExecutor(query); + + /// + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + public virtual Expression> CompileQueryExpression(Expression query, bool async) + => Dependencies.QueryCompilationContextFactory + .Create(async, precompiling: true) + .CreateQueryExecutorExpression(query); } diff --git a/src/EFCore/Storage/IDatabase.cs b/src/EFCore/Storage/IDatabase.cs index eeeb574a812..69ee1952b1d 100644 --- a/src/EFCore/Storage/IDatabase.cs +++ b/src/EFCore/Storage/IDatabase.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; + namespace Microsoft.EntityFrameworkCore.Storage; /// @@ -55,4 +57,14 @@ public interface IDatabase /// A value indicating whether this is an async query. /// A which can be invoked to get results of the query. Func CompileQuery(Expression query, bool async); + + /// + /// Compiles the given query to generate an expression tree which can be used to execute the query. + /// + /// The type of query result. + /// The query to compile. + /// A value indicating whether this is an async query. + /// An expression tree which can be used to execute the query. + [Experimental(EFDiagnostics.PrecompiledQueryExperimental)] + Expression> CompileQueryExpression(Expression query, bool async); } diff --git a/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj b/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj index 1ce0a1d4856..9562fd1d55e 100644 --- a/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj +++ b/test/EFCore.Design.Tests/EFCore.Design.Tests.csproj @@ -56,7 +56,7 @@ - + diff --git a/test/EFCore.Design.Tests/Query/CSharpToLinqTranslatorTest.cs b/test/EFCore.Design.Tests/Query/CSharpToLinqTranslatorTest.cs new file mode 100644 index 00000000000..3fa2f2c2f6f --- /dev/null +++ b/test/EFCore.Design.Tests/Query/CSharpToLinqTranslatorTest.cs @@ -0,0 +1,537 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.EntityFrameworkCore.Query.Internal; +using static System.Linq.Expressions.Expression; + +namespace Microsoft.EntityFrameworkCore.Query; + +// ReSharper disable InconsistentNaming +// ReSharper disable RedundantCast + +#nullable enable + +public class CSharpToLinqTranslatorTest +{ + [Fact] + public void ArrayCreation() + => AssertExpression( + () => new int[3], + "new int[3]"); + + // ReSharper disable RedundantExplicitArrayCreation + [Fact] + public void ArrayCreation_with_initializer() + => AssertExpression( + () => new int[] { 1, 2 }, + "new int[] { 1, 2 }"); + // ReSharper restore RedundantExplicitArrayCreation + + // ReSharper disable BuiltInTypeReferenceStyle + [Fact] + public void As() + => AssertExpression( + () => "foo" as String, + """ "foo" as String"""); + // ReSharper restore BuiltInTypeReferenceStyle + + [Fact] + public void As_with_predefined_type() + => AssertExpression( + () => "foo" as string, + """ "foo" as string"""); + + [Theory] + [InlineData("1 + 2", ExpressionType.Add)] + [InlineData("1 - 2", ExpressionType.Subtract)] + [InlineData("1 * 2", ExpressionType.Multiply)] + [InlineData("1 / 2", ExpressionType.Divide)] + [InlineData("1 % 2", ExpressionType.Modulo)] + [InlineData("1 & 2", ExpressionType.And)] + [InlineData("1 | 2", ExpressionType.Or)] + [InlineData("1 ^ 2", ExpressionType.ExclusiveOr)] + [InlineData("1 >> 2", ExpressionType.RightShift)] + [InlineData("1 << 2", ExpressionType.LeftShift)] + [InlineData("1 < 2", ExpressionType.LessThan)] + [InlineData("1 <= 2", ExpressionType.LessThanOrEqual)] + [InlineData("1 > 2", ExpressionType.GreaterThan)] + [InlineData("1 >= 2", ExpressionType.GreaterThanOrEqual)] + [InlineData("1 == 2", ExpressionType.Equal)] + [InlineData("1 != 2", ExpressionType.NotEqual)] + public void Binary_int(string code, ExpressionType binaryType) + => AssertExpression( + MakeBinary(binaryType, Constant(1), Constant(2)), + code); + + [Theory] + [InlineData("true && false", ExpressionType.AndAlso)] + [InlineData("true || false", ExpressionType.OrElse)] + [InlineData("true ^ false", ExpressionType.ExclusiveOr)] + public void Binary_bool(string code, ExpressionType binaryType) + => AssertExpression( + Lambda>( + MakeBinary(binaryType, Constant(true), Constant(false))), + code); + + [Fact] + public void Binary_add_string() + => AssertExpression( + () => new[] { "foo", "bar" }.Select(s => s + "foo"), + """new[] { "foo", "bar" }.Select(s => s + "foo")"""); + + [Fact] + public void Cast() + => AssertExpression( + () => (object)1, + "(object)1"); + + [Fact] + public void Coalesce() + => AssertExpression( + () => (object?)"foo" ?? (object)"bar", + """(object?)"foo" ?? (object)"bar" """); + + [Fact] + public void ElementAccess_over_array() + => AssertExpression( + () => new[] { 1, 2, 3 }[1], + "new[] { 1, 2, 3 } [1]"); + + [Fact] + public void ElementAccess_over_list() + => AssertExpression( + () => new List { 1, 2, 3 }[1], + "new List { 1, 2, 3 }[1]"); + + [Fact] + public void IdentifierName_for_lambda_parameter() + => AssertExpression( + () => new[] { 1, 2, 3 }.Where(i => i == 2), + "new[] { 1, 2, 3 }.Where(i => i == 2);"); + + [Fact] + public void ImplicitArrayCreation() + => AssertExpression( + () => new[] { 1, 2 }, + "new[] { 1, 2 }"); + + [Fact] + public void Interpolated_string() + => AssertExpression( + () => string.Format("Foo: {0}", new[] { (object)8 }), + """$"Foo: {8}" """); + + [Fact] + public void Interpolated_string_formattable() + => AssertExpression( + () => FormattableStringMethod(FormattableStringFactory.Create("Foo: {0}, {1}", (object)8, (object) 9)), + """CSharpToLinqTranslatorTest.FormattableStringMethod($"Foo: {8}, {9}")"""); + + [Fact] + public void Index_over_array() + => AssertExpression( + () => new[] { 1, 2 }[0], + "new[] { 1, 2 }[0]"); + + [Fact] + public void Index_over_List() + => AssertExpression( + () => new List { 1, 2 }[0], + "new List { 1, 2 }[0]"); + + [Fact] + public void Invocation_instance_method() + => AssertExpression( + () => "foo".Substring(2), + """ "foo".Substring(2)"""); + + [Fact] + public void Invocation_method_with_optional_parameter() + => AssertExpression( + Call( + typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!, + Constant(1), + Constant(4), + NewArrayInit(typeof(int))), + "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4)"); + + [Fact] + public void Invocation_method_with_optional_parameter_missing_argument() + => AssertExpression( + Call( + typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!, + Constant(1), + Constant(3), + NewArrayInit(typeof(int))), + "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1)"); + + [Fact] + public void Invocation_method_with_params_parameter_no_arguments() + => AssertExpression( + Call( + typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!, + Constant(1), + Constant(4), + NewArrayInit(typeof(int))), + "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4)"); + + [Fact] + public void Invocation_method_with_params_parameter_one_argument() + => AssertExpression( + Call( + typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!, + Constant(1), + Constant(4), + NewArrayInit(typeof(int), Constant(5))), + "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4, 5)"); + + [Fact] + public void Invocation_method_with_params_parameter_multiple_arguments() + => AssertExpression( + Call( + typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!, + Constant(1), + Constant(4), + NewArrayInit(typeof(int), Constant(5), Constant(6))), + "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1, 4, 5, 6)"); + + [Fact] + public void Invocation_method_with_params_parameter_missing_argument() + => AssertExpression( + Call( + typeof(CSharpToLinqTranslatorTest).GetMethod(nameof(ParamsAndOptionalMethod), [typeof(int), typeof(int), typeof(int[])])!, + Constant(1), + Constant(3), + NewArrayInit(typeof(int))), + "CSharpToLinqTranslatorTest.ParamsAndOptionalMethod(1)"); + + [Fact] + public void Invocation_static_method() + => AssertExpression( + () => DateTime.Parse("2020-01-01"), + """DateTime.Parse("2020-01-01")"""); + + [Fact] + public void Invocation_extension_method() + => AssertExpression( + () => typeof(string).GetTypeInfo(), + "typeof(string).GetTypeInfo()"); + + // ReSharper disable InvokeAsExtensionMethod + [Fact] + public void Invocation_extension_method_with_non_extension_syntax() + => AssertExpression( + () => IntrospectionExtensions.GetTypeInfo(typeof(string)), + "typeof(string).GetTypeInfo()"); + // ReSharper restore InvokeAsExtensionMethod + + [Fact] + public void Invocation_generic_method() + => AssertExpression( + () => Enumerable.Repeat("foo", 5), + """Enumerable.Repeat("foo", 5)"""); + + [Fact] + public void Invocation_generic_extension_method() + => AssertExpression( + () => new[] { 1, 2, 3 }.Where(i => i > 1), + "new[] { 1, 2, 3 }.Where(i => i > 1)"); + + [Fact] + public void Invocation_generic_queryable_extension_method() + => AssertExpression( + () => new[] { 1, 2, 3 }.AsQueryable().Where(i => i > 1), + "new[] { 1, 2, 3 }.AsQueryable().Where(i => i > 1)"); + + [Fact] + public void Invocation_non_generic_method_on_generic_type() + => AssertExpression( + () => SomeGenericType.SomeFunction(1), + "CSharpToLinqTranslatorTest.SomeGenericType.SomeFunction(1)"); + + [Fact] + public void Invocation_generic_method_on_generic_type() + => AssertExpression( + () => SomeGenericType.SomeGenericFunction(1, "foo"), + """CSharpToLinqTranslatorTest.SomeGenericType.SomeGenericFunction(1, "foo")"""); + + [Theory] + [InlineData(""" + "hello" + """, "hello")] + [InlineData("1", 1)] + [InlineData("1L", 1L)] + [InlineData("1U", 1U)] + [InlineData("1UL", 1UL)] + [InlineData("1.5D", 1.5)] + [InlineData("1.5F", 1.5F)] + [InlineData("true", true)] + public void Literal(string csharpLiteral, object expectedValue) + => AssertExpression( + Constant(expectedValue), + csharpLiteral); + + [Fact] + public void Literal_decimal() + => AssertExpression( + () => 1.5m, + "1.5m"); + + [Fact] + public void Literal_null() + => AssertExpression( + Equal(Constant("foo"), Constant(null, typeof(string))), + """ "foo" == null"""); + + [Fact] + public void Literal_enum() + => AssertExpression( + () => SomeEnum.Two, + "CSharpToLinqTranslatorTest.SomeEnum.Two"); + + [Fact] + public void Literal_enum_with_multiple_values() + => AssertExpression( + Convert( + Or( + Convert(Constant(SomeEnum.One), typeof(int)), + Convert(Constant(SomeEnum.Two), typeof(int))), + typeof(SomeEnum)), + "CSharpToLinqTranslatorTest.SomeEnum.One | CSharpToLinqTranslatorTest.SomeEnum.Two"); + + [Fact] + public void MemberAccess_array_length() + => AssertExpression( + () => new[] { 1, 2, 3 }.Length, + "new[] { 1, 2, 3 }.Length"); + + [Fact] + public void MemberAccess_instance_property() + => AssertExpression( + () => "foo".Length, + """ "foo".Length"""); + + [Fact] + public void MemberAccess_static_property() + => AssertExpression( + () => DateTime.Now, + "DateTime.Now"); + + // TODO: MemberAccess on fields + + [Fact] + public void Nested_type() + => AssertExpression( + () => (object)new Blog(), + "(object)new CSharpToLinqTranslatorTest.Blog()"); + + [Fact] + public void Not_boolean() + => AssertExpression( + Not(Constant(true)), + "!true"); + + [Fact] + public void ObjectCreation() + => AssertExpression( + () => new List(), + "new List()"); + + [Fact] + public void ObjectCreation_with_arguments() + => AssertExpression( + () => new List(10), + "new List(10)"); + + [Fact] + public void ObjectCreation_with_initializers() + => AssertExpression( + () => new Blog(8) { Name = "foo" }, + """new CSharpToLinqTranslatorTest.Blog(8) { Name = "foo" }"""); + + [Fact] + public void ObjectCreation_with_parameterless_struct_constructor() + => AssertExpression( + () => new DateTime(), + "new DateTime()"); + + [Fact] + public void Parenthesized() + => AssertExpression( + () => 1, + "(1)"); + + [Theory] + [InlineData("+8", 8, ExpressionType.UnaryPlus)] + [InlineData("-8", 8, ExpressionType.Negate)] + [InlineData("~8", 8, ExpressionType.Not)] + public void PrefixUnary(string code, object operandValue, ExpressionType expectedNodeType) + => AssertExpression( + MakeUnary(expectedNodeType, Constant(8), typeof(int)), + code); + + // ReSharper disable RedundantSuppressNullableWarningExpression + [Fact] + public void SuppressNullableWarningExpression() + => AssertExpression( + () => "foo"!, + """ "foo"! """); + // ReSharper restore RedundantSuppressNullableWarningExpression + + [ConditionalFact] + public void Typeof() + => AssertExpression( + () => typeof(string), + "typeof(string)"); + + [ConditionalFact] + public void Array_type() + => AssertExpression( + () => typeof(ParameterExpression[]), + "typeof(ParameterExpression[])"); + + protected virtual void AssertExpression(Expression> expected, string code) + => AssertExpression( + expected.Body, + code); + + protected virtual void AssertExpression(Expression expected, string code) + { + code = $""" +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Query; + +_ = {code}; +"""; + + var compilation = Compile(code); + + var syntaxTree = compilation.SyntaxTrees.Single(); + + if (syntaxTree.GetRoot() is CompilationUnitSyntax { Members: [GlobalStatementSyntax globalStatement, ..] }) + { + var expression = globalStatement switch + { + { Statement: ExpressionStatementSyntax { Expression: AssignmentExpressionSyntax { Right: var e } } } => e, + { Statement: LocalDeclarationStatementSyntax e } => e.Declaration.Variables[0].Initializer!.Value, + { Statement: ExpressionStatementSyntax { Expression: var e } } => e, + + _ => throw new InvalidOperationException("Could not find expression to assert on") + }; + + var actual = Translate(expression, compilation); + + Assert.Equal(expected, actual, ExpressionEqualityComparer.Instance); + } + else + { + Assert.Fail("Could not find expression to assert on"); + } + } + + private Compilation Compile(string code) + { + var syntaxTree = CSharpSyntaxTree.ParseText(code); + + var compilation = CSharpCompilation.Create( + "TestCompilation", + syntaxTrees: new[] { syntaxTree }, + references: MetadataReferences); + + var diagnostics = compilation.GetDiagnostics() + .Where(d => d.Severity is DiagnosticSeverity.Error) + .ToArray(); + + if (diagnostics.Any()) + { + var stringBuilder = new StringBuilder() + .AppendLine("Compilation errors:"); + + foreach (var diagnostic in diagnostics) + { + stringBuilder.AppendLine(diagnostic.ToString()); + } + + Assert.Fail(stringBuilder.ToString()); + } + + return compilation; + } + + private Expression Translate(SyntaxNode node, Compilation compilation) + { + var blogContext = new BlogContext(); + var translator = new CSharpToLinqTranslator(); + translator.Load(compilation, blogContext); + return translator.Translate(node, compilation.GetSemanticModel(node.SyntaxTree)); + } + + private static readonly MetadataReference[] MetadataReferences; + + static CSharpToLinqTranslatorTest() + { + var metadataReferences = new List + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Queryable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IQueryable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DbContext).Assembly.Location), + MetadataReference.CreateFromFile(typeof(BlogContext).Assembly.Location) + }; + + var netAssemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + + metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "mscorlib.dll"))); + metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.dll"))); + metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Core.dll"))); + metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Runtime.dll"))); + + MetadataReferences = metadataReferences.ToArray(); + } + + [Flags] + public enum SomeEnum + { + One = 1, + Two = 2 + } + + private class BlogContext : DbContext; + + public class Blog + { + public Blog() + { + } + + public Blog(int id) + => Id = id; + + public int Id { get; set; } + public string? Name { get; set; } + } + + public class SomeGenericType + { + public static int SomeFunction(T1 t1) + => 0; + + public static int SomeGenericFunction(T1 t1, T2 t2) + => 0; + } + + public static int ParamsAndOptionalMethod(int a, int b = 3, params int[] c) + => throw new NotSupportedException(); + + public static int FormattableStringMethod(FormattableString formattableString) + => throw new NotSupportedException(); +} diff --git a/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs b/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs index 15e7f815257..5f7821429b7 100644 --- a/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs +++ b/test/EFCore.Design.Tests/Query/LinqToCSharpSyntaxTranslatorTest.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Editing; using Microsoft.EntityFrameworkCore.Design.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; @@ -16,8 +17,6 @@ namespace Microsoft.EntityFrameworkCore.Query; public class LinqToCSharpSyntaxTranslatorTest(ITestOutputHelper testOutputHelper) { - private readonly ITestOutputHelper _testOutputHelper = testOutputHelper; - [Theory] [InlineData("hello", "\"hello\"")] [InlineData(1, "1")] @@ -33,9 +32,7 @@ public class LinqToCSharpSyntaxTranslatorTest(ITestOutputHelper testOutputHelper [InlineData(true, "true")] [InlineData(typeof(string), "typeof(string)")] public void Constant_values(object constantValue, string literalRepresentation) - => AssertExpression( - Constant(constantValue), - literalRepresentation); + => AssertExpression(Constant(constantValue), literalRepresentation); [Fact] public void Constant_DateTime_default() @@ -105,14 +102,6 @@ public void Binary_PowerAssign() PowerAssign(Parameter(typeof(double), "d"), Constant(3.0)), "d = Math.Pow(d, 3D)"); - [Fact] - public void Private_instance_field_SimpleAssign() - => AssertExpression( - Assign( - Field(Parameter(typeof(Blog), "blog"), "_privateField"), - Constant(3)), - """typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).SetValue(blog, 3)"""); - [Theory] [InlineData(ExpressionType.AddAssign, "+=")] [InlineData(ExpressionType.MultiplyAssign, "*=")] @@ -130,7 +119,14 @@ public void Private_instance_field_AssignOperators(ExpressionType expressionType expressionType, Field(Parameter(typeof(Blog), "blog"), "_privateField"), Constant(3)), - $"""typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).SetValue(blog, 3)"""); + $"UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(blog) {op} 3", + unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal( + """ +[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")] +private static extern ref int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(LinqToCSharpSyntaxTranslatorTest.Blog instance); +""", + Assert.Single(unsafeAccessors), + ignoreLineEndingDifferences: true)); [Theory] [InlineData(ExpressionType.AddAssign, "+=")] @@ -149,11 +145,9 @@ public void Private_instance_field_AssignOperators_with_replacements(ExpressionT expressionType, Field(Parameter(typeof(Blog), "blog"), "_privateField"), Constant(3)), - $"""AccessPrivateField(blog) {op} Three""", - new Dictionary() { { 3, "Three" } }, - new Dictionary() { + $"""AccessPrivateField(blog) {op} Three""", new Dictionary() { { 3, "Three" } }, new Dictionary() { { BlogPrivateField, new QualifiedName("AccessPrivateField", "") } - }); + }); [Theory] [InlineData(ExpressionType.Negate, "-i")] @@ -192,7 +186,7 @@ public void Unary_statement(ExpressionType expressionType, string expected) MakeUnary(expressionType, i, typeof(int))), $$""" { - int i; + int i = default; {{expected}}; } """); @@ -252,25 +246,64 @@ public void Static_property() typeof(DateTime).GetProperty(nameof(DateTime.Now))!), "DateTime.Now"); + [Fact] + public void Indexer_property() + => AssertExpression( + Call( + New(typeof(List)), + typeof(List).GetProperties().Single( + p => p.GetIndexParameters() is { Length: 1 } indexParameters && indexParameters[0].ParameterType == typeof(int)) + .GetMethod!, + Constant(1)), "new List()[1]"); + [Fact] public void Private_instance_field_read() => AssertExpression( - Field(Parameter(typeof(Blog), "blog"), "_privateField"), - """(int)typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).GetValue(blog)"""); + Field( + Parameter(typeof(Blog), "blog"), + "_privateField"), + "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Get(blog)", unsafeAccessorsAsserter: accessors => + Assert.Equal( + """ +[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")] +private static extern int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Get(LinqToCSharpSyntaxTranslatorTest.Blog instance); +""", + Assert.Single(accessors), + ignoreLineEndingDifferences: true)); [Fact] public void Private_instance_field_write() => AssertStatement( Assign( - Field(Parameter(typeof(Blog), "blog"), "_privateField"), + Field( + Parameter(typeof(Blog), "blog"), + "_privateField"), Constant(8)), - """typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("_privateField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).SetValue(blog, 8)"""); + "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(blog) = 8", unsafeAccessorsAsserter: accessors => + Assert.Equal( + """ +[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")] +private static extern ref int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog__privateField_Set(LinqToCSharpSyntaxTranslatorTest.Blog instance); +""", + Assert.Single(accessors), + ignoreLineEndingDifferences: true)); + + // TODO: Also test accessing private static fields + // TODO: Also test accessing private properties, instance and static [Fact] public void Internal_instance_field_read() => AssertExpression( - Field(Parameter(typeof(Blog), "blog"), "InternalField"), - """(int)typeof(LinqToCSharpSyntaxTranslatorTest.Blog).GetField("InternalField", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly).GetValue(blog)"""); + Field( + Parameter(typeof(Blog), "blog"), + "InternalField"), + "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog_InternalField_Get(blog)", unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal( + """ +[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "InternalField")] +private static extern int UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_Blog_InternalField_Get(LinqToCSharpSyntaxTranslatorTest.Blog instance); +""", + Assert.Single(unsafeAccessors), + ignoreLineEndingDifferences: true)); [Fact] public void Not() @@ -429,7 +462,7 @@ public void Method_call_namespace_is_collected() { var (translator, _) = CreateTranslator(); var namespaces = new HashSet(); - _ = translator.TranslateExpression(Call(FooMethod), null, namespaces); + _ = translator.TranslateExpression(Call(FooMethod), null, namespaces, new HashSet()); Assert.Collection( namespaces, ns => Assert.Equal(typeof(LinqToCSharpSyntaxTranslatorTest).Namespace, ns)); @@ -448,9 +481,9 @@ public void Method_call_with_in_out_ref_parameters() Call(WithInOutRefParameterMethod, [inParam, outParam, refParam])), """ { - int inParam; - int outParam; - int refParam; + int inParam = default; + int outParam = default; + int refParam = default; LinqToCSharpSyntaxTranslatorTest.WithInOutRefParameter(in inParam, out outParam, ref refParam); } """); @@ -467,19 +500,36 @@ public void Instantiation() [Fact] public void Instantiation_with_required_properties_and_parameterless_constructor() => AssertExpression( - New( - typeof(BlogWithRequiredProperties).GetConstructor([])!), - """ -Activator.CreateInstance() -"""); + New(typeof(BlogWithRequiredProperties).GetConstructor([])!), + "UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor()", + unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal( + """ +[UnsafeAccessor(UnsafeAccessorKind.Constructor)] +private static extern LinqToCSharpSyntaxTranslatorTest.BlogWithRequiredProperties UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor(); +""", + Assert.Single(unsafeAccessors), + ignoreLineEndingDifferences: true)); + +// => AssertExpression( +// New(typeof(BlogWithRequiredProperties).GetConstructor([])!), +// """ +// Activator.CreateInstance() +// """); [Fact] public void Instantiation_with_required_properties_and_non_parameterless_constructor() - => Assert.Throws( - () => AssertExpression( - New( - typeof(BlogWithRequiredProperties).GetConstructor([typeof(string)])!, - Constant("foo")), "")); + => AssertExpression( + New( + typeof(BlogWithRequiredProperties).GetConstructor([typeof(string)])!, + Constant("foo")), + """UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor("foo")""", + unsafeAccessorsAsserter: unsafeAccessors => Assert.Equal( + """ +[UnsafeAccessor(UnsafeAccessorKind.Constructor)] +private static extern LinqToCSharpSyntaxTranslatorTest.BlogWithRequiredProperties UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_BlogWithRequiredProperties_Ctor(string name); +""", + Assert.Single(unsafeAccessors), + ignoreLineEndingDifferences: true)); [Fact] public void Instantiation_with_required_properties_with_SetsRequiredMembers() @@ -571,6 +621,17 @@ public void Invocation_with_argument_that_has_side_effects() """); } + [Fact] + public void Invocation_with_property_argument() + => AssertExpression( + Invoke( + Property( + expression: null, + typeof(LinqToCSharpSyntaxTranslatorTest).GetProperty( + nameof(LambdaExpressionProperty), BindingFlags.Public | BindingFlags.Static)!), + Constant(8)), + "LinqToCSharpSyntaxTranslatorTest.LambdaExpressionProperty(8)"); + [Fact] public void Conditional_expression() => AssertExpression( @@ -669,7 +730,7 @@ public void IfThenElse_nested() Block(Assign(variable, Constant(3)))))), """ { - int i; + int i = default; if (true) { i = 1; @@ -788,7 +849,7 @@ public void Switch_expression_nested() Constant(200))))), """ { - int k; + int k = default; var j = 8; var i = j switch { @@ -852,7 +913,7 @@ public void Switch_statement_without_default() SwitchCase(Block(typeof(void), Assign(parameter, Constant(10))), Constant(-10)))), """ { - int i; + int i = default; switch (7) { case -9: @@ -886,7 +947,7 @@ public void Switch_statement_with_default() SwitchCase(Assign(parameter, Constant(10)), Constant(-10)))), """ { - int i; + int i = default; switch (7) { case -9: @@ -918,7 +979,7 @@ public void Switch_statement_with_multiple_labels() SwitchCase(Assign(parameter, Constant(10)), Constant(-10)))), """ { - int i; + int i = default; switch (7) { case -9: @@ -1081,8 +1142,7 @@ public void Same_parameter_instance_is_used_twice_in_nested_lambdas() [Fact] public void Block_with_non_standalone_expression_as_statement() - => AssertStatement( - Block(Add(Constant(1), Constant(2))), + => AssertStatement(Block(Add(Constant(1), Constant(2))), """ { _ = 1 + 2; @@ -1337,7 +1397,7 @@ public void Lift_variable_in_expression_block() Constant(9))))), """ { - int j; + int j = default; LinqToCSharpSyntaxTranslatorTest.Foo(); j = 8; var i = 9; @@ -1420,7 +1480,7 @@ public void Lift_switch_expression() SwitchCase(Constant(2), Constant(9))))), """ { - int i; + int i = default; var j = 8; switch (j) { @@ -1474,8 +1534,8 @@ public void Lift_nested_switch_expression() Constant(200))))), """ { - int i; - int k; + int i = default; + int k = default; var j = 8; switch (j) { @@ -1538,7 +1598,7 @@ public void Lift_non_literal_switch_expression() SwitchCase(Constant(3), Parameter(typeof(Blog), "blog4"))))), """ { - int i; + int i = default; if (blog1 == blog2) { LinqToCSharpSyntaxTranslatorTest.ReturnsIntWithParam(8); @@ -1813,7 +1873,7 @@ public void Try_catch_finally_statement() { LinqToCSharpSyntaxTranslatorTest.Bar(); } -catch (InvalidOperationException e)when (e.Message == "foo") +catch (InvalidOperationException e)when (((Exception)e).Message == "foo") { LinqToCSharpSyntaxTranslatorTest.Baz(); } @@ -1844,7 +1904,7 @@ public void Try_catch_statement_with_filter() { LinqToCSharpSyntaxTranslatorTest.Foo(); } -catch (InvalidOperationException e)when (e.Message == "foo") +catch (InvalidOperationException e)when (((Exception)e).Message == "foo") { LinqToCSharpSyntaxTranslatorTest.Bar(); } @@ -1889,19 +1949,29 @@ public void Try_fault_statement() // TODO: try/catch expressions - private void AssertStatement(Expression expression, string expected, + private void AssertStatement( + Expression expression, + string expected, Dictionary? constantReplacements = null, - Dictionary? memberAccessReplacements = null) - => AssertCore(expression, isStatement: true, expected, constantReplacements, memberAccessReplacements); + Dictionary? memberAccessReplacements = null, + Action>? unsafeAccessorsAsserter = null) + => AssertCore(expected, isStatement: true, expression, constantReplacements, memberAccessReplacements, unsafeAccessorsAsserter); - private void AssertExpression(Expression expression, string expected, + private void AssertExpression( + Expression expression, + string expected, Dictionary? constantReplacements = null, - Dictionary? memberAccessReplacements = null) - => AssertCore(expression, isStatement: false, expected, constantReplacements, memberAccessReplacements); - - private void AssertCore(Expression expression, bool isStatement, string expected, + Dictionary? memberAccessReplacements = null, + Action>? unsafeAccessorsAsserter = null) + => AssertCore(expected, isStatement: false, expression, constantReplacements, memberAccessReplacements, unsafeAccessorsAsserter); + + private void AssertCore( + string expected, + bool isStatement, + Expression expression, Dictionary? constantReplacements, - Dictionary? memberAccessReplacements) + Dictionary? memberAccessReplacements, + Action>? unsafeAccessorsAsserter) { var typeMappingSource = new SqlServerTypeMappingSource( TestServiceFactory.Instance.Create(), @@ -1909,14 +1979,15 @@ public void Try_fault_statement() var translator = new CSharpHelper(typeMappingSource); var namespaces = new HashSet(); + var unsafeAccessors = new HashSet(); var actual = isStatement - ? translator.Statement(expression, namespaces, constantReplacements, memberAccessReplacements) - : translator.Expression(expression, namespaces, constantReplacements, memberAccessReplacements); + ? translator.Statement(expression, namespaces, unsafeAccessors, constantReplacements, memberAccessReplacements) + : translator.Expression(expression, namespaces, unsafeAccessors, constantReplacements, memberAccessReplacements); if (_outputExpressionTrees) { - _testOutputHelper.WriteLine("---- Input LINQ expression tree:"); - _testOutputHelper.WriteLine(_expressionPrinter.PrintExpression(expression)); + testOutputHelper.WriteLine("---- Input LINQ expression tree:"); + testOutputHelper.WriteLine(_expressionPrinter.PrintExpression(expression)); } // TODO: Actually compile the output C# code to make sure it's valid. @@ -1928,17 +1999,26 @@ public void Try_fault_statement() if (_outputExpressionTrees) { - _testOutputHelper.WriteLine("---- Output Roslyn syntax tree:"); - _testOutputHelper.WriteLine(actual); + testOutputHelper.WriteLine("---- Output Roslyn syntax tree:"); + testOutputHelper.WriteLine(actual); } } catch (EqualException) { - _testOutputHelper.WriteLine("---- Output Roslyn syntax tree:"); - _testOutputHelper.WriteLine(actual); + testOutputHelper.WriteLine("---- Output Roslyn syntax tree:"); + testOutputHelper.WriteLine(actual); throw; } + + if (unsafeAccessorsAsserter is null) + { + Assert.Empty(unsafeAccessors); + } + else + { + unsafeAccessorsAsserter(unsafeAccessors); + } } private (LinqToCSharpSyntaxTranslator, AdhocWorkspace) CreateTranslator() @@ -1993,6 +2073,7 @@ public static int Baz() public static int MethodWithSixParams(int a, int b, int c, int d, int e, int f) => a + b + c + d + e + f; + public static Expression> LambdaExpressionProperty => f => f > 5; private static readonly FieldInfo BlogPrivateField = typeof(Blog).GetField("_privateField", BindingFlags.NonPublic | BindingFlags.Instance)!; diff --git a/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj b/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj index c4b531fd2ad..b07088761fb 100644 --- a/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj +++ b/test/EFCore.Relational.Specification.Tests/EFCore.Relational.Specification.Tests.csproj @@ -48,7 +48,11 @@ - + + + + + diff --git a/test/EFCore.Relational.Specification.Tests/Query/AdHocPrecompiledQueryRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/AdHocPrecompiledQueryRelationalTestBase.cs new file mode 100644 index 00000000000..dd352965ca3 --- /dev/null +++ b/test/EFCore.Relational.Specification.Tests/Query/AdHocPrecompiledQueryRelationalTestBase.cs @@ -0,0 +1,247 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using Microsoft.EntityFrameworkCore.Query.Internal; +using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers; + +namespace Microsoft.EntityFrameworkCore.Query; + +public abstract class AdHocPrecompiledQueryRelationalTestBase(ITestOutputHelper testOutputHelper) : NonSharedModelTestBase +{ + [ConditionalFact] + public virtual async Task Index_no_evaluatability() + { + var contextFactory = await InitializeAsync(); + var options = contextFactory.GetOptions(); + + await Test( + """ +await using var context = new AdHocPrecompiledQueryRelationalTestBase.JsonContext(dbContextOptions); +await context.Database.BeginTransactionAsync(); + +var blogs = context.JsonEntities.Where(b => b.IntList[b.Id] == 2).ToList(); +""", + typeof(JsonContext), + options); + } + + [ConditionalFact] + public virtual async Task Index_with_captured_variable() + { + var contextFactory = await InitializeAsync(); + var options = contextFactory.GetOptions(); + + await Test( + """ +await using var context = new AdHocPrecompiledQueryRelationalTestBase.JsonContext(dbContextOptions); +await context.Database.BeginTransactionAsync(); + +var id = 1; +var blogs = context.JsonEntities.Where(b => b.IntList[id] == 2).ToList(); +""", + typeof(JsonContext), + options); + } + + [ConditionalFact] + public virtual async Task JsonScalar() + { + var contextFactory = await InitializeAsync(); + var options = contextFactory.GetOptions(); + + await Test( + """ +await using var context = new AdHocPrecompiledQueryRelationalTestBase.JsonContext(dbContextOptions); +await context.Database.BeginTransactionAsync(); + +_ = context.JsonEntities.Where(b => b.JsonThing.StringProperty == "foo").ToList(); +""", + typeof(JsonContext), + options); + } + + public class JsonContext(DbContextOptions options) : DbContext(options) + { + public DbSet JsonEntities { get; set; } = null!; + + protected override void OnModelCreating(ModelBuilder modelBuilder) + => modelBuilder.Entity().OwnsOne(j => j.JsonThing, n => n.ToJson()); + } + + public class JsonEntity + { + public int Id { get; set; } + public List IntList { get; set; } = null!; + public JsonThing JsonThing { get; set; } = null!; + } + + public class JsonThing + { + public string StringProperty { get; set; } = null!; + } + + [ConditionalFact] + public virtual async Task Materialize_non_public() + { + var contextFactory = await InitializeAsync(); + var options = contextFactory.GetOptions(); + + await Test( + """ +await using var context = new AdHocPrecompiledQueryRelationalTestBase.NonPublicContext(dbContextOptions); + +var nonPublicEntity = (AdHocPrecompiledQueryRelationalTestBase.NonPublicEntity)Activator.CreateInstance(typeof(AdHocPrecompiledQueryRelationalTestBase.NonPublicEntity), nonPublic: true); +nonPublicEntity.PrivateFieldExposer = 8; +nonPublicEntity.PrivatePropertyExposer = 9; +nonPublicEntity.PrivateAutoPropertyExposer = 10; +context.NonPublicEntities.Add(nonPublicEntity); +await context.SaveChangesAsync(); + +context.ChangeTracker.Clear(); + +var e = await context.NonPublicEntities.SingleAsync(); +Assert.Equal(8, e.PrivateFieldExposer); +Assert.Equal(9, e.PrivatePropertyExposer); +Assert.Equal(10, e.PrivateAutoPropertyExposer); +""", + typeof(NonPublicContext), + options, + interceptorCodeAsserter: code => + { + Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_privateField")]""", code); + Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]""", code); + Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")]""", code); + Assert.Contains("""[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "set_PrivateProperty")]""", code); + + Assert.Contains("UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_NonPublicEntity__privateField(instance) =", code); + Assert.Contains("UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_NonPublicEntity_PrivateAutoProperty(instance) =", code); + Assert.Contains("UnsafeAccessor_Microsoft_EntityFrameworkCore_Query_NonPublicEntity_set_PrivateProperty(instance,", code); + }); + } + + public class NonPublicContext(DbContextOptions options) : DbContext(options) + { + public DbSet NonPublicEntities { get; set; } = null!; + + protected override void OnModelCreating(ModelBuilder modelBuilder) + => modelBuilder.Entity( + b => + { + b.Property("_privateField"); + b.Property("PrivateProperty"); + b.Property("PrivateAutoProperty"); + b.Ignore(b => b.PrivateFieldExposer); + b.Ignore(b => b.PrivatePropertyExposer); + b.Ignore(b => b.PrivateAutoPropertyExposer); + }); + } + +#pragma warning disable CS0169 +#pragma warning disable CS0649 + public class NonPublicEntity + { + private NonPublicEntity() + { + } + + public int Id { get; set; } + + private int? _privateField; + + // ReSharper disable once ConvertToAutoProperty + private int? PrivateProperty + { + get => _privatePropertyBackingField; + set => _privatePropertyBackingField = value; + } + private int? _privatePropertyBackingField; + + private int? PrivateAutoProperty { get; set; } + + // ReSharper disable once ConvertToAutoProperty + public int? PrivateFieldExposer + { + get => _privateField; + set => _privateField = value; + } + + public int? PrivatePropertyExposer + { + get => PrivateProperty; + set => PrivateProperty = value; + } + + public int? PrivateAutoPropertyExposer + { + get => PrivateAutoProperty; + set => PrivateAutoProperty = value; + } + } +#pragma warning restore CS0649 +#pragma warning restore CS0169 + +// [ConditionalFact] +// public virtual Task JsonScalar() +// => Test( +// // TODO: Remove Select() to Id after JSON is supported in materialization +// """_ = context.Blogs.Where(b => b.JsonThing.SomeProperty == "foo").Select(b => b.Id).ToList();""", +// modelSourceCode: providerOptions => $$""" +// public class BlogContext : DbContext +// { +// public DbSet Blogs { get; set; } +// +// protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) +// => optionsBuilder +// {{providerOptions}} +// .ReplaceService(); +// +// protected override void OnModelCreating(ModelBuilder modelBuilder) +// => modelBuilder.Entity().OwnsOne(b => b.JsonThing, n => n.ToJson()); +// } +// +// public class Blog +// { +// public int Id { get; set; } +// public JsonThing JsonThing { get; set; } +// } +// +// public class JsonThing +// { +// public string SomeProperty { get; set; } +// } +// """); + + protected TestSqlLoggerFactory TestSqlLoggerFactory + => (TestSqlLoggerFactory)ListLoggerFactory; + + protected void ClearLog() + => TestSqlLoggerFactory.Clear(); + + protected void AssertSql(params string[] expected) + => TestSqlLoggerFactory.AssertBaseline(expected); + + protected virtual Task Test( + string sourceCode, + Type dbContextType, + DbContextOptions dbContextOptions, + Action? interceptorCodeAsserter = null, + Action>? precompilationErrorAsserter = null, + [CallerMemberName] string callerName = "") + => PrecompiledQueryTestHelpers.Test( + sourceCode, dbContextOptions, dbContextType, interceptorCodeAsserter, precompilationErrorAsserter, testOutputHelper, + AlwaysPrintGeneratedSources, + callerName); + + protected virtual bool AlwaysPrintGeneratedSources + => false; + + protected abstract PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers { get; } + + protected override IServiceCollection AddServices(IServiceCollection serviceCollection) + => base.AddServices(serviceCollection) + .AddScoped(); + + protected override string StoreName + => "AdHocPrecompiledQueryTest"; +} diff --git a/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalFixture.cs b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalFixture.cs new file mode 100644 index 00000000000..d0b2cd519dc --- /dev/null +++ b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalFixture.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.Internal; +using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers; +using Blog = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Blog; +namespace Microsoft.EntityFrameworkCore.Query; + +public abstract class PrecompiledQueryRelationalFixture + : SharedStoreFixtureBase, ITestSqlLoggerFactory +{ + protected override string StoreName + => "PrecompiledQueryTest"; + + public TestSqlLoggerFactory TestSqlLoggerFactory + => (TestSqlLoggerFactory)ListLoggerFactory; + + protected override IServiceCollection AddServices(IServiceCollection serviceCollection) + => base.AddServices(serviceCollection) + .AddScoped(); + + protected override async Task SeedAsync(PrecompiledQueryRelationalTestBase.PrecompiledQueryContext context) + { + context.Blogs.AddRange( + new Blog { Id = 8, Name = "Blog1" }, + new Blog { Id = 9, Name = "Blog2" }); + await context.SaveChangesAsync(); + } + + public abstract PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers { get; } +} diff --git a/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalTestBase.cs new file mode 100644 index 00000000000..77acb29a6ee --- /dev/null +++ b/test/EFCore.Relational.Specification.Tests/Query/PrecompiledQueryRelationalTestBase.cs @@ -0,0 +1,1134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations.Schema; +using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Query.Internal; +using Xunit.Sdk; +using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers; + +namespace Microsoft.EntityFrameworkCore.Query; + +// ReSharper disable InconsistentNaming + +public class PrecompiledQueryRelationalTestBase +{ + public PrecompiledQueryRelationalTestBase(PrecompiledQueryRelationalFixture fixture, ITestOutputHelper testOutputHelper) + { + Fixture = fixture; + TestOutputHelper = testOutputHelper; + + Fixture.TestSqlLoggerFactory.Clear(); + Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); + } + + #region Expression types + + [ConditionalFact] + public virtual Task BinaryExpression() + => Test(""" +var id = 3; +var blogs = await context.Blogs.Where(b => b.Id > id).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Conditional_no_evaluatable() + => Test(""" +var id = 3; +var blogs = await context.Blogs.Select(b => b.Id == 2 ? "yes" : "no").ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Conditional_contains_captured_variable() + => Test(""" +var yes = "yes"; +var blogs = await context.Blogs.Select(b => b.Id == 2 ? yes : "no").ToListAsync(); +"""); + + // We do not support embedding Expression builder API calls into the query; this would require CSharpToLinqTranslator to actually + // evaluate those APIs and embed the results into the tree. It's (at least potentially) a form of dynamic query, unsupported for now. + [ConditionalFact] + public virtual Task Invoke_no_evaluatability_is_not_supported() + => Test( + """ +Expression> lambda = b => b.Name == "foo"; +var parameter = Expression.Parameter(typeof(Blog), "b"); + +var blogs = await context.Blogs + .Where(Expression.Lambda>(Expression.Invoke(lambda, parameter), parameter)) + .ToListAsync(); +""", + errorAsserter: errors => Assert.IsType(errors.Single().Exception)); + + [ConditionalFact] + public virtual Task ListInit_no_evaluatability() + => Test("_ = await context.Blogs.Select(b => new List { b.Id, b.Id + 1 }).ToListAsync();"); + + [ConditionalFact] + public virtual Task ListInit_with_evaluatable_with_captured_variable() + => Test( + """ +var i = 1; +_ = await context.Blogs.Select(b => new List { b.Id, i }).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task ListInit_with_evaluatable_without_captured_variable() + => Test( + """ +var i = 1; +_ = await context.Blogs.Select(b => new List { b.Id, 8 }).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task ListInit_fully_evaluatable() + => Test(""" +var blog = await context.Blogs.Where(b => new List { 7, 8 }.Contains(b.Id)).SingleAsync(); +Assert.Equal("Blog1", blog.Name); +"""); + + [ConditionalFact] + public virtual Task MethodCallExpression_no_evaluatability() + => Test("_ = await context.Blogs.Where(b => b.Name.StartsWith(b.Name)).ToListAsync();"); + + [ConditionalFact] + public virtual Task MethodCallExpression_with_evaluatable_with_captured_variable() + => Test(""" +var pattern = "foo"; +_ = await context.Blogs.Where(b => b.Name.StartsWith(pattern)).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task MethodCallExpression_with_evaluatable_without_captured_variable() + => Test("""_ = await context.Blogs.Where(b => b.Name.StartsWith("foo")).ToListAsync();"""); + + [ConditionalFact] + public virtual Task MethodCallExpression_fully_evaluatable() + => Test("""_ = await context.Blogs.Where(b => "foobar".StartsWith("foo")).ToListAsync();"""); + + [ConditionalFact] + public virtual Task New_with_no_arguments() + => Test( + """ +var i = 8; +_ = await context.Blogs.Where(b => b == new Blog()).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Where_New_with_captured_variable() + => Test( + """ +var i = 8; +_ = await context.Blogs.Where(b => b == new Blog(i, b.Name)).ToListAsync(); +""", + errorAsserter: errors => Assert.StartsWith("Translation of", errors.Single().Exception.Message)); + + [ConditionalFact] + public virtual Task Select_New_with_captured_variable() + => Test( + """ +var i = 8; +_ = await context.Blogs.Select(b => new Blog(i, b.Name)).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task MemberInit_no_evaluatable() + => Test("_ = await context.Blogs.Select(b => new Blog { Id = b.Id, Name = b.Name }).ToListAsync();"); + + [ConditionalFact] + public virtual Task MemberInit_contains_captured_variable() + => Test( + """ +var id = 8; +_ = await context.Blogs.Select(b => new Blog { Id = id, Name = b.Name }).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task MemberInit_evaluatable_as_constant() + => Test("""_ = await context.Blogs.Select(b => new Blog { Id = 1, Name = "foo" }).ToListAsync();"""); + + [ConditionalFact] + public virtual Task MemberInit_evaluatable_as_parameter() + => Test( + """ +var id = 8; +var foo = "foo"; +_ = await context.Blogs.Select(b => new Blog { Id = id, Name = foo }).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task NewArray() + => Test( + """ +var i = 8; +_ = await context.Blogs.Select(b => new[] { b.Id, b.Id + i }).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Unary() + => Test("_ = await context.Blogs.Where(b => (short)b.Id == (short)8).ToListAsync();"); + + #endregion Expression types + + #region Terminating operators + + [ConditionalFact] + public virtual Task Terminating_AsEnumerable() + => Test(""" +var blogs = context.Blogs.AsEnumerable().ToList(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_AsAsyncEnumerable_on_DbSet() + => Test(""" +var sum = 0; +await foreach (var blog in context.Blogs.AsAsyncEnumerable()) +{ + sum += blog.Id; +} +Assert.Equal(17, sum); +"""); + + [ConditionalFact] + public virtual Task Terminating_AsAsyncEnumerable_on_IQueryable() + => Test(""" +var sum = 0; +await foreach (var blog in context.Blogs.Where(b => b.Id > 8).AsAsyncEnumerable()) +{ + sum += blog.Id; +} +Assert.Equal(9, sum); +"""); + + [ConditionalFact] + public virtual Task Foreach_sync_over_operator() + => Test( + """ +foreach (var blog in context.Blogs.Where(b => b.Id > 8)) +{ +} +"""); + + [ConditionalFact] + public virtual Task Terminating_ToArray() + => Test( + """ +var blogs = context.Blogs.ToArray(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ToArrayAsync() + => Test( + """ +var blogs = await context.Blogs.ToArrayAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ToDictionary() + => Test( + """ +var blogs = context.Blogs.ToDictionary(b => b.Id, b => b.Name); +Assert.Equal(2, blogs.Count); +Assert.Equal("Blog1", blogs[8]); +Assert.Equal("Blog2", blogs[9]); +"""); + + [ConditionalFact] + public virtual Task Terminating_ToDictionaryAsync() + => Test( + """ +var blogs = await context.Blogs.ToDictionaryAsync(b => b.Id, b => b.Name); +Assert.Equal(2, blogs.Count); +Assert.Equal("Blog1", blogs[8]); +Assert.Equal("Blog2", blogs[9]); +"""); + + [ConditionalFact] + public virtual Task ToDictionary_over_anonymous_type() + => Test("_ = context.Blogs.Select(b => new { b.Id, b.Name }).ToDictionary(x => x.Id, x => x.Name);"); + + [ConditionalFact] + public virtual Task ToDictionaryAsync_over_anonymous_type() + => Test("_ = await context.Blogs.Select(b => new { b.Id, b.Name }).ToDictionaryAsync(x => x.Id, x => x.Name);"); + + [ConditionalFact] + public virtual Task Terminating_ToHashSet() + => Test( + """ +var blogs = context.Blogs.ToHashSet(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ToHashSetAsync() + => Test( + """ +var blogs = await context.Blogs.ToHashSetAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ToLookup() + => Test("_ = context.Blogs.ToLookup(b => b.Name);"); + + [ConditionalFact] + public virtual Task Terminating_ToList() + => Test( + """ +var blogs = context.Blogs.ToList(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ToListAsync() + => Test( + """ +var blogs = await context.Blogs.ToListAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + // foreach/await foreach directly over DbSet properties doesn't isn't supported, since we can't intercept property accesses. + [ConditionalFact] + public virtual async Task Foreach_sync_over_DbSet_property_is_not_supported() + { + // TODO: Assert diagnostics about non-intercepted query + var exception = await Assert.ThrowsAsync( + () => Test( + """ +foreach (var blog in context.Blogs) +{ +} +""")); + Assert.Equal(NonCompilingQueryCompiler.ErrorMessage, exception.Message); + } + + // foreach/await foreach directly over DbSet properties doesn't isn't supported, since we can't intercept property accesses. + [ConditionalFact] + public virtual async Task Foreach_async_is_not_supported() + { + // TODO: Assert diagnostics about non-intercepted query + var exception = await Assert.ThrowsAsync( + () => Test( + """ +await foreach (var blog in context.Blogs) +{ +} +""")); + Assert.Equal(NonCompilingQueryCompiler.ErrorMessage, exception.Message); + } + + #endregion Terminating operators + + #region Reducing terminating operators + + [ConditionalFact] + public virtual Task Terminating_All() + => Test( + """ +Assert.True(context.Blogs.All(b => b.Id > 7)); +Assert.False(context.Blogs.All(b => b.Id > 8)); +"""); + + [ConditionalFact] + public virtual Task Terminating_AllAsync() + => Test( + """ +Assert.True(await context.Blogs.AllAsync(b => b.Id > 7)); +Assert.False(await context.Blogs.AllAsync(b => b.Id > 8)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Any() + => Test( + """ +Assert.True(context.Blogs.Where(b => b.Id > 7).Any()); +Assert.False(context.Blogs.Where(b => b.Id < 7).Any()); + +Assert.True(context.Blogs.Any(b => b.Id > 7)); +Assert.False(context.Blogs.Any(b => b.Id < 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_AnyAsync() + => Test( + """ +Assert.True(await context.Blogs.Where(b => b.Id > 7).AnyAsync()); +Assert.False(await context.Blogs.Where(b => b.Id < 7).AnyAsync()); + +Assert.True(await context.Blogs.AnyAsync(b => b.Id > 7)); +Assert.False(await context.Blogs.AnyAsync(b => b.Id < 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Average() + => Test( + """ +Assert.Equal(8.5, context.Blogs.Select(b => b.Id).Average()); +Assert.Equal(8.5, context.Blogs.Average(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_AverageAsync() + => Test( + """ +Assert.Equal(8.5, await context.Blogs.Select(b => b.Id).AverageAsync()); +Assert.Equal(8.5, await context.Blogs.AverageAsync(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Contains() + => Test( + """ +Assert.True(context.Blogs.Select(b => b.Id).Contains(8)); +Assert.False(context.Blogs.Select(b => b.Id).Contains(7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ContainsAsync() + => Test( + """ +Assert.True(await context.Blogs.Select(b => b.Id).ContainsAsync(8)); +Assert.False(await context.Blogs.Select(b => b.Id).ContainsAsync(7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Count() + => Test( + """ +Assert.Equal(2, context.Blogs.Count()); +Assert.Equal(1, context.Blogs.Count(b => b.Id > 8)); +"""); + + [ConditionalFact] + public virtual Task Terminating_CountAsync() + => Test( + """ +Assert.Equal(2, await context.Blogs.CountAsync()); +Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id > 8)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ElementAt() + => Test( + """ +Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).ElementAt(1).Name); +Assert.Throws(() => context.Blogs.OrderBy(b => b.Id).ElementAt(3)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ElementAtAsync() + => Test( + """ +Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).ElementAtAsync(1)).Name); +await Assert.ThrowsAsync(() => context.Blogs.OrderBy(b => b.Id).ElementAtAsync(3)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ElementAtOrDefault() + => Test( + """ +Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).ElementAtOrDefault(1).Name); +Assert.Null(context.Blogs.OrderBy(b => b.Id).ElementAtOrDefault(3)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ElementAtOrDefaultAsync() + => Test( + """ +Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).ElementAtOrDefaultAsync(1)).Name); +Assert.Null(await context.Blogs.OrderBy(b => b.Id).ElementAtOrDefaultAsync(3)); +"""); + + [ConditionalFact] + public virtual Task Terminating_First() + => Test( + """ +Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).First().Name); +Assert.Throws(() => context.Blogs.Where(b => b.Id == 7).First()); + +Assert.Equal("Blog1", context.Blogs.First(b => b.Id == 8).Name); +Assert.Throws(() => context.Blogs.First(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_FirstAsync() + => Test( + """ +Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).FirstAsync()).Name); +await Assert.ThrowsAsync(() => context.Blogs.Where(b => b.Id == 7).FirstAsync()); + +Assert.Equal("Blog1", (await context.Blogs.FirstAsync(b => b.Id == 8)).Name); +await Assert.ThrowsAsync(() => context.Blogs.FirstAsync(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_FirstOrDefault() + => Test( + """ +Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).FirstOrDefault().Name); +Assert.Null(context.Blogs.Where(b => b.Id == 7).FirstOrDefault()); + +Assert.Equal("Blog1", context.Blogs.FirstOrDefault(b => b.Id == 8).Name); +Assert.Null(context.Blogs.FirstOrDefault(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_FirstOrDefaultAsync() + => Test( + """ +Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).FirstOrDefaultAsync()).Name); +Assert.Null(await context.Blogs.Where(b => b.Id == 7).FirstOrDefaultAsync()); + +Assert.Equal("Blog1", (await context.Blogs.FirstOrDefaultAsync(b => b.Id == 8)).Name); +Assert.Null(await context.Blogs.FirstOrDefaultAsync(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_GetEnumerator() + => Test( + """ +using var enumerator = context.Blogs.Where(b => b.Id == 8).GetEnumerator(); +Assert.True(enumerator.MoveNext()); +Assert.Equal("Blog1", enumerator.Current.Name); +Assert.False(enumerator.MoveNext()); +"""); + + [ConditionalFact] + public virtual Task Terminating_Last() + => Test( + """ +Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).Last().Name); +Assert.Throws(() => context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).Last()); + +Assert.Equal("Blog1", context.Blogs.OrderBy(b => b.Id).Last(b => b.Id == 8).Name); +Assert.Throws(() => context.Blogs.OrderBy(b => b.Id).Last(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_LastAsync() + => Test( + """ +Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).LastAsync()).Name); +await Assert.ThrowsAsync(() => context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).LastAsync()); + +Assert.Equal("Blog1", (await context.Blogs.OrderBy(b => b.Id).LastAsync(b => b.Id == 8)).Name); +await Assert.ThrowsAsync(() => context.Blogs.OrderBy(b => b.Id).LastAsync(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_LastOrDefault() + => Test( + """ +Assert.Equal("Blog2", context.Blogs.OrderBy(b => b.Id).LastOrDefault().Name); +Assert.Null(context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).LastOrDefault()); + +Assert.Equal("Blog1", context.Blogs.OrderBy(b => b.Id).LastOrDefault(b => b.Id == 8).Name); +Assert.Null(context.Blogs.OrderBy(b => b.Id).LastOrDefault(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_LastOrDefaultAsync() + => Test( + """ +Assert.Equal("Blog2", (await context.Blogs.OrderBy(b => b.Id).LastOrDefaultAsync()).Name); +Assert.Null(await context.Blogs.OrderBy(b => b.Id).Where(b => b.Id == 7).LastOrDefaultAsync()); + +Assert.Equal("Blog1", (await context.Blogs.OrderBy(b => b.Id).LastOrDefaultAsync(b => b.Id == 8)).Name); +Assert.Null(await context.Blogs.OrderBy(b => b.Id).LastOrDefaultAsync(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_LongCount() + => Test( + """ +Assert.Equal(2, context.Blogs.LongCount()); +Assert.Equal(1, context.Blogs.LongCount(b => b.Id == 8)); +"""); + + [ConditionalFact] + public virtual Task Terminating_LongCountAsync() + => Test( + """ +Assert.Equal(2, await context.Blogs.LongCountAsync()); +Assert.Equal(1, await context.Blogs.LongCountAsync(b => b.Id == 8)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Max() + => Test( + """ +Assert.Equal(9, context.Blogs.Select(b => b.Id).Max()); +Assert.Equal(9, context.Blogs.Max(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_MaxAsync() + => Test( + """ +Assert.Equal(9, await context.Blogs.Select(b => b.Id).MaxAsync()); +Assert.Equal(9, await context.Blogs.MaxAsync(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Min() + => Test( + """ +Assert.Equal(8, context.Blogs.Select(b => b.Id).Min()); +Assert.Equal(8, context.Blogs.Min(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_MinAsync() + => Test( + """ +Assert.Equal(8, await context.Blogs.Select(b => b.Id).MinAsync()); +Assert.Equal(8, await context.Blogs.MinAsync(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Single() + => Test( + """ +Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).Single().Name); +Assert.Throws(() => context.Blogs.Where(b => b.Id == 7).Single()); + +Assert.Equal("Blog1", context.Blogs.Single(b => b.Id == 8).Name); +Assert.Throws(() => context.Blogs.Single(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_SingleAsync() + => Test( + """ +Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).SingleAsync()).Name); +await Assert.ThrowsAsync(() => context.Blogs.Where(b => b.Id == 7).SingleAsync()); + +Assert.Equal("Blog1", (await context.Blogs.SingleAsync(b => b.Id == 8)).Name); +await Assert.ThrowsAsync(() => context.Blogs.SingleAsync(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_SingleOrDefault() + => Test( + """ +Assert.Equal("Blog1", context.Blogs.Where(b => b.Id == 8).SingleOrDefault().Name); +Assert.Null(context.Blogs.Where(b => b.Id == 7).SingleOrDefault()); + +Assert.Equal("Blog1", context.Blogs.SingleOrDefault(b => b.Id == 8).Name); +Assert.Null(context.Blogs.SingleOrDefault(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_SingleOrDefaultAsync() + => Test( + """ +Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).SingleOrDefaultAsync()).Name); +Assert.Null(await context.Blogs.Where(b => b.Id == 7).SingleOrDefaultAsync()); + +Assert.Equal("Blog1", (await context.Blogs.SingleOrDefaultAsync(b => b.Id == 8)).Name); +Assert.Null(await context.Blogs.SingleOrDefaultAsync(b => b.Id == 7)); +"""); + + [ConditionalFact] + public virtual Task Terminating_Sum() + => Test( + """ +Assert.Equal(17, context.Blogs.Select(b => b.Id).Sum()); +Assert.Equal(17, context.Blogs.Sum(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_SumAsync() + => Test( + """ +Assert.Equal(17, await context.Blogs.Select(b => b.Id).SumAsync()); +Assert.Equal(17, await context.Blogs.SumAsync(b => b.Id)); +"""); + + [ConditionalFact] + public virtual Task Terminating_ExecuteDelete() + => Test( + """ +await context.Database.BeginTransactionAsync(); + +var rowsAffected = context.Blogs.Where(b => b.Id > 8).ExecuteDelete(); +Assert.Equal(1, rowsAffected); +Assert.Equal(1, await context.Blogs.CountAsync()); +"""); + + [ConditionalFact] + public virtual Task Terminating_ExecuteDeleteAsync() + => Test( + """ +await context.Database.BeginTransactionAsync(); + +var rowsAffected = await context.Blogs.Where(b => b.Id > 8).ExecuteDeleteAsync(); +Assert.Equal(1, rowsAffected); +Assert.Equal(1, await context.Blogs.CountAsync()); +"""); + + [ConditionalFact] + public virtual Task Terminating_ExecuteUpdate() + => Test( + """ +await context.Database.BeginTransactionAsync(); + +var suffix = "Suffix"; +var rowsAffected = context.Blogs.Where(b => b.Id > 8).ExecuteUpdate(setters => setters.SetProperty(b => b.Name, b => b.Name + suffix)); +Assert.Equal(1, rowsAffected); +Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id == 9 && b.Name == "Blog2Suffix")); +"""); + + [ConditionalFact] + public virtual Task Terminating_ExecuteUpdateAsync() + => Test( + """ +await context.Database.BeginTransactionAsync(); + +var suffix = "Suffix"; +var rowsAffected = await context.Blogs.Where(b => b.Id > 8).ExecuteUpdateAsync(setters => setters.SetProperty(b => b.Name, b => b.Name + suffix)); +Assert.Equal(1, rowsAffected); +Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id == 9 && b.Name == "Blog2Suffix")); +"""); + + #endregion Reducing terminating operators + + #region SQL expression quotability + + [ConditionalFact] + public virtual Task Union() + => Test( + """ +var blogs = await context.Blogs.Where(b => b.Id > 7) + .Union(context.Blogs.Where(b => b.Id < 10)) + .OrderBy(b => b.Id) + .ToListAsync(); + +Assert.Collection(blogs, + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Concat() + => Test( + """ +var blogs = await context.Blogs.Where(b => b.Id > 7) + .Concat(context.Blogs.Where(b => b.Id < 10)) + .OrderBy(b => b.Id) + .ToListAsync(); + +Assert.Collection(blogs, + b => Assert.Equal(8, b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id), + b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Intersect() + => Test( + """ +var blogs = await context.Blogs.Where(b => b.Id > 7) + .Intersect(context.Blogs.Where(b => b.Id > 8)) + .OrderBy(b => b.Id) + .ToListAsync(); + +Assert.Collection(blogs, b => Assert.Equal(9, b.Id)); +"""); + + [ConditionalFact] + public virtual Task Except() + => Test( + """ +var blogs = await context.Blogs.Where(b => b.Id > 7) + .Except(context.Blogs.Where(b => b.Id > 8)) + .OrderBy(b => b.Id) + .ToListAsync(); + +Assert.Collection(blogs, b => Assert.Equal(8, b.Id)); +"""); + + [ConditionalFact] + public virtual Task ValuesExpression() + => Test("_ = await context.Blogs.Where(b => new[] { 7, b.Id }.Count(i => i > 8) == 2).ToListAsync();"); + + // Tests e.g. OPENJSON on SQL Server + [ConditionalFact] + public virtual Task Contains_with_parameterized_collection() + => Test( + """ +int[] ids = [1, 2, 3]; +_ = await context.Blogs.Where(b => ids.Contains(b.Id)).ToListAsync(); +"""); + + // TODO: SQL Server-specific + [ConditionalFact] + public virtual Task FromSqlRaw() + => Test("""_ = await context.Blogs.FromSqlRaw("SELECT * FROM Blogs").OrderBy(b => b.Id).ToListAsync();"""); + + [ConditionalFact] + public virtual Task FromSql_with_FormattableString_parameters() + => Test("""_ = await context.Blogs.FromSql($"SELECT * FROM Blogs WHERE Id > {8} AND Id < {9}").OrderBy(b => b.Id).ToListAsync();"""); + + #endregion SQL expression quotability + + #region Different DbContext expressions + + [ConditionalFact] + public virtual Task DbContext_as_local_variable() + => Test( + """ +var context2 = context; + +_ = await context2.Blogs.ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task DbContext_as_field() + => FullSourceTest( + """ +public static class TestContainer +{ + private static PrecompiledQueryContext _context; + + public static async Task Test(DbContextOptions dbContextOptions) + { + using (_context = new PrecompiledQueryContext(dbContextOptions)) + { + var blogs = await _context.Blogs.ToListAsync(); + Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); + } + } +} +"""); + + [ConditionalFact] + public virtual Task DbContext_as_property() + => FullSourceTest( + """ +public static class TestContainer +{ + private static PrecompiledQueryContext Context { get; set; } + + public static async Task Test(DbContextOptions dbContextOptions) + { + using (Context = new PrecompiledQueryContext(dbContextOptions)) + { + var blogs = await Context.Blogs.ToListAsync(); + Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); + } + } +} +"""); + + [ConditionalFact] + public virtual Task DbContext_as_captured_variable() + => Test( + """ +Func> foo = () => context.Blogs.ToList(); +_ = foo(); +"""); + + [ConditionalFact] + public virtual Task DbContext_as_method_invocation_result() + => FullSourceTest( + """ +public static class TestContainer +{ + private static PrecompiledQueryContext _context; + + public static async Task Test(DbContextOptions dbContextOptions) + { + using (_context = new PrecompiledQueryContext(dbContextOptions)) + { + var blogs = await GetContext().Blogs.ToListAsync(); + Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); + } + } + + private static PrecompiledQueryContext GetContext() + => _context; +} +"""); + + #endregion Different DbContext expressions + + #region Negative cases + + [ConditionalFact] + public virtual Task Dynamic_query_does_not_get_precompiled() + => Test( + """ +var query = context.Blogs; +var blogs = await query.ToListAsync(); +""", + errorAsserter: errors => + { + var dynamicQueryError = errors.Single(); + Assert.IsType(dynamicQueryError.Exception); + Assert.Equal(DesignStrings.DynamicQueryNotSupported, dynamicQueryError.Exception.Message); + Assert.Equal("query.ToListAsync()", dynamicQueryError.SyntaxNode.NormalizeWhitespace().ToFullString()); + }); + + [ConditionalFact] + public virtual Task ToList_over_objects_does_not_get_precompiled() + => Test( + """ +int[] numbers = [1, 2, 3]; +var lessNumbers = numbers.Where(i => i > 1).ToList(); +"""); + + [ConditionalFact] + public virtual async Task Query_compilation_failure() + => await Test( + "_ = await context.Blogs.Where(b => PrecompiledQueryRelationalTestBase.Untranslatable(b.Id) == 999).ToListAsync();", + errorAsserter: errors + => Assert.Contains( + CoreStrings.TranslationFailedWithDetails( + "DbSet()\n .Where(b => PrecompiledQueryRelationalTestBase.Untranslatable(b.Id) == 999)", + "Translation of method 'Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Untranslatable' failed. If this method can be mapped to your custom function, see https://go.microsoft.com/fwlink/?linkid=2132413 for more information."), + errors.Single().Exception.Message)); + + public static int Untranslatable(int foo) + => throw new InvalidOperationException(); + + [ConditionalFact] + public virtual Task EF_Constant_is_not_supported() + => Test( + "_ = await context.Blogs.Where(b => b.Id > EF.Constant(8)).ToListAsync();", + errorAsserter: errors + => Assert.Equal(CoreStrings.EFConstantNotSupportedInPrecompiledQueries, errors.Single().Exception.Message)); + + [ConditionalFact] + public virtual Task NotParameterizedAttribute_with_constant() + => Test( + """ +var blog = await context.Blogs.Where(b => EF.Property(b, "Name") == "Blog2").SingleAsync(); +Assert.Equal(9, blog.Id); +"""); + + [ConditionalFact] + public virtual Task NotParameterizedAttribute_is_not_supported_with_non_constant_argument() + => Test( + """ +var propertyName = "Name"; +var blog = await context.Blogs.Where(b => EF.Property(b, propertyName) == "Blog2").SingleAsync(); +""", + errorAsserter: errors + => Assert.Equal( + CoreStrings.NotParameterizedAttributeWithNonConstantNotSupportedInPrecompiledQueries("propertyName", "Property"), + errors.Single().Exception.Message)); + + [ConditionalFact] + public virtual Task Query_syntax_is_not_supported() + => Test( + """ +var id = 3; +var blogs = await ( + from b in context.Blogs + where b.Id > 8 + select b).ToListAsync(); +""", + errorAsserter: errors + => Assert.Equal(DesignStrings.QueryComprehensionSyntaxNotSupportedInPrecompiledQueries, errors.Single().Exception.Message)); + + #endregion Negative cases + + [ConditionalFact] + public virtual Task Select_changes_type() + => Test("_ = await context.Blogs.Select(b => b.Name).ToListAsync();"); + + [ConditionalFact] + public virtual Task OrderBy() + => Test("_ = await context.Blogs.OrderBy(b => b.Name).ToListAsync();"); + + [ConditionalFact] + public virtual Task Skip() + => Test("_ = await context.Blogs.OrderBy(b => b.Name).Skip(1).ToListAsync();"); + + [ConditionalFact] + public virtual Task Take() + => Test("_ = await context.Blogs.OrderBy(b => b.Name).Take(1).ToListAsync();"); + + [ConditionalFact] + public virtual Task Project_anonymous_object() + => Test("""_ = await context.Blogs.Select(b => new { Foo = b.Name + "Foo" }).ToListAsync();"""); + + [ConditionalFact] + public virtual Task Two_captured_variables_in_same_lambda() + => Test(""" +var yes = "yes"; +var no = "no"; +var blogs = await context.Blogs.Select(b => b.Id == 3 ? yes : no).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Two_captured_variables_in_different_lambdas() + => Test(""" +var starts = "blog"; +var ends = "2"; +var blog = await context.Blogs.Where(b => b.Name.StartsWith(starts)).Where(b => b.Name.EndsWith(ends)).SingleAsync(); +Assert.Equal(9, blog.Id); +"""); + + [ConditionalFact] + public virtual Task Same_captured_variable_twice_in_same_lambda() + => Test(""" +var foo = "X"; +var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo) && b.Name.EndsWith(foo)).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Same_captured_variable_twice_in_different_lambdas() + => Test(""" +var foo = "X"; +var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo)).Where(b => b.Name.EndsWith(foo)).ToListAsync(); +"""); + + [ConditionalFact] + public virtual Task Include_single() + => Test("var blogs = await context.Blogs.Include(b => b.Posts).Where(b => b.Id > 8).ToListAsync();"); + + [ConditionalFact] + public virtual Task Include_split() + => Test("var blogs = await context.Blogs.AsSplitQuery().Include(b => b.Posts).ToListAsync();"); + + [ConditionalFact] + public virtual Task Final_GroupBy() + => Test("""var blogs = await context.Blogs.GroupBy(b => b.Name).ToListAsync();"""); + + [ConditionalFact] + public virtual Task Multiple_queries_with_captured_variables() + => Test(""" +var id1 = 8; +var id2 = 9; +var blogs = await context.Blogs.Where(b => b.Id == id1 || b.Id == id2).ToListAsync(); +var blog1 = await context.Blogs.Where(b => b.Id == id1).SingleAsync(); +Assert.Collection( + blogs.OrderBy(b => b.Id), + b => Assert.Equal(8, b.Id), + b => Assert.Equal(9, b.Id)); +Assert.Equal("Blog1", blog1.Name); +"""); + + [ConditionalFact] + public virtual Task Unsafe_accessor_gets_generated_once_for_multiple_queries() + => Test(""" +var blogs1 = await context.Blogs.ToListAsync(); +var blogs2 = await context.Blogs.ToListAsync(); +""", + interceptorCodeAsserter: code => Assert.Equal(2, code.Split("GetSet_Microsoft_EntityFrameworkCore_Query_Blog_Id").Length)); + + public class PrecompiledQueryContext(DbContextOptions options) : DbContext(options) + { + public DbSet Blogs { get; set; } = null!; + public DbSet Posts { get; set; } = null!; + } + + protected PrecompiledQueryRelationalFixture Fixture { get; } + protected ITestOutputHelper TestOutputHelper { get; } + + protected void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); + + protected virtual Task Test( + string sourceCode, + Action? interceptorCodeAsserter = null, + Action>? errorAsserter = null, + [CallerMemberName] string callerName = "") + => Fixture.PrecompiledQueryTestHelpers.Test( + """ +await using var context = new PrecompiledQueryContext(dbContextOptions); + +""" + sourceCode, + Fixture.ServiceProvider.GetRequiredService(), + typeof(PrecompiledQueryContext), + interceptorCodeAsserter, + errorAsserter, + TestOutputHelper, + AlwaysPrintGeneratedSources, + callerName); + + protected virtual Task FullSourceTest( + string sourceCode, + Action? interceptorCodeAsserter = null, + Action>? errorAsserter = null, + [CallerMemberName] string callerName = "") + => Fixture.PrecompiledQueryTestHelpers.FullSourceTest( + sourceCode, + Fixture.ServiceProvider.GetRequiredService(), + typeof(PrecompiledQueryContext), + interceptorCodeAsserter, + errorAsserter, + TestOutputHelper, + AlwaysPrintGeneratedSources, + callerName); + + protected virtual bool AlwaysPrintGeneratedSources + => false; + + public class Blog + { + public Blog() + { + } + + public Blog(int id, string name) + { + Id = id; + Name = name; + } + + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public int Id { get; set; } + public string? Name { get; set; } + + public List Posts { get; set; } = new(); + } + + public class Post + { + public int Id { get; set; } + public string? Title { get; set; } + + public Blog? Blog { get; set; } + } + + public static IEnumerable IsAsyncData = new object[][] { [false], [true] }; +} diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/PrecompiledQueryTestHelpers.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/PrecompiledQueryTestHelpers.cs new file mode 100644 index 00000000000..960b3bab4f2 --- /dev/null +++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/PrecompiledQueryTestHelpers.cs @@ -0,0 +1,297 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel; +using System.ComponentModel.DataAnnotations.Schema; +using System.Runtime.Loader; +using System.Text.Encodings.Web; +using System.Text.Json; +using System.Text.RegularExpressions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Editing; +using Microsoft.EntityFrameworkCore.Query.Internal; +using Microsoft.Extensions.Caching.Memory; + +namespace Microsoft.EntityFrameworkCore.TestUtilities; + +public abstract class PrecompiledQueryTestHelpers +{ + private readonly MetadataReference[] _metadataReferences; + + protected PrecompiledQueryTestHelpers() + => _metadataReferences = BuildMetadataReferences().ToArray(); + + public Task Test( + string sourceCode, + DbContextOptions dbContextOptions, + Type dbContextType, + Action? interceptorCodeAsserter, + Action>? errorAsserter, + ITestOutputHelper testOutputHelper, + bool alwaysPrintGeneratedSources, + string callerName) + { + var source = $$""" +public static class TestContainer +{ + public static async Task Test(DbContextOptions dbContextOptions) + { +{{sourceCode}} + } +} +"""; + return FullSourceTest( + source, dbContextOptions, dbContextType, interceptorCodeAsserter, errorAsserter, testOutputHelper, alwaysPrintGeneratedSources, + callerName); + } + + public async Task FullSourceTest( + string sourceCode, + DbContextOptions dbContextOptions, + Type dbContextType, + Action? interceptorCodeAsserter, + Action>? errorAsserter, + ITestOutputHelper testOutputHelper, + bool alwaysPrintGeneratedSources, + string callerName) + { + // The overall end-to-end testing for precompiled queries is as follows: + // 1. Compile the user code, produce an assembly from it and load it. We need to do this since precompiled query generation requires + // an actual DbContext instance, from which we get the model, services, ec. + // 2. Do precompiled query generation. This outputs additional source files (syntax trees) containing interceptors for the located + // EF LINQ queries. + // 3. Integrate the additional syntax trees into the compilation, and again, produce an assembly from it and load it. + // 4. Use reflection to find the EntryPoint (Main method) on this assembly, and invoke it. + var source = $""" +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Threading.Tasks; +using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Query; +using Xunit; +using static Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase; +//using Microsoft.EntityFrameworkCore.PrecompiledQueryTest; + +{sourceCode} +"""; + + // This turns on the interceptors feature for the designated namespace(s). + var parseOptions = new CSharpParseOptions().WithFeatures( + new[] + { + new KeyValuePair("InterceptorsPreviewNamespaces", "Microsoft.EntityFrameworkCore.GeneratedInterceptors") + }); + + var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions, path: "Test.cs"); + + var compilation = CSharpCompilation.Create( + "TestCompilation", + syntaxTrees: [syntaxTree], + _metadataReferences, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + IReadOnlyList? generatedFiles = null; + + try + { + // The test code compiled - emit and assembly and load it. + var (assemblyLoadContext, assembly) = EmitAndLoadAssembly(compilation, callerName + "_Original"); + try + { + var workspace = new AdhocWorkspace(); + var syntaxGenerator = SyntaxGenerator.GetGenerator(workspace, LanguageNames.CSharp); + + // TODO: Look up as regular dependencies + var precompiledQueryCodeGenerator = new PrecompiledQueryCodeGenerator(); + + await using var dbContext = (DbContext)Activator.CreateInstance(dbContextType, args: [dbContextOptions])!; + + // Perform precompilation + var precompilationErrors = new List(); + generatedFiles = precompiledQueryCodeGenerator.GeneratePrecompiledQueries( + compilation, syntaxGenerator, dbContext, precompilationErrors, additionalAssembly: assembly); + + if (errorAsserter is null) + { + if (precompilationErrors.Count > 0) + { + Assert.Fail("Precompilation error: " + precompilationErrors[0].Exception); + } + } + else + { + errorAsserter(precompilationErrors); + return; + } + } + finally + { + assemblyLoadContext.Unload(); + } + + // We now have the code-generated interceptors; add them to the compilation and re-emit. + compilation = compilation.AddSyntaxTrees( + generatedFiles.Select(f => CSharpSyntaxTree.ParseText(f.Code, parseOptions, f.Path))); + + // We have the final compilation, including the interceptors. Emit and load it, and then invoke its entry point, which contains + // the original test code with the EF LINQ query, etc. + (assemblyLoadContext, assembly) = EmitAndLoadAssembly(compilation, callerName + "_WithInterceptors"); + try + { + await using var dbContext = (DbContext)Activator.CreateInstance(dbContextType, dbContextOptions)!; + + var testContainer = assembly.ExportedTypes.Single(t => t.Name == "TestContainer"); + var testMethod = testContainer.GetMethod("Test")!; + await (Task)testMethod.Invoke(obj: null, parameters: [dbContextOptions])!; + } + finally + { + assemblyLoadContext.Unload(); + } + } + catch + { + PrintGeneratedSources(); + + throw; + } + + if (alwaysPrintGeneratedSources) + { + PrintGeneratedSources(); + } + + void PrintGeneratedSources() + { + if (generatedFiles is not null) + { + foreach (var generatedFile in generatedFiles) + { + testOutputHelper.WriteLine($"Generated file {generatedFile.Path}: "); + testOutputHelper.WriteLine(""); + testOutputHelper.WriteLine(generatedFile.Code); + } + } + } + + static (AssemblyLoadContext, Assembly) EmitAndLoadAssembly(Compilation compilation, string assemblyLoadContextName) + { + var errorDiagnostics = compilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error).ToList(); + if (errorDiagnostics.Count > 0) + { + var stringBuilder = new StringBuilder(); + stringBuilder.AppendLine("Compilation failed:").AppendLine(); + + foreach (var errorDiagnostic in errorDiagnostics) + { + stringBuilder.AppendLine(errorDiagnostic.ToString()); + + var textLines = errorDiagnostic.Location.SourceTree!.GetText().Lines; + var startLine = errorDiagnostic.Location.GetLineSpan().StartLinePosition.Line; + var endLine = errorDiagnostic.Location.GetLineSpan().EndLinePosition.Line; + + if (startLine == endLine) + { + stringBuilder.Append("Line: ").AppendLine(textLines[startLine].ToString().TrimStart()); + } + else + { + stringBuilder.AppendLine("Lines:"); + for (var i = startLine; i <= endLine; i++) + { + stringBuilder.AppendLine(textLines[i].ToString()); + } + } + } + + throw new InvalidOperationException("Compilation failed:" + stringBuilder); + } + + using var memoryStream = new MemoryStream(); + var emitResult = compilation.Emit(memoryStream); + memoryStream.Position = 0; + + errorDiagnostics = emitResult.Diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToList(); + if (errorDiagnostics.Count > 0) + { + throw new InvalidOperationException( + "Compilation emit failed:" + Environment.NewLine + string.Join(Environment.NewLine, errorDiagnostics)); + } + + var assemblyLoadContext = new AssemblyLoadContext(assemblyLoadContextName, isCollectible: true); + var assembly = assemblyLoadContext.LoadFromStream(memoryStream); + return (assemblyLoadContext, assembly); + } + } + + protected virtual IEnumerable BuildMetadataReferences() + { + var netAssemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + + return new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Queryable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IQueryable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(List<>).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Regex).Assembly.Location), + MetadataReference.CreateFromFile(typeof(JsonSerializer).Assembly.Location), + MetadataReference.CreateFromFile(typeof(JavaScriptEncoder).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DatabaseGeneratedAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DbContext).Assembly.Location), + MetadataReference.CreateFromFile(typeof(RelationalOptionsExtension).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DbConnection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IListSource).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceProvider).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IMemoryCache).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Assert).Assembly.Location), + // This is to allow referencing types from this file, e.g. NonCompilingQueryCompiler + MetadataReference.CreateFromFile(Assembly.GetExecutingAssembly().Location), + MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "mscorlib.dll")), + MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.dll")), + MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Core.dll")), + MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Runtime.dll")), + MetadataReference.CreateFromFile(Path.Combine(netAssemblyPath, "System.Collections.dll")) + } + .Concat(BuildProviderMetadataReferences()); + } + + protected abstract IEnumerable BuildProviderMetadataReferences(); + + // Used from inside the tested code to ensure that we never end up compiling queries at runtime. + // TODO: Probably remove this later, once we have a regular mechanism for failing non-intercepted queries at runtime. + // ReSharper disable once UnusedMember.Global + public class NonCompilingQueryCompiler( + IQueryContextFactory queryContextFactory, + ICompiledQueryCache compiledQueryCache, + ICompiledQueryCacheKeyGenerator compiledQueryCacheKeyGenerator, + IDatabase database, + IDiagnosticsLogger logger, + ICurrentDbContext currentContext, + IEvaluatableExpressionFilter evaluatableExpressionFilter, + IModel model) + : QueryCompiler(queryContextFactory, compiledQueryCache, compiledQueryCacheKeyGenerator, database, logger, + currentContext, evaluatableExpressionFilter, model) + { + public const string ErrorMessage = + "A query reached the query compilation pipeline, indicating that it was not intercepted as a precompiled query."; + + public override TResult Execute(Expression query) + { + Assert.Fail(ErrorMessage); + throw new UnreachableException(); + } + + public override TResult ExecuteAsync(Expression query, CancellationToken cancellationToken = default) + { + Assert.Fail(ErrorMessage); + throw new UnreachableException(); + } + } +} diff --git a/test/EFCore.Specification.Tests/JsonTypesTestBase.cs b/test/EFCore.Specification.Tests/JsonTypesTestBase.cs index 5a79e1a355e..4b6a3638399 100644 --- a/test/EFCore.Specification.Tests/JsonTypesTestBase.cs +++ b/test/EFCore.Specification.Tests/JsonTypesTestBase.cs @@ -3924,7 +3924,6 @@ protected class BinaryListArrayArrayListType { var contextFactory = await CreateContextFactory( buildModel, - addServices: AddServices, configureConventions: configureConventions); using var context = contextFactory.CreateContext(); @@ -4002,9 +4001,6 @@ protected class BinaryListArrayArrayListType protected override string StoreName => "JsonTypesTest"; - protected virtual IServiceCollection AddServices(IServiceCollection serviceCollection) - => serviceCollection; - protected virtual void AssertElementFacets(IElementType element, Dictionary? facets) { Assert.Equal(FacetValue(CoreAnnotationNames.Precision), element.GetPrecision()); diff --git a/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs b/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs index ff738874eb2..7f8f924020e 100644 --- a/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs +++ b/test/EFCore.Specification.Tests/NonSharedModelTestBase.cs @@ -72,10 +72,11 @@ public virtual Task InitializeAsync() : CreateTestStore(); shouldLogCategory ??= _ => false; - var services = (useServiceProvider + var services = AddServices( + (useServiceProvider ? TestStoreFactory.AddProviderServices(new ServiceCollection()) : new ServiceCollection()) - .AddSingleton(TestStoreFactory.CreateListLoggerFactory(shouldLogCategory)); + .AddSingleton(TestStoreFactory.CreateListLoggerFactory(shouldLogCategory))); if (onModelCreating != null) { @@ -117,6 +118,9 @@ public virtual Task InitializeAsync() return optionsBuilder; } + protected virtual IServiceCollection AddServices(IServiceCollection serviceCollection) + => serviceCollection; + protected virtual DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) => builder .EnableSensitiveDataLogging() @@ -178,5 +182,8 @@ public virtual TContext CreateContext() => UsePooling ? PooledContextFactory!.CreateDbContext() : (TContext)ServiceProvider.GetRequiredService(typeof(TContext)); + + public virtual DbContextOptions GetOptions() + => ServiceProvider.GetRequiredService(); } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/AdHocPrecompiledQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/AdHocPrecompiledQuerySqlServerTest.cs new file mode 100644 index 00000000000..99148b0d900 --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/Query/AdHocPrecompiledQuerySqlServerTest.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +public class AdHocPrecompiledQuerySqlServerTest(ITestOutputHelper testOutputHelper) + : AdHocPrecompiledQueryRelationalTestBase(testOutputHelper) +{ + protected override bool AlwaysPrintGeneratedSources + => false; + + public override async Task Index_no_evaluatability() + { + await base.Index_no_evaluatability(); + + AssertSql(""" +SELECT [j].[Id], [j].[IntList], [j].[JsonThing] +FROM [JsonEntities] AS [j] +WHERE CAST(JSON_VALUE([j].[IntList], '$[' + CAST([j].[Id] AS nvarchar(max)) + ']') AS int) = 2 +"""); + } + + public override async Task Index_with_captured_variable() + { + await base.Index_with_captured_variable(); + + AssertSql(""" +@__id_0='1' + +SELECT [j].[Id], [j].[IntList], [j].[JsonThing] +FROM [JsonEntities] AS [j] +WHERE CAST(JSON_VALUE([j].[IntList], '$[' + CAST(@__id_0 AS nvarchar(max)) + ']') AS int) = 2 +"""); + } + + public override async Task JsonScalar() + { + await base.JsonScalar(); + + AssertSql(""" +SELECT [j].[Id], [j].[IntList], [j].[JsonThing] +FROM [JsonEntities] AS [j] +WHERE JSON_VALUE([j].[JsonThing], '$.StringProperty') = N'foo' +"""); + } + + public override async Task Materialize_non_public() + { + await base.Materialize_non_public(); + + AssertSql( + """ +@p0='10' (Nullable = true) +@p1='9' (Nullable = true) +@p2='8' (Nullable = true) + +SET IMPLICIT_TRANSACTIONS OFF; +SET NOCOUNT ON; +INSERT INTO [NonPublicEntities] ([PrivateAutoProperty], [PrivateProperty], [_privateField]) +OUTPUT INSERTED.[Id] +VALUES (@p0, @p1, @p2); +""", + // + """ +SELECT TOP(2) [n].[Id], [n].[PrivateAutoProperty], [n].[PrivateProperty], [n].[_privateField] +FROM [NonPublicEntities] AS [n] +"""); + } + + [ConditionalFact] + public virtual void Check_all_tests_overridden() + => TestHelpers.AssertAllMethodsOverridden(GetType()); + + protected override ITestStoreFactory TestStoreFactory + => SqlServerTestStoreFactory.Instance; + + protected override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers + => SqlServerPrecompiledQueryTestHelpers.Instance; + + protected override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + { + builder = base.AddOptions(builder); + + // TODO: Figure out if there's a nice way to continue using the retrying strategy + var sqlServerOptionsBuilder = new SqlServerDbContextOptionsBuilder(builder); + sqlServerOptionsBuilder.ExecutionStrategy(d => new NonRetryingExecutionStrategy(d)); + return builder; + } +} diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrecompiledQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrecompiledQuerySqlServerTest.cs new file mode 100644 index 00000000000..342808d8613 --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrecompiledQuerySqlServerTest.cs @@ -0,0 +1,2018 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// ReSharper disable InconsistentNaming + +namespace Microsoft.EntityFrameworkCore.Query; + +public class PrecompiledQuerySqlServerTest( + PrecompiledQuerySqlServerTest.PrecompiledQuerySqlServerFixture fixture, + ITestOutputHelper testOutputHelper) + : PrecompiledQueryRelationalTestBase(fixture, testOutputHelper), + IClassFixture +{ + protected override bool AlwaysPrintGeneratedSources + => false; + + #region Expression types + + public override async Task BinaryExpression() + { + await base.BinaryExpression(); + + AssertSql( + """ +@__id_0='3' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] > @__id_0 +"""); + } + + public override async Task Conditional_no_evaluatable() + { + await base.Conditional_no_evaluatable(); + + AssertSql( + """ +SELECT CASE + WHEN [b].[Id] = 2 THEN N'yes' + ELSE N'no' +END +FROM [Blogs] AS [b] +"""); + } + + public override async Task Conditional_contains_captured_variable() + { + await base.Conditional_contains_captured_variable(); + + AssertSql( + """ +@__yes_0='yes' (Size = 4000) + +SELECT CASE + WHEN [b].[Id] = 2 THEN @__yes_0 + ELSE N'no' +END +FROM [Blogs] AS [b] +"""); + } + + public override async Task Invoke_no_evaluatability_is_not_supported() + { + await base.Invoke_no_evaluatability_is_not_supported(); + + AssertSql(); + } + + public override async Task ListInit_no_evaluatability() + { + await base.ListInit_no_evaluatability(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Id] + 1 +FROM [Blogs] AS [b] +"""); + } + + public override async Task ListInit_with_evaluatable_with_captured_variable() + { + await base.ListInit_with_evaluatable_with_captured_variable(); + + AssertSql( + """ +SELECT [b].[Id] +FROM [Blogs] AS [b] +"""); + } + + public override async Task ListInit_with_evaluatable_without_captured_variable() + { + await base.ListInit_with_evaluatable_without_captured_variable(); + + AssertSql( + """ +SELECT [b].[Id] +FROM [Blogs] AS [b] +"""); + } + + public override async Task ListInit_fully_evaluatable() + { + await base.ListInit_fully_evaluatable(); + + AssertSql( + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] IN (7, 8) +"""); + } + + public override async Task MethodCallExpression_no_evaluatability() + { + await base.MethodCallExpression_no_evaluatability(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] IS NOT NULL AND LEFT([b].[Name], LEN([b].[Name])) = [b].[Name] +"""); + } + + public override async Task MethodCallExpression_with_evaluatable_with_captured_variable() + { + await base.MethodCallExpression_with_evaluatable_with_captured_variable(); + + AssertSql( + """ +@__pattern_0_startswith='foo%' (Size = 4000) + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] LIKE @__pattern_0_startswith ESCAPE N'\' +"""); + } + + public override async Task MethodCallExpression_with_evaluatable_without_captured_variable() + { + await base.MethodCallExpression_with_evaluatable_without_captured_variable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] LIKE N'foo%' +"""); + } + + public override async Task MethodCallExpression_fully_evaluatable() + { + await base.MethodCallExpression_fully_evaluatable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task New_with_no_arguments() + { + await base.New_with_no_arguments(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 0 +"""); + } + + public override async Task Where_New_with_captured_variable() + { + await base.Where_New_with_captured_variable(); + + AssertSql(); + } + + public override async Task Select_New_with_captured_variable() + { + await base.Select_New_with_captured_variable(); + + AssertSql( + """ +SELECT [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task MemberInit_no_evaluatable() + { + await base.MemberInit_no_evaluatable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task MemberInit_contains_captured_variable() + { + await base.MemberInit_contains_captured_variable(); + + AssertSql( + """ +@__id_0='8' + +SELECT @__id_0 AS [Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task MemberInit_evaluatable_as_constant() + { + await base.MemberInit_evaluatable_as_constant(); + + AssertSql( + """ +SELECT 1 AS [Id], N'foo' AS [Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task MemberInit_evaluatable_as_parameter() + { + await base.MemberInit_evaluatable_as_parameter(); + + AssertSql( + """ +SELECT 1 +FROM [Blogs] AS [b] +"""); + } + + public override async Task NewArray() + { + await base.NewArray(); + + AssertSql( + """ +@__i_0='8' + +SELECT [b].[Id], [b].[Id] + @__i_0 +FROM [Blogs] AS [b] +"""); + } + + public override async Task Unary() + { + await base.Unary(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE CAST([b].[Id] AS smallint) = CAST(8 AS smallint) +"""); + } + + public virtual async Task Collate() + { + await Test("""_ = context.Blogs.Where(b => EF.Functions.Collate(b.Name, "German_PhoneBook_CI_AS") == "foo").ToList();"""); + + AssertSql(); + } + + #endregion Expression types + + #region Terminating operators + + public override async Task Terminating_AsEnumerable() + { + await base.Terminating_AsEnumerable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_AsAsyncEnumerable_on_DbSet() + { + await base.Terminating_AsAsyncEnumerable_on_DbSet(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_AsAsyncEnumerable_on_IQueryable() + { + await base.Terminating_AsAsyncEnumerable_on_IQueryable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +"""); + } + + public override async Task Foreach_sync_over_operator() + { + await base.Foreach_sync_over_operator(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +"""); + } + + public override async Task Terminating_ToArray() + { + await base.Terminating_ToArray(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToArrayAsync() + { + await base.Terminating_ToArrayAsync(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToDictionary() + { + await base.Terminating_ToDictionary(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToDictionaryAsync() + { + await base.Terminating_ToDictionaryAsync(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task ToDictionary_over_anonymous_type() + { + await base.ToDictionary_over_anonymous_type(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task ToDictionaryAsync_over_anonymous_type() + { + await base.ToDictionaryAsync_over_anonymous_type(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToHashSet() + { + await base.Terminating_ToHashSet(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToHashSetAsync() + { + await base.Terminating_ToHashSetAsync(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToLookup() + { + await base.Terminating_ToLookup(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToList() + { + await base.Terminating_ToList(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ToListAsync() + { + await base.Terminating_ToListAsync(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Foreach_sync_over_DbSet_property_is_not_supported() + { + await base.Foreach_sync_over_DbSet_property_is_not_supported(); + + AssertSql(); + } + + public override async Task Foreach_async_is_not_supported() + { + await base.Foreach_async_is_not_supported(); + + AssertSql(); + } + + #endregion Terminating operators + + #region Reducing terminating operators + + public override async Task Terminating_All() + { + await base.Terminating_All(); + + AssertSql( + """ +SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] <= 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] <= 8) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +"""); + } + + public override async Task Terminating_AllAsync() + { + await base.Terminating_AllAsync(); + + AssertSql( + """ +SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] <= 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] <= 8) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +"""); + } + + public override async Task Terminating_Any() + { + await base.Terminating_Any(); + + AssertSql( + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] < 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] < 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +"""); + } + + public override async Task Terminating_AnyAsync() + { + await base.Terminating_AnyAsync(); + + AssertSql( + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] < 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Blogs] AS [b] + WHERE [b].[Id] < 7) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +"""); + } + + public override async Task Terminating_Average() + { + await base.Terminating_Average(); + + AssertSql( + """ +SELECT AVG(CAST([b].[Id] AS float)) +FROM [Blogs] AS [b] +""", + // + """ +SELECT AVG(CAST([b].[Id] AS float)) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_AverageAsync() + { + await base.Terminating_AverageAsync(); + + AssertSql( + """ +SELECT AVG(CAST([b].[Id] AS float)) +FROM [Blogs] AS [b] +""", + // + """ +SELECT AVG(CAST([b].[Id] AS float)) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_Contains() + { + await base.Terminating_Contains(); + + AssertSql( + """ +@__p_0='8' + +SELECT CASE + WHEN @__p_0 IN ( + SELECT [b].[Id] + FROM [Blogs] AS [b] + ) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +@__p_0='7' + +SELECT CASE + WHEN @__p_0 IN ( + SELECT [b].[Id] + FROM [Blogs] AS [b] + ) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +"""); + } + + public override async Task Terminating_ContainsAsync() + { + await base.Terminating_ContainsAsync(); + + AssertSql( + """ +@__p_0='8' + +SELECT CASE + WHEN @__p_0 IN ( + SELECT [b].[Id] + FROM [Blogs] AS [b] + ) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +""", + // + """ +@__p_0='7' + +SELECT CASE + WHEN @__p_0 IN ( + SELECT [b].[Id] + FROM [Blogs] AS [b] + ) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +"""); + } + + public override async Task Terminating_Count() + { + await base.Terminating_Count(); + + AssertSql( + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +""", + // + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +"""); + } + + public override async Task Terminating_CountAsync() + { + await base.Terminating_CountAsync(); + + AssertSql( + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +""", + // + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +"""); + } + + public override async Task Terminating_ElementAt() + { + await base.Terminating_ElementAt(); + + AssertSql( + """ +@__p_0='1' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +""", + // + """ +@__p_0='3' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +"""); + } + + public override async Task Terminating_ElementAtAsync() + { + await base.Terminating_ElementAtAsync(); + + AssertSql( + """ +@__p_0='1' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +""", + // + """ +@__p_0='3' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +"""); + } + + public override async Task Terminating_ElementAtOrDefault() + { + await base.Terminating_ElementAtOrDefault(); + + AssertSql( + """ +@__p_0='1' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +""", + // + """ +@__p_0='3' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +"""); + } + + public override async Task Terminating_ElementAtOrDefaultAsync() + { + await base.Terminating_ElementAtOrDefaultAsync(); + + AssertSql( + """ +@__p_0='1' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +""", + // + """ +@__p_0='3' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +OFFSET @__p_0 ROWS FETCH NEXT 1 ROWS ONLY +"""); + } + + public override async Task Terminating_First() + { + await base.Terminating_First(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_FirstAsync() + { + await base.Terminating_FirstAsync(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_FirstOrDefault() + { + await base.Terminating_FirstOrDefault(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_FirstOrDefaultAsync() + { + await base.Terminating_FirstOrDefaultAsync(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_GetEnumerator() + { + await base.Terminating_GetEnumerator(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +"""); + } + + public override async Task Terminating_Last() + { + await base.Terminating_Last(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +"""); + } + + public override async Task Terminating_LastAsync() + { + await base.Terminating_LastAsync(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +"""); + } + + public override async Task Terminating_LastOrDefault() + { + await base.Terminating_LastOrDefault(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +"""); + } + + public override async Task Terminating_LastOrDefaultAsync() + { + await base.Terminating_LastOrDefaultAsync(); + + AssertSql( + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +ORDER BY [b].[Id] DESC +""", + // + """ +SELECT TOP(1) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +ORDER BY [b].[Id] DESC +"""); + } + + public override async Task Terminating_LongCount() + { + await base.Terminating_LongCount(); + + AssertSql( + """ +SELECT COUNT_BIG(*) +FROM [Blogs] AS [b] +""", + // + """ +SELECT COUNT_BIG(*) +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +"""); + } + + public override async Task Terminating_LongCountAsync() + { + await base.Terminating_LongCountAsync(); + + AssertSql( + """ +SELECT COUNT_BIG(*) +FROM [Blogs] AS [b] +""", + // + """ +SELECT COUNT_BIG(*) +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +"""); + } + + public override async Task Terminating_Max() + { + await base.Terminating_Max(); + + AssertSql( + """ +SELECT MAX([b].[Id]) +FROM [Blogs] AS [b] +""", + // + """ +SELECT MAX([b].[Id]) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_MaxAsync() + { + await base.Terminating_MaxAsync(); + + AssertSql( + """ +SELECT MAX([b].[Id]) +FROM [Blogs] AS [b] +""", + // + """ +SELECT MAX([b].[Id]) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_Min() + { + await base.Terminating_Min(); + + AssertSql( + """ +SELECT MIN([b].[Id]) +FROM [Blogs] AS [b] +""", + // + """ +SELECT MIN([b].[Id]) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_MinAsync() + { + await base.Terminating_MinAsync(); + + AssertSql( + """ +SELECT MIN([b].[Id]) +FROM [Blogs] AS [b] +""", + // + """ +SELECT MIN([b].[Id]) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_Single() + { + await base.Terminating_Single(); + + AssertSql( + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_SingleAsync() + { + await base.Terminating_SingleAsync(); + + AssertSql( + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_SingleOrDefault() + { + await base.Terminating_SingleOrDefault(); + + AssertSql( + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_SingleOrDefaultAsync() + { + await base.Terminating_SingleOrDefaultAsync(); + + AssertSql( + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 8 +""", + // + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = 7 +"""); + } + + public override async Task Terminating_Sum() + { + await base.Terminating_Sum(); + + AssertSql( + """ +SELECT COALESCE(SUM([b].[Id]), 0) +FROM [Blogs] AS [b] +""", + // + """ +SELECT COALESCE(SUM([b].[Id]), 0) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_SumAsync() + { + await base.Terminating_SumAsync(); + + AssertSql( + """ +SELECT COALESCE(SUM([b].[Id]), 0) +FROM [Blogs] AS [b] +""", + // + """ +SELECT COALESCE(SUM([b].[Id]), 0) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ExecuteDelete() + { + await base.Terminating_ExecuteDelete(); + + AssertSql( + """ +DELETE FROM [b] +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +""", + // + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ExecuteDeleteAsync() + { + await base.Terminating_ExecuteDeleteAsync(); + + AssertSql( + """ +DELETE FROM [b] +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +""", + // + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +"""); + } + + public override async Task Terminating_ExecuteUpdate() + { + await base.Terminating_ExecuteUpdate(); + + AssertSql( + """ +@__suffix_0='Suffix' (Size = 4000) + +UPDATE [b] +SET [b].[Name] = COALESCE([b].[Name], N'') + @__suffix_0 +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +""", + // + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +WHERE [b].[Id] = 9 AND [b].[Name] = N'Blog2Suffix' +"""); + } + + public override async Task Terminating_ExecuteUpdateAsync() + { + await base.Terminating_ExecuteUpdateAsync(); + + AssertSql( + """ +@__suffix_0='Suffix' (Size = 4000) + +UPDATE [b] +SET [b].[Name] = COALESCE([b].[Name], N'') + @__suffix_0 +FROM [Blogs] AS [b] +WHERE [b].[Id] > 8 +""", + // + """ +SELECT COUNT(*) +FROM [Blogs] AS [b] +WHERE [b].[Id] = 9 AND [b].[Name] = N'Blog2Suffix' +"""); + } + + #endregion Reducing terminating operators + + #region SQL expression quotability + + public override async Task Union() + { + await base.Union(); + + AssertSql( + """ +SELECT [u].[Id], [u].[Name] +FROM ( + SELECT [b].[Id], [b].[Name] + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7 + UNION + SELECT [b0].[Id], [b0].[Name] + FROM [Blogs] AS [b0] + WHERE [b0].[Id] < 10 +) AS [u] +ORDER BY [u].[Id] +"""); + } + + public override async Task Concat() + { + await base.Concat(); + + AssertSql( + """ +SELECT [u].[Id], [u].[Name] +FROM ( + SELECT [b].[Id], [b].[Name] + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7 + UNION ALL + SELECT [b0].[Id], [b0].[Name] + FROM [Blogs] AS [b0] + WHERE [b0].[Id] < 10 +) AS [u] +ORDER BY [u].[Id] +"""); + } + + public override async Task Intersect() + { + await base.Intersect(); + + AssertSql( + """ +SELECT [i].[Id], [i].[Name] +FROM ( + SELECT [b].[Id], [b].[Name] + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7 + INTERSECT + SELECT [b0].[Id], [b0].[Name] + FROM [Blogs] AS [b0] + WHERE [b0].[Id] > 8 +) AS [i] +ORDER BY [i].[Id] +"""); + } + + public override async Task Except() + { + await base.Except(); + + AssertSql( + """ +SELECT [e].[Id], [e].[Name] +FROM ( + SELECT [b].[Id], [b].[Name] + FROM [Blogs] AS [b] + WHERE [b].[Id] > 7 + EXCEPT + SELECT [b0].[Id], [b0].[Name] + FROM [Blogs] AS [b0] + WHERE [b0].[Id] > 8 +) AS [e] +ORDER BY [e].[Id] +"""); + } + + public override async Task ValuesExpression() + { + await base.ValuesExpression(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE ( + SELECT COUNT(*) + FROM (VALUES (CAST(7 AS int)), ([b].[Id])) AS [v]([Value]) + WHERE [v].[Value] > 8) = 2 +"""); + } + + public override async Task Contains_with_parameterized_collection() + { + await base.Contains_with_parameterized_collection(); + + AssertSql( + """ +@__ids_0='[1,2,3]' (Size = 4000) + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] IN ( + SELECT [i].[value] + FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i] +) +"""); + } + + public override async Task FromSqlRaw() + { + await base.FromSqlRaw(); + + AssertSql( + """ +SELECT [m].[Id], [m].[Name] +FROM ( + SELECT * FROM Blogs +) AS [m] +ORDER BY [m].[Id] +"""); + } + + public override async Task FromSql_with_FormattableString_parameters() + { + await base.FromSql_with_FormattableString_parameters(); + + AssertSql( + """ +p0='8' +p1='9' + +SELECT [m].[Id], [m].[Name] +FROM ( + SELECT * FROM Blogs WHERE Id > @p0 AND Id < @p1 +) AS [m] +ORDER BY [m].[Id] +"""); + } + + [ConditionalFact] + public virtual async Task SqlServerAggregateFunctionExpression() + { + await Test( + """ +_ = context.Blogs + .GroupBy(b => b.Id) + .Select(g => string.Join(", ", g.OrderBy(b => b.Name).Select(b => b.Name))) + .ToList(); +"""); + + AssertSql( + """ +SELECT COALESCE(STRING_AGG(COALESCE([b].[Name], N''), N', ') WITHIN GROUP (ORDER BY [b].[Name]), N'') +FROM [Blogs] AS [b] +GROUP BY [b].[Id] +"""); + } + + // SqlServerOpenJsonExpression is covered by PrecompiledQueryRelationalTestBase.Contains_with_parameterized_collection + +// [ConditionalFact] +// public virtual Task TableValuedFunctionExpression_toplevel() +// => Test( +// "_ = context.GetBlogsWithAtLeast(9).ToList();", +// modelSourceCode: providerOptions => $$""" +// public class BlogContext : DbContext +// { +// public DbSet Blogs { get; set; } +// +// protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) +// => optionsBuilder +// {{providerOptions}} +// .ReplaceService(); +// +// protected override void OnModelCreating(ModelBuilder modelBuilder) +// { +// modelBuilder.HasDbFunction(typeof(BlogContext).GetMethod(nameof(GetBlogsWithAtLeast))); +// } +// +// public IQueryable GetBlogsWithAtLeast(int minBlogId) => FromExpression(() => GetBlogsWithAtLeast(minBlogId)); +// } +// +// public class Blog +// { +// [DatabaseGenerated(DatabaseGeneratedOption.None)] +// public int Id { get; set; } +// public string StringProperty { get; set; } +// } +// """, +// setupSql: """ +// CREATE FUNCTION dbo.GetBlogsWithAtLeast(@minBlogId int) +// RETURNS TABLE AS RETURN +// ( +// SELECT [b].[Id], [b].[Name] FROM [Blogs] AS [b] WHERE [b].[Id] >= @minBlogId +// ) +// """, +// cleanupSql: "DROP FUNCTION dbo.GetBlogsWithAtLeast;"); +// +// [ConditionalFact] +// public virtual Task TableValuedFunctionExpression_non_toplevel() +// => Test( +// "_ = context.Blogs.Where(b => context.GetPosts(b.Id).Count() == 2).ToList();", +// modelSourceCode: providerOptions => $$""" +// public class BlogContext : DbContext +// { +// public DbSet Blogs { get; set; } +// +// protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) +// => optionsBuilder +// {{providerOptions}} +// .ReplaceService(); +// +// protected override void OnModelCreating(ModelBuilder modelBuilder) +// { +// modelBuilder.HasDbFunction(typeof(BlogContext).GetMethod(nameof(GetPosts))); +// } +// +// public IQueryable GetPosts(int blogId) => FromExpression(() => GetPosts(blogId)); +// } +// +// public class Blog +// { +// public int Id { get; set; } +// public string StringProperty { get; set; } +// public List Post { get; set; } +// } +// +// public class Post +// { +// public int Id { get; set; } +// public string Title { get; set; } +// +// public Blog Blog { get; set; } +// } +// """, +// setupSql: """ +// CREATE FUNCTION dbo.GetPosts(@blogId int) +// RETURNS TABLE AS RETURN +// ( +// SELECT [p].[Id], [p].[Title], [p].[BlogId] FROM [Posts] AS [p] WHERE [p].[BlogId] = @blogId +// ) +// """, +// cleanupSql: "DROP FUNCTION dbo.GetPosts;"); + + #endregion SQL expression quotability + + #region Different query roots + + public override async Task DbContext_as_local_variable() + { + await base.DbContext_as_local_variable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task DbContext_as_field() + { + await base.DbContext_as_field(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task DbContext_as_property() + { + await base.DbContext_as_property(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task DbContext_as_captured_variable() + { + await base.DbContext_as_captured_variable(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task DbContext_as_method_invocation_result() + { + await base.DbContext_as_method_invocation_result(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + #endregion Different query roots + + #region Negative cases + + public override async Task Dynamic_query_does_not_get_precompiled() + { + await base.Dynamic_query_does_not_get_precompiled(); + + AssertSql(); + } + + public override async Task ToList_over_objects_does_not_get_precompiled() + { + await base.ToList_over_objects_does_not_get_precompiled(); + + AssertSql(); + } + + public override async Task Query_compilation_failure() + { + await base.Query_compilation_failure(); + + AssertSql(); + } + + public override async Task EF_Constant_is_not_supported() + { + await base.EF_Constant_is_not_supported(); + + AssertSql(); + } + + public override async Task NotParameterizedAttribute_with_constant() + { + await base.NotParameterizedAttribute_with_constant(); + + AssertSql( + """ +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] = N'Blog2' +"""); + } + + public override async Task NotParameterizedAttribute_is_not_supported_with_non_constant_argument() + { + await base.NotParameterizedAttribute_is_not_supported_with_non_constant_argument(); + + AssertSql(); + } + + public override async Task Query_syntax_is_not_supported() + { + await base.Query_syntax_is_not_supported(); + + AssertSql(); + } + + #endregion Negative cases + + public override async Task Select_changes_type() + { + await base.Select_changes_type(); + + AssertSql( + """ +SELECT [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + public override async Task OrderBy() + { + await base.OrderBy(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Name] +"""); + } + + public override async Task Skip() + { + await base.Skip(); + + AssertSql( + """ +@__p_0='1' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Name] +OFFSET @__p_0 ROWS +"""); + } + + public override async Task Take() + { + await base.Take(); + + AssertSql( + """ +@__p_0='1' + +SELECT TOP(@__p_0) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Name] +"""); + } + + public override async Task Project_anonymous_object() + { + await base.Project_anonymous_object(); + + AssertSql( + """ +SELECT COALESCE([b].[Name], N'') + N'Foo' AS [Foo] +FROM [Blogs] AS [b] +"""); + } + + public override async Task Two_captured_variables_in_same_lambda() + { + await base.Two_captured_variables_in_same_lambda(); + + AssertSql( + """ +@__yes_0='yes' (Size = 4000) +@__no_1='no' (Size = 4000) + +SELECT CASE + WHEN [b].[Id] = 3 THEN @__yes_0 + ELSE @__no_1 +END +FROM [Blogs] AS [b] +"""); + } + + public override async Task Two_captured_variables_in_different_lambdas() + { + await base.Two_captured_variables_in_different_lambdas(); + + AssertSql( + """ +@__starts_0_startswith='blog%' (Size = 4000) +@__ends_1_endswith='%2' (Size = 4000) + +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] LIKE @__starts_0_startswith ESCAPE N'\' AND [b].[Name] LIKE @__ends_1_endswith ESCAPE N'\' +"""); + } + + public override async Task Same_captured_variable_twice_in_same_lambda() + { + await base.Same_captured_variable_twice_in_same_lambda(); + + AssertSql( + """ +@__foo_0_startswith='X%' (Size = 4000) +@__foo_0_endswith='%X' (Size = 4000) + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] LIKE @__foo_0_startswith ESCAPE N'\' AND [b].[Name] LIKE @__foo_0_endswith ESCAPE N'\' +"""); + } + + public override async Task Same_captured_variable_twice_in_different_lambdas() + { + await base.Same_captured_variable_twice_in_different_lambdas(); + + AssertSql( + """ +@__foo_0_startswith='X%' (Size = 4000) +@__foo_0_endswith='%X' (Size = 4000) + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Name] LIKE @__foo_0_startswith ESCAPE N'\' AND [b].[Name] LIKE @__foo_0_endswith ESCAPE N'\' +"""); + } + + public override async Task Include_single() + { + await base.Include_single(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name], [p].[Id], [p].[BlogId], [p].[Title] +FROM [Blogs] AS [b] +LEFT JOIN [Posts] AS [p] ON [b].[Id] = [p].[BlogId] +WHERE [b].[Id] > 8 +ORDER BY [b].[Id] +"""); + } + + public override async Task Include_split() + { + await base.Include_split(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +ORDER BY [b].[Id] +""", + // + """ +SELECT [p].[Id], [p].[BlogId], [p].[Title], [b].[Id] +FROM [Blogs] AS [b] +INNER JOIN [Posts] AS [p] ON [b].[Id] = [p].[BlogId] +ORDER BY [b].[Id] +"""); + } + + public override async Task Final_GroupBy() + { + await base.Final_GroupBy(); + + AssertSql( + """ +SELECT [b].[Name], [b].[Id] +FROM [Blogs] AS [b] +ORDER BY [b].[Name] +"""); + } + + public override async Task Multiple_queries_with_captured_variables() + { + await base.Multiple_queries_with_captured_variables(); + + AssertSql( + """ +@__id1_0='8' +@__id2_1='9' + +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = @__id1_0 OR [b].[Id] = @__id2_1 +""", + // + """ +@__id1_0='8' + +SELECT TOP(2) [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +WHERE [b].[Id] = @__id1_0 +"""); + } + + public override async Task Unsafe_accessor_gets_generated_once_for_multiple_queries() + { + await base.Unsafe_accessor_gets_generated_once_for_multiple_queries(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +""", + // + """ +SELECT [b].[Id], [b].[Name] +FROM [Blogs] AS [b] +"""); + } + + [ConditionalFact] + public virtual void Check_all_tests_overridden() + => TestHelpers.AssertAllMethodsOverridden(GetType()); + + public class PrecompiledQuerySqlServerFixture : PrecompiledQueryRelationalFixture + { + protected override ITestStoreFactory TestStoreFactory + => SqlServerTestStoreFactory.Instance; + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + { + builder = base.AddOptions(builder); + + // TODO: Figure out if there's a nice way to continue using the retrying strategy + var sqlServerOptionsBuilder = new SqlServerDbContextOptionsBuilder(builder); + sqlServerOptionsBuilder.ExecutionStrategy(d => new NonRetryingExecutionStrategy(d)); + return builder; + } + + public override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers => SqlServerPrecompiledQueryTestHelpers.Instance; + } +} diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs index 47969a88022..943b0f500f6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_parameter_type_mapping/FunctionParameterTypeMappingContextModelBuilder.cs @@ -50,15 +50,15 @@ private FunctionParameterTypeMappingContextModel() param.TypeMapping = StringTypeMapping.Default.Clone( comparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), keyComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), providerValueComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "varchar", @@ -67,15 +67,15 @@ private FunctionParameterTypeMappingContextModel() getSqlFragmentStatic.TypeMapping = SqlServerStringTypeMapping.Default.Clone( comparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), keyComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), providerValueComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "nvarchar(max)", diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs index b3068bed402..d9d601aa499 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/Custom_function_type_mapping/FunctionTypeMappingContextModelBuilder.cs @@ -50,15 +50,15 @@ private FunctionTypeMappingContextModel() param.TypeMapping = SqlServerStringTypeMapping.Default.Clone( comparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), keyComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), providerValueComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "nvarchar(max)", @@ -69,15 +69,15 @@ private FunctionTypeMappingContextModel() getSqlFragmentStatic.TypeMapping = StringTypeMapping.Default.Clone( comparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), keyComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), providerValueComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "varchar", diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs index e2d187cd8ce..4a979a96d2b 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DataEntityType.cs @@ -60,7 +60,7 @@ public static RuntimeEntityType Create(RuntimeModel model, RuntimeEntityType bas blob.TypeMapping = SqlServerByteArrayTypeMapping.Default.Clone( comparer: new ValueComparer( (byte[] v1, byte[] v2) => StructuralComparisons.StructuralEqualityComparer.Equals((object)v1, (object)v2), - (byte[] v) => v.GetHashCode(), + (byte[] v) => ((object)v).GetHashCode(), (byte[] v) => v), keyComparer: new ValueComparer( (byte[] v1, byte[] v2) => StructuralComparisons.StructuralEqualityComparer.Equals((object)v1, (object)v2), @@ -86,7 +86,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType) (InternalEntityEntry source) => { var entity = (CompiledModelTestBase.Data)source.Entity; - return (ISnapshot)new Snapshot(source.GetCurrentValue(blob) == null ? null : ((ValueComparer)blob.GetValueComparer()).Snapshot(source.GetCurrentValue(blob))); + return (ISnapshot)new Snapshot(source.GetCurrentValue(blob) == null ? null : ((ValueComparer)((IProperty)blob).GetValueComparer()).Snapshot(source.GetCurrentValue(blob))); }); runtimeEntityType.SetStoreGeneratedValuesFactory( () => Snapshot.Empty); diff --git a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs index b5f009363bc..47595a59c8e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Scaffolding/Baselines/DbFunctions/DbFunctionContextModelBuilder.cs @@ -70,15 +70,15 @@ private DbFunctionContextModel() id.TypeMapping = GuidTypeMapping.Default.Clone( comparer: new ValueComparer( (Guid v1, Guid v2) => v1 == v2, - (Guid v) => v.GetHashCode(), + (Guid v) => ((object)v).GetHashCode(), (Guid v) => v), keyComparer: new ValueComparer( (Guid v1, Guid v2) => v1 == v2, - (Guid v) => v.GetHashCode(), + (Guid v) => ((object)v).GetHashCode(), (Guid v) => v), providerValueComparer: new ValueComparer( (Guid v1, Guid v2) => v1 == v2, - (Guid v) => v.GetHashCode(), + (Guid v) => ((object)v).GetHashCode(), (Guid v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "uniqueidentifier")); @@ -92,15 +92,15 @@ private DbFunctionContextModel() condition.TypeMapping = SqlServerStringTypeMapping.Default.Clone( comparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), keyComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), providerValueComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "nchar(256)", @@ -197,15 +197,15 @@ private DbFunctionContextModel() date.TypeMapping = SqlServerStringTypeMapping.Default.Clone( comparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), keyComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), providerValueComparer: new ValueComparer( (string v1, string v2) => v1 == v2, - (string v) => v.GetHashCode(), + (string v) => ((object)v).GetHashCode(), (string v) => v), mappingInfo: new RelationalTypeMappingInfo( storeTypeName: "nchar(256)", @@ -217,15 +217,15 @@ private DbFunctionContextModel() isDateStatic.TypeMapping = SqlServerBoolTypeMapping.Default.Clone( comparer: new ValueComparer( (bool v1, bool v2) => v1 == v2, - (bool v) => v.GetHashCode(), + (bool v) => ((object)v).GetHashCode(), (bool v) => v), keyComparer: new ValueComparer( (bool v1, bool v2) => v1 == v2, - (bool v) => v.GetHashCode(), + (bool v) => ((object)v).GetHashCode(), (bool v) => v), providerValueComparer: new ValueComparer( (bool v1, bool v2) => v1 == v2, - (bool v) => v.GetHashCode(), + (bool v) => ((object)v).GetHashCode(), (bool v) => v)); isDateStatic.AddAnnotation("MyGuid", new Guid("00000000-0000-0000-0000-000000000000")); functions["Microsoft.EntityFrameworkCore.Scaffolding.CompiledModelRelationalTestBase+DbFunctionContext.IsDateStatic(string)"] = isDateStatic; diff --git a/test/EFCore.SqlServer.FunctionalTests/TestUtilities/SqlServerPrecompiledQueryTestHelpers.cs b/test/EFCore.SqlServer.FunctionalTests/TestUtilities/SqlServerPrecompiledQueryTestHelpers.cs new file mode 100644 index 00000000000..e25d0a1df12 --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/TestUtilities/SqlServerPrecompiledQueryTestHelpers.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal; + +namespace Microsoft.EntityFrameworkCore.TestUtilities; + +public class SqlServerPrecompiledQueryTestHelpers : PrecompiledQueryTestHelpers +{ + public static SqlServerPrecompiledQueryTestHelpers Instance = new(); + + protected override IEnumerable BuildProviderMetadataReferences() + { + yield return MetadataReference.CreateFromFile(typeof(SqlServerOptionsExtension).Assembly.Location); + yield return MetadataReference.CreateFromFile(Assembly.GetExecutingAssembly().Location); + } +} diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/AdHocPrecompiledQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocPrecompiledQuerySqliteTest.cs new file mode 100644 index 00000000000..076e532fb33 --- /dev/null +++ b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocPrecompiledQuerySqliteTest.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +public class AdHocPrecompiledQuerySqliteTest(ITestOutputHelper testOutputHelper) + : AdHocPrecompiledQueryRelationalTestBase(testOutputHelper) +{ + protected override bool AlwaysPrintGeneratedSources + => false; + + protected override ITestStoreFactory TestStoreFactory + => SqliteTestStoreFactory.Instance; + + protected override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers + => SqlitePrecompiledQueryTestHelpers.Instance; +} diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/PrecompiledQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/PrecompiledQuerySqliteTest.cs new file mode 100644 index 00000000000..f124247ed7f --- /dev/null +++ b/test/EFCore.Sqlite.FunctionalTests/Query/PrecompiledQuerySqliteTest.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// ReSharper disable InconsistentNaming + +namespace Microsoft.EntityFrameworkCore.Query; + +public class PrecompiledQuerySqliteTest( + PrecompiledQuerySqliteTest.PrecompiledQuerySqliteFixture fixture, + ITestOutputHelper testOutputHelper) + : PrecompiledQueryRelationalTestBase(fixture, testOutputHelper), + IClassFixture +{ + protected override bool AlwaysPrintGeneratedSources + => false; + + [ConditionalFact] + public virtual Task Glob() + => Test("""_ = context.Blogs.Where(b => EF.Functions.Glob(b.Name, "*foo*")).ToList();"""); + + [ConditionalFact] + public virtual Task Regexp() + => Test("""_ = context.Blogs.Where(b => Regex.IsMatch(b.Name, "^foo")).ToList();"""); + + public class PrecompiledQuerySqliteFixture : PrecompiledQueryRelationalFixture + { + protected override ITestStoreFactory TestStoreFactory + => SqliteTestStoreFactory.Instance; + + public override PrecompiledQueryTestHelpers PrecompiledQueryTestHelpers => SqlitePrecompiledQueryTestHelpers.Instance; + } +} diff --git a/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqlitePrecompiledQueryTestHelpers.cs b/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqlitePrecompiledQueryTestHelpers.cs new file mode 100644 index 00000000000..4e9984a139c --- /dev/null +++ b/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqlitePrecompiledQueryTestHelpers.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using Microsoft.EntityFrameworkCore.Sqlite.Infrastructure.Internal; + +namespace Microsoft.EntityFrameworkCore.TestUtilities; + +public class SqlitePrecompiledQueryTestHelpers : PrecompiledQueryTestHelpers +{ + public static SqlitePrecompiledQueryTestHelpers Instance = new(); + + protected override IEnumerable BuildProviderMetadataReferences() + { + yield return MetadataReference.CreateFromFile(typeof(SqliteOptionsExtension).Assembly.Location); + yield return MetadataReference.CreateFromFile(Assembly.GetExecutingAssembly().Location); + } +} diff --git a/test/EFCore.Tests/EFCore.Tests.csproj b/test/EFCore.Tests/EFCore.Tests.csproj index f81c6d6cc98..c38c3eb46b9 100644 --- a/test/EFCore.Tests/EFCore.Tests.csproj +++ b/test/EFCore.Tests/EFCore.Tests.csproj @@ -6,6 +6,7 @@ Microsoft.EntityFrameworkCore disable true + $(NoWarn);EF9100 diff --git a/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs b/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs index aa8517389b9..6144b826051 100644 --- a/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs +++ b/test/EFCore.Tests/Query/Internal/NavigationExpandingExpressionVisitorTests.cs @@ -22,19 +22,19 @@ public TestNavigationExpandingExpressionVisitor() null, new QueryCompilationContext( new QueryCompilationContextDependencies( - null, - null, - null, - null, - null, - null, - null, + model: null, + queryTranslationPreprocessorFactory: null, + queryableMethodTranslatingExpressionVisitorFactory: null, + queryTranslationPostprocessorFactory: null, + shapedQueryCompilingExpressionVisitorFactory: null, + liftableConstantFactory: null, + liftableConstantProcessor: null, new ExecutionStrategyTest.TestExecutionStrategy(new MyDemoContext()), new CurrentDbContext(new MyDemoContext()), - null, - null, + contextOptions: null, + logger: null, new TestInterceptors() - ), false), + ), async: false, precompiling: false), null, null) { @@ -77,7 +77,7 @@ private class A } [ConditionalFact] - public void Visits_extention_childrens() + public void Visits_extension_children() { var model = new Model(); var e = model.AddEntityType(typeof(A), false, ConfigurationSource.Explicit);