diff --git a/belongs_to.go b/belongs_to.go index 0b5c977e..d261a315 100644 --- a/belongs_to.go +++ b/belongs_to.go @@ -19,7 +19,7 @@ func (c *Connection) BelongsToAs(model interface{}, as string) *Query { // BelongsTo adds a "where" clause based on the "ID" of the // "model" passed into it. func (q *Query) BelongsTo(model interface{}) *Query { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID()) return q } @@ -27,7 +27,7 @@ func (q *Query) BelongsTo(model interface{}) *Query { // BelongsToAs adds a "where" clause based on the "ID" of the // "model" passed into it, using an alias. func (q *Query) BelongsToAs(model interface{}, as string) *Query { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) q.Where(fmt.Sprintf("%s = ?", as), m.ID()) return q } @@ -42,8 +42,8 @@ func (c *Connection) BelongsToThrough(bt, thru interface{}) *Query { // through the associated "thru" model. func (q *Query) BelongsToThrough(bt, thru interface{}) *Query { q.belongsToThroughClauses = append(q.belongsToThroughClauses, belongsToThroughClause{ - BelongsTo: &Model{Value: bt}, - Through: &Model{Value: thru}, + BelongsTo: NewModel(bt, q.Connection.Context()), + Through: NewModel(thru, q.Connection.Context()), }) return q } diff --git a/belongs_to_test.go b/belongs_to_test.go index e4e3a3e7..9eb99914 100644 --- a/belongs_to_test.go +++ b/belongs_to_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -14,7 +15,7 @@ func Test_BelongsTo(t *testing.T) { q := PDB.BelongsTo(&User{ID: 1}) - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE user_id = ?"), sql) @@ -28,7 +29,7 @@ func Test_BelongsToAs(t *testing.T) { q := PDB.BelongsToAs(&User{ID: 1}, "u_id") - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE u_id = ?"), sql) @@ -43,7 +44,7 @@ func Test_BelongsToThrough(t *testing.T) { q := PDB.BelongsToThrough(&User{ID: 1}, &Friend{}) qs := "SELECT enemies.A FROM enemies AS enemies, good_friends AS good_friends WHERE good_friends.user_id = ? AND enemies.id = good_friends.enemy_id" - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts(qs), sql) } diff --git a/connection.go b/connection.go index df116fb3..6e0f55ee 100644 --- a/connection.go +++ b/connection.go @@ -33,6 +33,16 @@ func (c *Connection) URL() string { return c.Dialect.URL() } +// Context returns the connection's context set by "Context()" or context.TODO() +// if no context is set. +func (c *Connection) Context() context.Context { + if c, ok := c.Store.(interface{ Context() context.Context }); ok { + return c.Context() + } + + return context.TODO() +} + // MigrationURL returns the datasource connection string used for running the migrations func (c *Connection) MigrationURL() string { return c.Dialect.MigrationURL() diff --git a/connection_details.go b/connection_details.go index 29c0ae41..6456b7d1 100644 --- a/connection_details.go +++ b/connection_details.go @@ -2,13 +2,14 @@ package pop import ( "fmt" - "github.com/luna-duclos/instrumentedsql" "net/url" "regexp" "strconv" "strings" "time" + "github.com/luna-duclos/instrumentedsql" + "github.com/gobuffalo/pop/v5/internal/defaults" "github.com/gobuffalo/pop/v5/logging" "github.com/pkg/errors" diff --git a/connection_instrumented.go b/connection_instrumented.go index 84cf7820..f2d46b02 100644 --- a/connection_instrumented.go +++ b/connection_instrumented.go @@ -3,6 +3,7 @@ package pop import ( "database/sql" "database/sql/driver" + mysqld "github.com/go-sql-driver/mysql" "github.com/gobuffalo/pop/v5/logging" pgx "github.com/jackc/pgx/v4/stdlib" diff --git a/connection_instrumented_nosqlite_test.go b/connection_instrumented_nosqlite_test.go index 715a92ab..2cf6ae01 100644 --- a/connection_instrumented_nosqlite_test.go +++ b/connection_instrumented_nosqlite_test.go @@ -3,8 +3,9 @@ package pop import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestInstrumentation_WithoutSqlite(t *testing.T) { diff --git a/connection_instrumented_test.go b/connection_instrumented_test.go index d5125639..808d8638 100644 --- a/connection_instrumented_test.go +++ b/connection_instrumented_test.go @@ -3,12 +3,13 @@ package pop import ( "context" "fmt" - "github.com/luna-duclos/instrumentedsql" - "github.com/stretchr/testify/suite" "os" "strings" "sync" "time" + + "github.com/luna-duclos/instrumentedsql" + "github.com/stretchr/testify/suite" ) func testInstrumentedDriver(p *suite.Suite) { diff --git a/dialect_sqlite.go b/dialect_sqlite.go index b3045090..47bc3e28 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -5,7 +5,6 @@ package pop import ( "database/sql/driver" "fmt" - "github.com/mattn/go-sqlite3" "io" "net/url" "os" @@ -15,6 +14,8 @@ import ( "sync" "time" + "github.com/mattn/go-sqlite3" + "github.com/gobuffalo/fizz" "github.com/gobuffalo/fizz/translators" _ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver diff --git a/executors.go b/executors.go index b3560141..657ecea7 100644 --- a/executors.go +++ b/executors.go @@ -13,7 +13,7 @@ import ( // Reload fetch fresh data for a given model, using its ID. func (c *Connection) Reload(model interface{}) error { - sm := Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.Find(m.Value, m.ID()) }) @@ -51,7 +51,7 @@ func (q *Query) ExecWithCount() (int, error) { // // If model is a slice, each item of the slice is validated then saved in the database. func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -77,7 +77,7 @@ func IsZeroOfUnderlyingType(x interface{}) bool { // // If model is a slice, each item of the slice is saved in the database. func (c *Connection) Save(model interface{}, excludeColumns ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { id, err := m.fieldByName("ID") if err != nil { @@ -95,7 +95,7 @@ func (c *Connection) Save(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is validated then created in the database. func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -126,7 +126,7 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) verrs, err := sm.validateAndOnlyCreate(c) if err != nil || verrs.HasAny() { return verrs, err @@ -140,14 +140,14 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) verrs, err := sm.validateAndOnlyCreate(c) if err != nil || verrs.HasAny() { return verrs, err } } - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) verrs, err = sm.validateCreate(c) if err != nil || verrs.HasAny() { return verrs, err @@ -170,7 +170,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { c.disableEager() - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Create", func() error { var localIsEager = isEager @@ -203,7 +203,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { } if localIsEager { - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) err = sm.iterate(func(m *Model) error { id, err := m.fieldByName("ID") if err != nil { @@ -255,7 +255,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) err = sm.iterate(func(m *Model) error { fbn, err := m.fieldByName("ID") if err != nil { @@ -318,7 +318,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is validated then updated in the database. func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -337,7 +337,7 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri // // If model is a slice, each item of the slice is updated in the database. func (c *Connection) Update(model interface{}, excludeColumns ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Update", func() error { var err error @@ -377,7 +377,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is updated in the database. func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Update", func() error { var err error @@ -419,7 +419,7 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err // // If model is a slice, each item of the slice is deleted from the database. func (c *Connection) Destroy(model interface{}) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Destroy", func() error { var err error diff --git a/finders.go b/finders.go index 9026d14b..fa1a220a 100644 --- a/finders.go +++ b/finders.go @@ -29,7 +29,7 @@ func (c *Connection) Find(model interface{}, id interface{}) error { // // q.Find(&User{}, 1) func (q *Query) Find(model interface{}, id interface{}) error { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) idq := m.whereID() switch t := id.(type) { case uuid.UUID: @@ -69,7 +69,7 @@ func (c *Connection) First(model interface{}) error { func (q *Query) First(model interface{}) error { err := q.Connection.timeFunc("First", func() error { q.Limit(1) - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { return err } @@ -102,7 +102,7 @@ func (q *Query) Last(model interface{}) error { err := q.Connection.timeFunc("Last", func() error { q.Limit(1) q.Order("created_at DESC, id DESC") - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { return err } @@ -134,7 +134,7 @@ func (c *Connection) All(models interface{}) error { // q.Where("name = ?", "mark").All(&[]User{}) func (q *Query) All(models interface{}) error { err := q.Connection.timeFunc("All", func() error { - m := &Model{Value: models} + m := NewModel(models, q.Connection.Context()) err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) if err != nil { return err @@ -258,7 +258,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error { } } - sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()}) + sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context())) query = query.RawQuery(sqlSentence, args...) if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { @@ -302,7 +302,7 @@ func (q *Query) Exists(model interface{}) (bool, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) + query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context())) // when query contains custom selected fields / executed using RawQuery, // sql may already contains limit and offset @@ -348,7 +348,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) + query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context())) // when query contains custom selected fields / executed using RawQuery, // sql may already contains limit and offset diff --git a/finders_test.go b/finders_test.go index 389dc80a..7f30727e 100644 --- a/finders_test.go +++ b/finders_test.go @@ -101,7 +101,7 @@ func Test_Select(t *testing.T) { q := tx.Select("name", "email", "\n", "\t\n", "") - sm := &Model{Value: &User{}} + sm := NewModel(new(User), tx.Context()) sql, _ := q.ToSQL(sm) r.Equal(tx.Dialect.TranslateSQL("SELECT email, name FROM users AS users"), sql) diff --git a/match_test.go b/match_test.go index 8cc41a6b..0f1591f1 100644 --- a/match_test.go +++ b/match_test.go @@ -1,8 +1,9 @@ package pop import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func Test_ParseMigrationFilenameFizzDown(t *testing.T) { diff --git a/migration_info_test.go b/migration_info_test.go index 0f76824e..f2174551 100644 --- a/migration_info_test.go +++ b/migration_info_test.go @@ -1,9 +1,10 @@ package pop import ( - "github.com/stretchr/testify/assert" "sort" "testing" + + "github.com/stretchr/testify/assert" ) func TestSortingMigrations(t *testing.T) { diff --git a/model.go b/model.go index 58019f73..7ba7cee2 100644 --- a/model.go +++ b/model.go @@ -1,13 +1,15 @@ package pop import ( + "context" "fmt" - "github.com/gobuffalo/pop/v5/columns" - "github.com/pkg/errors" "reflect" "sync" "time" + "github.com/gobuffalo/pop/v5/columns" + "github.com/pkg/errors" + "github.com/gobuffalo/flect" nflect "github.com/gobuffalo/flect/name" "github.com/gofrs/uuid" @@ -27,10 +29,16 @@ type modelIterable func(*Model) error // that is passed in to many functions. type Model struct { Value + ctx context.Context tableName string As string } +// NewModel returns a new model with the specified value and context. +func NewModel(v Value, ctx context.Context) *Model { + return &Model{Value: v, ctx: ctx} +} + // ID returns the ID of the Model. All models must have an `ID` field this is // of type `int`,`int64` or of type `uuid.UUID`. func (m *Model) ID() interface{} { @@ -86,6 +94,13 @@ type TableNameAble interface { TableName() string } +// TableNameAbleWithContext is equal to TableNameAble but will +// be passed the queries' context. Useful in cases where the +// table name depends on e.g. +type TableNameAbleWithContext interface { + TableName(ctx context.Context) string +} + // TableName returns the corresponding name of the underlying database table // for a given `Model`. See also `TableNameAble` to change the default name of the table. func (m *Model) TableName() string { @@ -96,6 +111,13 @@ func (m *Model) TableName() string { return n.TableName() } + if n, ok := m.Value.(TableNameAbleWithContext); ok { + if m.ctx == nil { + m.ctx = context.TODO() + } + return n.TableName(m.ctx) + } + if m.tableName != "" { return m.tableName } @@ -133,7 +155,7 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { } // validates if the elem of slice or array implements TableNameAble interface. - tableNameAble := (*TableNameAble)(nil) + var tableNameAble *TableNameAble if el.Implements(reflect.TypeOf(tableNameAble).Elem()) { v := reflect.New(el) out := v.MethodByName("TableName").Call([]reflect.Value{}) @@ -143,6 +165,17 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { } } + // validates if the elem of slice or array implements TableNameAbleWithContext interface. + var tableNameAbleWithContext *TableNameAbleWithContext + if el.Implements(reflect.TypeOf(tableNameAbleWithContext).Elem()) { + v := reflect.New(el) + out := v.MethodByName("TableName").Call([]reflect.Value{reflect.ValueOf(m.ctx)}) + name := out[0].String() + if tableMap[m.cacheKey(el)] == "" { + tableMap[m.cacheKey(el)] = name + } + } + return el.Name(), m.cacheKey(el) default: return t.Name(), m.cacheKey(t) @@ -226,7 +259,10 @@ func (m *Model) iterate(fn modelIterable) error { v := reflect.Indirect(reflect.ValueOf(m.Value)) for i := 0; i < v.Len(); i++ { val := v.Index(i) - newModel := &Model{Value: val.Addr().Interface()} + newModel := &Model{ + Value: val.Addr().Interface(), + ctx: m.ctx, + } err := fn(newModel) if err != nil { diff --git a/model_context_test.go b/model_context_test.go new file mode 100644 index 00000000..f5312310 --- /dev/null +++ b/model_context_test.go @@ -0,0 +1,84 @@ +package pop + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type ContextTable struct { + ID string `db:"id"` + Value string `db:"value"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +func (t ContextTable) TableName(ctx context.Context) string { + return "context_prefix_" + ctx.Value("prefix").(string) + "_table" +} + +func Test_ModelContext(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + + t.Run("contextless", func(t *testing.T) { + r := require.New(t) + r.Panics(func() { + var c ContextTable + r.NoError(PDB.Create(&c)) + }, "panics if context prefix is not set") + }) + + for _, prefix := range []string{"a", "b"} { + t.Run("prefix="+prefix, func(t *testing.T) { + r := require.New(t) + + expected := ContextTable{ID: prefix, Value: prefix} + c := PDB.WithContext(context.WithValue(context.Background(), "prefix", prefix)) + r.NoError(c.Create(&expected)) + + var actual ContextTable + r.NoError(c.Find(&actual, expected.ID)) + r.EqualValues(prefix, actual.Value) + r.EqualValues(prefix, actual.ID) + + exists, err := c.Where("id = ?", actual.ID).Exists(new(ContextTable)) + r.NoError(err) + r.True(exists) + + count, err := c.Where("id = ?", actual.ID).Count(new(ContextTable)) + r.NoError(err) + r.EqualValues(1, count) + + expected.Value += expected.Value + r.NoError(c.Update(&expected)) + + r.NoError(c.Find(&actual, expected.ID)) + r.EqualValues(prefix+prefix, actual.Value) + r.EqualValues(prefix, actual.ID) + + var results []ContextTable + require.NoError(t, c.All(&results)) + + require.NoError(t, c.First(&expected)) + require.NoError(t, c.Last(&expected)) + + r.NoError(c.Destroy(&expected)) + }) + } + + t.Run("prefix=unknown", func(t *testing.T) { + r := require.New(t) + c := PDB.WithContext(context.WithValue(context.Background(), "prefix", "unknown")) + err := c.Create(&ContextTable{ID: "unknown"}) + r.Error(err) + + if !strings.Contains(err.Error(), "context_prefix_unknown_table") { // All other databases + t.Fatalf("Expected error to contain indicator that table does not exist but got: %s", err.Error()) + } + }) +} diff --git a/model_test.go b/model_test.go index 545eb09e..5f8f7660 100644 --- a/model_test.go +++ b/model_test.go @@ -1,9 +1,13 @@ package pop import ( + "context" "testing" "time" + "github.com/gobuffalo/pop/v5/testdata/models/ac" + "github.com/gobuffalo/pop/v5/testdata/models/bc" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -40,6 +44,12 @@ func (tn) TableName() string { return "this is my table name" } +type tnc struct{} + +func (tnc) TableName(ctx context.Context) string { + return ctx.Value("name").(string) +} + // A failing test case for #477 func Test_TableNameCache(t *testing.T) { r := assert.New(t) @@ -49,6 +59,17 @@ func Test_TableNameCache(t *testing.T) { r.Equal("userb", (&Model{Value: []b.User{}}).TableName()) } +// A failing test case for #477 +func Test_TableNameContextCache(t *testing.T) { + ctx := context.WithValue(context.Background(), "name", "context-table") + + r := assert.New(t) + r.Equal("context-table-usera", (&Model{Value: ac.User{}, ctx: ctx}).TableName()) + r.Equal("context-table-userb", (&Model{Value: bc.User{}, ctx: ctx}).TableName()) + r.Equal("context-table-usera", (&Model{Value: []ac.User{}, ctx: ctx}).TableName()) + r.Equal("context-table-userb", (&Model{Value: []bc.User{}, ctx: ctx}).TableName()) +} + func Test_TableName(t *testing.T) { r := require.New(t) @@ -62,6 +83,22 @@ func Test_TableName(t *testing.T) { } } +func Test_TableNameContext(t *testing.T) { + r := require.New(t) + + tn := "context table name" + ctx := context.WithValue(context.Background(), "name", tn) + + cases := []interface{}{ + tnc{}, + []tnc{}, + } + for _, tc := range cases { + m := Model{Value: tc, ctx: ctx} + r.Equal(tn, m.TableName()) + } +} + type TimeTimestamp struct { ID int `db:"id"` CreatedAt time.Time `db:"created_at"` @@ -77,7 +114,7 @@ type UnixTimestamp struct { func Test_Touch_Time_Timestamp(t *testing.T) { r := require.New(t) - m := Model{Value: &TimeTimestamp{}} + m := NewModel(&TimeTimestamp{}, context.Background()) // Override time.Now() t0, _ := time.Parse(time.RFC3339, "2019-07-14T00:00:00Z") @@ -101,7 +138,7 @@ func Test_Touch_Time_Timestamp_With_Existing_Value(t *testing.T) { createdAt := nowFunc().Add(-36 * time.Hour) - m := Model{Value: &TimeTimestamp{CreatedAt: createdAt}} + m := NewModel(&TimeTimestamp{CreatedAt: createdAt}, context.Background()) m.touchCreatedAt() m.touchUpdatedAt() v := m.Value.(*TimeTimestamp) @@ -112,7 +149,7 @@ func Test_Touch_Time_Timestamp_With_Existing_Value(t *testing.T) { func Test_Touch_Unix_Timestamp(t *testing.T) { r := require.New(t) - m := Model{Value: &UnixTimestamp{}} + m := NewModel(&UnixTimestamp{}, context.Background()) // Override time.Now() t0, _ := time.Parse(time.RFC3339, "2019-07-14T00:00:00Z") @@ -136,7 +173,7 @@ func Test_Touch_Unix_Timestamp_With_Existing_Value(t *testing.T) { createdAt := int(time.Now().Add(-36 * time.Hour).Unix()) - m := Model{Value: &UnixTimestamp{CreatedAt: createdAt}} + m := NewModel(&UnixTimestamp{CreatedAt: createdAt}, context.Background()) m.touchCreatedAt() m.touchUpdatedAt() v := m.Value.(*UnixTimestamp) diff --git a/preload_associations.go b/preload_associations.go index 4a03b579..dc19ccea 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -167,7 +167,7 @@ func (ami *AssociationMetaInfo) fkName() string { // preload is the query mode used to load associations from database // similar to the active record default approach on Rails. func preload(tx *Connection, model interface{}, fields ...string) error { - mmi := NewModelMetaInfo(&Model{Value: model}) + mmi := NewModelMetaInfo(NewModel(model, tx.Context())) preloadFields, err := mmi.preloadFields(fields...) if err != nil { diff --git a/query_test.go b/query_test.go index 93da765c..7df5a994 100644 --- a/query_test.go +++ b/query_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "fmt" "testing" @@ -13,7 +14,7 @@ func Test_Where(t *testing.T) { t.Skip("skipping integration tests") } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) q := PDB.Where("id = ?", 1) sql, _ := q.ToSQL(m) @@ -107,7 +108,7 @@ func Test_Order(t *testing.T) { } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(&Enemy{}, context.Background()) q := PDB.Order("id desc") sql, _ := q.ToSQL(m) a.Equal(ts("SELECT enemies.A FROM enemies AS enemies ORDER BY id desc"), sql) @@ -123,7 +124,7 @@ func Test_GroupBy(t *testing.T) { } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(&Enemy{}, context.Background()) q := PDB.Q() q.GroupBy("A") sql, _ := q.ToSQL(m) @@ -159,7 +160,7 @@ func Test_ToSQL(t *testing.T) { } a := require.New(t) transaction(func(tx *Connection) { - user := &Model{Value: &User{}} + user := NewModel(&User{}, tx.Context()) s := "SELECT name as full_name, users.alive, users.bio, users.birth_date, users.created_at, users.email, users.id, users.name, users.price, users.updated_at, users.user_name FROM users AS users" @@ -171,10 +172,10 @@ func Test_ToSQL(t *testing.T) { q, _ = query.ToSQL(user) a.Equal(fmt.Sprintf("%s ORDER BY id desc", s), q) - q, _ = query.ToSQL(&Model{Value: &User{}, As: "u"}) + q, _ = query.ToSQL(&Model{Value: &User{}, As: "u", ctx: tx.Context()}) a.Equal("SELECT name as full_name, u.alive, u.bio, u.birth_date, u.created_at, u.email, u.id, u.name, u.price, u.updated_at, u.user_name FROM users AS u ORDER BY id desc", q) - q, _ = query.ToSQL(&Model{Value: &Family{}}) + q, _ = query.ToSQL(&Model{Value: &Family{}, ctx: tx.Context()}) a.Equal("SELECT family_members.created_at, family_members.first_name, family_members.id, family_members.last_name, family_members.updated_at FROM family.members AS family_members ORDER BY id desc", q) query = tx.Where("id = 1") @@ -262,7 +263,7 @@ func Test_ToSQLInjection(t *testing.T) { } a := require.New(t) transaction(func(tx *Connection) { - user := &Model{Value: &User{}} + user := NewModel(new(User), tx.Context()) query := tx.Where("name = '?'", "\\\u0027 or 1=1 limit 1;\n-- ") q, _ := query.ToSQL(user) a.NotEqual("SELECT * FROM users AS users WHERE name = '\\'' or 1=1 limit 1;\n-- '", q) diff --git a/scopes_test.go b/scopes_test.go index 0e22b76e..f393ffb3 100644 --- a/scopes_test.go +++ b/scopes_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -13,7 +14,7 @@ func Test_Scopes(t *testing.T) { r := require.New(t) oql := "SELECT enemies.A FROM enemies AS enemies" - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) q := PDB.Q() s, _ := q.ToSQL(m) diff --git a/soda/cmd/migrate_status.go b/soda/cmd/migrate_status.go index e9a85df6..98d0aae4 100644 --- a/soda/cmd/migrate_status.go +++ b/soda/cmd/migrate_status.go @@ -1,9 +1,10 @@ package cmd import ( + "os" + "github.com/gobuffalo/pop/v5" "github.com/spf13/cobra" - "os" ) var migrateStatusCmd = &cobra.Command{ diff --git a/store.go b/store.go index cc98a266..39c554ed 100644 --- a/store.go +++ b/store.go @@ -54,3 +54,7 @@ func (s contextStore) Exec(query string, args ...interface{}) (sql.Result, error func (s contextStore) PrepareNamed(query string) (*sqlx.NamedStmt, error) { return s.store.PrepareNamedContext(s.ctx, query) } + +func (s contextStore) Context() context.Context { + return s.ctx +} diff --git a/testdata/migrations/20210104145901_context_tables.down.fizz b/testdata/migrations/20210104145901_context_tables.down.fizz new file mode 100644 index 00000000..d0f82ee2 --- /dev/null +++ b/testdata/migrations/20210104145901_context_tables.down.fizz @@ -0,0 +1,2 @@ +drop_table("context_prefix_a_table") +drop_table("context_prefix_b_table") diff --git a/testdata/migrations/20210104145901_context_tables.up.fizz b/testdata/migrations/20210104145901_context_tables.up.fizz new file mode 100644 index 00000000..ae94796f --- /dev/null +++ b/testdata/migrations/20210104145901_context_tables.up.fizz @@ -0,0 +1,9 @@ +create_table("context_prefix_a_table") { + t.Column("id", "string", { primary: true }) + t.Column("value", "string") +} + +create_table("context_prefix_b_table") { + t.Column("id", "string", { primary: true }) + t.Column("value", "string") +} diff --git a/testdata/models/ac/user.go b/testdata/models/ac/user.go new file mode 100644 index 00000000..39a5e934 --- /dev/null +++ b/testdata/models/ac/user.go @@ -0,0 +1,9 @@ +package ac + +import "context" + +type User struct{} + +func (u User) TableName(ctx context.Context) string { + return ctx.Value("name").(string) + "-usera" +} diff --git a/testdata/models/bc/user.go b/testdata/models/bc/user.go new file mode 100644 index 00000000..30b543bc --- /dev/null +++ b/testdata/models/bc/user.go @@ -0,0 +1,9 @@ +package bc + +import "context" + +type User struct{} + +func (u User) TableName(ctx context.Context) string { + return ctx.Value("name").(string) + "-userb" +} diff --git a/validations.go b/validations.go index 79e76105..ab5b845e 100644 --- a/validations.go +++ b/validations.go @@ -133,7 +133,7 @@ func (m *Model) iterateAndValidate(fn modelIterableValidator) (*validate.Errors, if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { for i := 0; i < v.Len(); i++ { val := v.Index(i) - newModel := &Model{Value: val.Addr().Interface()} + newModel := NewModel(val.Addr().Interface(), m.ctx) verrs, err := fn(newModel) if err != nil || verrs.HasAny() {