Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Union Queries for Squirrel #320

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
45 changes: 45 additions & 0 deletions cte.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package squirrel

import (
"bytes"
"strings"
)

// CTE represents a single common table expression. They are composed of an alias, a few optional components, and a data manipulation statement, though exactly what sort of statement depends on the database system you're using. MySQL, for example, only allows SELECT statements; others, like PostgreSQL, permit INSERTs, UPDATEs, and DELETEs.
// The optional components supported by this fork of Squirrel include:
// * a list of columns
// * the keyword RECURSIVE, the use of which may place additional constraints on the data manipulation statement
type CTE struct {
Alias string
ColumnList []string
Recursive bool
Expression Sqlizer
}

// ToSql builds the SQL for a CTE
func (c CTE) ToSql() (string, []interface{}, error) {

var buf bytes.Buffer

if c.Recursive {
buf.WriteString("RECURSIVE ")
}

buf.WriteString(c.Alias)

if len(c.ColumnList) > 0 {
buf.WriteString("(")
buf.WriteString(strings.Join(c.ColumnList, ", "))
buf.WriteString(")")
}

buf.WriteString(" AS (")
sql, args, err := c.Expression.ToSql()
if err != nil {
return "", []interface{}{}, err
}
buf.WriteString(sql)
buf.WriteString(")")

return buf.String(), args, nil
}
42 changes: 42 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package squirrel

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestNormalCTE(t *testing.T) {

cte := CTE{
Alias: "cte",
ColumnList: []string{"abc", "def"},
Recursive: false,
Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}),
}

sql, args, err := cte.ToSql()

assert.Equal(t, "cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql)
assert.Equal(t, []interface{}{1}, args)
assert.Nil(t, err)

}

func TestRecursiveCTE(t *testing.T) {

// this isn't usually valid SQL, but the point is to test the RECURSIVE part
cte := CTE{
Alias: "cte",
ColumnList: []string{"abc", "def"},
Recursive: true,
Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}),
}

sql, args, err := cte.ToSql()

assert.Equal(t, "RECURSIVE cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql)
assert.Equal(t, []interface{}{1}, args)
assert.Nil(t, err)

}
68 changes: 68 additions & 0 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type selectData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes []Sqlizer
CTEs []Sqlizer
Union Sqlizer
UnionAll Sqlizer
Options []string
Columns []Sqlizer
From Sqlizer
Expand Down Expand Up @@ -78,6 +81,15 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
sql.WriteString(" ")
}

if len(d.CTEs) > 0 {
sql.WriteString("WITH ")
args, err = appendToSql(d.CTEs, sql, ", ", args)
if err != nil {
return
}
sql.WriteString(" ")
}

sql.WriteString("SELECT ")

if len(d.Options) > 0 {
Expand Down Expand Up @@ -116,6 +128,22 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
}
}

if d.Union != nil {
sql.WriteString(" UNION ")
args, err = appendToSql([]Sqlizer{d.Union}, sql, "", args)
if err != nil {
return
}
}

if d.UnionAll != nil {
sql.WriteString(" UNION ALL ")
args, err = appendToSql([]Sqlizer{d.UnionAll}, sql, "", args)
if err != nil {
return
}
}

if len(d.GroupBys) > 0 {
sql.WriteString(" GROUP BY ")
sql.WriteString(strings.Join(d.GroupBys, ", "))
Expand Down Expand Up @@ -253,6 +281,22 @@ func (b SelectBuilder) Options(options ...string) SelectBuilder {
return builder.Extend(b, "Options", options).(SelectBuilder)
}

// With adds a non-recursive CTE to the query.
func (b SelectBuilder) With(alias string, expr Sqlizer) SelectBuilder {
return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: false, Expression: expr})
}

// WithRecursive adds a recursive CTE to the query.
func (b SelectBuilder) WithRecursive(alias string, expr Sqlizer) SelectBuilder {
return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: true, Expression: expr})
}

// WithCTE adds an arbitrary Sqlizer to the query.
// The sqlizer will be sandwiched between the keyword WITH and, if there's more than one CTE, a comma.
func (b SelectBuilder) WithCTE(cte Sqlizer) SelectBuilder {
return builder.Append(b, "CTEs", cte).(SelectBuilder)
}

// Columns adds result columns to the query.
func (b SelectBuilder) Columns(columns ...string) SelectBuilder {
parts := make([]interface{}, 0, len(columns))
Expand Down Expand Up @@ -289,6 +333,20 @@ func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilde
return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder)
}

// UnionSelect sets a union SelectBuilder which removes duplicate rows
// --> UNION combines the result from multiple SELECT statements into a single result set
func (b SelectBuilder) UnionSelect(union SelectBuilder) SelectBuilder {
union = union.PlaceholderFormat(Question)
return builder.Set(b, "Union", union).(SelectBuilder)
}

// UnionAllSelect sets a union SelectBuilder which includes all matching rows
// --> UNION combines the result from multiple SELECT statements into a single result set
func (b SelectBuilder) UnionAllSelect(union SelectBuilder) SelectBuilder {
union = union.PlaceholderFormat(Question)
return builder.Set(b, "UnionAll", union).(SelectBuilder)
}

// JoinClause adds a join clause to the query.
func (b SelectBuilder) JoinClause(pred interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder)
Expand Down Expand Up @@ -319,6 +377,16 @@ func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder
return b.JoinClause("CROSS JOIN "+join, rest...)
}

// Union adds UNION to the query. (duplicate rows are removed)
func (b SelectBuilder) Union(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("UNION "+join, rest...)
}

// UnionAll adds UNION ALL to the query. (includes all matching rows)
func (b SelectBuilder) UnionAll(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("UNION ALL "+join, rest...)
}

// Where adds an expression to the WHERE clause of the query.
//
// Expressions are ANDed together in the generated SQL.
Expand Down
63 changes: 63 additions & 0 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,30 @@ func TestSelectSubqueryInConjunctionPlaceholderNumbering(t *testing.T) {
assert.Equal(t, []interface{}{1, 2}, args)
}

func TestOneCTE(t *testing.T) {
sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).ToSql()

assert.NoError(t, err)

assert.Equal(t, "WITH cte AS (SELECT abc FROM def) SELECT * FROM cte", sql)
}

func TestTwoCTEs(t *testing.T) {
sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).With("cte2", Select("ghi").From("jkl")).ToSql()

assert.NoError(t, err)

assert.Equal(t, "WITH cte AS (SELECT abc FROM def), cte2 AS (SELECT ghi FROM jkl) SELECT * FROM cte", sql)
}

func TestCTEErrorBubblesUp(t *testing.T) {

// a SELECT with no columns raises an error
_, _, err := Select("*").From("cte").With("cte", SelectBuilder{}.From("def")).ToSql()

assert.Error(t, err)
}

func TestSelectJoinClausePlaceholderNumbering(t *testing.T) {
subquery := Select("a").Where(Eq{"b": 2}).PlaceholderFormat(Dollar)

Expand Down Expand Up @@ -452,6 +476,45 @@ func ExampleSelectBuilder_ToSql() {
}
}

func TestSelectBuilderUnionToSql(t *testing.T) {
multi := Select("column1", "column2").
From("table1").
Where(Eq{"column1": "test"}).
UnionSelect(Select("column3", "column4").From("table2").Where(Lt{"column4": 5}).
UnionSelect(Select("column5", "column6").From("table3").Where(LtOrEq{"column5": 6})))
sql, args, err := multi.ToSql()
assert.NoError(t, err)

expectedSql := `SELECT column1, column2 FROM table1 WHERE column1 = ? ` +
"UNION SELECT column3, column4 FROM table2 WHERE column4 < ? " +
"UNION SELECT column5, column6 FROM table3 WHERE column5 <= ?"
assert.Equal(t, expectedSql, sql)

expectedArgs := []interface{}{"test", 5, 6}
assert.Equal(t, expectedArgs, args)

sql, _, err = multi.PlaceholderFormat(Dollar).ToSql()
assert.NoError(t, err)
expectedSql = `SELECT column1, column2 FROM table1 WHERE column1 = $1 ` +
"UNION SELECT column3, column4 FROM table2 WHERE column4 < $2 " +
"UNION SELECT column5, column6 FROM table3 WHERE column5 <= $3"
assert.Equal(t, expectedSql, sql)

unionAll := Select("count(true) as C").
From("table1").
Where(Eq{"column1": []string{"test", "tester"}}).
UnionAllSelect(Select("count(true) as C").From("table2").Where(Select("true").Prefix("NOT EXISTS(").Suffix(")").From("table3").Where("id=table2.column3")))
sql, args, err = unionAll.ToSql()
assert.NoError(t, err)

expectedSql = `SELECT count(true) as C FROM table1 WHERE column1 IN (?,?) ` +
"UNION ALL SELECT count(true) as C FROM table2 WHERE NOT EXISTS( SELECT true FROM table3 WHERE id=table2.column3 )"
assert.Equal(t, expectedSql, sql)

expectedArgs = []interface{}{"test", "tester"}
assert.Equal(t, expectedArgs, args)
}

func TestRemoveColumns(t *testing.T) {
query := Select("id").
From("users").
Expand Down