Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialect/sql: add support for window functions #2431

Merged
merged 1 commit into from Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
172 changes: 151 additions & 21 deletions dialect/sql/builder.go
Expand Up @@ -2172,15 +2172,26 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
return s
}

// AppendSelectExpr appends additional expressions to the SELECT statement.
// AppendSelectExpr appends additional expressions to the SELECT statement.
func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
for i := range exprs {
s.selection = append(s.selection, exprs[i])
}
return s
}

// SelectedColumns returns the selected columns of the Selector.
// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name.
func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector {
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
b.WriteByte('(')
b.Join(expr)
b.WriteString(") AS ")
b.Ident(as)
}))
return s
}

// SelectedColumns returns the selected columns in the Selector.
func (s *Selector) SelectedColumns() []string {
columns := make([]string, 0, len(s.selection))
for i := range s.selection {
Expand All @@ -2191,6 +2202,28 @@ func (s *Selector) SelectedColumns() []string {
return columns
}

// UnqualifiedColumns returns the an unqualified version of the
// selected columns in the Selector. e.g. "t1"."c" => "c".
func (s *Selector) UnqualifiedColumns() []string {
columns := make([]string, 0, len(s.selection))
for i := range s.selection {
c, ok := s.selection[i].(string)
if !ok {
continue
}
if s.isIdent(c) {
parts := strings.FieldsFunc(c, func(r rune) bool {
return r == '`' || r == '"'
})
if n := len(parts); n > 0 && parts[n-1] != "" {
c = parts[n-1]
}
}
columns = append(columns, c)
}
return columns
}

// From sets the source of `FROM` clause.
func (s *Selector) From(t TableView) *Selector {
s.from = t
Expand Down Expand Up @@ -2578,6 +2611,18 @@ func (s *Selector) OrderBy(columns ...string) *Selector {
return s
}

// OrderColumns returns the ordered columns in the Selector.
// Note, this function skips columns selected with expressions.
func (s *Selector) OrderColumns() []string {
columns := make([]string, 0, len(s.order))
for i := range s.order {
if c, ok := s.order[i].(string); ok {
columns = append(columns, c)
}
}
return columns
}

// OrderExpr appends the `ORDER BY` clause to the `SELECT`
// statement with custom list of expressions.
func (s *Selector) OrderExpr(exprs ...Querier) *Selector {
Expand Down Expand Up @@ -2626,7 +2671,7 @@ func (s *Selector) Query() (string, []interface{}) {
b.Ident(t.as)
case *WithBuilder:
t.SetDialect(s.dialect)
b.Ident(t.name)
b.Ident(t.Name())
}
for _, join := range s.joins {
b.WriteString(" " + join.kind + " ")
Expand All @@ -2643,7 +2688,7 @@ func (s *Selector) Query() (string, []interface{}) {
b.Ident(view.as)
case *WithBuilder:
view.SetDialect(s.dialect)
b.Ident(view.name)
b.Ident(view.Name())
}
if join.on != nil {
b.WriteString(" ON ")
Expand Down Expand Up @@ -2764,9 +2809,11 @@ func (*Selector) view() {}
type WithBuilder struct {
Builder
recursive bool
name string
columns []string
s *Selector
ctes []struct {
name string
columns []string
s *Selector
}
}

// With returns a new builder for the `WITH` statement.
Expand All @@ -2778,7 +2825,15 @@ type WithBuilder struct {
// return n.Query()
//
func With(name string, columns ...string) *WithBuilder {
return &WithBuilder{name: name, columns: columns}
return &WithBuilder{
ctes: []struct {
name string
columns []string
s *Selector
}{
{name: name, columns: columns},
},
}
}

// WithRecursive returns a new builder for the `WITH RECURSIVE` statement.
Expand All @@ -2790,22 +2845,32 @@ func With(name string, columns ...string) *WithBuilder {
// return n.Query()
//
func WithRecursive(name string, columns ...string) *WithBuilder {
return &WithBuilder{name: name, columns: columns, recursive: true}
w := With(name, columns...)
w.recursive = true
return w
}

// Name returns the name of the view.
func (w *WithBuilder) Name() string { return w.name }
func (w *WithBuilder) Name() string {
return w.ctes[0].name
}

// As sets the view sub query.
func (w *WithBuilder) As(s *Selector) *WithBuilder {
w.s = s
w.ctes[len(w.ctes)-1].s = s
return w
}

// With appends another named CTE to the statement.
func (w *WithBuilder) With(name string, columns ...string) *WithBuilder {
w.ctes = append(w.ctes, With(name, columns...).ctes...)
return w
}

// C returns a formatted string for the WITH column.
func (w *WithBuilder) C(column string) string {
b := &Builder{dialect: w.dialect}
b.Ident(w.name).WriteByte('.').Ident(column)
b.Ident(w.Name()).WriteByte('.').Ident(column)
return b.String()
}

Expand All @@ -2815,22 +2880,87 @@ func (w *WithBuilder) Query() (string, []interface{}) {
if w.recursive {
w.WriteString("RECURSIVE ")
}
w.Ident(w.name)
if len(w.columns) > 0 {
w.WriteByte('(')
w.IdentComma(w.columns...)
w.WriteByte(')')
for i, cte := range w.ctes {
if i > 0 {
w.Comma()
}
w.Ident(cte.name)
if len(cte.columns) > 0 {
w.WriteByte('(')
w.IdentComma(cte.columns...)
w.WriteByte(')')
}
w.WriteString(" AS ")
w.Nested(func(b *Builder) {
b.Join(cte.s)
})
}
w.WriteString(" AS ")
w.Nested(func(b *Builder) {
b.Join(w.s)
})
return w.String(), w.args
}

// implement the table view interface.
func (*WithBuilder) view() {}

// WindowBuilder represents a builder for a window clause.
// Note that window functions support is limited and used
// only to query rows-limited edges in pagination.
type WindowBuilder struct {
Builder
fn string // e.g. ROW_NUMBER(), RANK().
partition func(*Builder)
order func(*Builder)
}

// RowNumber returns a new window clause with the ROW_NUMBER() as a function.
// Using this function will assign a each row a number, from 1 to N, in the
// order defined by the ORDER BY clause in the window spec.
func RowNumber() *WindowBuilder {
return &WindowBuilder{fn: "ROW_NUMBER"}
}

// PartitionBy indicates to divide the query rows into groups by the given columns.
// Note that, standard SQL spec allows partition only by columns, and in order to
// use the "expression" version, use the PartitionByExpr.
func (w *WindowBuilder) PartitionBy(columns ...string) *WindowBuilder {
w.partition = func(b *Builder) {
b.IdentComma(columns...)
}
return w
}

// PartitionExpr indicates to divide the query rows into groups by the given expression.
func (w *WindowBuilder) PartitionExpr(x Querier) *WindowBuilder {
w.partition = func(b *Builder) {
b.Join(x)
}
return w
}

// OrderBy indicates how to sort rows in each partition.
func (w *WindowBuilder) OrderBy(columns ...string) *WindowBuilder {
w.order = func(b *Builder) {
b.IdentComma(columns...)
}
return w
}

// Query returns query representation of the window function.
func (w *WindowBuilder) Query() (string, []interface{}) {
w.WriteString(w.fn)
w.WriteString("() OVER ")
w.Nested(func(b *Builder) {
if w.partition != nil {
b.WriteString("PARTITION BY ")
w.partition(b)
}
if w.order != nil {
b.WriteString(" ORDER BY ")
w.order(b)
}
})
return w.Builder.String(), w.args
}

// Wrapper wraps a given Querier with different format.
// Used to prefix/suffix other queries.
type Wrapper struct {
Expand Down
35 changes: 35 additions & 0 deletions dialect/sql/builder_test.go
Expand Up @@ -2026,3 +2026,38 @@ func TestBoolPredicates(t *testing.T) {
require.Nil(t, args)
require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query)
}

func TestWindowFunction(t *testing.T) {
posts := Table("posts")
base := Select(posts.Columns("id", "content", "author_id")...).
From(posts).
Where(EQ("active", true))
with := With("active_posts").
As(base).
With("selected_posts").
As(
Select().
AppendSelect("*").
AppendSelectExprAs(
RowNumber().PartitionBy("author_id").OrderBy("id"),
"row_number",
).
From(Table("active_posts")),
)
query, args := Select("*").From(Table("selected_posts")).Where(LTE("row_number", 2)).Prefix(with).Query()
require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query)
require.Equal(t, []interface{}{2}, args)
}

func TestSelector_UnqualifiedColumns(t *testing.T) {
t1, t2 := Table("t1"), Table("t2")
s := Select(t1.C("a"), t2.C("b"))
require.Equal(t, []string{"`t1`.`a`", "`t2`.`b`"}, s.SelectedColumns())
require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns())

d := Dialect(dialect.Postgres)
t1, t2 = d.Table("t1"), d.Table("t2")
s = d.Select(t1.C("a"), t2.C("b"))
require.Equal(t, []string{`"t1"."a"`, `"t2"."b"`}, s.SelectedColumns())
require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns())
}