From 6fa380dd7606231d19522e833534f2213e278236 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Fri, 1 Apr 2022 21:31:00 +0300 Subject: [PATCH] dialect/sql: support for order by expressions in window functions (#2442) --- dialect/sql/builder.go | 30 ++++++++++++------- dialect/sql/builder_test.go | 4 +-- .../dialect/sql/feature/modifier.tmpl | 2 +- entc/integration/ent/card_query.go | 2 +- entc/integration/ent/comment_query.go | 2 +- entc/integration/ent/fieldtype_query.go | 2 +- entc/integration/ent/file_query.go | 2 +- entc/integration/ent/filetype_query.go | 2 +- entc/integration/ent/goods_query.go | 2 +- entc/integration/ent/group_query.go | 2 +- entc/integration/ent/groupinfo_query.go | 2 +- entc/integration/ent/item_query.go | 2 +- entc/integration/ent/node_query.go | 2 +- entc/integration/ent/pet_query.go | 2 +- entc/integration/ent/spec_query.go | 2 +- entc/integration/ent/task_query.go | 2 +- entc/integration/ent/user_query.go | 2 +- .../multischema/ent/group_query.go | 2 +- entc/integration/multischema/ent/pet_query.go | 2 +- .../integration/multischema/ent/user_query.go | 2 +- 20 files changed, 39 insertions(+), 31 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index b25789e6e4..67749b25b7 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2711,7 +2711,7 @@ func (s *Selector) Query() (string, []interface{}) { s.joinUnion(&b) } if len(s.order) > 0 { - s.joinOrder(&b) + joinOrder(s.order, &b) } if s.limit != nil { b.WriteString(" LIMIT ") @@ -2773,17 +2773,17 @@ func (s *Selector) joinUnion(b *Builder) { } } -func (s *Selector) joinOrder(b *Builder) { +func joinOrder(order []interface{}, b *Builder) { b.WriteString(" ORDER BY ") - for i := range s.order { + for i := range order { if i > 0 { b.Comma() } - switch order := s.order[i].(type) { + switch r := order[i].(type) { case string: - b.Ident(order) + b.Ident(r) case Querier: - b.Join(order) + b.Join(r) } } } @@ -2908,7 +2908,7 @@ type WindowBuilder struct { Builder fn string // e.g. ROW_NUMBER(), RANK(). partition func(*Builder) - order func(*Builder) + order []interface{} } // RowNumber returns a new window clause with the ROW_NUMBER() as a function. @@ -2938,8 +2938,17 @@ func (w *WindowBuilder) PartitionExpr(x Querier) *WindowBuilder { // OrderBy indicates how to sort rows in each partition. func (w *WindowBuilder) OrderBy(columns ...string) *WindowBuilder { - w.order = func(b *Builder) { - b.IdentComma(columns...) + for i := range columns { + w.order = append(w.order, columns[i]) + } + return w +} + +// OrderExpr appends the `ORDER BY` clause to the window +// partition with custom list of expressions. +func (w *WindowBuilder) OrderExpr(exprs ...Querier) *WindowBuilder { + for i := range exprs { + w.order = append(w.order, exprs[i]) } return w } @@ -2954,8 +2963,7 @@ func (w *WindowBuilder) Query() (string, []interface{}) { w.partition(b) } if w.order != nil { - b.WriteString(" ORDER BY ") - w.order(b) + joinOrder(w.order, b) } }) return w.Builder.String(), w.args diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 4886f60859..9435aa8812 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -2039,13 +2039,13 @@ func TestWindowFunction(t *testing.T) { Select(). AppendSelect("*"). AppendSelectExprAs( - RowNumber().PartitionBy("author_id").OrderBy("id"), + RowNumber().PartitionBy("author_id").OrderBy("id").OrderExpr(Expr("f(`s`)")), "row_number", ). From(Table("active_posts")), ) query, args := Select("*").From(Table("selected_posts")).Where(LTE("row_number", 2)).Prefix(with).Query() - require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query) + require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`, f(`s`))) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query) require.Equal(t, []interface{}{2}, args) } diff --git a/entc/gen/template/dialect/sql/feature/modifier.tmpl b/entc/gen/template/dialect/sql/feature/modifier.tmpl index 82bd41f22b..5b515d5370 100644 --- a/entc/gen/template/dialect/sql/feature/modifier.tmpl +++ b/entc/gen/template/dialect/sql/feature/modifier.tmpl @@ -11,7 +11,7 @@ in the LICENSE file in the root directory of this source tree. {{/* Template for adding the "modifiers" field to the query builder. */}} {{ define "dialect/sql/query/fields/additional/modify" -}} {{- if or ($.FeatureEnabled "sql/lock") ($.FeatureEnabled "sql/modifier") }} - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) {{- end }} {{- end -}} diff --git a/entc/integration/ent/card_query.go b/entc/integration/ent/card_query.go index 87d8962f13..87b5f0074b 100644 --- a/entc/integration/ent/card_query.go +++ b/entc/integration/ent/card_query.go @@ -35,7 +35,7 @@ type CardQuery struct { withOwner *UserQuery withSpec *SpecQuery withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/comment_query.go b/entc/integration/ent/comment_query.go index 2786e0e927..698a3b425a 100644 --- a/entc/integration/ent/comment_query.go +++ b/entc/integration/ent/comment_query.go @@ -28,7 +28,7 @@ type CommentQuery struct { order []OrderFunc fields []string predicates []predicate.Comment - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/fieldtype_query.go b/entc/integration/ent/fieldtype_query.go index 47c90a9596..6a0ad8e925 100644 --- a/entc/integration/ent/fieldtype_query.go +++ b/entc/integration/ent/fieldtype_query.go @@ -29,7 +29,7 @@ type FieldTypeQuery struct { fields []string predicates []predicate.FieldType withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/file_query.go b/entc/integration/ent/file_query.go index e79495bec5..d0825ce2e1 100644 --- a/entc/integration/ent/file_query.go +++ b/entc/integration/ent/file_query.go @@ -37,7 +37,7 @@ type FileQuery struct { withType *FileTypeQuery withField *FieldTypeQuery withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/filetype_query.go b/entc/integration/ent/filetype_query.go index e8dd352290..20ee5ea68f 100644 --- a/entc/integration/ent/filetype_query.go +++ b/entc/integration/ent/filetype_query.go @@ -32,7 +32,7 @@ type FileTypeQuery struct { predicates []predicate.FileType // eager-loading edges. withFiles *FileQuery - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/goods_query.go b/entc/integration/ent/goods_query.go index 16f4bac4dc..ec63a90e6d 100644 --- a/entc/integration/ent/goods_query.go +++ b/entc/integration/ent/goods_query.go @@ -28,7 +28,7 @@ type GoodsQuery struct { order []OrderFunc fields []string predicates []predicate.Goods - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/group_query.go b/entc/integration/ent/group_query.go index 97cfc87eb5..fb5863799c 100644 --- a/entc/integration/ent/group_query.go +++ b/entc/integration/ent/group_query.go @@ -38,7 +38,7 @@ type GroupQuery struct { withUsers *UserQuery withInfo *GroupInfoQuery withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/groupinfo_query.go b/entc/integration/ent/groupinfo_query.go index f4fddd7d12..3bb0d9c45d 100644 --- a/entc/integration/ent/groupinfo_query.go +++ b/entc/integration/ent/groupinfo_query.go @@ -32,7 +32,7 @@ type GroupInfoQuery struct { predicates []predicate.GroupInfo // eager-loading edges. withGroups *GroupQuery - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/item_query.go b/entc/integration/ent/item_query.go index 6ead8f1ce9..a905bfb0f1 100644 --- a/entc/integration/ent/item_query.go +++ b/entc/integration/ent/item_query.go @@ -28,7 +28,7 @@ type ItemQuery struct { order []OrderFunc fields []string predicates []predicate.Item - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/node_query.go b/entc/integration/ent/node_query.go index a82b000ec4..f04f745716 100644 --- a/entc/integration/ent/node_query.go +++ b/entc/integration/ent/node_query.go @@ -33,7 +33,7 @@ type NodeQuery struct { withPrev *NodeQuery withNext *NodeQuery withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/pet_query.go b/entc/integration/ent/pet_query.go index 1213381954..f57a455efa 100644 --- a/entc/integration/ent/pet_query.go +++ b/entc/integration/ent/pet_query.go @@ -33,7 +33,7 @@ type PetQuery struct { withTeam *UserQuery withOwner *UserQuery withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/spec_query.go b/entc/integration/ent/spec_query.go index 9172d6e8fc..58c1443912 100644 --- a/entc/integration/ent/spec_query.go +++ b/entc/integration/ent/spec_query.go @@ -32,7 +32,7 @@ type SpecQuery struct { predicates []predicate.Spec // eager-loading edges. withCard *CardQuery - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/task_query.go b/entc/integration/ent/task_query.go index 7ef8f972f8..0bac116e36 100644 --- a/entc/integration/ent/task_query.go +++ b/entc/integration/ent/task_query.go @@ -29,7 +29,7 @@ type TaskQuery struct { order []OrderFunc fields []string predicates []predicate.Task - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/ent/user_query.go b/entc/integration/ent/user_query.go index cc9aea3627..e2f590c9d1 100644 --- a/entc/integration/ent/user_query.go +++ b/entc/integration/ent/user_query.go @@ -46,7 +46,7 @@ type UserQuery struct { withChildren *UserQuery withParent *UserQuery withFKs bool - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/multischema/ent/group_query.go b/entc/integration/multischema/ent/group_query.go index adbc6f61f7..c2bb199a34 100644 --- a/entc/integration/multischema/ent/group_query.go +++ b/entc/integration/multischema/ent/group_query.go @@ -32,7 +32,7 @@ type GroupQuery struct { predicates []predicate.Group // eager-loading edges. withUsers *UserQuery - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/multischema/ent/pet_query.go b/entc/integration/multischema/ent/pet_query.go index 15dd9c1d24..95fc3aef2f 100644 --- a/entc/integration/multischema/ent/pet_query.go +++ b/entc/integration/multischema/ent/pet_query.go @@ -31,7 +31,7 @@ type PetQuery struct { predicates []predicate.Pet // eager-loading edges. withOwner *UserQuery - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) diff --git a/entc/integration/multischema/ent/user_query.go b/entc/integration/multischema/ent/user_query.go index e97b8d0b41..8c0bcb110a 100644 --- a/entc/integration/multischema/ent/user_query.go +++ b/entc/integration/multischema/ent/user_query.go @@ -34,7 +34,7 @@ type UserQuery struct { // eager-loading edges. withPets *PetQuery withGroups *GroupQuery - modifiers []func(s *sql.Selector) + modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error)