Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Results in the ordering of the predicates in the query. #92

Open
wants to merge 10 commits into
base: v1
Choose a base branch
from
53 changes: 36 additions & 17 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"reflect"
"sort"
"strings"
)

Expand Down Expand Up @@ -66,6 +67,18 @@ func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) {
return
}

// GenerateOrderPredicateIndex provides a slice of keys useful for ordering predicates.
func GenerateOrderPredicateIndex(predicates map[string]interface{}) []string {
keys := make([]string, len(predicates))
counter := 0
for key := range predicates {
keys[counter] = key
counter++
}
sort.Strings(keys)
return keys
}

// Eq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(Eq{"id": 1})
Expand All @@ -74,9 +87,9 @@ type Eq map[string]interface{}
func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
var (
exprs []string
equalOpr string = "="
inOpr string = "IN"
nullOpr string = "IS"
equalOpr = "="
inOpr = "IN"
nullOpr = "IS"
)

if useNotOpr {
Expand All @@ -85,8 +98,12 @@ func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
nullOpr = "IS NOT"
}

for key, val := range eq {
expr := ""
predicateIndex := GenerateOrderPredicateIndex(eq)

for _, key := range predicateIndex {
val := eq[key]

var expr string

switch v := val.(type) {
case driver.Valuer:
Expand Down Expand Up @@ -143,7 +160,7 @@ type Lt map[string]interface{}
func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) {
var (
exprs []string
opr string = "<"
opr = "<"
)

if opposite {
Expand All @@ -154,8 +171,10 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err
opr = fmt.Sprintf("%s%s", opr, "=")
}

for key, val := range lt {
expr := ""
predicateIndex := GenerateOrderPredicateIndex(lt)

for _, key := range predicateIndex {
val := lt[key]

switch v := val.(type) {
case driver.Valuer:
Expand All @@ -167,16 +186,16 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err
if val == nil {
err = fmt.Errorf("cannot use null with less than or greater than operators")
return
} else {
valVal := reflect.ValueOf(val)
if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
} else {
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
}
}

valVal := reflect.ValueOf(val)
if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
}

expr := fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
Expand Down
46 changes: 45 additions & 1 deletion expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,54 @@ package squirrel

import (
"database/sql"
"github.com/stretchr/testify/assert"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGenerateOrderPredicateIndex(t *testing.T) {
output := []string{"one", "two"}

type args struct {
predicates map[string]interface{}
}

tests := []struct {
args args
want []string
}{
{args{Eq{}}, []string{}},
{args{Eq{"one": 1}}, []string{"one"}},
{args{Eq{"one": 1, "two": 2}}, output},
{args{Eq{"two": 2, "one": 1}}, output},

{args{Lt{}}, []string{}},
{args{Lt{"one": 1}}, []string{"one"}},
{args{Lt{"one": 1, "two": 2}}, output},
{args{Lt{"two": 2, "one": 1}}, output},

{args{Gt{}}, []string{}},
{args{Gt{"one": 1}}, []string{"one"}},
{args{Gt{"one": 1, "two": 2}}, output},
{args{Gt{"two": 2, "one": 1}}, output},

{args{GtOrEq{}}, []string{}},
{args{GtOrEq{"one": 1}}, []string{"one"}},
{args{GtOrEq{"one": 1, "two": 2}}, output},
{args{GtOrEq{"two": 2, "one": 1}}, output},

{args{LtOrEq{}}, []string{}},
{args{LtOrEq{"one": 1}}, []string{"one"}},
{args{LtOrEq{"one": 1, "two": 2}}, output},
{args{LtOrEq{"two": 2, "one": 1}}, output},
}

for _, tt := range tests {
got := GenerateOrderPredicateIndex(tt.args.predicates)
assert.Equal(t, tt.want, got)
}
}

func TestEqToSql(t *testing.T) {
b := Eq{"id": 1}
sql, args, err := b.ToSql()
Expand Down