diff --git a/belongs_to.go b/belongs_to.go index 0b5c977e..d261a315 100644 --- a/belongs_to.go +++ b/belongs_to.go @@ -19,7 +19,7 @@ func (c *Connection) BelongsToAs(model interface{}, as string) *Query { // BelongsTo adds a "where" clause based on the "ID" of the // "model" passed into it. func (q *Query) BelongsTo(model interface{}) *Query { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID()) return q } @@ -27,7 +27,7 @@ func (q *Query) BelongsTo(model interface{}) *Query { // BelongsToAs adds a "where" clause based on the "ID" of the // "model" passed into it, using an alias. func (q *Query) BelongsToAs(model interface{}, as string) *Query { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) q.Where(fmt.Sprintf("%s = ?", as), m.ID()) return q } @@ -42,8 +42,8 @@ func (c *Connection) BelongsToThrough(bt, thru interface{}) *Query { // through the associated "thru" model. func (q *Query) BelongsToThrough(bt, thru interface{}) *Query { q.belongsToThroughClauses = append(q.belongsToThroughClauses, belongsToThroughClause{ - BelongsTo: &Model{Value: bt}, - Through: &Model{Value: thru}, + BelongsTo: NewModel(bt, q.Connection.Context()), + Through: NewModel(thru, q.Connection.Context()), }) return q } diff --git a/belongs_to_test.go b/belongs_to_test.go index e4e3a3e7..9eb99914 100644 --- a/belongs_to_test.go +++ b/belongs_to_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -14,7 +15,7 @@ func Test_BelongsTo(t *testing.T) { q := PDB.BelongsTo(&User{ID: 1}) - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE user_id = ?"), sql) @@ -28,7 +29,7 @@ func Test_BelongsToAs(t *testing.T) { q := PDB.BelongsToAs(&User{ID: 1}, "u_id") - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE u_id = ?"), sql) @@ -43,7 +44,7 @@ func Test_BelongsToThrough(t *testing.T) { q := PDB.BelongsToThrough(&User{ID: 1}, &Friend{}) qs := "SELECT enemies.A FROM enemies AS enemies, good_friends AS good_friends WHERE good_friends.user_id = ? AND enemies.id = good_friends.enemy_id" - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts(qs), sql) } diff --git a/connection.go b/connection.go index df116fb3..6e0f55ee 100644 --- a/connection.go +++ b/connection.go @@ -33,6 +33,16 @@ func (c *Connection) URL() string { return c.Dialect.URL() } +// Context returns the connection's context set by "Context()" or context.TODO() +// if no context is set. +func (c *Connection) Context() context.Context { + if c, ok := c.Store.(interface{ Context() context.Context }); ok { + return c.Context() + } + + return context.TODO() +} + // MigrationURL returns the datasource connection string used for running the migrations func (c *Connection) MigrationURL() string { return c.Dialect.MigrationURL() diff --git a/connection_details.go b/connection_details.go index 29c0ae41..6456b7d1 100644 --- a/connection_details.go +++ b/connection_details.go @@ -2,13 +2,14 @@ package pop import ( "fmt" - "github.com/luna-duclos/instrumentedsql" "net/url" "regexp" "strconv" "strings" "time" + "github.com/luna-duclos/instrumentedsql" + "github.com/gobuffalo/pop/v5/internal/defaults" "github.com/gobuffalo/pop/v5/logging" "github.com/pkg/errors" diff --git a/connection_instrumented.go b/connection_instrumented.go index 84cf7820..f2d46b02 100644 --- a/connection_instrumented.go +++ b/connection_instrumented.go @@ -3,6 +3,7 @@ package pop import ( "database/sql" "database/sql/driver" + mysqld "github.com/go-sql-driver/mysql" "github.com/gobuffalo/pop/v5/logging" pgx "github.com/jackc/pgx/v4/stdlib" diff --git a/connection_instrumented_nosqlite_test.go b/connection_instrumented_nosqlite_test.go index 715a92ab..2cf6ae01 100644 --- a/connection_instrumented_nosqlite_test.go +++ b/connection_instrumented_nosqlite_test.go @@ -3,8 +3,9 @@ package pop import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestInstrumentation_WithoutSqlite(t *testing.T) { diff --git a/connection_instrumented_test.go b/connection_instrumented_test.go index d5125639..808d8638 100644 --- a/connection_instrumented_test.go +++ b/connection_instrumented_test.go @@ -3,12 +3,13 @@ package pop import ( "context" "fmt" - "github.com/luna-duclos/instrumentedsql" - "github.com/stretchr/testify/suite" "os" "strings" "sync" "time" + + "github.com/luna-duclos/instrumentedsql" + "github.com/stretchr/testify/suite" ) func testInstrumentedDriver(p *suite.Suite) { diff --git a/dialect_sqlite.go b/dialect_sqlite.go index b3045090..47bc3e28 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -5,7 +5,6 @@ package pop import ( "database/sql/driver" "fmt" - "github.com/mattn/go-sqlite3" "io" "net/url" "os" @@ -15,6 +14,8 @@ import ( "sync" "time" + "github.com/mattn/go-sqlite3" + "github.com/gobuffalo/fizz" "github.com/gobuffalo/fizz/translators" _ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver diff --git a/executors.go b/executors.go index b3560141..657ecea7 100644 --- a/executors.go +++ b/executors.go @@ -13,7 +13,7 @@ import ( // Reload fetch fresh data for a given model, using its ID. func (c *Connection) Reload(model interface{}) error { - sm := Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.Find(m.Value, m.ID()) }) @@ -51,7 +51,7 @@ func (q *Query) ExecWithCount() (int, error) { // // If model is a slice, each item of the slice is validated then saved in the database. func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -77,7 +77,7 @@ func IsZeroOfUnderlyingType(x interface{}) bool { // // If model is a slice, each item of the slice is saved in the database. func (c *Connection) Save(model interface{}, excludeColumns ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { id, err := m.fieldByName("ID") if err != nil { @@ -95,7 +95,7 @@ func (c *Connection) Save(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is validated then created in the database. func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -126,7 +126,7 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) verrs, err := sm.validateAndOnlyCreate(c) if err != nil || verrs.HasAny() { return verrs, err @@ -140,14 +140,14 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) verrs, err := sm.validateAndOnlyCreate(c) if err != nil || verrs.HasAny() { return verrs, err } } - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) verrs, err = sm.validateCreate(c) if err != nil || verrs.HasAny() { return verrs, err @@ -170,7 +170,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { c.disableEager() - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Create", func() error { var localIsEager = isEager @@ -203,7 +203,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { } if localIsEager { - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) err = sm.iterate(func(m *Model) error { id, err := m.fieldByName("ID") if err != nil { @@ -255,7 +255,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) err = sm.iterate(func(m *Model) error { fbn, err := m.fieldByName("ID") if err != nil { @@ -318,7 +318,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is validated then updated in the database. func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -337,7 +337,7 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri // // If model is a slice, each item of the slice is updated in the database. func (c *Connection) Update(model interface{}, excludeColumns ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Update", func() error { var err error @@ -377,7 +377,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is updated in the database. func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Update", func() error { var err error @@ -419,7 +419,7 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err // // If model is a slice, each item of the slice is deleted from the database. func (c *Connection) Destroy(model interface{}) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Destroy", func() error { var err error diff --git a/finders.go b/finders.go index 9026d14b..fa1a220a 100644 --- a/finders.go +++ b/finders.go @@ -29,7 +29,7 @@ func (c *Connection) Find(model interface{}, id interface{}) error { // // q.Find(&User{}, 1) func (q *Query) Find(model interface{}, id interface{}) error { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) idq := m.whereID() switch t := id.(type) { case uuid.UUID: @@ -69,7 +69,7 @@ func (c *Connection) First(model interface{}) error { func (q *Query) First(model interface{}) error { err := q.Connection.timeFunc("First", func() error { q.Limit(1) - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { return err } @@ -102,7 +102,7 @@ func (q *Query) Last(model interface{}) error { err := q.Connection.timeFunc("Last", func() error { q.Limit(1) q.Order("created_at DESC, id DESC") - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { return err } @@ -134,7 +134,7 @@ func (c *Connection) All(models interface{}) error { // q.Where("name = ?", "mark").All(&[]User{}) func (q *Query) All(models interface{}) error { err := q.Connection.timeFunc("All", func() error { - m := &Model{Value: models} + m := NewModel(models, q.Connection.Context()) err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) if err != nil { return err @@ -258,7 +258,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error { } } - sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()}) + sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context())) query = query.RawQuery(sqlSentence, args...) if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { @@ -302,7 +302,7 @@ func (q *Query) Exists(model interface{}) (bool, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) + query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context())) // when query contains custom selected fields / executed using RawQuery, // sql may already contains limit and offset @@ -348,7 +348,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) + query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context())) // when query contains custom selected fields / executed using RawQuery, // sql may already contains limit and offset diff --git a/finders_test.go b/finders_test.go index 389dc80a..7f30727e 100644 --- a/finders_test.go +++ b/finders_test.go @@ -101,7 +101,7 @@ func Test_Select(t *testing.T) { q := tx.Select("name", "email", "\n", "\t\n", "") - sm := &Model{Value: &User{}} + sm := NewModel(new(User), tx.Context()) sql, _ := q.ToSQL(sm) r.Equal(tx.Dialect.TranslateSQL("SELECT email, name FROM users AS users"), sql) diff --git a/match_test.go b/match_test.go index 8cc41a6b..0f1591f1 100644 --- a/match_test.go +++ b/match_test.go @@ -1,8 +1,9 @@ package pop import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func Test_ParseMigrationFilenameFizzDown(t *testing.T) { diff --git a/migration_info_test.go b/migration_info_test.go index 0f76824e..f2174551 100644 --- a/migration_info_test.go +++ b/migration_info_test.go @@ -1,9 +1,10 @@ package pop import ( - "github.com/stretchr/testify/assert" "sort" "testing" + + "github.com/stretchr/testify/assert" ) func TestSortingMigrations(t *testing.T) { diff --git a/model.go b/model.go index 58019f73..1dea97f5 100644 --- a/model.go +++ b/model.go @@ -1,23 +1,23 @@ package pop import ( + "context" "fmt" - "github.com/gobuffalo/pop/v5/columns" - "github.com/pkg/errors" "reflect" - "sync" + "strings" "time" - "github.com/gobuffalo/flect" nflect "github.com/gobuffalo/flect/name" + + "github.com/gobuffalo/pop/v5/columns" + "github.com/pkg/errors" + + "github.com/gobuffalo/flect" "github.com/gofrs/uuid" ) var nowFunc = time.Now -var tableMap = map[string]string{} -var tableMapMu = sync.RWMutex{} - // Value is the contents of a `Model`. type Value interface{} @@ -27,8 +27,13 @@ type modelIterable func(*Model) error // that is passed in to many functions. type Model struct { Value - tableName string - As string + ctx context.Context + As string +} + +// NewModel returns a new model with the specified value and context. +func NewModel(v Value, ctx context.Context) *Model { + return &Model{Value: v, ctx: ctx} } // ID returns the ID of the Model. All models must have an `ID` field this is @@ -86,31 +91,32 @@ type TableNameAble interface { TableName() string } +// TableNameAbleWithContext is equal to TableNameAble but will +// be passed the queries' context. Useful in cases where the +// table name depends on e.g. +type TableNameAbleWithContext interface { + TableName(ctx context.Context) string +} + // TableName returns the corresponding name of the underlying database table // for a given `Model`. See also `TableNameAble` to change the default name of the table. func (m *Model) TableName() string { if s, ok := m.Value.(string); ok { return s } + if n, ok := m.Value.(TableNameAble); ok { return n.TableName() } - if m.tableName != "" { - return m.tableName + if n, ok := m.Value.(TableNameAbleWithContext); ok { + if m.ctx == nil { + m.ctx = context.TODO() + } + return n.TableName(m.ctx) } - t := reflect.TypeOf(m.Value) - name, cacheKey := m.typeName(t) - - defer tableMapMu.Unlock() - tableMapMu.Lock() - - if tableMap[cacheKey] == "" { - m.tableName = nflect.Tableize(name) - tableMap[cacheKey] = m.tableName - } - return tableMap[cacheKey] + return m.typeName(reflect.TypeOf(m.Value)) } func (m *Model) Columns() columns.Columns { @@ -121,7 +127,7 @@ func (m *Model) cacheKey(t reflect.Type) string { return t.PkgPath() + "." + t.Name() } -func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { +func (m *Model) typeName(t reflect.Type) (name string) { if t.Kind() == reflect.Ptr { t = t.Elem() } @@ -133,19 +139,27 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { } // validates if the elem of slice or array implements TableNameAble interface. - tableNameAble := (*TableNameAble)(nil) + var tableNameAble *TableNameAble if el.Implements(reflect.TypeOf(tableNameAble).Elem()) { v := reflect.New(el) out := v.MethodByName("TableName").Call([]reflect.Value{}) - name := out[0].String() - if tableMap[m.cacheKey(el)] == "" { - tableMap[m.cacheKey(el)] = name - } + return out[0].String() + } + + // validates if the elem of slice or array implements TableNameAbleWithContext interface. + var tableNameAbleWithContext *TableNameAbleWithContext + if el.Implements(reflect.TypeOf(tableNameAbleWithContext).Elem()) { + v := reflect.New(el) + out := v.MethodByName("TableName").Call([]reflect.Value{reflect.ValueOf(m.ctx)}) + return out[0].String() + + // We do not want to cache contextualized TableNames because that would break + // the contextualization. } - return el.Name(), m.cacheKey(el) + return nflect.Tableize(el.Name()) default: - return t.Name(), m.cacheKey(t) + return nflect.Tableize(t.Name()) } } @@ -209,11 +223,21 @@ func (m *Model) touchUpdatedAt() { } func (m *Model) whereID() string { - return fmt.Sprintf("%s.%s = ?", m.TableName(), m.IDField()) + as := m.As + if as == "" { + as = strings.ReplaceAll(m.TableName(), ".", "_") + } + + return fmt.Sprintf("%s.%s = ?", as, m.IDField()) } func (m *Model) whereNamedID() string { - return fmt.Sprintf("%s.%s = :%s", m.TableName(), m.IDField(), m.IDField()) + as := m.As + if as == "" { + as = strings.ReplaceAll(m.TableName(), ".", "_") + } + + return fmt.Sprintf("%s.%s = :%s", as, m.IDField(), m.IDField()) } func (m *Model) isSlice() bool { @@ -226,7 +250,10 @@ func (m *Model) iterate(fn modelIterable) error { v := reflect.Indirect(reflect.ValueOf(m.Value)) for i := 0; i < v.Len(); i++ { val := v.Index(i) - newModel := &Model{Value: val.Addr().Interface()} + newModel := &Model{ + Value: val.Addr().Interface(), + ctx: m.ctx, + } err := fn(newModel) if err != nil { diff --git a/model_context_test.go b/model_context_test.go new file mode 100644 index 00000000..6b03a882 --- /dev/null +++ b/model_context_test.go @@ -0,0 +1,109 @@ +package pop + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type ContextTable struct { + ID string `db:"id"` + Value string `db:"value"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +func (t ContextTable) TableName(ctx context.Context) string { + // This is singular on purpose! It will checck if the TableName is properly + // Respected in slices as well. + return "context_prefix_" + ctx.Value("prefix").(string) + "_table" +} + +func Test_ModelContext(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + + t.Run("contextless", func(t *testing.T) { + r := require.New(t) + r.Panics(func() { + var c ContextTable + r.NoError(PDB.Create(&c)) + }, "panics if context prefix is not set") + }) + + for _, prefix := range []string{"a", "b"} { + t.Run("prefix="+prefix, func(t *testing.T) { + r := require.New(t) + + expected := ContextTable{ID: prefix, Value: prefix} + c := PDB.WithContext(context.WithValue(context.Background(), "prefix", prefix)) + r.NoError(c.Create(&expected)) + + var actual ContextTable + r.NoError(c.Find(&actual, expected.ID)) + r.EqualValues(prefix, actual.Value) + r.EqualValues(prefix, actual.ID) + + exists, err := c.Where("id = ?", actual.ID).Exists(new(ContextTable)) + r.NoError(err) + r.True(exists) + + count, err := c.Where("id = ?", actual.ID).Count(new(ContextTable)) + r.NoError(err) + r.EqualValues(1, count) + + expected.Value += expected.Value + r.NoError(c.Update(&expected)) + + r.NoError(c.Find(&actual, expected.ID)) + r.EqualValues(prefix+prefix, actual.Value) + r.EqualValues(prefix, actual.ID) + + var results []ContextTable + require.NoError(t, c.All(&results)) + + require.NoError(t, c.First(&expected)) + require.NoError(t, c.Last(&expected)) + + r.NoError(c.Destroy(&expected)) + }) + } + + t.Run("prefix=unknown", func(t *testing.T) { + r := require.New(t) + c := PDB.WithContext(context.WithValue(context.Background(), "prefix", "unknown")) + err := c.Create(&ContextTable{ID: "unknown"}) + r.Error(err) + + if !strings.Contains(err.Error(), "context_prefix_unknown_table") { // All other databases + t.Fatalf("Expected error to contain indicator that table does not exist but got: %s", err.Error()) + } + }) + + t.Run("cache_busting", func(t *testing.T) { + r := require.New(t) + + var expectedA, expectedB ContextTable + expectedA.ID = "expectedA" + expectedB.ID = "expectedB" + + cA := PDB.WithContext(context.WithValue(context.Background(), "prefix", "a")) + r.NoError(cA.Create(&expectedA)) + + cB := PDB.WithContext(context.WithValue(context.Background(), "prefix", "b")) + r.NoError(cB.Create(&expectedB)) + + var actualA, actualB []ContextTable + r.NoError(cA.All(&actualA)) + r.NoError(cB.All(&actualB)) + + r.Len(actualA, 1) + r.Len(actualB, 1) + + r.NotEqual(actualA[0].ID, actualB[0].ID, "if these are equal context switching did not work") + }) +} diff --git a/model_test.go b/model_test.go index 545eb09e..d0d0c23c 100644 --- a/model_test.go +++ b/model_test.go @@ -1,9 +1,14 @@ package pop import ( + "context" + "fmt" "testing" "time" + "github.com/gobuffalo/pop/v5/testdata/models/ac" + "github.com/gobuffalo/pop/v5/testdata/models/bc" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -12,26 +17,25 @@ import ( ) func Test_Model_TableName(t *testing.T) { - r := require.New(t) - - m := Model{Value: User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: &User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: &Users{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: []User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: &[]User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: []*User{}} - r.Equal(m.TableName(), "users") - + for k, v := range []interface{}{ + User{}, + &User{}, + + &Users{}, + Users{}, + + []*User{}, + &[]*User{}, + + []User{}, + &[]User{}, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + r := require.New(t) + m := Model{Value: v} + r.Equal("users", m.TableName()) + }) + } } type tn struct{} @@ -40,6 +44,12 @@ func (tn) TableName() string { return "this is my table name" } +type tnc struct{} + +func (tnc) TableName(ctx context.Context) string { + return ctx.Value("name").(string) +} + // A failing test case for #477 func Test_TableNameCache(t *testing.T) { r := assert.New(t) @@ -49,12 +59,27 @@ func Test_TableNameCache(t *testing.T) { r.Equal("userb", (&Model{Value: []b.User{}}).TableName()) } +// A failing test case for #477 +func Test_TableNameContextCache(t *testing.T) { + ctx := context.WithValue(context.Background(), "name", "context_table") + + r := assert.New(t) + r.Equal("context_table_useras", (&Model{Value: ac.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_userbs", (&Model{Value: bc.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_useras", (&Model{Value: []ac.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_userbs", (&Model{Value: []bc.User{}, ctx: ctx}).TableName()) +} + func Test_TableName(t *testing.T) { r := require.New(t) cases := []interface{}{ tn{}, + &tn{}, []tn{}, + &[]tn{}, + []*tn{}, + &[]*tn{}, } for _, tc := range cases { m := Model{Value: tc} @@ -62,6 +87,22 @@ func Test_TableName(t *testing.T) { } } +func Test_TableNameContext(t *testing.T) { + r := require.New(t) + + tn := "context_table_names" + ctx := context.WithValue(context.Background(), "name", tn) + + cases := []interface{}{ + tnc{}, + []tnc{}, + } + for _, tc := range cases { + m := Model{Value: tc, ctx: ctx} + r.Equal(tn, m.TableName()) + } +} + type TimeTimestamp struct { ID int `db:"id"` CreatedAt time.Time `db:"created_at"` @@ -77,7 +118,7 @@ type UnixTimestamp struct { func Test_Touch_Time_Timestamp(t *testing.T) { r := require.New(t) - m := Model{Value: &TimeTimestamp{}} + m := NewModel(&TimeTimestamp{}, context.Background()) // Override time.Now() t0, _ := time.Parse(time.RFC3339, "2019-07-14T00:00:00Z") @@ -101,7 +142,7 @@ func Test_Touch_Time_Timestamp_With_Existing_Value(t *testing.T) { createdAt := nowFunc().Add(-36 * time.Hour) - m := Model{Value: &TimeTimestamp{CreatedAt: createdAt}} + m := NewModel(&TimeTimestamp{CreatedAt: createdAt}, context.Background()) m.touchCreatedAt() m.touchUpdatedAt() v := m.Value.(*TimeTimestamp) @@ -112,7 +153,7 @@ func Test_Touch_Time_Timestamp_With_Existing_Value(t *testing.T) { func Test_Touch_Unix_Timestamp(t *testing.T) { r := require.New(t) - m := Model{Value: &UnixTimestamp{}} + m := NewModel(&UnixTimestamp{}, context.Background()) // Override time.Now() t0, _ := time.Parse(time.RFC3339, "2019-07-14T00:00:00Z") @@ -136,7 +177,7 @@ func Test_Touch_Unix_Timestamp_With_Existing_Value(t *testing.T) { createdAt := int(time.Now().Add(-36 * time.Hour).Unix()) - m := Model{Value: &UnixTimestamp{CreatedAt: createdAt}} + m := NewModel(&UnixTimestamp{CreatedAt: createdAt}, context.Background()) m.touchCreatedAt() m.touchUpdatedAt() v := m.Value.(*UnixTimestamp) @@ -159,3 +200,25 @@ func Test_IDField(t *testing.T) { m = Model{Value: &testNormalID{ID: 1}} r.Equal("id", m.IDField()) } + +type testPrefixID struct { + ID int `db:"custom_id"` +} + +func (t testPrefixID) TableName() string { + return "foo.bar" +} + +func Test_WhereID(t *testing.T) { + r := require.New(t) + m := Model{Value: &testPrefixID{ID: 1}} + + r.Equal("foo_bar.custom_id = ?", m.whereID()) + r.Equal("foo_bar.custom_id = :custom_id", m.whereNamedID()) + + type testNormalID struct { + ID int + } + m = Model{Value: &testNormalID{ID: 1}} + r.Equal("id", m.IDField()) +} diff --git a/preload_associations.go b/preload_associations.go index 8ddea3c9..c7c8b6c5 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -167,7 +167,7 @@ func (ami *AssociationMetaInfo) fkName() string { // preload is the query mode used to load associations from database // similar to the active record default approach on Rails. func preload(tx *Connection, model interface{}, fields ...string) error { - mmi := NewModelMetaInfo(&Model{Value: model}) + mmi := NewModelMetaInfo(NewModel(model, tx.Context())) preloadFields, err := mmi.preloadFields(fields...) if err != nil { diff --git a/query_test.go b/query_test.go index 93da765c..7df5a994 100644 --- a/query_test.go +++ b/query_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "fmt" "testing" @@ -13,7 +14,7 @@ func Test_Where(t *testing.T) { t.Skip("skipping integration tests") } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) q := PDB.Where("id = ?", 1) sql, _ := q.ToSQL(m) @@ -107,7 +108,7 @@ func Test_Order(t *testing.T) { } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(&Enemy{}, context.Background()) q := PDB.Order("id desc") sql, _ := q.ToSQL(m) a.Equal(ts("SELECT enemies.A FROM enemies AS enemies ORDER BY id desc"), sql) @@ -123,7 +124,7 @@ func Test_GroupBy(t *testing.T) { } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(&Enemy{}, context.Background()) q := PDB.Q() q.GroupBy("A") sql, _ := q.ToSQL(m) @@ -159,7 +160,7 @@ func Test_ToSQL(t *testing.T) { } a := require.New(t) transaction(func(tx *Connection) { - user := &Model{Value: &User{}} + user := NewModel(&User{}, tx.Context()) s := "SELECT name as full_name, users.alive, users.bio, users.birth_date, users.created_at, users.email, users.id, users.name, users.price, users.updated_at, users.user_name FROM users AS users" @@ -171,10 +172,10 @@ func Test_ToSQL(t *testing.T) { q, _ = query.ToSQL(user) a.Equal(fmt.Sprintf("%s ORDER BY id desc", s), q) - q, _ = query.ToSQL(&Model{Value: &User{}, As: "u"}) + q, _ = query.ToSQL(&Model{Value: &User{}, As: "u", ctx: tx.Context()}) a.Equal("SELECT name as full_name, u.alive, u.bio, u.birth_date, u.created_at, u.email, u.id, u.name, u.price, u.updated_at, u.user_name FROM users AS u ORDER BY id desc", q) - q, _ = query.ToSQL(&Model{Value: &Family{}}) + q, _ = query.ToSQL(&Model{Value: &Family{}, ctx: tx.Context()}) a.Equal("SELECT family_members.created_at, family_members.first_name, family_members.id, family_members.last_name, family_members.updated_at FROM family.members AS family_members ORDER BY id desc", q) query = tx.Where("id = 1") @@ -262,7 +263,7 @@ func Test_ToSQLInjection(t *testing.T) { } a := require.New(t) transaction(func(tx *Connection) { - user := &Model{Value: &User{}} + user := NewModel(new(User), tx.Context()) query := tx.Where("name = '?'", "\\\u0027 or 1=1 limit 1;\n-- ") q, _ := query.ToSQL(user) a.NotEqual("SELECT * FROM users AS users WHERE name = '\\'' or 1=1 limit 1;\n-- '", q) diff --git a/scopes_test.go b/scopes_test.go index 0e22b76e..f393ffb3 100644 --- a/scopes_test.go +++ b/scopes_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -13,7 +14,7 @@ func Test_Scopes(t *testing.T) { r := require.New(t) oql := "SELECT enemies.A FROM enemies AS enemies" - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) q := PDB.Q() s, _ := q.ToSQL(m) diff --git a/soda/cmd/migrate_status.go b/soda/cmd/migrate_status.go index e9a85df6..98d0aae4 100644 --- a/soda/cmd/migrate_status.go +++ b/soda/cmd/migrate_status.go @@ -1,9 +1,10 @@ package cmd import ( + "os" + "github.com/gobuffalo/pop/v5" "github.com/spf13/cobra" - "os" ) var migrateStatusCmd = &cobra.Command{ diff --git a/store.go b/store.go index cc98a266..39c554ed 100644 --- a/store.go +++ b/store.go @@ -54,3 +54,7 @@ func (s contextStore) Exec(query string, args ...interface{}) (sql.Result, error func (s contextStore) PrepareNamed(query string) (*sqlx.NamedStmt, error) { return s.store.PrepareNamedContext(s.ctx, query) } + +func (s contextStore) Context() context.Context { + return s.ctx +} diff --git a/testdata/migrations/20210104145901_context_tables.down.fizz b/testdata/migrations/20210104145901_context_tables.down.fizz new file mode 100644 index 00000000..d0f82ee2 --- /dev/null +++ b/testdata/migrations/20210104145901_context_tables.down.fizz @@ -0,0 +1,2 @@ +drop_table("context_prefix_a_table") +drop_table("context_prefix_b_table") diff --git a/testdata/migrations/20210104145901_context_tables.up.fizz b/testdata/migrations/20210104145901_context_tables.up.fizz new file mode 100644 index 00000000..ae94796f --- /dev/null +++ b/testdata/migrations/20210104145901_context_tables.up.fizz @@ -0,0 +1,9 @@ +create_table("context_prefix_a_table") { + t.Column("id", "string", { primary: true }) + t.Column("value", "string") +} + +create_table("context_prefix_b_table") { + t.Column("id", "string", { primary: true }) + t.Column("value", "string") +} diff --git a/testdata/models/ac/user.go b/testdata/models/ac/user.go new file mode 100644 index 00000000..92335a16 --- /dev/null +++ b/testdata/models/ac/user.go @@ -0,0 +1,9 @@ +package ac + +import "context" + +type User struct{} + +func (u User) TableName(ctx context.Context) string { + return ctx.Value("name").(string) + "_useras" +} diff --git a/testdata/models/bc/user.go b/testdata/models/bc/user.go new file mode 100644 index 00000000..4b1c6257 --- /dev/null +++ b/testdata/models/bc/user.go @@ -0,0 +1,9 @@ +package bc + +import "context" + +type User struct{} + +func (u User) TableName(ctx context.Context) string { + return ctx.Value("name").(string) + "_userbs" +} diff --git a/validations.go b/validations.go index 79e76105..ab5b845e 100644 --- a/validations.go +++ b/validations.go @@ -133,7 +133,7 @@ func (m *Model) iterateAndValidate(fn modelIterableValidator) (*validate.Errors, if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { for i := 0; i < v.Len(); i++ { val := v.Index(i) - newModel := &Model{Value: val.Addr().Interface()} + newModel := NewModel(val.Addr().Interface(), m.ctx) verrs, err := fn(newModel) if err != nil || verrs.HasAny() {