diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 4685e0dee2..b25789e6e4 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -2172,7 +2172,7 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector { return s } -// AppendSelectExpr appends additional expressions to the SELECT statement. +// AppendSelectExpr appends additional expressions to the SELECT statement. func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector { for i := range exprs { s.selection = append(s.selection, exprs[i]) @@ -2180,7 +2180,18 @@ func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector { return s } -// SelectedColumns returns the selected columns of the Selector. +// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name. +func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector { + s.selection = append(s.selection, ExprFunc(func(b *Builder) { + b.WriteByte('(') + b.Join(expr) + b.WriteString(") AS ") + b.Ident(as) + })) + return s +} + +// SelectedColumns returns the selected columns in the Selector. func (s *Selector) SelectedColumns() []string { columns := make([]string, 0, len(s.selection)) for i := range s.selection { @@ -2191,6 +2202,28 @@ func (s *Selector) SelectedColumns() []string { return columns } +// UnqualifiedColumns returns the an unqualified version of the +// selected columns in the Selector. e.g. "t1"."c" => "c". +func (s *Selector) UnqualifiedColumns() []string { + columns := make([]string, 0, len(s.selection)) + for i := range s.selection { + c, ok := s.selection[i].(string) + if !ok { + continue + } + if s.isIdent(c) { + parts := strings.FieldsFunc(c, func(r rune) bool { + return r == '`' || r == '"' + }) + if n := len(parts); n > 0 && parts[n-1] != "" { + c = parts[n-1] + } + } + columns = append(columns, c) + } + return columns +} + // From sets the source of `FROM` clause. func (s *Selector) From(t TableView) *Selector { s.from = t @@ -2578,6 +2611,18 @@ func (s *Selector) OrderBy(columns ...string) *Selector { return s } +// OrderColumns returns the ordered columns in the Selector. +// Note, this function skips columns selected with expressions. +func (s *Selector) OrderColumns() []string { + columns := make([]string, 0, len(s.order)) + for i := range s.order { + if c, ok := s.order[i].(string); ok { + columns = append(columns, c) + } + } + return columns +} + // OrderExpr appends the `ORDER BY` clause to the `SELECT` // statement with custom list of expressions. func (s *Selector) OrderExpr(exprs ...Querier) *Selector { @@ -2626,7 +2671,7 @@ func (s *Selector) Query() (string, []interface{}) { b.Ident(t.as) case *WithBuilder: t.SetDialect(s.dialect) - b.Ident(t.name) + b.Ident(t.Name()) } for _, join := range s.joins { b.WriteString(" " + join.kind + " ") @@ -2643,7 +2688,7 @@ func (s *Selector) Query() (string, []interface{}) { b.Ident(view.as) case *WithBuilder: view.SetDialect(s.dialect) - b.Ident(view.name) + b.Ident(view.Name()) } if join.on != nil { b.WriteString(" ON ") @@ -2764,9 +2809,11 @@ func (*Selector) view() {} type WithBuilder struct { Builder recursive bool - name string - columns []string - s *Selector + ctes []struct { + name string + columns []string + s *Selector + } } // With returns a new builder for the `WITH` statement. @@ -2778,7 +2825,15 @@ type WithBuilder struct { // return n.Query() // func With(name string, columns ...string) *WithBuilder { - return &WithBuilder{name: name, columns: columns} + return &WithBuilder{ + ctes: []struct { + name string + columns []string + s *Selector + }{ + {name: name, columns: columns}, + }, + } } // WithRecursive returns a new builder for the `WITH RECURSIVE` statement. @@ -2790,22 +2845,32 @@ func With(name string, columns ...string) *WithBuilder { // return n.Query() // func WithRecursive(name string, columns ...string) *WithBuilder { - return &WithBuilder{name: name, columns: columns, recursive: true} + w := With(name, columns...) + w.recursive = true + return w } // Name returns the name of the view. -func (w *WithBuilder) Name() string { return w.name } +func (w *WithBuilder) Name() string { + return w.ctes[0].name +} // As sets the view sub query. func (w *WithBuilder) As(s *Selector) *WithBuilder { - w.s = s + w.ctes[len(w.ctes)-1].s = s + return w +} + +// With appends another named CTE to the statement. +func (w *WithBuilder) With(name string, columns ...string) *WithBuilder { + w.ctes = append(w.ctes, With(name, columns...).ctes...) return w } // C returns a formatted string for the WITH column. func (w *WithBuilder) C(column string) string { b := &Builder{dialect: w.dialect} - b.Ident(w.name).WriteByte('.').Ident(column) + b.Ident(w.Name()).WriteByte('.').Ident(column) return b.String() } @@ -2815,22 +2880,87 @@ func (w *WithBuilder) Query() (string, []interface{}) { if w.recursive { w.WriteString("RECURSIVE ") } - w.Ident(w.name) - if len(w.columns) > 0 { - w.WriteByte('(') - w.IdentComma(w.columns...) - w.WriteByte(')') + for i, cte := range w.ctes { + if i > 0 { + w.Comma() + } + w.Ident(cte.name) + if len(cte.columns) > 0 { + w.WriteByte('(') + w.IdentComma(cte.columns...) + w.WriteByte(')') + } + w.WriteString(" AS ") + w.Nested(func(b *Builder) { + b.Join(cte.s) + }) } - w.WriteString(" AS ") - w.Nested(func(b *Builder) { - b.Join(w.s) - }) return w.String(), w.args } // implement the table view interface. func (*WithBuilder) view() {} +// WindowBuilder represents a builder for a window clause. +// Note that window functions support is limited and used +// only to query rows-limited edges in pagination. +type WindowBuilder struct { + Builder + fn string // e.g. ROW_NUMBER(), RANK(). + partition func(*Builder) + order func(*Builder) +} + +// RowNumber returns a new window clause with the ROW_NUMBER() as a function. +// Using this function will assign a each row a number, from 1 to N, in the +// order defined by the ORDER BY clause in the window spec. +func RowNumber() *WindowBuilder { + return &WindowBuilder{fn: "ROW_NUMBER"} +} + +// PartitionBy indicates to divide the query rows into groups by the given columns. +// Note that, standard SQL spec allows partition only by columns, and in order to +// use the "expression" version, use the PartitionByExpr. +func (w *WindowBuilder) PartitionBy(columns ...string) *WindowBuilder { + w.partition = func(b *Builder) { + b.IdentComma(columns...) + } + return w +} + +// PartitionExpr indicates to divide the query rows into groups by the given expression. +func (w *WindowBuilder) PartitionExpr(x Querier) *WindowBuilder { + w.partition = func(b *Builder) { + b.Join(x) + } + return w +} + +// OrderBy indicates how to sort rows in each partition. +func (w *WindowBuilder) OrderBy(columns ...string) *WindowBuilder { + w.order = func(b *Builder) { + b.IdentComma(columns...) + } + return w +} + +// Query returns query representation of the window function. +func (w *WindowBuilder) Query() (string, []interface{}) { + w.WriteString(w.fn) + w.WriteString("() OVER ") + w.Nested(func(b *Builder) { + if w.partition != nil { + b.WriteString("PARTITION BY ") + w.partition(b) + } + if w.order != nil { + b.WriteString(" ORDER BY ") + w.order(b) + } + }) + return w.Builder.String(), w.args +} + // Wrapper wraps a given Querier with different format. // Used to prefix/suffix other queries. type Wrapper struct { diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 0798b0fcb4..4886f60859 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -2026,3 +2026,38 @@ func TestBoolPredicates(t *testing.T) { require.Nil(t, args) require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query) } + +func TestWindowFunction(t *testing.T) { + posts := Table("posts") + base := Select(posts.Columns("id", "content", "author_id")...). + From(posts). + Where(EQ("active", true)) + with := With("active_posts"). + As(base). + With("selected_posts"). + As( + Select(). + AppendSelect("*"). + AppendSelectExprAs( + RowNumber().PartitionBy("author_id").OrderBy("id"), + "row_number", + ). + From(Table("active_posts")), + ) + query, args := Select("*").From(Table("selected_posts")).Where(LTE("row_number", 2)).Prefix(with).Query() + require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query) + require.Equal(t, []interface{}{2}, args) +} + +func TestSelector_UnqualifiedColumns(t *testing.T) { + t1, t2 := Table("t1"), Table("t2") + s := Select(t1.C("a"), t2.C("b")) + require.Equal(t, []string{"`t1`.`a`", "`t2`.`b`"}, s.SelectedColumns()) + require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns()) + + d := Dialect(dialect.Postgres) + t1, t2 = d.Table("t1"), d.Table("t2") + s = d.Select(t1.C("a"), t2.C("b")) + require.Equal(t, []string{`"t1"."a"`, `"t2"."b"`}, s.SelectedColumns()) + require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns()) +} diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 751fca25ac..8c64f5f0b3 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -1408,9 +1408,7 @@ func Tx(t *testing.T, client *ent.Client) { require.NoError(t, tx.Rollback()) }) t.Run("TxOptions Rollback", func(t *testing.T) { - if client.Dialect() == dialect.SQLite { - t.Skip("Skipping SQLite") - } + skip(t, "SQLite") tx, err := client.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) require.NoError(t, err) var m mocker @@ -1429,9 +1427,7 @@ func Tx(t *testing.T, client *ent.Client) { require.NoError(t, tx.Rollback()) }) t.Run("TxOptions Commit", func(t *testing.T) { - if client.Dialect() == dialect.SQLite { - t.Skip("Skipping SQLite") - } + skip(t, "SQLite") tx, err := client.BeginTx(ctx, &sql.TxOptions{Isolation: stdsql.LevelReadCommitted}) require.NoError(t, err) var m mocker @@ -1685,6 +1681,117 @@ func EagerLoading(t *testing.T, client *ent.Client) { require.Equal(typ.Name, f.Edges.Type.Name) } }) + + t.Run("LimitRows/O2M", func(t *testing.T) { + skip(t, "MySQL/5") + client.Pet.Delete().ExecX(ctx) + client.Pet.Create().SetName("nala").SetOwner(nati).ExecX(ctx) + client.Pet.Create().SetName("xabi3").SetOwner(a8m).ExecX(ctx) + client.Pet.Create().SetName("xabi2").SetOwner(a8m).ExecX(ctx) + client.Pet.Create().SetName("xabi1").SetOwner(a8m).ExecX(ctx) + client.Pet.Create().SetName("lola4").SetOwner(alex).ExecX(ctx) + client.Pet.Create().SetName("lola3").SetOwner(alex).ExecX(ctx) + client.Pet.Create().SetName("lola2").SetOwner(alex).ExecX(ctx) + client.Pet.Create().SetName("lola1").SetOwner(alex).ExecX(ctx) + + users := client.User.Query().WithPets().Order(ent.Asc(user.FieldID)).AllX(ctx) + require.Len(users[0].Edges.Pets, 3) + require.Len(users[1].Edges.Pets, 1) + require.Len(users[2].Edges.Pets, 4) + + users = client.User. + Query(). + WithPets(func(q *ent.PetQuery) { + q.Modify(limitRows(pet.OwnerColumn, 2)) + }). + Order(ent.Asc(user.FieldID)). + AllX(ctx) + require.Len(users[0].Edges.Pets, 2) + require.Equal(users[0].Edges.Pets[0].Name, "xabi3") + require.Equal(users[0].Edges.Pets[1].Name, "xabi2") + require.Len(users[1].Edges.Pets, 1) + require.Equal(users[1].Edges.Pets[0].Name, "nala") + require.Len(users[2].Edges.Pets, 2) + require.Equal(users[2].Edges.Pets[0].Name, "lola4") + require.Equal(users[2].Edges.Pets[1].Name, "lola3") + + users = client.User. + Query(). + WithPets(func(q *ent.PetQuery) { + q.Modify(limitRows(pet.OwnerColumn, 1, pet.FieldName)) + }). + Order(ent.Asc(user.FieldID)). + AllX(ctx) + require.Len(users[0].Edges.Pets, 1) + require.Equal(users[0].Edges.Pets[0].Name, "xabi1") + require.Len(users[1].Edges.Pets, 1) + require.Equal(users[1].Edges.Pets[0].Name, "nala") + require.Len(users[2].Edges.Pets, 1) + require.Equal(users[2].Edges.Pets[0].Name, "lola1") + }) + + t.Run("LimitRows/M2M", func(t *testing.T) { + skip(t, "MySQL/5") + users := client.User.Query().WithGroups().Order(ent.Asc(user.FieldID)).AllX(ctx) + require.Len(users[0].Edges.Groups, 2) + require.Len(users[1].Edges.Groups, 1) + require.Len(users[2].Edges.Groups, 1) + + users = client.User. + Query(). + WithGroups(func(q *ent.GroupQuery) { + q.Modify(limitRows(user.GroupsPrimaryKey[0], 1)) + }). + Order(ent.Asc(user.FieldID)). + AllX(ctx) + require.Len(users[0].Edges.Groups, 1) + require.Equal(users[0].Edges.Groups[0].Name, "GitHub") + require.Len(users[1].Edges.Groups, 1) + require.Equal(users[1].Edges.Groups[0].Name, "GitLab") + require.Len(users[2].Edges.Groups, 1) + require.Equal(users[2].Edges.Groups[0].Name, "GitHub") + + client.Group.Create().SetName("BitBucket").SetExpire(time.Now()).AddUsers(alex, a8m).SetInfo(inf).SaveX(ctx) + users = client.User. + Query(). + WithGroups(func(q *ent.GroupQuery) { + q.Modify(limitRows(user.GroupsPrimaryKey[0], 1, group.FieldName)) + }). + Order(ent.Asc(user.FieldID)). + AllX(ctx) + require.Len(users[0].Edges.Groups, 1) + require.Equal(users[0].Edges.Groups[0].Name, "BitBucket") + require.Len(users[1].Edges.Groups, 1) + require.Equal(users[1].Edges.Groups[0].Name, "GitLab") + require.Len(users[2].Edges.Groups, 1) + require.Equal(users[2].Edges.Groups[0].Name, "BitBucket") + }) +} + +func limitRows(partitionBy string, limit int, orderBy ...string) func(s *sql.Selector) { + return func(s *sql.Selector) { + d := sql.Dialect(s.Dialect()) + s.SetDistinct(false) + if len(orderBy) == 0 { + orderBy = append(orderBy, "id") + } + with := d.With("src_query"). + As(s.Clone()). + With("limited_query"). + As( + d.Select("*"). + AppendSelectExprAs( + sql.RowNumber().PartitionBy(partitionBy).OrderBy(orderBy...), + "row_number", + ). + From(d.Table("src_query")), + ) + t := d.Table("limited_query").As(s.TableName()) + *s = *d.Select(s.UnqualifiedColumns()...). + From(t). + Where(sql.LTE(t.C("row_number"), limit)). + Prefix(with) + } } // writerFunc is an io.Writer implemented by the underlying func. @@ -1851,11 +1958,7 @@ func ConstraintChecks(t *testing.T, client *ent.Client) { } func Lock(t *testing.T, client *ent.Client) { - for _, d := range []string{"SQLite", "MySQL/5", "Maria/10.2"} { - if strings.Contains(t.Name(), d) { - t.Skip("unsupported version") - } - } + skip(t, "SQLite", "MySQL/5", "Maria/10.2") ctx := context.Background() xabi := client.Pet.Create().SetName("Xabi").SaveX(ctx) @@ -1890,9 +1993,7 @@ func Lock(t *testing.T, client *ent.Client) { }) t.Run("ForShare", func(t *testing.T) { - if strings.Contains(t.Name(), "Maria") { - t.Skip("unsupported version") - } + skip(t, "Maria") tx1, err := client.Tx(ctx) require.NoError(t, err) tx2, err := client.Tx(ctx) @@ -1915,6 +2016,14 @@ func Lock(t *testing.T, client *ent.Client) { }) } +func skip(t *testing.T, names ...string) { + for _, n := range names { + if strings.Contains(t.Name(), n) { + t.Skipf("skip %s", n) + } + } +} + func drop(t *testing.T, client *ent.Client) { t.Log("drop data from database") ctx := context.Background()