Skip to content

Commit

Permalink
feat: implement UpdateQuery
Browse files Browse the repository at this point in the history
This commit introduces a new function, UpdateQuery, that enables updating all
rows matched by a query. It can be used for conditional updates.
  • Loading branch information
grantzvolsky committed Apr 18, 2022
1 parent d6d7437 commit cd6c23e
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 1 deletion.
1 change: 1 addition & 0 deletions dialect.go
Expand Up @@ -12,6 +12,7 @@ type crudable interface {
SelectMany(store, *Model, Query) error
Create(store, *Model, columns.Columns) error
Update(store, *Model, columns.Columns) error
UpdateQuery(store, *Model, columns.Columns, Query) (int64, error)
Destroy(store, *Model) error
Delete(store, *Model, Query) error
}
Expand Down
4 changes: 4 additions & 0 deletions dialect_cockroach.go
Expand Up @@ -106,6 +106,10 @@ func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols, p)
}

func (p *cockroach) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
return genericUpdateQuery(s, model, cols, p, query, sqlx.DOLLAR)
}

func (p *cockroach) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
_, err := genericExec(s, stmt, model.ID())
Expand Down
27 changes: 27 additions & 0 deletions dialect_common.go
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/logging"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx"
)

func init() {
Expand Down Expand Up @@ -110,6 +111,32 @@ func genericUpdate(s store, model *Model, cols columns.Columns, quoter quotable)
return nil
}

func genericUpdateQuery(s store, model *Model, cols columns.Columns, quoter quotable, query Query, bindType int) (int64, error) {
q := fmt.Sprintf("UPDATE %s AS %s SET %s", quoter.Quote(model.TableName()), model.Alias(), cols.Writeable().QuotedUpdateString(quoter))

q, updateArgs, err := sqlx.Named(q, model.Value)
if err != nil {
return 0, err
}

sb := query.toSQLBuilder(model)
q = sb.buildWhereClauses(q)

q = sqlx.Rebind(bindType, q)

result, err := genericExec(s, q, append(updateArgs, sb.args...)...)
if err != nil {
return 0, err
}

n, err := result.RowsAffected()
if err != nil {
return 0, err
}

return n, err
}

func genericDestroy(s store, model *Model, quoter quotable) error {
stmt := fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", quoter.Quote(model.TableName()), model.Alias(), model.WhereID())
_, err := genericExec(s, stmt, model.ID())
Expand Down
9 changes: 9 additions & 0 deletions dialect_mysql.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/logging"
"github.com/jmoiron/sqlx"
)

const nameMySQL = "mysql"
Expand Down Expand Up @@ -94,6 +95,14 @@ func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
return nil
}

func (m *mysql) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
if n, err := genericUpdateQuery(s, model, cols, m, query, sqlx.QUESTION); err != nil {
return n, fmt.Errorf("mysql update query: %w", err)
} else {
return n, nil
}
}

func (m *mysql) Destroy(s store, model *Model) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.Quote(model.TableName()), model.IDField())
_, err := genericExec(s, stmt, model.ID())
Expand Down
4 changes: 4 additions & 0 deletions dialect_postgresql.go
Expand Up @@ -90,6 +90,10 @@ func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols, p)
}

func (p *postgresql) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
return genericUpdateQuery(s, model, cols, p, query, sqlx.DOLLAR)
}

func (p *postgresql) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
_, err := genericExec(s, stmt, model.ID())
Expand Down
16 changes: 16 additions & 0 deletions dialect_sqlite.go
@@ -1,3 +1,4 @@
//go:build sqlite
// +build sqlite

package pop
Expand All @@ -19,6 +20,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/logging"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
)
Expand Down Expand Up @@ -109,6 +111,20 @@ func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error {
})
}

func (m *sqlite) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
rowsAffected := int64(0)
err := m.locker(m.smGil, func() error {
if n, err := genericUpdateQuery(s, model, cols, m, query, sqlx.QUESTION); err != nil {
rowsAffected = n
return fmt.Errorf("sqlite update query: %w", err)
} else {
rowsAffected = n
return nil
}
})
return rowsAffected, err
}

func (m *sqlite) Destroy(s store, model *Model) error {
return m.locker(m.smGil, func() error {
if err := genericDestroy(s, model, m); err != nil {
Expand Down
29 changes: 29 additions & 0 deletions executors.go
Expand Up @@ -380,6 +380,35 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
})
}

// UpdateQuery updates all rows matched by the query. The new values are read
// from the first argument, which must be a struct. The column names to be
// updated must be listed explicitly in subsequent arguments. The ID and
// CreatedAt columns are never updated. The UpdatedAt column is updated
// automatically.
//
// UpdateQuery does not execute (before|after)(Create|Update|Save) callbacks.
//
// Calling UpdateQuery with no columnNames will result in only the UpdatedAt
// column being updated.
func (q *Query) UpdateQuery(model interface{}, columnNames ...string) (int64, error) {
sm := NewModel(model, q.Connection.Context())
modelKind := reflect.TypeOf(reflect.Indirect(reflect.ValueOf(model))).Kind()
if modelKind != reflect.Struct {
return 0, fmt.Errorf("model must be a struct; got %s", modelKind)
}

cols := columns.NewColumnsWithAlias(sm.TableName(), sm.As, sm.IDField())
cols.Add(columnNames...)
if _, err := sm.fieldByName("UpdatedAt"); err == nil {
cols.Add("updated_at")
}
cols.Remove(sm.IDField(), "created_at")

now := nowFunc().Truncate(time.Microsecond)
sm.setUpdatedAt(now)
return q.Connection.Dialect.UpdateQuery(q.Connection.Store, sm, cols, *q)
}

// UpdateColumns writes changes from an entry to the database, including only the given columns
// or all columns if no column names are provided.
// It updates the `updated_at` column automatically if you include `updated_at` in columnNames.
Expand Down
108 changes: 108 additions & 0 deletions executors_test.go
Expand Up @@ -556,6 +556,13 @@ func Test_Embedded_Struct(t *testing.T) {
r.NoError(tx.Find(&actual, entry.ID))
r.Equal(entry.AdditionalField, actual.AdditionalField)

entry.AdditionalField = entry.AdditionalField + "; updated again"
count, err := tx.Where("id = ?", entry.ID).UpdateQuery(entry, "additional_field")
r.NoError(err)
require.Equal(t, int64(1), count)
r.NoError(tx.Find(&actual, entry.ID))
r.Equal(entry.AdditionalField, actual.AdditionalField)

r.NoError(tx.Destroy(entry))
})
}
Expand Down Expand Up @@ -1349,6 +1356,107 @@ func Test_UpdateColumns(t *testing.T) {
})
}

func Test_UpdateQuery_NoUpdatedAt(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)
r.NoError(PDB.Create(&NonStandardID{OutfacingID: "must-change"}))
count, err := PDB.Where("true").UpdateQuery(&NonStandardID{OutfacingID: "has-changed"}, "id")
r.NoError(err)
r.Equal(int64(1), count)
entity := NonStandardID{}
r.NoError(PDB.First(&entity))
r.Equal("has-changed", entity.OutfacingID)
})
}

func Test_UpdateQuery_NoTransaction(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}

r := require.New(t)
u1 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-1")}
r.NoError(PDB.Create(&u1))
r.NoError(PDB.Reload(&u1))
count, err := PDB.Where("name = ?", "Nemo").UpdateQuery(&User{Bio: nulls.NewString("did-change")}, "bio")
r.NoError(err)
require.Equal(t, int64(0), count)

count, err = PDB.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "name")
r.NoError(err)
r.Equal(int64(1), count)

require.NoError(t, PDB.Destroy(&u1))
}

func Test_UpdateQuery(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

u1 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-1")}
u2 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-2")}
u3 := User{Name: nulls.NewString("Baz"), Bio: nulls.NewString("must-not-change-3")}
tx.Create(&u1)
tx.Create(&u2)
tx.Create(&u3)
r.NoError(tx.Reload(&u1))
r.NoError(tx.Reload(&u2))
r.NoError(tx.Reload(&u3))
time.Sleep(time.Millisecond * 1)

// No affected rows
count, err := tx.Where("name = ?", "Nemo").UpdateQuery(&User{Bio: nulls.NewString("did-change")}, "bio")
r.NoError(err)
require.Equal(t, int64(0), count)
mustUnchanged := &User{}
r.NoError(tx.Find(mustUnchanged, u1.ID))
r.Equal(u1.Bio, mustUnchanged.Bio)
r.Equal(u1.UpdatedAt, mustUnchanged.UpdatedAt)

// Correct rows are updated, including updated_at
count, err = tx.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "name")
r.NoError(err)
r.Equal(int64(2), count)

u1b, u2b, u3b := &User{}, &User{}, &User{}
r.NoError(tx.Find(u1b, u1.ID))
r.NoError(tx.Find(u2b, u2.ID))
r.NoError(tx.Find(u3b, u3.ID))
r.Equal(u1b.Name.String, "Bar")
r.Equal(u2b.Name.String, "Bar")
r.Equal(u3b.Name.String, "Baz")
r.Equal(u1b.Bio.String, "must-not-change-1")
r.Equal(u2b.Bio.String, "must-not-change-2")
r.Equal(u3b.Bio.String, "must-not-change-3")
if tx.Dialect.Name() != nameMySQL { // MySQL timestamps are in seconds
r.NotEqual(u1.UpdatedAt, u1b.UpdatedAt)
r.NotEqual(u2.UpdatedAt, u2b.UpdatedAt)
}
r.Equal(u3.UpdatedAt, u3b.UpdatedAt)

// ID is ignored
count, err = tx.Where("true").UpdateQuery(&User{ID: 123, Name: nulls.NewString("Bar")}, "id", "name")
r.NoError(tx.Find(u1b, u1.ID))
r.NoError(tx.Find(u2b, u2.ID))
r.NoError(tx.Find(u3b, u3.ID))
r.Equal(u1b.Name.String, "Bar")
r.Equal(u2b.Name.String, "Bar")
r.Equal(u3b.Name.String, "Bar")

// Invalid column yields an error
count, err = tx.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "mistake")
r.Contains(err.Error(), "could not find name mistake")

tx.Where("true").Delete(&User{})
})
}

func Test_UpdateColumns_UpdatedAt(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down
2 changes: 1 addition & 1 deletion pop_test.go
Expand Up @@ -52,7 +52,7 @@ func init() {
dialect := os.Getenv("SODA_DIALECT")

if dialect == "" {
log(logging.Info, "Skipping integration tests")
log(logging.Info, "Skipping integration tests because SODA_DIALECT is blank or unset")
return
}

Expand Down

0 comments on commit cd6c23e

Please sign in to comment.