Skip to content

Commit

Permalink
dialect/sql: avoid passing bool arguments on bool predicates (ent#2405)
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m authored and gitlawr committed Apr 13, 2022
1 parent 55273e5 commit fd7ee2a
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 37 deletions.
64 changes: 54 additions & 10 deletions dialect/sql/builder.go
Expand Up @@ -1316,18 +1316,52 @@ func And(preds ...*Predicate) *Predicate {
})
}

// IsTrue appends a predicate that checks if the column value is truthy.
func IsTrue(col string) *Predicate {
return P().IsTrue(col)
}

// IsTrue appends a predicate that checks if the column value is truthy.
func (p *Predicate) IsTrue(col string) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
})
}

// IsFalse appends a predicate that checks if the column value is falsey.
func IsFalse(col string) *Predicate {
return P().IsFalse(col)
}

// IsFalse appends a predicate that checks if the column value is falsey.
func (p *Predicate) IsFalse(col string) *Predicate {
return p.Append(func(b *Builder) {
b.WriteString("NOT ").Ident(col)
})
}

// EQ returns a "=" predicate.
func EQ(col string, value interface{}) *Predicate {
return P().EQ(col, value)
}

// EQ appends a "=" predicate.
func (p *Predicate) EQ(col string, arg interface{}) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpEQ)
p.arg(b, arg)
})
// A small optimization to avoid passing
// arguments when it can be avoided.
switch arg := arg.(type) {
case bool:
if arg {
return IsTrue(col)
}
return IsFalse(col)
default:
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpEQ)
p.arg(b, arg)
})
}
}

// ColumnsEQ appends a "=" predicate between 2 columns.
Expand All @@ -1347,11 +1381,21 @@ func NEQ(col string, value interface{}) *Predicate {

// NEQ appends a "<>" predicate.
func (p *Predicate) NEQ(col string, arg interface{}) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpNEQ)
p.arg(b, arg)
})
// A small optimization to avoid passing
// arguments when it can be avoided.
switch arg := arg.(type) {
case bool:
if arg {
return IsFalse(col)
}
return IsTrue(col)
default:
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpNEQ)
p.arg(b, arg)
})
}
}

// ColumnsNEQ appends a "<>" predicate between 2 columns.
Expand Down
38 changes: 27 additions & 11 deletions dialect/sql/builder_test.go
Expand Up @@ -1520,8 +1520,8 @@ func TestBuilder(t *testing.T) {
EQ("active", true),
),
),
wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active" = $5)`,
wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna", true},
wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active")`,
wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna"},
},
{
input: func() Querier {
Expand Down Expand Up @@ -1637,8 +1637,8 @@ func TestSelector_Union(t *testing.T) {
),
).
Query()
require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" WHERE "is_active" = $2 AND "age" > $3 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $4 AND "age" < $5`, query)
require.Equal(t, []interface{}{true, true, 20, "true", 18}, args)
require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query)
require.Equal(t, []interface{}{20, "true", 18}, args)

t1, t2, t3 := Table("files"), Table("files"), Table("path")
n := Queries{
Expand All @@ -1665,8 +1665,8 @@ func TestSelector_Union(t *testing.T) {
From(t3),
}
query, args = n.Query()
require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND `files`.`deleted` = ? UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE `files`.`deleted` = ?) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query)
require.Equal(t, []interface{}{false, false}, args)
require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND NOT `files`.`deleted` UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE NOT `files`.`deleted`) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query)
require.Nil(t, args)
}

func TestBuilderContext(t *testing.T) {
Expand Down Expand Up @@ -1770,7 +1770,7 @@ func TestSelector_UnionOrderBy(t *testing.T) {
Union(Select("*").From(Table("old_users1"))).
OrderBy(table.C("whatever")).
Query()
require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query)
require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query)
}

func TestUpdateBuilder_SetExpr(t *testing.T) {
Expand Down Expand Up @@ -1958,8 +1958,7 @@ func TestReusePredicates(t *testing.T) {
}{
{
p: EQ("active", false),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1`,
wantArgs: []interface{}{false},
wantQuery: `SELECT * FROM "users" WHERE NOT "active"`,
},
{
p: Or(
Expand All @@ -1979,8 +1978,8 @@ func TestReusePredicates(t *testing.T) {
In("id", Select("oid").From(Table("history"))),
),
),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "name" LIKE $2 AND "name" LIKE $3 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`,
wantArgs: []interface{}{true, "foo%", "%bar"},
wantQuery: `SELECT * FROM "users" WHERE "active" AND "name" LIKE $1 AND "name" LIKE $2 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`,
wantArgs: []interface{}{"foo%", "%bar"},
},
{
p: func() *Predicate {
Expand Down Expand Up @@ -2010,3 +2009,20 @@ func TestReusePredicates(t *testing.T) {
require.Equal(t, tt.wantArgs, args)
}
}

func TestBoolPredicates(t *testing.T) {
t1, t2 := Table("users"), Table("posts")
query, args := Select().
From(t1).
Join(t2).
On(t1.C("id"), t2.C("author_id")).
Where(
And(
EQ(t1.C("active"), true),
NEQ(t2.C("deleted"), true),
),
).
Query()
require.Nil(t, args)
require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query)
}
11 changes: 5 additions & 6 deletions dialect/sql/sqlgraph/entql_test.go
Expand Up @@ -142,22 +142,21 @@ func TestGraph_EvalP(t *testing.T) {
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.HasEdgeWith("groups", entql.Or(entql.FieldEQ("name", "GitHub"), entql.FieldEQ("name", "GitLab"))),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $2 OR "t1"."name" = $3)`,
wantArgs: []interface{}{true, "GitHub", "GitLab"},
wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $1 OR "t1"."name" = $2)`,
wantArgs: []interface{}{"GitHub", "GitLab"},
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.And(entql.HasEdge("pets"), entql.HasEdge("groups"), entql.EQ(entql.F("name"), entql.F("uid"))),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`,
wantArgs: []interface{}{true},
wantQuery: `SELECT * FROM "users" WHERE "active" AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`,
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) {
s.Where(sql.EQ("owner_id", 10))
})),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $2 AND "owner_id" = $3)`,
wantArgs: []interface{}{true, "pedro", 10},
wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 AND "owner_id" = $2)`,
wantArgs: []interface{}{"pedro", 10},
},
}
for i, tt := range tests {
Expand Down
14 changes: 6 additions & 8 deletions dialect/sql/sqlgraph/graph_test.go
Expand Up @@ -655,8 +655,7 @@ func TestHasNeighborsWith(t *testing.T) {
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("expired", false))
},
wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE "expired" = $1)`,
wantArgs: []interface{}{false},
wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE NOT "expired")`,
},
{
name: "O2O/inverse",
Expand Down Expand Up @@ -779,8 +778,7 @@ WHERE "groups"."id" IN
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("expired", false))
},
wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE "expired" = $1)`,
wantArgs: []interface{}{false},
wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE NOT "expired")`,
},
{
name: "schema/O2M",
Expand Down Expand Up @@ -1549,11 +1547,11 @@ func TestUpdateNode(t *testing.T) {
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND `deleted` = ?")).
WithArgs(1, 1, false).
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
WithArgs(1, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND `deleted` = ?")).
WithArgs(1, false).
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND NOT `deleted`")).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
AddRow(1, 31, nil))
mock.ExpectCommit()
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/sqljson/sqljson_test.go
Expand Up @@ -81,8 +81,8 @@ func TestWritePath(t *testing.T) {
sql.EQ("active", true),
),
),
wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND (JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL OR JSON_TYPE(`j`, \"$.a.*.c\") = 'null') AND `active` = ?",
wantArgs: []interface{}{100, true},
wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND (JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL OR JSON_TYPE(`j`, \"$.a.*.c\") = 'null') AND `active`",
wantArgs: []interface{}{100},
},
{
input: sql.Dialect(dialect.Postgres).
Expand Down
8 changes: 8 additions & 0 deletions entc/integration/integration_test.go
Expand Up @@ -766,6 +766,14 @@ func Predicate(t *testing.T, client *ent.Client) {
).
CountX(ctx),
)

inf := client.GroupInfo.Create().SetDesc("desc").SaveX(ctx)
hub := client.Group.Create().SetName("GitHub").SetExpire(time.Now()).SetInfo(inf).SaveX(ctx)
lab := client.Group.Create().SetName("GitLab").SetExpire(time.Now()).SetInfo(inf).SetActive(false).SaveX(ctx)
require.Equal(hub.ID, client.Group.Query().Where(group.Active(true)).OnlyIDX(ctx))
require.Equal(lab.ID, client.Group.Query().Where(group.Active(false)).OnlyIDX(ctx))
require.Equal(hub.ID, client.Group.Query().Where(group.ActiveNEQ(false)).OnlyIDX(ctx))
require.Equal(lab.ID, client.Group.Query().Where(group.ActiveNEQ(true)).OnlyIDX(ctx))
}

func AddValues(t *testing.T, client *ent.Client) {
Expand Down

0 comments on commit fd7ee2a

Please sign in to comment.