Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement UpdateQuery #706

Merged
merged 1 commit into from Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dialect.go
Expand Up @@ -12,6 +12,7 @@ type crudable interface {
SelectMany(store, *Model, Query) error
Create(store, *Model, columns.Columns) error
Update(store, *Model, columns.Columns) error
UpdateQuery(store, *Model, columns.Columns, Query) (int64, error)
Destroy(store, *Model) error
Delete(store, *Model, Query) error
}
Expand Down
4 changes: 4 additions & 0 deletions dialect_cockroach.go
Expand Up @@ -106,6 +106,10 @@ func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols, p)
}

func (p *cockroach) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
return genericUpdateQuery(s, model, cols, p, query, sqlx.DOLLAR)
}

func (p *cockroach) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
_, err := genericExec(s, stmt, model.ID())
Expand Down
27 changes: 27 additions & 0 deletions dialect_common.go
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/logging"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx"
)

func init() {
Expand Down Expand Up @@ -110,6 +111,32 @@ func genericUpdate(s store, model *Model, cols columns.Columns, quoter quotable)
return nil
}

func genericUpdateQuery(s store, model *Model, cols columns.Columns, quoter quotable, query Query, bindType int) (int64, error) {
q := fmt.Sprintf("UPDATE %s AS %s SET %s", quoter.Quote(model.TableName()), model.Alias(), cols.Writeable().QuotedUpdateString(quoter))

q, updateArgs, err := sqlx.Named(q, model.Value)
if err != nil {
return 0, err
}

sb := query.toSQLBuilder(model)
q = sb.buildWhereClauses(q)

q = sqlx.Rebind(bindType, q)

result, err := genericExec(s, q, append(updateArgs, sb.args...)...)
if err != nil {
return 0, err
}

n, err := result.RowsAffected()
if err != nil {
return 0, err
}

return n, err
}

func genericDestroy(s store, model *Model, quoter quotable) error {
stmt := fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", quoter.Quote(model.TableName()), model.Alias(), model.WhereID())
_, err := genericExec(s, stmt, model.ID())
Expand Down
9 changes: 9 additions & 0 deletions dialect_mysql.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/logging"
"github.com/jmoiron/sqlx"
)

const nameMySQL = "mysql"
Expand Down Expand Up @@ -94,6 +95,14 @@ func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
return nil
}

func (m *mysql) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
if n, err := genericUpdateQuery(s, model, cols, m, query, sqlx.QUESTION); err != nil {
return n, fmt.Errorf("mysql update query: %w", err)
} else {
return n, nil
}
}

func (m *mysql) Destroy(s store, model *Model) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.Quote(model.TableName()), model.IDField())
_, err := genericExec(s, stmt, model.ID())
Expand Down
4 changes: 4 additions & 0 deletions dialect_postgresql.go
Expand Up @@ -90,6 +90,10 @@ func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols, p)
}

func (p *postgresql) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
return genericUpdateQuery(s, model, cols, p, query, sqlx.DOLLAR)
}

func (p *postgresql) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
_, err := genericExec(s, stmt, model.ID())
Expand Down
16 changes: 16 additions & 0 deletions dialect_sqlite.go
@@ -1,3 +1,4 @@
//go:build sqlite
// +build sqlite

package pop
Expand All @@ -19,6 +20,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/logging"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
)
Expand Down Expand Up @@ -109,6 +111,20 @@ func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error {
})
}

func (m *sqlite) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) {
rowsAffected := int64(0)
err := m.locker(m.smGil, func() error {
if n, err := genericUpdateQuery(s, model, cols, m, query, sqlx.QUESTION); err != nil {
rowsAffected = n
return fmt.Errorf("sqlite update query: %w", err)
} else {
rowsAffected = n
return nil
}
})
return rowsAffected, err
}

func (m *sqlite) Destroy(s store, model *Model) error {
return m.locker(m.smGil, func() error {
if err := genericDestroy(s, model, m); err != nil {
Expand Down
29 changes: 29 additions & 0 deletions executors.go
Expand Up @@ -380,6 +380,35 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
})
}

// UpdateQuery updates all rows matched by the query. The new values are read
// from the first argument, which must be a struct. The column names to be
// updated must be listed explicitly in subsequent arguments. The ID and
// CreatedAt columns are never updated. The UpdatedAt column is updated
// automatically.
//
// UpdateQuery does not execute (before|after)(Create|Update|Save) callbacks.
//
// Calling UpdateQuery with no columnNames will result in only the UpdatedAt
// column being updated.
func (q *Query) UpdateQuery(model interface{}, columnNames ...string) (int64, error) {
sm := NewModel(model, q.Connection.Context())
modelKind := reflect.TypeOf(reflect.Indirect(reflect.ValueOf(model))).Kind()
if modelKind != reflect.Struct {
return 0, fmt.Errorf("model must be a struct; got %s", modelKind)
}

cols := columns.NewColumnsWithAlias(sm.TableName(), sm.As, sm.IDField())
cols.Add(columnNames...)
if _, err := sm.fieldByName("UpdatedAt"); err == nil {
cols.Add("updated_at")
}
cols.Remove(sm.IDField(), "created_at")

now := nowFunc().Truncate(time.Microsecond)
sm.setUpdatedAt(now)
return q.Connection.Dialect.UpdateQuery(q.Connection.Store, sm, cols, *q)
}

// UpdateColumns writes changes from an entry to the database, including only the given columns
// or all columns if no column names are provided.
// It updates the `updated_at` column automatically if you include `updated_at` in columnNames.
Expand Down
108 changes: 108 additions & 0 deletions executors_test.go
Expand Up @@ -556,6 +556,13 @@ func Test_Embedded_Struct(t *testing.T) {
r.NoError(tx.Find(&actual, entry.ID))
r.Equal(entry.AdditionalField, actual.AdditionalField)

entry.AdditionalField = entry.AdditionalField + "; updated again"
count, err := tx.Where("id = ?", entry.ID).UpdateQuery(entry, "additional_field")
r.NoError(err)
require.Equal(t, int64(1), count)
r.NoError(tx.Find(&actual, entry.ID))
r.Equal(entry.AdditionalField, actual.AdditionalField)

r.NoError(tx.Destroy(entry))
})
}
Expand Down Expand Up @@ -1493,6 +1500,107 @@ func Test_UpdateColumns(t *testing.T) {
})
}

func Test_UpdateQuery_NoUpdatedAt(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)
r.NoError(PDB.Create(&NonStandardID{OutfacingID: "must-change"}))
count, err := PDB.Where("true").UpdateQuery(&NonStandardID{OutfacingID: "has-changed"}, "id")
r.NoError(err)
r.Equal(int64(1), count)
entity := NonStandardID{}
r.NoError(PDB.First(&entity))
r.Equal("has-changed", entity.OutfacingID)
})
}

func Test_UpdateQuery_NoTransaction(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}

r := require.New(t)
u1 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-1")}
r.NoError(PDB.Create(&u1))
r.NoError(PDB.Reload(&u1))
count, err := PDB.Where("name = ?", "Nemo").UpdateQuery(&User{Bio: nulls.NewString("did-change")}, "bio")
r.NoError(err)
require.Equal(t, int64(0), count)

count, err = PDB.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "name")
r.NoError(err)
r.Equal(int64(1), count)

require.NoError(t, PDB.Destroy(&u1))
}

func Test_UpdateQuery(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

u1 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-1")}
u2 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-2")}
u3 := User{Name: nulls.NewString("Baz"), Bio: nulls.NewString("must-not-change-3")}
tx.Create(&u1)
tx.Create(&u2)
tx.Create(&u3)
r.NoError(tx.Reload(&u1))
r.NoError(tx.Reload(&u2))
r.NoError(tx.Reload(&u3))
time.Sleep(time.Millisecond * 1)

// No affected rows
count, err := tx.Where("name = ?", "Nemo").UpdateQuery(&User{Bio: nulls.NewString("did-change")}, "bio")
r.NoError(err)
require.Equal(t, int64(0), count)
mustUnchanged := &User{}
r.NoError(tx.Find(mustUnchanged, u1.ID))
r.Equal(u1.Bio, mustUnchanged.Bio)
r.Equal(u1.UpdatedAt, mustUnchanged.UpdatedAt)

// Correct rows are updated, including updated_at
count, err = tx.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "name")
r.NoError(err)
r.Equal(int64(2), count)

u1b, u2b, u3b := &User{}, &User{}, &User{}
r.NoError(tx.Find(u1b, u1.ID))
r.NoError(tx.Find(u2b, u2.ID))
r.NoError(tx.Find(u3b, u3.ID))
r.Equal(u1b.Name.String, "Bar")
r.Equal(u2b.Name.String, "Bar")
r.Equal(u3b.Name.String, "Baz")
r.Equal(u1b.Bio.String, "must-not-change-1")
r.Equal(u2b.Bio.String, "must-not-change-2")
r.Equal(u3b.Bio.String, "must-not-change-3")
if tx.Dialect.Name() != nameMySQL { // MySQL timestamps are in seconds
r.NotEqual(u1.UpdatedAt, u1b.UpdatedAt)
r.NotEqual(u2.UpdatedAt, u2b.UpdatedAt)
}
r.Equal(u3.UpdatedAt, u3b.UpdatedAt)

// ID is ignored
count, err = tx.Where("true").UpdateQuery(&User{ID: 123, Name: nulls.NewString("Bar")}, "id", "name")
r.NoError(tx.Find(u1b, u1.ID))
r.NoError(tx.Find(u2b, u2.ID))
r.NoError(tx.Find(u3b, u3.ID))
r.Equal(u1b.Name.String, "Bar")
r.Equal(u2b.Name.String, "Bar")
r.Equal(u3b.Name.String, "Bar")

// Invalid column yields an error
count, err = tx.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "mistake")
r.Contains(err.Error(), "could not find name mistake")

tx.Where("true").Delete(&User{})
})
}

func Test_UpdateColumns_UpdatedAt(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down
2 changes: 1 addition & 1 deletion pop_test.go
Expand Up @@ -52,7 +52,7 @@ func init() {
dialect := os.Getenv("SODA_DIALECT")

if dialect == "" {
log(logging.Info, "Skipping integration tests")
log(logging.Info, "Skipping integration tests because SODA_DIALECT is blank or unset")
return
}

Expand Down