diff --git a/dialect.go b/dialect.go index 1b87847d..0e49b168 100644 --- a/dialect.go +++ b/dialect.go @@ -12,6 +12,7 @@ type crudable interface { SelectMany(store, *Model, Query) error Create(store, *Model, columns.Columns) error Update(store, *Model, columns.Columns) error + UpdateQuery(store, *Model, columns.Columns, Query) (int64, error) Destroy(store, *Model) error Delete(store, *Model, Query) error } diff --git a/dialect_cockroach.go b/dialect_cockroach.go index 16d4d96b..5fb7e324 100644 --- a/dialect_cockroach.go +++ b/dialect_cockroach.go @@ -106,6 +106,10 @@ func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error { return genericUpdate(s, model, cols, p) } +func (p *cockroach) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) { + return genericUpdateQuery(s, model, cols, p, query, sqlx.DOLLAR) +} + func (p *cockroach) Destroy(s store, model *Model) error { stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID())) _, err := genericExec(s, stmt, model.ID()) diff --git a/dialect_common.go b/dialect_common.go index d0876295..5b1d76f4 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -14,6 +14,7 @@ import ( "github.com/gobuffalo/pop/v6/columns" "github.com/gobuffalo/pop/v6/logging" "github.com/gofrs/uuid" + "github.com/jmoiron/sqlx" ) func init() { @@ -110,6 +111,32 @@ func genericUpdate(s store, model *Model, cols columns.Columns, quoter quotable) return nil } +func genericUpdateQuery(s store, model *Model, cols columns.Columns, quoter quotable, query Query, bindType int) (int64, error) { + q := fmt.Sprintf("UPDATE %s AS %s SET %s", quoter.Quote(model.TableName()), model.Alias(), cols.Writeable().QuotedUpdateString(quoter)) + + q, updateArgs, err := sqlx.Named(q, model.Value) + if err != nil { + return 0, err + } + + sb := query.toSQLBuilder(model) + q = sb.buildWhereClauses(q) + + q = sqlx.Rebind(bindType, q) + + result, err := genericExec(s, q, append(updateArgs, sb.args...)...) + if err != nil { + return 0, err + } + + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + + return n, err +} + func genericDestroy(s store, model *Model, quoter quotable) error { stmt := fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", quoter.Quote(model.TableName()), model.Alias(), model.WhereID()) _, err := genericExec(s, stmt, model.ID()) diff --git a/dialect_mysql.go b/dialect_mysql.go index 1cfda3cf..7b5022b4 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -13,6 +13,7 @@ import ( "github.com/gobuffalo/pop/v6/columns" "github.com/gobuffalo/pop/v6/internal/defaults" "github.com/gobuffalo/pop/v6/logging" + "github.com/jmoiron/sqlx" ) const nameMySQL = "mysql" @@ -94,6 +95,14 @@ func (m *mysql) Update(s store, model *Model, cols columns.Columns) error { return nil } +func (m *mysql) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) { + if n, err := genericUpdateQuery(s, model, cols, m, query, sqlx.QUESTION); err != nil { + return n, fmt.Errorf("mysql update query: %w", err) + } else { + return n, nil + } +} + func (m *mysql) Destroy(s store, model *Model) error { stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.Quote(model.TableName()), model.IDField()) _, err := genericExec(s, stmt, model.ID()) diff --git a/dialect_postgresql.go b/dialect_postgresql.go index fe4d0236..6ce61569 100644 --- a/dialect_postgresql.go +++ b/dialect_postgresql.go @@ -90,6 +90,10 @@ func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error { return genericUpdate(s, model, cols, p) } +func (p *postgresql) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) { + return genericUpdateQuery(s, model, cols, p, query, sqlx.DOLLAR) +} + func (p *postgresql) Destroy(s store, model *Model) error { stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID())) _, err := genericExec(s, stmt, model.ID()) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index 1ac8691f..91dfce97 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -1,3 +1,4 @@ +//go:build sqlite // +build sqlite package pop @@ -19,6 +20,7 @@ import ( "github.com/gobuffalo/pop/v6/columns" "github.com/gobuffalo/pop/v6/internal/defaults" "github.com/gobuffalo/pop/v6/logging" + "github.com/jmoiron/sqlx" "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver ) @@ -109,6 +111,20 @@ func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error { }) } +func (m *sqlite) UpdateQuery(s store, model *Model, cols columns.Columns, query Query) (int64, error) { + rowsAffected := int64(0) + err := m.locker(m.smGil, func() error { + if n, err := genericUpdateQuery(s, model, cols, m, query, sqlx.QUESTION); err != nil { + rowsAffected = n + return fmt.Errorf("sqlite update query: %w", err) + } else { + rowsAffected = n + return nil + } + }) + return rowsAffected, err +} + func (m *sqlite) Destroy(s store, model *Model) error { return m.locker(m.smGil, func() error { if err := genericDestroy(s, model, m); err != nil { diff --git a/executors.go b/executors.go index 04478f66..37b03b35 100644 --- a/executors.go +++ b/executors.go @@ -380,6 +380,35 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error { }) } +// UpdateQuery updates all rows matched by the query. The new values are read +// from the first argument, which must be a struct. The column names to be +// updated must be listed explicitly in subsequent arguments. The ID and +// CreatedAt columns are never updated. The UpdatedAt column is updated +// automatically. +// +// UpdateQuery does not execute (before|after)(Create|Update|Save) callbacks. +// +// Calling UpdateQuery with no columnNames will result in only the UpdatedAt +// column being updated. +func (q *Query) UpdateQuery(model interface{}, columnNames ...string) (int64, error) { + sm := NewModel(model, q.Connection.Context()) + modelKind := reflect.TypeOf(reflect.Indirect(reflect.ValueOf(model))).Kind() + if modelKind != reflect.Struct { + return 0, fmt.Errorf("model must be a struct; got %s", modelKind) + } + + cols := columns.NewColumnsWithAlias(sm.TableName(), sm.As, sm.IDField()) + cols.Add(columnNames...) + if _, err := sm.fieldByName("UpdatedAt"); err == nil { + cols.Add("updated_at") + } + cols.Remove(sm.IDField(), "created_at") + + now := nowFunc().Truncate(time.Microsecond) + sm.setUpdatedAt(now) + return q.Connection.Dialect.UpdateQuery(q.Connection.Store, sm, cols, *q) +} + // UpdateColumns writes changes from an entry to the database, including only the given columns // or all columns if no column names are provided. // It updates the `updated_at` column automatically if you include `updated_at` in columnNames. diff --git a/executors_test.go b/executors_test.go index e47e4d0a..d5154836 100644 --- a/executors_test.go +++ b/executors_test.go @@ -556,6 +556,13 @@ func Test_Embedded_Struct(t *testing.T) { r.NoError(tx.Find(&actual, entry.ID)) r.Equal(entry.AdditionalField, actual.AdditionalField) + entry.AdditionalField = entry.AdditionalField + "; updated again" + count, err := tx.Where("id = ?", entry.ID).UpdateQuery(entry, "additional_field") + r.NoError(err) + require.Equal(t, int64(1), count) + r.NoError(tx.Find(&actual, entry.ID)) + r.Equal(entry.AdditionalField, actual.AdditionalField) + r.NoError(tx.Destroy(entry)) }) } @@ -1493,6 +1500,107 @@ func Test_UpdateColumns(t *testing.T) { }) } +func Test_UpdateQuery_NoUpdatedAt(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + r.NoError(PDB.Create(&NonStandardID{OutfacingID: "must-change"})) + count, err := PDB.Where("true").UpdateQuery(&NonStandardID{OutfacingID: "has-changed"}, "id") + r.NoError(err) + r.Equal(int64(1), count) + entity := NonStandardID{} + r.NoError(PDB.First(&entity)) + r.Equal("has-changed", entity.OutfacingID) + }) +} + +func Test_UpdateQuery_NoTransaction(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + + r := require.New(t) + u1 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-1")} + r.NoError(PDB.Create(&u1)) + r.NoError(PDB.Reload(&u1)) + count, err := PDB.Where("name = ?", "Nemo").UpdateQuery(&User{Bio: nulls.NewString("did-change")}, "bio") + r.NoError(err) + require.Equal(t, int64(0), count) + + count, err = PDB.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "name") + r.NoError(err) + r.Equal(int64(1), count) + + require.NoError(t, PDB.Destroy(&u1)) +} + +func Test_UpdateQuery(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + u1 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-1")} + u2 := User{Name: nulls.NewString("Foo"), Bio: nulls.NewString("must-not-change-2")} + u3 := User{Name: nulls.NewString("Baz"), Bio: nulls.NewString("must-not-change-3")} + tx.Create(&u1) + tx.Create(&u2) + tx.Create(&u3) + r.NoError(tx.Reload(&u1)) + r.NoError(tx.Reload(&u2)) + r.NoError(tx.Reload(&u3)) + time.Sleep(time.Millisecond * 1) + + // No affected rows + count, err := tx.Where("name = ?", "Nemo").UpdateQuery(&User{Bio: nulls.NewString("did-change")}, "bio") + r.NoError(err) + require.Equal(t, int64(0), count) + mustUnchanged := &User{} + r.NoError(tx.Find(mustUnchanged, u1.ID)) + r.Equal(u1.Bio, mustUnchanged.Bio) + r.Equal(u1.UpdatedAt, mustUnchanged.UpdatedAt) + + // Correct rows are updated, including updated_at + count, err = tx.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "name") + r.NoError(err) + r.Equal(int64(2), count) + + u1b, u2b, u3b := &User{}, &User{}, &User{} + r.NoError(tx.Find(u1b, u1.ID)) + r.NoError(tx.Find(u2b, u2.ID)) + r.NoError(tx.Find(u3b, u3.ID)) + r.Equal(u1b.Name.String, "Bar") + r.Equal(u2b.Name.String, "Bar") + r.Equal(u3b.Name.String, "Baz") + r.Equal(u1b.Bio.String, "must-not-change-1") + r.Equal(u2b.Bio.String, "must-not-change-2") + r.Equal(u3b.Bio.String, "must-not-change-3") + if tx.Dialect.Name() != nameMySQL { // MySQL timestamps are in seconds + r.NotEqual(u1.UpdatedAt, u1b.UpdatedAt) + r.NotEqual(u2.UpdatedAt, u2b.UpdatedAt) + } + r.Equal(u3.UpdatedAt, u3b.UpdatedAt) + + // ID is ignored + count, err = tx.Where("true").UpdateQuery(&User{ID: 123, Name: nulls.NewString("Bar")}, "id", "name") + r.NoError(tx.Find(u1b, u1.ID)) + r.NoError(tx.Find(u2b, u2.ID)) + r.NoError(tx.Find(u3b, u3.ID)) + r.Equal(u1b.Name.String, "Bar") + r.Equal(u2b.Name.String, "Bar") + r.Equal(u3b.Name.String, "Bar") + + // Invalid column yields an error + count, err = tx.Where("name = ?", "Foo").UpdateQuery(&User{Name: nulls.NewString("Bar")}, "mistake") + r.Contains(err.Error(), "could not find name mistake") + + tx.Where("true").Delete(&User{}) + }) +} + func Test_UpdateColumns_UpdatedAt(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests") diff --git a/pop_test.go b/pop_test.go index 629ea35a..95fd8648 100644 --- a/pop_test.go +++ b/pop_test.go @@ -52,7 +52,7 @@ func init() { dialect := os.Getenv("SODA_DIALECT") if dialect == "" { - log(logging.Info, "Skipping integration tests") + log(logging.Info, "Skipping integration tests because SODA_DIALECT is blank or unset") return }