Skip to content

Commit

Permalink
dialect/sql: avoid passing bool arguments on bool predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Mar 15, 2022
1 parent db1617b commit 3fc5314
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
64 changes: 54 additions & 10 deletions dialect/sql/builder.go
Expand Up @@ -1316,18 +1316,52 @@ func And(preds ...*Predicate) *Predicate {
})
}

// IsTrue appends a predicate that checks if the column value is truthy.
func IsTrue(col string) *Predicate {
return P().IsTrue(col)
}

// IsTrue appends a predicate that checks if the column value is truthy.
func (p *Predicate) IsTrue(col string) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
})
}

// IsFalse appends a predicate that checks if the column value is falsey.
func IsFalse(col string) *Predicate {
return P().IsFalse(col)
}

// IsFalse appends a predicate that checks if the column value is falsey.
func (p *Predicate) IsFalse(col string) *Predicate {
return p.Append(func(b *Builder) {
b.WriteString("NOT ").Ident(col)
})
}

// EQ returns a "=" predicate.
func EQ(col string, value interface{}) *Predicate {
return P().EQ(col, value)
}

// EQ appends a "=" predicate.
func (p *Predicate) EQ(col string, arg interface{}) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpEQ)
p.arg(b, arg)
})
// A small optimization to avoid passing
// arguments when it can be avoided.
switch arg := arg.(type) {
case bool:
if arg {
return IsTrue(col)
}
return IsFalse(col)
default:
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpEQ)
p.arg(b, arg)
})
}
}

// ColumnsEQ appends a "=" predicate between 2 columns.
Expand All @@ -1347,11 +1381,21 @@ func NEQ(col string, value interface{}) *Predicate {

// NEQ appends a "<>" predicate.
func (p *Predicate) NEQ(col string, arg interface{}) *Predicate {
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpNEQ)
p.arg(b, arg)
})
// A small optimization to avoid passing
// arguments when it can be avoided.
switch arg := arg.(type) {
case bool:
if arg {
return IsFalse(col)
}
return IsTrue(col)
default:
return p.Append(func(b *Builder) {
b.Ident(col)
b.WriteOp(OpNEQ)
p.arg(b, arg)
})
}
}

// ColumnsNEQ appends a "<>" predicate between 2 columns.
Expand Down
17 changes: 17 additions & 0 deletions dialect/sql/builder_test.go
Expand Up @@ -2010,3 +2010,20 @@ func TestReusePredicates(t *testing.T) {
require.Equal(t, tt.wantArgs, args)
}
}

func TestBoolPredicates(t *testing.T) {
t1, t2 := Table("users"), Table("posts")
query, args := Select().
From(t1).
Join(t2).
On(t1.C("id"), t2.C("author_id")).
Where(
And(
EQ(t1.C("active"), true),
NEQ(t2.C("deleted"), true),
),
).
Query()
require.Nil(t, args)
require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query)
}
8 changes: 8 additions & 0 deletions entc/integration/integration_test.go
Expand Up @@ -766,6 +766,14 @@ func Predicate(t *testing.T, client *ent.Client) {
).
CountX(ctx),
)

inf := client.GroupInfo.Create().SetDesc("desc").SaveX(ctx)
hub := client.Group.Create().SetName("GitHub").SetExpire(time.Now()).SetInfo(inf).SaveX(ctx)
lab := client.Group.Create().SetName("GitLab").SetExpire(time.Now()).SetInfo(inf).SetActive(false).SaveX(ctx)
require.Equal(hub.ID, client.Group.Query().Where(group.Active(true)).OnlyIDX(ctx))
require.Equal(lab.ID, client.Group.Query().Where(group.Active(false)).OnlyIDX(ctx))
require.Equal(hub.ID, client.Group.Query().Where(group.ActiveNEQ(false)).OnlyIDX(ctx))
require.Equal(lab.ID, client.Group.Query().Where(group.ActiveNEQ(true)).OnlyIDX(ctx))
}

func AddValues(t *testing.T, client *ent.Client) {
Expand Down

0 comments on commit 3fc5314

Please sign in to comment.