Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialect/sql: avoid passing bool arguments on bool predicates #2405

Merged
merged 1 commit into from Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 54 additions & 10 deletions dialect/sql/builder.go
Expand Up @@ -1316,18 +1316,52 @@ func And(preds ...*Predicate) *Predicate {
})
}

// IsTrue appends a predicate that checks if the column value is truthy.
func IsTrue(col string) *Predicate {
return P().IsTrue(col)
}

// IsTrue appends a predicate that checks if the column value is truthy.
func (p *Predicate) IsTrue(col string) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
})
}

// IsFalse appends a predicate that checks if the column value is falsey.
func IsFalse(col string) *Predicate {
return P().IsFalse(col)
}

// IsFalse appends a predicate that checks if the column value is falsey.
func (p *Predicate) IsFalse(col string) *Predicate {
return p.Append(func(b *Builder) {
b.WriteString("NOT ").Ident(col)
})
}

// EQ returns a "=" predicate.
func EQ(col string, value interface{}) *Predicate {
return P().EQ(col, value)
}

// EQ appends a "=" predicate.
func (p *Predicate) EQ(col string, arg interface{}) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpEQ)
p.arg(b, arg)
})
// A small optimization to avoid passing
// arguments when it can be avoided.
switch arg := arg.(type) {
case bool:
if arg {
return IsTrue(col)
}
return IsFalse(col)
default:
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpEQ)
p.arg(b, arg)
})
}
}

// ColumnsEQ appends a "=" predicate between 2 columns.
Expand All @@ -1347,11 +1381,21 @@ func NEQ(col string, value interface{}) *Predicate {

// NEQ appends a "<>" predicate.
func (p *Predicate) NEQ(col string, arg interface{}) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpNEQ)
p.arg(b, arg)
})
// A small optimization to avoid passing
// arguments when it can be avoided.
switch arg := arg.(type) {
case bool:
if arg {
return IsFalse(col)
}
return IsTrue(col)
default:
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpNEQ)
p.arg(b, arg)
})
}
}

// ColumnsNEQ appends a "<>" predicate between 2 columns.
Expand Down
38 changes: 27 additions & 11 deletions dialect/sql/builder_test.go
Expand Up @@ -1520,8 +1520,8 @@ func TestBuilder(t *testing.T) {
EQ("active", true),
),
),
wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active" = $5)`,
wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna", true},
wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active")`,
wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna"},
},
{
input: func() Querier {
Expand Down Expand Up @@ -1637,8 +1637,8 @@ func TestSelector_Union(t *testing.T) {
),
).
Query()
require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" WHERE "is_active" = $2 AND "age" > $3 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $4 AND "age" < $5`, query)
require.Equal(t, []interface{}{true, true, 20, "true", 18}, args)
require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query)
require.Equal(t, []interface{}{20, "true", 18}, args)

t1, t2, t3 := Table("files"), Table("files"), Table("path")
n := Queries{
Expand All @@ -1665,8 +1665,8 @@ func TestSelector_Union(t *testing.T) {
From(t3),
}
query, args = n.Query()
require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND `files`.`deleted` = ? UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE `files`.`deleted` = ?) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query)
require.Equal(t, []interface{}{false, false}, args)
require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND NOT `files`.`deleted` UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE NOT `files`.`deleted`) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query)
require.Nil(t, args)
}

func TestBuilderContext(t *testing.T) {
Expand Down Expand Up @@ -1770,7 +1770,7 @@ func TestSelector_UnionOrderBy(t *testing.T) {
Union(Select("*").From(Table("old_users1"))).
OrderBy(table.C("whatever")).
Query()
require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query)
require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query)
}

func TestUpdateBuilder_SetExpr(t *testing.T) {
Expand Down Expand Up @@ -1958,8 +1958,7 @@ func TestReusePredicates(t *testing.T) {
}{
{
p: EQ("active", false),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1`,
wantArgs: []interface{}{false},
wantQuery: `SELECT * FROM "users" WHERE NOT "active"`,
},
{
p: Or(
Expand All @@ -1979,8 +1978,8 @@ func TestReusePredicates(t *testing.T) {
In("id", Select("oid").From(Table("history"))),
),
),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "name" LIKE $2 AND "name" LIKE $3 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`,
wantArgs: []interface{}{true, "foo%", "%bar"},
wantQuery: `SELECT * FROM "users" WHERE "active" AND "name" LIKE $1 AND "name" LIKE $2 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`,
wantArgs: []interface{}{"foo%", "%bar"},
},
{
p: func() *Predicate {
Expand Down Expand Up @@ -2010,3 +2009,20 @@ func TestReusePredicates(t *testing.T) {
require.Equal(t, tt.wantArgs, args)
}
}

func TestBoolPredicates(t *testing.T) {
t1, t2 := Table("users"), Table("posts")
query, args := Select().
From(t1).
Join(t2).
On(t1.C("id"), t2.C("author_id")).
Where(
And(
EQ(t1.C("active"), true),
NEQ(t2.C("deleted"), true),
),
).
Query()
require.Nil(t, args)
require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query)
}
11 changes: 5 additions & 6 deletions dialect/sql/sqlgraph/entql_test.go
Expand Up @@ -142,22 +142,21 @@ func TestGraph_EvalP(t *testing.T) {
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.HasEdgeWith("groups", entql.Or(entql.FieldEQ("name", "GitHub"), entql.FieldEQ("name", "GitLab"))),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $2 OR "t1"."name" = $3)`,
wantArgs: []interface{}{true, "GitHub", "GitLab"},
wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $1 OR "t1"."name" = $2)`,
wantArgs: []interface{}{"GitHub", "GitLab"},
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.And(entql.HasEdge("pets"), entql.HasEdge("groups"), entql.EQ(entql.F("name"), entql.F("uid"))),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`,
wantArgs: []interface{}{true},
wantQuery: `SELECT * FROM "users" WHERE "active" AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`,
},
{
s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)),
p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) {
s.Where(sql.EQ("owner_id", 10))
})),
wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $2 AND "owner_id" = $3)`,
wantArgs: []interface{}{true, "pedro", 10},
wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 AND "owner_id" = $2)`,
wantArgs: []interface{}{"pedro", 10},
},
}
for i, tt := range tests {
Expand Down
14 changes: 6 additions & 8 deletions dialect/sql/sqlgraph/graph_test.go
Expand Up @@ -655,8 +655,7 @@ func TestHasNeighborsWith(t *testing.T) {
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("expired", false))
},
wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE "expired" = $1)`,
wantArgs: []interface{}{false},
wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE NOT "expired")`,
},
{
name: "O2O/inverse",
Expand Down Expand Up @@ -779,8 +778,7 @@ WHERE "groups"."id" IN
predicate: func(s *sql.Selector) {
s.Where(sql.EQ("expired", false))
},
wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE "expired" = $1)`,
wantArgs: []interface{}{false},
wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE NOT "expired")`,
},
{
name: "schema/O2M",
Expand Down Expand Up @@ -1549,11 +1547,11 @@ func TestUpdateNode(t *testing.T) {
},
prepare: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND `deleted` = ?")).
WithArgs(1, 1, false).
mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")).
WithArgs(1, 1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND `deleted` = ?")).
WithArgs(1, false).
mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND NOT `deleted`")).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
AddRow(1, 31, nil))
mock.ExpectCommit()
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/sqljson/sqljson_test.go
Expand Up @@ -81,8 +81,8 @@ func TestWritePath(t *testing.T) {
sql.EQ("active", true),
),
),
wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND (JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL OR JSON_TYPE(`j`, \"$.a.*.c\") = 'null') AND `active` = ?",
wantArgs: []interface{}{100, true},
wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND (JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL OR JSON_TYPE(`j`, \"$.a.*.c\") = 'null') AND `active`",
wantArgs: []interface{}{100},
},
{
input: sql.Dialect(dialect.Postgres).
Expand Down
8 changes: 8 additions & 0 deletions entc/integration/integration_test.go
Expand Up @@ -766,6 +766,14 @@ func Predicate(t *testing.T, client *ent.Client) {
).
CountX(ctx),
)

inf := client.GroupInfo.Create().SetDesc("desc").SaveX(ctx)
hub := client.Group.Create().SetName("GitHub").SetExpire(time.Now()).SetInfo(inf).SaveX(ctx)
lab := client.Group.Create().SetName("GitLab").SetExpire(time.Now()).SetInfo(inf).SetActive(false).SaveX(ctx)
require.Equal(hub.ID, client.Group.Query().Where(group.Active(true)).OnlyIDX(ctx))
require.Equal(lab.ID, client.Group.Query().Where(group.Active(false)).OnlyIDX(ctx))
require.Equal(hub.ID, client.Group.Query().Where(group.ActiveNEQ(false)).OnlyIDX(ctx))
require.Equal(lab.ID, client.Group.Query().Where(group.ActiveNEQ(true)).OnlyIDX(ctx))
}

func AddValues(t *testing.T, client *ent.Client) {
Expand Down