Skip to content

Commit

Permalink
dialect/sql: add support for window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Mar 27, 2022
1 parent cb7e0c1 commit 83cdd2a
Show file tree
Hide file tree
Showing 3 changed files with 309 additions and 35 deletions.
172 changes: 151 additions & 21 deletions dialect/sql/builder.go
Expand Up @@ -2172,15 +2172,26 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector {
return s
}

// AppendSelectExpr appends additional expressions to the SELECT statement.
// AppendSelectExpr appends additional expressions to the SELECT statement.
func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector {
for i := range exprs {
s.selection = append(s.selection, exprs[i])
}
return s
}

// SelectedColumns returns the selected columns of the Selector.
// AppendSelectExprAs appends additional expressions to the SELECT statement with the given name.
func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector {
s.selection = append(s.selection, ExprFunc(func(b *Builder) {
b.WriteByte('(')
b.Join(expr)
b.WriteString(") AS ")
b.Ident(as)
}))
return s
}

// SelectedColumns returns the selected columns in the Selector.
func (s *Selector) SelectedColumns() []string {
columns := make([]string, 0, len(s.selection))
for i := range s.selection {
Expand All @@ -2191,6 +2202,28 @@ func (s *Selector) SelectedColumns() []string {
return columns
}

// UnqualifiedColumns returns the an unqualified version of the
// selected columns in the Selector. e.g. "t1"."c" => "c".
func (s *Selector) UnqualifiedColumns() []string {
columns := make([]string, 0, len(s.selection))
for i := range s.selection {
c, ok := s.selection[i].(string)
if !ok {
continue
}
if s.isIdent(c) {
parts := strings.FieldsFunc(c, func(r rune) bool {
return r == '`' || r == '"'
})
if n := len(parts); n > 0 && parts[n-1] != "" {
c = parts[n-1]
}
}
columns = append(columns, c)
}
return columns
}

// From sets the source of `FROM` clause.
func (s *Selector) From(t TableView) *Selector {
s.from = t
Expand Down Expand Up @@ -2578,6 +2611,18 @@ func (s *Selector) OrderBy(columns ...string) *Selector {
return s
}

// OrderColumns returns the ordered columns in the Selector.
// Note, this function skips columns selected with expressions.
func (s *Selector) OrderColumns() []string {
columns := make([]string, 0, len(s.order))
for i := range s.order {
if c, ok := s.order[i].(string); ok {
columns = append(columns, c)
}
}
return columns
}

// OrderExpr appends the `ORDER BY` clause to the `SELECT`
// statement with custom list of expressions.
func (s *Selector) OrderExpr(exprs ...Querier) *Selector {
Expand Down Expand Up @@ -2626,7 +2671,7 @@ func (s *Selector) Query() (string, []interface{}) {
b.Ident(t.as)
case *WithBuilder:
t.SetDialect(s.dialect)
b.Ident(t.name)
b.Ident(t.Name())
}
for _, join := range s.joins {
b.WriteString(" " + join.kind + " ")
Expand All @@ -2643,7 +2688,7 @@ func (s *Selector) Query() (string, []interface{}) {
b.Ident(view.as)
case *WithBuilder:
view.SetDialect(s.dialect)
b.Ident(view.name)
b.Ident(view.Name())
}
if join.on != nil {
b.WriteString(" ON ")
Expand Down Expand Up @@ -2764,9 +2809,11 @@ func (*Selector) view() {}
type WithBuilder struct {
Builder
recursive bool
name string
columns []string
s *Selector
ctes []struct {
name string
columns []string
s *Selector
}
}

// With returns a new builder for the `WITH` statement.
Expand All @@ -2778,7 +2825,15 @@ type WithBuilder struct {
// return n.Query()
//
func With(name string, columns ...string) *WithBuilder {
return &WithBuilder{name: name, columns: columns}
return &WithBuilder{
ctes: []struct {
name string
columns []string
s *Selector
}{
{name: name, columns: columns},
},
}
}

// WithRecursive returns a new builder for the `WITH RECURSIVE` statement.
Expand All @@ -2790,22 +2845,32 @@ func With(name string, columns ...string) *WithBuilder {
// return n.Query()
//
func WithRecursive(name string, columns ...string) *WithBuilder {
return &WithBuilder{name: name, columns: columns, recursive: true}
w := With(name, columns...)
w.recursive = true
return w
}

// Name returns the name of the view.
func (w *WithBuilder) Name() string { return w.name }
func (w *WithBuilder) Name() string {
return w.ctes[0].name
}

// As sets the view sub query.
func (w *WithBuilder) As(s *Selector) *WithBuilder {
w.s = s
w.ctes[len(w.ctes)-1].s = s
return w
}

// With appends another named CTE to the statement.
func (w *WithBuilder) With(name string, columns ...string) *WithBuilder {
w.ctes = append(w.ctes, With(name, columns...).ctes...)
return w
}

// C returns a formatted string for the WITH column.
func (w *WithBuilder) C(column string) string {
b := &Builder{dialect: w.dialect}
b.Ident(w.name).WriteByte('.').Ident(column)
b.Ident(w.Name()).WriteByte('.').Ident(column)
return b.String()
}

Expand All @@ -2815,22 +2880,87 @@ func (w *WithBuilder) Query() (string, []interface{}) {
if w.recursive {
w.WriteString("RECURSIVE ")
}
w.Ident(w.name)
if len(w.columns) > 0 {
w.WriteByte('(')
w.IdentComma(w.columns...)
w.WriteByte(')')
for i, cte := range w.ctes {
if i > 0 {
w.Comma()
}
w.Ident(cte.name)
if len(cte.columns) > 0 {
w.WriteByte('(')
w.IdentComma(cte.columns...)
w.WriteByte(')')
}
w.WriteString(" AS ")
w.Nested(func(b *Builder) {
b.Join(cte.s)
})
}
w.WriteString(" AS ")
w.Nested(func(b *Builder) {
b.Join(w.s)
})
return w.String(), w.args
}

// implement the table view interface.
func (*WithBuilder) view() {}

// WindowBuilder represents a builder for a window clause.
// Note that window functions support is limited and used
// only to query rows-limited edges in pagination.
type WindowBuilder struct {
Builder
fn string // e.g. ROW_NUMBER(), RANK().
partition func(*Builder)
order func(*Builder)
}

// RowNumber returns a new window clause with the ROW_NUMBER() as a function.
// Using this function will assign a each row a number, from 1 to N, in the
// order defined by the ORDER BY clause in the window spec.
func RowNumber() *WindowBuilder {
return &WindowBuilder{fn: "ROW_NUMBER"}
}

// PartitionBy indicates to divide the query rows into groups by the given columns.
// Note that, standard SQL spec allows partition only by columns, and in order to
// use the "expression" version, use the PartitionByExpr.
func (w *WindowBuilder) PartitionBy(columns ...string) *WindowBuilder {
w.partition = func(b *Builder) {
b.IdentComma(columns...)
}
return w
}

// PartitionExpr indicates to divide the query rows into groups by the given expression.
func (w *WindowBuilder) PartitionExpr(x Querier) *WindowBuilder {
w.partition = func(b *Builder) {
b.Join(x)
}
return w
}

// OrderBy indicates how to sort rows in each partition.
func (w *WindowBuilder) OrderBy(columns ...string) *WindowBuilder {
w.order = func(b *Builder) {
b.IdentComma(columns...)
}
return w
}

// Query returns query representation of the window function.
func (w *WindowBuilder) Query() (string, []interface{}) {
w.WriteString(w.fn)
w.WriteString("() OVER ")
w.Nested(func(b *Builder) {
if w.partition != nil {
b.WriteString("PARTITION BY ")
w.partition(b)
}
if w.order != nil {
b.WriteString(" ORDER BY ")
w.order(b)
}
})
return w.Builder.String(), w.args
}

// Wrapper wraps a given Querier with different format.
// Used to prefix/suffix other queries.
type Wrapper struct {
Expand Down
35 changes: 35 additions & 0 deletions dialect/sql/builder_test.go
Expand Up @@ -2026,3 +2026,38 @@ func TestBoolPredicates(t *testing.T) {
require.Nil(t, args)
require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query)
}

func TestWindowFunction(t *testing.T) {
posts := Table("posts")
base := Select(posts.Columns("id", "content", "author_id")...).
From(posts).
Where(EQ("active", true))
with := With("active_posts").
As(base).
With("selected_posts").
As(
Select().
AppendSelect("*").
AppendSelectExprAs(
RowNumber().PartitionBy("author_id").OrderBy("id"),
"row_number",
).
From(Table("active_posts")),
)
query, args := Select("*").From(Table("selected_posts")).Where(LTE("row_number", 2)).Prefix(with).Query()
require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query)
require.Equal(t, []interface{}{2}, args)
}

func TestSelector_UnqualifiedColumns(t *testing.T) {
t1, t2 := Table("t1"), Table("t2")
s := Select(t1.C("a"), t2.C("b"))
require.Equal(t, []string{"`t1`.`a`", "`t2`.`b`"}, s.SelectedColumns())
require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns())

d := Dialect(dialect.Postgres)
t1, t2 = d.Table("t1"), d.Table("t2")
s = d.Select(t1.C("a"), t2.C("b"))
require.Equal(t, []string{`"t1"."a"`, `"t2"."b"`}, s.SelectedColumns())
require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns())
}

0 comments on commit 83cdd2a

Please sign in to comment.