diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 72a65616e6..4685e0dee2 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1316,6 +1316,30 @@ func And(preds ...*Predicate) *Predicate { }) } +// IsTrue appends a predicate that checks if the column value is truthy. +func IsTrue(col string) *Predicate { + return P().IsTrue(col) +} + +// IsTrue appends a predicate that checks if the column value is truthy. +func (p *Predicate) IsTrue(col string) *Predicate { + return p.Append(func(b *Builder) { + b.Ident(col) + }) +} + +// IsFalse appends a predicate that checks if the column value is falsey. +func IsFalse(col string) *Predicate { + return P().IsFalse(col) +} + +// IsFalse appends a predicate that checks if the column value is falsey. +func (p *Predicate) IsFalse(col string) *Predicate { + return p.Append(func(b *Builder) { + b.WriteString("NOT ").Ident(col) + }) +} + // EQ returns a "=" predicate. func EQ(col string, value interface{}) *Predicate { return P().EQ(col, value) @@ -1323,11 +1347,21 @@ func EQ(col string, value interface{}) *Predicate { // EQ appends a "=" predicate. func (p *Predicate) EQ(col string, arg interface{}) *Predicate { - return p.Append(func(b *Builder) { - b.Ident(col) - b.WriteOp(OpEQ) - p.arg(b, arg) - }) + // A small optimization to avoid passing + // arguments when it can be avoided. + switch arg := arg.(type) { + case bool: + if arg { + return IsTrue(col) + } + return IsFalse(col) + default: + return p.Append(func(b *Builder) { + b.Ident(col) + b.WriteOp(OpEQ) + p.arg(b, arg) + }) + } } // ColumnsEQ appends a "=" predicate between 2 columns. @@ -1347,11 +1381,21 @@ func NEQ(col string, value interface{}) *Predicate { // NEQ appends a "<>" predicate. func (p *Predicate) NEQ(col string, arg interface{}) *Predicate { - return p.Append(func(b *Builder) { - b.Ident(col) - b.WriteOp(OpNEQ) - p.arg(b, arg) - }) + // A small optimization to avoid passing + // arguments when it can be avoided. + switch arg := arg.(type) { + case bool: + if arg { + return IsFalse(col) + } + return IsTrue(col) + default: + return p.Append(func(b *Builder) { + b.Ident(col) + b.WriteOp(OpNEQ) + p.arg(b, arg) + }) + } } // ColumnsNEQ appends a "<>" predicate between 2 columns. diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 7157bc5b9c..0798b0fcb4 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -1520,8 +1520,8 @@ func TestBuilder(t *testing.T) { EQ("active", true), ), ), - wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active" = $5)`, - wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna", true}, + wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active")`, + wantArgs: []interface{}{"pedro", "pedro", "pedro", "luna"}, }, { input: func() Querier { @@ -1637,8 +1637,8 @@ func TestSelector_Union(t *testing.T) { ), ). Query() - require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" WHERE "is_active" = $2 AND "age" > $3 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $4 AND "age" < $5`, query) - require.Equal(t, []interface{}{true, true, 20, "true", 18}, args) + require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query) + require.Equal(t, []interface{}{20, "true", 18}, args) t1, t2, t3 := Table("files"), Table("files"), Table("path") n := Queries{ @@ -1665,8 +1665,8 @@ func TestSelector_Union(t *testing.T) { From(t3), } query, args = n.Query() - require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND `files`.`deleted` = ? UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE `files`.`deleted` = ?) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query) - require.Equal(t, []interface{}{false, false}, args) + require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND NOT `files`.`deleted` UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE NOT `files`.`deleted`) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query) + require.Nil(t, args) } func TestBuilderContext(t *testing.T) { @@ -1770,7 +1770,7 @@ func TestSelector_UnionOrderBy(t *testing.T) { Union(Select("*").From(Table("old_users1"))). OrderBy(table.C("whatever")). Query() - require.Equal(t, `SELECT * FROM "users" WHERE "active" = $1 UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query) + require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query) } func TestUpdateBuilder_SetExpr(t *testing.T) { @@ -1958,8 +1958,7 @@ func TestReusePredicates(t *testing.T) { }{ { p: EQ("active", false), - wantQuery: `SELECT * FROM "users" WHERE "active" = $1`, - wantArgs: []interface{}{false}, + wantQuery: `SELECT * FROM "users" WHERE NOT "active"`, }, { p: Or( @@ -1979,8 +1978,8 @@ func TestReusePredicates(t *testing.T) { In("id", Select("oid").From(Table("history"))), ), ), - wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "name" LIKE $2 AND "name" LIKE $3 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`, - wantArgs: []interface{}{true, "foo%", "%bar"}, + wantQuery: `SELECT * FROM "users" WHERE "active" AND "name" LIKE $1 AND "name" LIKE $2 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`, + wantArgs: []interface{}{"foo%", "%bar"}, }, { p: func() *Predicate { @@ -2010,3 +2009,20 @@ func TestReusePredicates(t *testing.T) { require.Equal(t, tt.wantArgs, args) } } + +func TestBoolPredicates(t *testing.T) { + t1, t2 := Table("users"), Table("posts") + query, args := Select(). + From(t1). + Join(t2). + On(t1.C("id"), t2.C("author_id")). + Where( + And( + EQ(t1.C("active"), true), + NEQ(t2.C("deleted"), true), + ), + ). + Query() + require.Nil(t, args) + require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query) +} diff --git a/dialect/sql/sqlgraph/entql_test.go b/dialect/sql/sqlgraph/entql_test.go index e4dafe272d..cbbc637efe 100644 --- a/dialect/sql/sqlgraph/entql_test.go +++ b/dialect/sql/sqlgraph/entql_test.go @@ -142,22 +142,21 @@ func TestGraph_EvalP(t *testing.T) { { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.HasEdgeWith("groups", entql.Or(entql.FieldEQ("name", "GitHub"), entql.FieldEQ("name", "GitLab"))), - wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $2 OR "t1"."name" = $3)`, - wantArgs: []interface{}{true, "GitHub", "GitLab"}, + wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $1 OR "t1"."name" = $2)`, + wantArgs: []interface{}{"GitHub", "GitLab"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.And(entql.HasEdge("pets"), entql.HasEdge("groups"), entql.EQ(entql.F("name"), entql.F("uid"))), - wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`, - wantArgs: []interface{}{true}, + wantQuery: `SELECT * FROM "users" WHERE "active" AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) { s.Where(sql.EQ("owner_id", 10)) })), - wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $2 AND "owner_id" = $3)`, - wantArgs: []interface{}{true, "pedro", 10}, + wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 AND "owner_id" = $2)`, + wantArgs: []interface{}{"pedro", 10}, }, } for i, tt := range tests { diff --git a/dialect/sql/sqlgraph/graph_test.go b/dialect/sql/sqlgraph/graph_test.go index 9c9c9afbb3..7389f46b3d 100644 --- a/dialect/sql/sqlgraph/graph_test.go +++ b/dialect/sql/sqlgraph/graph_test.go @@ -655,8 +655,7 @@ func TestHasNeighborsWith(t *testing.T) { predicate: func(s *sql.Selector) { s.Where(sql.EQ("expired", false)) }, - wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE "expired" = $1)`, - wantArgs: []interface{}{false}, + wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE NOT "expired")`, }, { name: "O2O/inverse", @@ -779,8 +778,7 @@ WHERE "groups"."id" IN predicate: func(s *sql.Selector) { s.Where(sql.EQ("expired", false)) }, - wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE "expired" = $1)`, - wantArgs: []interface{}{false}, + wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE NOT "expired")`, }, { name: "schema/O2M", @@ -1549,11 +1547,11 @@ func TestUpdateNode(t *testing.T) { }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() - mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND `deleted` = ?")). - WithArgs(1, 1, false). + mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")). + WithArgs(1, 1). WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND `deleted` = ?")). - WithArgs(1, false). + mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ? AND NOT `deleted`")). + WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). AddRow(1, 31, nil)) mock.ExpectCommit() diff --git a/dialect/sql/sqljson/sqljson_test.go b/dialect/sql/sqljson/sqljson_test.go index 8c40a5782b..35ceb02b6d 100644 --- a/dialect/sql/sqljson/sqljson_test.go +++ b/dialect/sql/sqljson/sqljson_test.go @@ -81,8 +81,8 @@ func TestWritePath(t *testing.T) { sql.EQ("active", true), ), ), - wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND (JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL OR JSON_TYPE(`j`, \"$.a.*.c\") = 'null') AND `active` = ?", - wantArgs: []interface{}{100, true}, + wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND (JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL OR JSON_TYPE(`j`, \"$.a.*.c\") = 'null') AND `active`", + wantArgs: []interface{}{100}, }, { input: sql.Dialect(dialect.Postgres). diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 24ed9a7db7..12605bf6bc 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -766,6 +766,14 @@ func Predicate(t *testing.T, client *ent.Client) { ). CountX(ctx), ) + + inf := client.GroupInfo.Create().SetDesc("desc").SaveX(ctx) + hub := client.Group.Create().SetName("GitHub").SetExpire(time.Now()).SetInfo(inf).SaveX(ctx) + lab := client.Group.Create().SetName("GitLab").SetExpire(time.Now()).SetInfo(inf).SetActive(false).SaveX(ctx) + require.Equal(hub.ID, client.Group.Query().Where(group.Active(true)).OnlyIDX(ctx)) + require.Equal(lab.ID, client.Group.Query().Where(group.Active(false)).OnlyIDX(ctx)) + require.Equal(hub.ID, client.Group.Query().Where(group.ActiveNEQ(false)).OnlyIDX(ctx)) + require.Equal(lab.ID, client.Group.Query().Where(group.ActiveNEQ(true)).OnlyIDX(ctx)) } func AddValues(t *testing.T, client *ent.Client) {