Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common Table Expressions helper #347

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
130 changes: 130 additions & 0 deletions cte.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package squirrel

import (
"bytes"
"fmt"

"github.com/lann/builder"
)

// Common Table Expressions helper
// e.g.
// WITH cte AS (
// ...
// ), cte_2 AS (
// ...
// )
// SELECT ... FROM cte ... cte_2;

type commonTableExpressionsData struct {
PlaceholderFormat PlaceholderFormat
Recursive bool
CurrentCteName string
Ctes []Sqlizer
Statement Sqlizer
}

func (d *commonTableExpressionsData) toSql() (sqlStr string, args []interface{}, err error) {
if len(d.Ctes) == 0 {
err = fmt.Errorf("common table expressions statements must have at least one label and subquery")
return
}

if d.Statement == nil {
err = fmt.Errorf("common table expressions must one of the following final statement: (select, insert, replace, update, delete)")
return
}

sql := &bytes.Buffer{}

sql.WriteString("WITH ")
if d.Recursive {
sql.WriteString("RECURSIVE ")
}

args, err = appendToSql(d.Ctes, sql, ", ", args)
sql.WriteString("\n")
args, err = appendToSql([]Sqlizer{d.Statement}, sql, "", args)

sqlStr = sql.String()
return
}

func (d *commonTableExpressionsData) ToSql() (sql string, args []interface{}, err error) {
return d.toSql()
}

// Builder

// CommonTableExpressionsBuilder builds CTE (Common Table Expressions) SQL statements.
type CommonTableExpressionsBuilder builder.Builder

func init() {
builder.Register(CommonTableExpressionsBuilder{}, commonTableExpressionsData{})
}

// Format methods

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b CommonTableExpressionsBuilder) PlaceholderFormat(f PlaceholderFormat) CommonTableExpressionsBuilder {
return builder.Set(b, "PlaceholderFormat", f).(CommonTableExpressionsBuilder)
}

// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b CommonTableExpressionsBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(commonTableExpressionsData)
return data.ToSql()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b CommonTableExpressionsBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}

func (b CommonTableExpressionsBuilder) Recursive(recursive bool) CommonTableExpressionsBuilder {
return builder.Set(b, "Recursive", recursive).(CommonTableExpressionsBuilder)
}

// Cte starts a new cte
func (b CommonTableExpressionsBuilder) Cte(cte string) CommonTableExpressionsBuilder {
return builder.Set(b, "CurrentCteName", cte).(CommonTableExpressionsBuilder)
}

// As sets the expression for the Cte
func (b CommonTableExpressionsBuilder) As(as SelectBuilder) CommonTableExpressionsBuilder {
data := builder.GetStruct(b).(commonTableExpressionsData)
return builder.Append(b, "Ctes", cteExpr{as, data.CurrentCteName}).(CommonTableExpressionsBuilder)
}

// Select finalizes the CommonTableExpressionsBuilder with a SELECT
func (b CommonTableExpressionsBuilder) Select(statement SelectBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}

// Insert finalizes the CommonTableExpressionsBuilder with an INSERT
func (b CommonTableExpressionsBuilder) Insert(statement InsertBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}

// Replace finalizes the CommonTableExpressionsBuilder with a REPLACE
func (b CommonTableExpressionsBuilder) Replace(statement InsertBuilder) CommonTableExpressionsBuilder {
return b.Insert(statement)
}

// Update finalizes the CommonTableExpressionsBuilder with an UPDATE
func (b CommonTableExpressionsBuilder) Update(statement UpdateBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}

// Delete finalizes the CommonTableExpressionsBuilder with a DELETE
func (b CommonTableExpressionsBuilder) Delete(statement DeleteBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}
135 changes: 135 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package squirrel

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestWithAsQuery_OneSubquery(t *testing.T) {
w := With("lab").As(
Select("col").From("tab").
Where("simple").
Where("NOT hard"),
).Select(
Select("col").
From("lab"),
)
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab AS (\n" +
"SELECT col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"SELECT col FROM lab"
assert.Equal(t, expectedSql, q)

w = WithRecursive("lab").As(
Select("col").From("tab").
Where("simple").
Where("NOT hard"),
).Select(Select("col").
From("lab"),
)
q, _, err = w.ToSql()
assert.NoError(t, err)

expectedSql = "WITH RECURSIVE lab AS (\n" +
"SELECT col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"SELECT col FROM lab"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_TwoSubqueries(t *testing.T) {
w := With("lab_1").As(
Select("col_1", "col_common").From("tab_1").
Where("simple").
Where("NOT hard"),
).Cte("lab_2").As(
Select("col_2", "col_common").From("tab_2"),
).Select(Select("col_1", "col_2", "col_common").
From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common"),
)
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab_1 AS (\n" +
"SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard\n" +
"), lab_2 AS (\n" +
"SELECT col_2, col_common FROM tab_2\n" +
")\n" +
"SELECT col_1, col_2, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_ManySubqueries(t *testing.T) {
w := With("lab_1").As(
Select("col_1", "col_common").From("tab_1").
Where("simple").
Where("NOT hard"),
).Cte("lab_2").As(
Select("col_2", "col_common").From("tab_2"),
).Cte("lab_3").As(
Select("col_3", "col_common").From("tab_3"),
).Cte("lab_4").As(
Select("col_4", "col_common").From("tab_4"),
).Select(
Select("col_1", "col_2", "col_3", "col_4", "col_common").
From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common").
Join("lab_3 ON lab_1.col_common = lab_3.col_common").
Join("lab_4 ON lab_1.col_common = lab_4.col_common"),
)
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab_1 AS (\n" +
"SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard\n" +
"), lab_2 AS (\n" +
"SELECT col_2, col_common FROM tab_2\n" +
"), lab_3 AS (\n" +
"SELECT col_3, col_common FROM tab_3\n" +
"), lab_4 AS (\n" +
"SELECT col_4, col_common FROM tab_4\n" +
")\n" +
"SELECT col_1, col_2, col_3, col_4, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common JOIN lab_3 ON lab_1.col_common = lab_3.col_common JOIN lab_4 ON lab_1.col_common = lab_4.col_common"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_Insert(t *testing.T) {
w := With("lab").As(
Select("col").From("tab").
Where("simple").
Where("NOT hard"),
).Insert(Insert("ins_tab").Columns("ins_col").Select(Select("col").From("lab")))
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab AS (\n" +
"SELECT col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"INSERT INTO ins_tab (ins_col) SELECT col FROM lab"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_Update(t *testing.T) {
w := With("lab").As(
Select("col", "common_col").From("tab").
Where("simple").
Where("NOT hard"),
).Update(
Update("upd_tab, lab").
Set("upd_col", Expr("lab.col")).
Where("common_col = lab.common_col"),
)

q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab AS (\n" +
"SELECT col, common_col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"UPDATE upd_tab, lab SET upd_col = lab.col WHERE common_col = lab.common_col"

assert.Equal(t, expectedSql, q)
}
17 changes: 17 additions & 0 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,20 @@ func isListType(val interface{}) bool {
valVal := reflect.ValueOf(val)
return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice
}

type cteExpr struct {
expr Sqlizer
cte string
}

func Cte(expr Sqlizer, cte string) cteExpr {
return cteExpr{expr, cte}
}

func (e cteExpr) ToSql() (sql string, args []interface{}, err error) {
sql, args, err = e.expr.ToSql()
if err == nil {
sql = fmt.Sprintf("%s AS (\n%s\n)", e.cte, sql)
}
return
}
19 changes: 19 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ func (b StatementBuilderType) Delete(from string) DeleteBuilder {
return DeleteBuilder(b).From(from)
}

// With returns a CommonTableExpressionsBuilder for this StatementBuilderType
func (b StatementBuilderType) With(cte string) CommonTableExpressionsBuilder {
return CommonTableExpressionsBuilder(b).Cte(cte)
}

// PlaceholderFormat sets the PlaceholderFormat field for any child builders.
func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType {
return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType)
Expand Down Expand Up @@ -87,6 +92,20 @@ func Delete(from string) DeleteBuilder {
return StatementBuilder.Delete(from)
}

// With returns a new CommonTableExpressionsBuilder with the given first cte name
//
// See CommonTableExpressionsBuilder.Cte
func With(cte string) CommonTableExpressionsBuilder {
return StatementBuilder.With(cte)
}

// WithRecursive returns a new CommonTableExpressionsBuilder with the RECURSIVE option and the given first cte name
//
// See CommonTableExpressionsBuilder.Cte, CommonTableExpressionsBuilder.Recursive
func WithRecursive(cte string) CommonTableExpressionsBuilder {
return StatementBuilder.With(cte).Recursive(true)
}

// Case returns a new CaseBuilder
// "what" represents case value
func Case(what ...interface{}) CaseBuilder {
Expand Down