Skip to content

Commit

Permalink
Improve non-aggregate string.Join translations (#2536)
Browse files Browse the repository at this point in the history
Translate to concat_ws for the simple case

Closes #2485
  • Loading branch information
roji committed Oct 16, 2022
1 parent e974eca commit 0ff4911
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,6 @@ public class NpgsqlArrayTranslator : IMethodCallTranslator, IMemberTranslator
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2);

private static readonly MethodInfo String_Join1 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(object[]) })!;

private static readonly MethodInfo String_Join2 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(string[]) })!;

private static readonly MethodInfo String_Join3 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(object[]) })!;

private static readonly MethodInfo String_Join4 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(string[]) })!;

private static readonly MethodInfo String_Join_generic1 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(string));

private static readonly MethodInfo String_Join_generic2 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(char));

#endregion Methods

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -124,23 +104,6 @@ public class NpgsqlArrayTranslator : IMethodCallTranslator, IMemberTranslator
return TranslateCommon(arguments[0], arguments.Slice(1));
}

if (method.DeclaringType == typeof(string)
&& (method == String_Join1
|| method == String_Join2
|| method == String_Join3
|| method == String_Join4
|| method.IsClosedFormOf(String_Join_generic1)
|| method.IsClosedFormOf(String_Join_generic2))
&& !IsMappedToNonArray(arguments[0]))
{
return _sqlExpressionFactory.Function(
"array_to_string",
new[] { arguments[1], arguments[0], _sqlExpressionFactory.Constant("") },
nullable: true,
argumentsPropagateNullability: TrueArrays[3],
typeof(string));
}

// Not an array/list
return null;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;
Expand Down Expand Up @@ -58,6 +59,21 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator
m => m.Name == nameof(Enumerable.LastOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(typeof(char));

private static readonly MethodInfo String_Join1 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(object[]) })!;
private static readonly MethodInfo String_Join2 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(string[]) })!;
private static readonly MethodInfo String_Join3 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(object[]) })!;
private static readonly MethodInfo String_Join4 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(string[]) })!;
private static readonly MethodInfo String_Join_generic1 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(string));
private static readonly MethodInfo String_Join_generic2 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(char));

#endregion

/// <summary>
Expand Down Expand Up @@ -307,6 +323,56 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator
arguments[1].TypeMapping);
}

if (method.DeclaringType == typeof(string)
&& (method == String_Join1
|| method == String_Join2
|| method == String_Join3
|| method == String_Join4
|| method.IsClosedFormOf(String_Join_generic1)
|| method.IsClosedFormOf(String_Join_generic2)))
{
// If the array of strings to be joined is a constant (NewArrayExpression), we translate to concat_ws.
// Otherwise we translate to array_to_string, which also supports array columns and parameters.
if (arguments[1] is PostgresNewArrayExpression newArrayExpression)
{
var rewrittenArguments = new SqlExpression[newArrayExpression.Expressions.Count + 1];
rewrittenArguments[0] = arguments[0];

for (var i = 0; i < newArrayExpression.Expressions.Count; i++)
{
var argument = newArrayExpression.Expressions[i];

rewrittenArguments[i + 1] = argument switch
{
ColumnExpression { IsNullable: false } => argument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? _sqlExpressionFactory.Constant(string.Empty, typeof(string))
: constantExpression,
_ => _sqlExpressionFactory.Coalesce(argument, _sqlExpressionFactory.Constant(string.Empty, typeof(string)))
};
}

// Only the delimiter (first arg) propagates nullability - all others are non-nullable, since we wrap the others in coalesce
// (where needed).
var argumentsPropagateNullability = new bool[rewrittenArguments.Length];
argumentsPropagateNullability[0] = true;

return _sqlExpressionFactory.Function(
"concat_ws",
rewrittenArguments,
nullable: true,
argumentsPropagateNullability,
typeof(string));
}

return _sqlExpressionFactory.Function(
"array_to_string",
new[] { arguments[1], arguments[0], _sqlExpressionFactory.Constant("") },
nullable: true,
argumentsPropagateNullability: TrueArrays[3],
typeof(string));
}

if (method == StartsWith)
{
return TranslateStartsEndsWith(instance!, arguments[0], true);
Expand Down
5 changes: 3 additions & 2 deletions test/EFCore.PG.FunctionalTests/Query/ArrayArrayQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -869,9 +869,10 @@ public override async Task Array_IndexOf2(bool async)
""");
}

public override async Task String_Join(bool async)
// Note: see NorthwindFunctionsQueryNpgsqlTest.String_Join_non_aggregate for regular use without an array column/parameter
public override async Task String_Join_with_array_parameter(bool async)
{
await base.String_Join(async);
await base.String_Join_with_array_parameter(async);

AssertSql(
"""
Expand Down
5 changes: 3 additions & 2 deletions test/EFCore.PG.FunctionalTests/Query/ArrayListQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -876,9 +876,10 @@ public override async Task Array_IndexOf2(bool async)
""");
}

public override async Task String_Join(bool async)
// Note: see NorthwindFunctionsQueryNpgsqlTest.String_Join_non_aggregate for regular use without an array column/parameter
public override async Task String_Join_with_array_parameter(bool async)
{
await base.String_Join(async);
await base.String_Join_with_array_parameter(async);

AssertSql(
"""
Expand Down
3 changes: 2 additions & 1 deletion test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,10 @@ public virtual Task Concat(bool async)
[MemberData(nameof(IsAsyncData))]
public abstract Task Array_IndexOf2(bool async);

// Note: see NorthwindFunctionsQueryNpgsqlTest.String_Join_non_aggregate for regular use without an array column/parameter
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Join(bool async)
public virtual Task String_Join_with_array_parameter(bool async)
=> AssertQuery(
async,
ss => ss.Set<ArrayEntity>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class NorthwindFunctionsQueryNpgsqlTest : NorthwindFunctionsQueryRelation
: base(fixture)
{
ClearLog();
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
// Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

public override async Task IsNullOrWhiteSpace_in_predicate(bool async)
Expand Down Expand Up @@ -44,6 +44,28 @@ public override Task Where_mathf_round2(bool async)
public override Task Convert_ToString(bool async)
=> AssertTranslationFailed(() => base.Convert_ToString(async));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task String_Join_non_aggregate(bool async)
{
var param = "param";
string nullParam = null;

await AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => string.Join("|", c.CustomerID, c.CompanyName, param, nullParam, "constant", null) == "ALFKI|Alfreds Futterkiste|param||constant|"),
entryCount: 1);

AssertSql(
"""
@__param_0='param'

SELECT c."CustomerID", c."Address", c."City", c."CompanyName", c."ContactName", c."ContactTitle", c."Country", c."Fax", c."Phone", c."PostalCode", c."Region"
FROM "Customers" AS c
WHERE concat_ws('|', c."CustomerID", COALESCE(c."CompanyName", ''), COALESCE(@__param_0, ''), COALESCE(NULL, ''), 'constant', '') = 'ALFKI|Alfreds Futterkiste|param||constant|'
""");
}

#region Substring

[ConditionalTheory]
Expand Down

0 comments on commit 0ff4911

Please sign in to comment.