diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 72a65616e6..4685e0dee2 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1316,6 +1316,30 @@ func And(preds ...*Predicate) *Predicate { }) } +// IsTrue appends a predicate that checks if the column value is truthy. +func IsTrue(col string) *Predicate { + return P().IsTrue(col) +} + +// IsTrue appends a predicate that checks if the column value is truthy. +func (p *Predicate) IsTrue(col string) *Predicate { + return p.Append(func(b *Builder) { + b.Ident(col) + }) +} + +// IsFalse appends a predicate that checks if the column value is falsey. +func IsFalse(col string) *Predicate { + return P().IsFalse(col) +} + +// IsFalse appends a predicate that checks if the column value is falsey. +func (p *Predicate) IsFalse(col string) *Predicate { + return p.Append(func(b *Builder) { + b.WriteString("NOT ").Ident(col) + }) +} + // EQ returns a "=" predicate. func EQ(col string, value interface{}) *Predicate { return P().EQ(col, value) @@ -1323,11 +1347,21 @@ func EQ(col string, value interface{}) *Predicate { // EQ appends a "=" predicate. func (p *Predicate) EQ(col string, arg interface{}) *Predicate { - return p.Append(func(b *Builder) { - b.Ident(col) - b.WriteOp(OpEQ) - p.arg(b, arg) - }) + // A small optimization to avoid passing + // arguments when it can be avoided. + switch arg := arg.(type) { + case bool: + if arg { + return IsTrue(col) + } + return IsFalse(col) + default: + return p.Append(func(b *Builder) { + b.Ident(col) + b.WriteOp(OpEQ) + p.arg(b, arg) + }) + } } // ColumnsEQ appends a "=" predicate between 2 columns. @@ -1347,11 +1381,21 @@ func NEQ(col string, value interface{}) *Predicate { // NEQ appends a "<>" predicate. func (p *Predicate) NEQ(col string, arg interface{}) *Predicate { - return p.Append(func(b *Builder) { - b.Ident(col) - b.WriteOp(OpNEQ) - p.arg(b, arg) - }) + // A small optimization to avoid passing + // arguments when it can be avoided. + switch arg := arg.(type) { + case bool: + if arg { + return IsFalse(col) + } + return IsTrue(col) + default: + return p.Append(func(b *Builder) { + b.Ident(col) + b.WriteOp(OpNEQ) + p.arg(b, arg) + }) + } } // ColumnsNEQ appends a "<>" predicate between 2 columns. diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 7157bc5b9c..75b5c8521d 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -2010,3 +2010,20 @@ func TestReusePredicates(t *testing.T) { require.Equal(t, tt.wantArgs, args) } } + +func TestBoolPredicates(t *testing.T) { + t1, t2 := Table("users"), Table("posts") + query, args := Select(). + From(t1). + Join(t2). + On(t1.C("id"), t2.C("author_id")). + Where( + And( + EQ(t1.C("active"), true), + NEQ(t2.C("deleted"), true), + ), + ). + Query() + require.Nil(t, args) + require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query) +} diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 24ed9a7db7..12605bf6bc 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -766,6 +766,14 @@ func Predicate(t *testing.T, client *ent.Client) { ). CountX(ctx), ) + + inf := client.GroupInfo.Create().SetDesc("desc").SaveX(ctx) + hub := client.Group.Create().SetName("GitHub").SetExpire(time.Now()).SetInfo(inf).SaveX(ctx) + lab := client.Group.Create().SetName("GitLab").SetExpire(time.Now()).SetInfo(inf).SetActive(false).SaveX(ctx) + require.Equal(hub.ID, client.Group.Query().Where(group.Active(true)).OnlyIDX(ctx)) + require.Equal(lab.ID, client.Group.Query().Where(group.Active(false)).OnlyIDX(ctx)) + require.Equal(hub.ID, client.Group.Query().Where(group.ActiveNEQ(false)).OnlyIDX(ctx)) + require.Equal(lab.ID, client.Group.Query().Where(group.ActiveNEQ(true)).OnlyIDX(ctx)) } func AddValues(t *testing.T, client *ent.Client) {