From 5d69084434a81bee40d46574d84be03f1eda33f2 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 4 Jan 2021 20:13:45 +0100 Subject: [PATCH 1/4] feat: support context-aware tablenames This patch adds a feature which enables pop to pass down the connection context to the model's `TableName()` function by implementing `TableName(ctx context.Context) string`. The context can be used to dynamically generate tablenames which can be important for prefixed or generic tables and other use cases. Signed-off-by: aeneasr <3372410+aeneasr@users.noreply.github.com> --- belongs_to.go | 8 +- belongs_to_test.go | 7 +- connection.go | 10 +++ connection_details.go | 3 +- connection_instrumented.go | 1 + connection_instrumented_nosqlite_test.go | 3 +- connection_instrumented_test.go | 5 +- dialect_sqlite.go | 3 +- executors.go | 28 +++---- finders.go | 14 ++-- finders_test.go | 2 +- match_test.go | 3 +- migration_info_test.go | 3 +- model.go | 44 +++++++++- model_context_test.go | 84 +++++++++++++++++++ model_test.go | 45 +++++++++- preload_associations.go | 2 +- query_test.go | 15 ++-- scopes_test.go | 3 +- soda/cmd/migrate_status.go | 3 +- store.go | 4 + .../20210104145901_context_tables.down.fizz | 2 + .../20210104145901_context_tables.up.fizz | 9 ++ testdata/models/ac/user.go | 9 ++ testdata/models/bc/user.go | 9 ++ validations.go | 2 +- 26 files changed, 266 insertions(+), 55 deletions(-) create mode 100644 model_context_test.go create mode 100644 testdata/migrations/20210104145901_context_tables.down.fizz create mode 100644 testdata/migrations/20210104145901_context_tables.up.fizz create mode 100644 testdata/models/ac/user.go create mode 100644 testdata/models/bc/user.go diff --git a/belongs_to.go b/belongs_to.go index 0b5c977e..d261a315 100644 --- a/belongs_to.go +++ b/belongs_to.go @@ -19,7 +19,7 @@ func (c *Connection) BelongsToAs(model interface{}, as string) *Query { // BelongsTo adds a "where" clause based on the "ID" of the // "model" passed into it. func (q *Query) BelongsTo(model interface{}) *Query { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID()) return q } @@ -27,7 +27,7 @@ func (q *Query) BelongsTo(model interface{}) *Query { // BelongsToAs adds a "where" clause based on the "ID" of the // "model" passed into it, using an alias. func (q *Query) BelongsToAs(model interface{}, as string) *Query { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) q.Where(fmt.Sprintf("%s = ?", as), m.ID()) return q } @@ -42,8 +42,8 @@ func (c *Connection) BelongsToThrough(bt, thru interface{}) *Query { // through the associated "thru" model. func (q *Query) BelongsToThrough(bt, thru interface{}) *Query { q.belongsToThroughClauses = append(q.belongsToThroughClauses, belongsToThroughClause{ - BelongsTo: &Model{Value: bt}, - Through: &Model{Value: thru}, + BelongsTo: NewModel(bt, q.Connection.Context()), + Through: NewModel(thru, q.Connection.Context()), }) return q } diff --git a/belongs_to_test.go b/belongs_to_test.go index e4e3a3e7..9eb99914 100644 --- a/belongs_to_test.go +++ b/belongs_to_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -14,7 +15,7 @@ func Test_BelongsTo(t *testing.T) { q := PDB.BelongsTo(&User{ID: 1}) - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE user_id = ?"), sql) @@ -28,7 +29,7 @@ func Test_BelongsToAs(t *testing.T) { q := PDB.BelongsToAs(&User{ID: 1}, "u_id") - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE u_id = ?"), sql) @@ -43,7 +44,7 @@ func Test_BelongsToThrough(t *testing.T) { q := PDB.BelongsToThrough(&User{ID: 1}, &Friend{}) qs := "SELECT enemies.A FROM enemies AS enemies, good_friends AS good_friends WHERE good_friends.user_id = ? AND enemies.id = good_friends.enemy_id" - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) sql, _ := q.ToSQL(m) r.Equal(ts(qs), sql) } diff --git a/connection.go b/connection.go index df116fb3..6e0f55ee 100644 --- a/connection.go +++ b/connection.go @@ -33,6 +33,16 @@ func (c *Connection) URL() string { return c.Dialect.URL() } +// Context returns the connection's context set by "Context()" or context.TODO() +// if no context is set. +func (c *Connection) Context() context.Context { + if c, ok := c.Store.(interface{ Context() context.Context }); ok { + return c.Context() + } + + return context.TODO() +} + // MigrationURL returns the datasource connection string used for running the migrations func (c *Connection) MigrationURL() string { return c.Dialect.MigrationURL() diff --git a/connection_details.go b/connection_details.go index 29c0ae41..6456b7d1 100644 --- a/connection_details.go +++ b/connection_details.go @@ -2,13 +2,14 @@ package pop import ( "fmt" - "github.com/luna-duclos/instrumentedsql" "net/url" "regexp" "strconv" "strings" "time" + "github.com/luna-duclos/instrumentedsql" + "github.com/gobuffalo/pop/v5/internal/defaults" "github.com/gobuffalo/pop/v5/logging" "github.com/pkg/errors" diff --git a/connection_instrumented.go b/connection_instrumented.go index 84cf7820..f2d46b02 100644 --- a/connection_instrumented.go +++ b/connection_instrumented.go @@ -3,6 +3,7 @@ package pop import ( "database/sql" "database/sql/driver" + mysqld "github.com/go-sql-driver/mysql" "github.com/gobuffalo/pop/v5/logging" pgx "github.com/jackc/pgx/v4/stdlib" diff --git a/connection_instrumented_nosqlite_test.go b/connection_instrumented_nosqlite_test.go index 715a92ab..2cf6ae01 100644 --- a/connection_instrumented_nosqlite_test.go +++ b/connection_instrumented_nosqlite_test.go @@ -3,8 +3,9 @@ package pop import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestInstrumentation_WithoutSqlite(t *testing.T) { diff --git a/connection_instrumented_test.go b/connection_instrumented_test.go index d5125639..808d8638 100644 --- a/connection_instrumented_test.go +++ b/connection_instrumented_test.go @@ -3,12 +3,13 @@ package pop import ( "context" "fmt" - "github.com/luna-duclos/instrumentedsql" - "github.com/stretchr/testify/suite" "os" "strings" "sync" "time" + + "github.com/luna-duclos/instrumentedsql" + "github.com/stretchr/testify/suite" ) func testInstrumentedDriver(p *suite.Suite) { diff --git a/dialect_sqlite.go b/dialect_sqlite.go index b3045090..47bc3e28 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -5,7 +5,6 @@ package pop import ( "database/sql/driver" "fmt" - "github.com/mattn/go-sqlite3" "io" "net/url" "os" @@ -15,6 +14,8 @@ import ( "sync" "time" + "github.com/mattn/go-sqlite3" + "github.com/gobuffalo/fizz" "github.com/gobuffalo/fizz/translators" _ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver diff --git a/executors.go b/executors.go index b3560141..657ecea7 100644 --- a/executors.go +++ b/executors.go @@ -13,7 +13,7 @@ import ( // Reload fetch fresh data for a given model, using its ID. func (c *Connection) Reload(model interface{}) error { - sm := Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.Find(m.Value, m.ID()) }) @@ -51,7 +51,7 @@ func (q *Query) ExecWithCount() (int, error) { // // If model is a slice, each item of the slice is validated then saved in the database. func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -77,7 +77,7 @@ func IsZeroOfUnderlyingType(x interface{}) bool { // // If model is a slice, each item of the slice is saved in the database. func (c *Connection) Save(model interface{}, excludeColumns ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { id, err := m.fieldByName("ID") if err != nil { @@ -95,7 +95,7 @@ func (c *Connection) Save(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is validated then created in the database. func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -126,7 +126,7 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) verrs, err := sm.validateAndOnlyCreate(c) if err != nil || verrs.HasAny() { return verrs, err @@ -140,14 +140,14 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) verrs, err := sm.validateAndOnlyCreate(c) if err != nil || verrs.HasAny() { return verrs, err } } - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) verrs, err = sm.validateCreate(c) if err != nil || verrs.HasAny() { return verrs, err @@ -170,7 +170,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { c.disableEager() - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Create", func() error { var localIsEager = isEager @@ -203,7 +203,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { } if localIsEager { - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) err = sm.iterate(func(m *Model) error { id, err := m.fieldByName("ID") if err != nil { @@ -255,7 +255,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { continue } - sm := &Model{Value: i} + sm := NewModel(i, c.Context()) err = sm.iterate(func(m *Model) error { fbn, err := m.fieldByName("ID") if err != nil { @@ -318,7 +318,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is validated then updated in the database. func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) if err := sm.beforeValidate(c); err != nil { return nil, err } @@ -337,7 +337,7 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri // // If model is a slice, each item of the slice is updated in the database. func (c *Connection) Update(model interface{}, excludeColumns ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Update", func() error { var err error @@ -377,7 +377,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error { // // If model is a slice, each item of the slice is updated in the database. func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Update", func() error { var err error @@ -419,7 +419,7 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err // // If model is a slice, each item of the slice is deleted from the database. func (c *Connection) Destroy(model interface{}) error { - sm := &Model{Value: model} + sm := NewModel(model, c.Context()) return sm.iterate(func(m *Model) error { return c.timeFunc("Destroy", func() error { var err error diff --git a/finders.go b/finders.go index 9026d14b..fa1a220a 100644 --- a/finders.go +++ b/finders.go @@ -29,7 +29,7 @@ func (c *Connection) Find(model interface{}, id interface{}) error { // // q.Find(&User{}, 1) func (q *Query) Find(model interface{}, id interface{}) error { - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) idq := m.whereID() switch t := id.(type) { case uuid.UUID: @@ -69,7 +69,7 @@ func (c *Connection) First(model interface{}) error { func (q *Query) First(model interface{}) error { err := q.Connection.timeFunc("First", func() error { q.Limit(1) - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { return err } @@ -102,7 +102,7 @@ func (q *Query) Last(model interface{}) error { err := q.Connection.timeFunc("Last", func() error { q.Limit(1) q.Order("created_at DESC, id DESC") - m := &Model{Value: model} + m := NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { return err } @@ -134,7 +134,7 @@ func (c *Connection) All(models interface{}) error { // q.Where("name = ?", "mark").All(&[]User{}) func (q *Query) All(models interface{}) error { err := q.Connection.timeFunc("All", func() error { - m := &Model{Value: models} + m := NewModel(models, q.Connection.Context()) err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) if err != nil { return err @@ -258,7 +258,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error { } } - sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()}) + sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context())) query = query.RawQuery(sqlSentence, args...) if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { @@ -302,7 +302,7 @@ func (q *Query) Exists(model interface{}) (bool, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) + query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context())) // when query contains custom selected fields / executed using RawQuery, // sql may already contains limit and offset @@ -348,7 +348,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) + query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context())) // when query contains custom selected fields / executed using RawQuery, // sql may already contains limit and offset diff --git a/finders_test.go b/finders_test.go index 389dc80a..7f30727e 100644 --- a/finders_test.go +++ b/finders_test.go @@ -101,7 +101,7 @@ func Test_Select(t *testing.T) { q := tx.Select("name", "email", "\n", "\t\n", "") - sm := &Model{Value: &User{}} + sm := NewModel(new(User), tx.Context()) sql, _ := q.ToSQL(sm) r.Equal(tx.Dialect.TranslateSQL("SELECT email, name FROM users AS users"), sql) diff --git a/match_test.go b/match_test.go index 8cc41a6b..0f1591f1 100644 --- a/match_test.go +++ b/match_test.go @@ -1,8 +1,9 @@ package pop import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func Test_ParseMigrationFilenameFizzDown(t *testing.T) { diff --git a/migration_info_test.go b/migration_info_test.go index 0f76824e..f2174551 100644 --- a/migration_info_test.go +++ b/migration_info_test.go @@ -1,9 +1,10 @@ package pop import ( - "github.com/stretchr/testify/assert" "sort" "testing" + + "github.com/stretchr/testify/assert" ) func TestSortingMigrations(t *testing.T) { diff --git a/model.go b/model.go index 58019f73..7ba7cee2 100644 --- a/model.go +++ b/model.go @@ -1,13 +1,15 @@ package pop import ( + "context" "fmt" - "github.com/gobuffalo/pop/v5/columns" - "github.com/pkg/errors" "reflect" "sync" "time" + "github.com/gobuffalo/pop/v5/columns" + "github.com/pkg/errors" + "github.com/gobuffalo/flect" nflect "github.com/gobuffalo/flect/name" "github.com/gofrs/uuid" @@ -27,10 +29,16 @@ type modelIterable func(*Model) error // that is passed in to many functions. type Model struct { Value + ctx context.Context tableName string As string } +// NewModel returns a new model with the specified value and context. +func NewModel(v Value, ctx context.Context) *Model { + return &Model{Value: v, ctx: ctx} +} + // ID returns the ID of the Model. All models must have an `ID` field this is // of type `int`,`int64` or of type `uuid.UUID`. func (m *Model) ID() interface{} { @@ -86,6 +94,13 @@ type TableNameAble interface { TableName() string } +// TableNameAbleWithContext is equal to TableNameAble but will +// be passed the queries' context. Useful in cases where the +// table name depends on e.g. +type TableNameAbleWithContext interface { + TableName(ctx context.Context) string +} + // TableName returns the corresponding name of the underlying database table // for a given `Model`. See also `TableNameAble` to change the default name of the table. func (m *Model) TableName() string { @@ -96,6 +111,13 @@ func (m *Model) TableName() string { return n.TableName() } + if n, ok := m.Value.(TableNameAbleWithContext); ok { + if m.ctx == nil { + m.ctx = context.TODO() + } + return n.TableName(m.ctx) + } + if m.tableName != "" { return m.tableName } @@ -133,7 +155,7 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { } // validates if the elem of slice or array implements TableNameAble interface. - tableNameAble := (*TableNameAble)(nil) + var tableNameAble *TableNameAble if el.Implements(reflect.TypeOf(tableNameAble).Elem()) { v := reflect.New(el) out := v.MethodByName("TableName").Call([]reflect.Value{}) @@ -143,6 +165,17 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { } } + // validates if the elem of slice or array implements TableNameAbleWithContext interface. + var tableNameAbleWithContext *TableNameAbleWithContext + if el.Implements(reflect.TypeOf(tableNameAbleWithContext).Elem()) { + v := reflect.New(el) + out := v.MethodByName("TableName").Call([]reflect.Value{reflect.ValueOf(m.ctx)}) + name := out[0].String() + if tableMap[m.cacheKey(el)] == "" { + tableMap[m.cacheKey(el)] = name + } + } + return el.Name(), m.cacheKey(el) default: return t.Name(), m.cacheKey(t) @@ -226,7 +259,10 @@ func (m *Model) iterate(fn modelIterable) error { v := reflect.Indirect(reflect.ValueOf(m.Value)) for i := 0; i < v.Len(); i++ { val := v.Index(i) - newModel := &Model{Value: val.Addr().Interface()} + newModel := &Model{ + Value: val.Addr().Interface(), + ctx: m.ctx, + } err := fn(newModel) if err != nil { diff --git a/model_context_test.go b/model_context_test.go new file mode 100644 index 00000000..f5312310 --- /dev/null +++ b/model_context_test.go @@ -0,0 +1,84 @@ +package pop + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type ContextTable struct { + ID string `db:"id"` + Value string `db:"value"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +func (t ContextTable) TableName(ctx context.Context) string { + return "context_prefix_" + ctx.Value("prefix").(string) + "_table" +} + +func Test_ModelContext(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + + t.Run("contextless", func(t *testing.T) { + r := require.New(t) + r.Panics(func() { + var c ContextTable + r.NoError(PDB.Create(&c)) + }, "panics if context prefix is not set") + }) + + for _, prefix := range []string{"a", "b"} { + t.Run("prefix="+prefix, func(t *testing.T) { + r := require.New(t) + + expected := ContextTable{ID: prefix, Value: prefix} + c := PDB.WithContext(context.WithValue(context.Background(), "prefix", prefix)) + r.NoError(c.Create(&expected)) + + var actual ContextTable + r.NoError(c.Find(&actual, expected.ID)) + r.EqualValues(prefix, actual.Value) + r.EqualValues(prefix, actual.ID) + + exists, err := c.Where("id = ?", actual.ID).Exists(new(ContextTable)) + r.NoError(err) + r.True(exists) + + count, err := c.Where("id = ?", actual.ID).Count(new(ContextTable)) + r.NoError(err) + r.EqualValues(1, count) + + expected.Value += expected.Value + r.NoError(c.Update(&expected)) + + r.NoError(c.Find(&actual, expected.ID)) + r.EqualValues(prefix+prefix, actual.Value) + r.EqualValues(prefix, actual.ID) + + var results []ContextTable + require.NoError(t, c.All(&results)) + + require.NoError(t, c.First(&expected)) + require.NoError(t, c.Last(&expected)) + + r.NoError(c.Destroy(&expected)) + }) + } + + t.Run("prefix=unknown", func(t *testing.T) { + r := require.New(t) + c := PDB.WithContext(context.WithValue(context.Background(), "prefix", "unknown")) + err := c.Create(&ContextTable{ID: "unknown"}) + r.Error(err) + + if !strings.Contains(err.Error(), "context_prefix_unknown_table") { // All other databases + t.Fatalf("Expected error to contain indicator that table does not exist but got: %s", err.Error()) + } + }) +} diff --git a/model_test.go b/model_test.go index 545eb09e..5f8f7660 100644 --- a/model_test.go +++ b/model_test.go @@ -1,9 +1,13 @@ package pop import ( + "context" "testing" "time" + "github.com/gobuffalo/pop/v5/testdata/models/ac" + "github.com/gobuffalo/pop/v5/testdata/models/bc" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -40,6 +44,12 @@ func (tn) TableName() string { return "this is my table name" } +type tnc struct{} + +func (tnc) TableName(ctx context.Context) string { + return ctx.Value("name").(string) +} + // A failing test case for #477 func Test_TableNameCache(t *testing.T) { r := assert.New(t) @@ -49,6 +59,17 @@ func Test_TableNameCache(t *testing.T) { r.Equal("userb", (&Model{Value: []b.User{}}).TableName()) } +// A failing test case for #477 +func Test_TableNameContextCache(t *testing.T) { + ctx := context.WithValue(context.Background(), "name", "context-table") + + r := assert.New(t) + r.Equal("context-table-usera", (&Model{Value: ac.User{}, ctx: ctx}).TableName()) + r.Equal("context-table-userb", (&Model{Value: bc.User{}, ctx: ctx}).TableName()) + r.Equal("context-table-usera", (&Model{Value: []ac.User{}, ctx: ctx}).TableName()) + r.Equal("context-table-userb", (&Model{Value: []bc.User{}, ctx: ctx}).TableName()) +} + func Test_TableName(t *testing.T) { r := require.New(t) @@ -62,6 +83,22 @@ func Test_TableName(t *testing.T) { } } +func Test_TableNameContext(t *testing.T) { + r := require.New(t) + + tn := "context table name" + ctx := context.WithValue(context.Background(), "name", tn) + + cases := []interface{}{ + tnc{}, + []tnc{}, + } + for _, tc := range cases { + m := Model{Value: tc, ctx: ctx} + r.Equal(tn, m.TableName()) + } +} + type TimeTimestamp struct { ID int `db:"id"` CreatedAt time.Time `db:"created_at"` @@ -77,7 +114,7 @@ type UnixTimestamp struct { func Test_Touch_Time_Timestamp(t *testing.T) { r := require.New(t) - m := Model{Value: &TimeTimestamp{}} + m := NewModel(&TimeTimestamp{}, context.Background()) // Override time.Now() t0, _ := time.Parse(time.RFC3339, "2019-07-14T00:00:00Z") @@ -101,7 +138,7 @@ func Test_Touch_Time_Timestamp_With_Existing_Value(t *testing.T) { createdAt := nowFunc().Add(-36 * time.Hour) - m := Model{Value: &TimeTimestamp{CreatedAt: createdAt}} + m := NewModel(&TimeTimestamp{CreatedAt: createdAt}, context.Background()) m.touchCreatedAt() m.touchUpdatedAt() v := m.Value.(*TimeTimestamp) @@ -112,7 +149,7 @@ func Test_Touch_Time_Timestamp_With_Existing_Value(t *testing.T) { func Test_Touch_Unix_Timestamp(t *testing.T) { r := require.New(t) - m := Model{Value: &UnixTimestamp{}} + m := NewModel(&UnixTimestamp{}, context.Background()) // Override time.Now() t0, _ := time.Parse(time.RFC3339, "2019-07-14T00:00:00Z") @@ -136,7 +173,7 @@ func Test_Touch_Unix_Timestamp_With_Existing_Value(t *testing.T) { createdAt := int(time.Now().Add(-36 * time.Hour).Unix()) - m := Model{Value: &UnixTimestamp{CreatedAt: createdAt}} + m := NewModel(&UnixTimestamp{CreatedAt: createdAt}, context.Background()) m.touchCreatedAt() m.touchUpdatedAt() v := m.Value.(*UnixTimestamp) diff --git a/preload_associations.go b/preload_associations.go index 4a03b579..dc19ccea 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -167,7 +167,7 @@ func (ami *AssociationMetaInfo) fkName() string { // preload is the query mode used to load associations from database // similar to the active record default approach on Rails. func preload(tx *Connection, model interface{}, fields ...string) error { - mmi := NewModelMetaInfo(&Model{Value: model}) + mmi := NewModelMetaInfo(NewModel(model, tx.Context())) preloadFields, err := mmi.preloadFields(fields...) if err != nil { diff --git a/query_test.go b/query_test.go index 93da765c..7df5a994 100644 --- a/query_test.go +++ b/query_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "fmt" "testing" @@ -13,7 +14,7 @@ func Test_Where(t *testing.T) { t.Skip("skipping integration tests") } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) q := PDB.Where("id = ?", 1) sql, _ := q.ToSQL(m) @@ -107,7 +108,7 @@ func Test_Order(t *testing.T) { } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(&Enemy{}, context.Background()) q := PDB.Order("id desc") sql, _ := q.ToSQL(m) a.Equal(ts("SELECT enemies.A FROM enemies AS enemies ORDER BY id desc"), sql) @@ -123,7 +124,7 @@ func Test_GroupBy(t *testing.T) { } a := require.New(t) - m := &Model{Value: &Enemy{}} + m := NewModel(&Enemy{}, context.Background()) q := PDB.Q() q.GroupBy("A") sql, _ := q.ToSQL(m) @@ -159,7 +160,7 @@ func Test_ToSQL(t *testing.T) { } a := require.New(t) transaction(func(tx *Connection) { - user := &Model{Value: &User{}} + user := NewModel(&User{}, tx.Context()) s := "SELECT name as full_name, users.alive, users.bio, users.birth_date, users.created_at, users.email, users.id, users.name, users.price, users.updated_at, users.user_name FROM users AS users" @@ -171,10 +172,10 @@ func Test_ToSQL(t *testing.T) { q, _ = query.ToSQL(user) a.Equal(fmt.Sprintf("%s ORDER BY id desc", s), q) - q, _ = query.ToSQL(&Model{Value: &User{}, As: "u"}) + q, _ = query.ToSQL(&Model{Value: &User{}, As: "u", ctx: tx.Context()}) a.Equal("SELECT name as full_name, u.alive, u.bio, u.birth_date, u.created_at, u.email, u.id, u.name, u.price, u.updated_at, u.user_name FROM users AS u ORDER BY id desc", q) - q, _ = query.ToSQL(&Model{Value: &Family{}}) + q, _ = query.ToSQL(&Model{Value: &Family{}, ctx: tx.Context()}) a.Equal("SELECT family_members.created_at, family_members.first_name, family_members.id, family_members.last_name, family_members.updated_at FROM family.members AS family_members ORDER BY id desc", q) query = tx.Where("id = 1") @@ -262,7 +263,7 @@ func Test_ToSQLInjection(t *testing.T) { } a := require.New(t) transaction(func(tx *Connection) { - user := &Model{Value: &User{}} + user := NewModel(new(User), tx.Context()) query := tx.Where("name = '?'", "\\\u0027 or 1=1 limit 1;\n-- ") q, _ := query.ToSQL(user) a.NotEqual("SELECT * FROM users AS users WHERE name = '\\'' or 1=1 limit 1;\n-- '", q) diff --git a/scopes_test.go b/scopes_test.go index 0e22b76e..f393ffb3 100644 --- a/scopes_test.go +++ b/scopes_test.go @@ -1,6 +1,7 @@ package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -13,7 +14,7 @@ func Test_Scopes(t *testing.T) { r := require.New(t) oql := "SELECT enemies.A FROM enemies AS enemies" - m := &Model{Value: &Enemy{}} + m := NewModel(new(Enemy), context.Background()) q := PDB.Q() s, _ := q.ToSQL(m) diff --git a/soda/cmd/migrate_status.go b/soda/cmd/migrate_status.go index e9a85df6..98d0aae4 100644 --- a/soda/cmd/migrate_status.go +++ b/soda/cmd/migrate_status.go @@ -1,9 +1,10 @@ package cmd import ( + "os" + "github.com/gobuffalo/pop/v5" "github.com/spf13/cobra" - "os" ) var migrateStatusCmd = &cobra.Command{ diff --git a/store.go b/store.go index cc98a266..39c554ed 100644 --- a/store.go +++ b/store.go @@ -54,3 +54,7 @@ func (s contextStore) Exec(query string, args ...interface{}) (sql.Result, error func (s contextStore) PrepareNamed(query string) (*sqlx.NamedStmt, error) { return s.store.PrepareNamedContext(s.ctx, query) } + +func (s contextStore) Context() context.Context { + return s.ctx +} diff --git a/testdata/migrations/20210104145901_context_tables.down.fizz b/testdata/migrations/20210104145901_context_tables.down.fizz new file mode 100644 index 00000000..d0f82ee2 --- /dev/null +++ b/testdata/migrations/20210104145901_context_tables.down.fizz @@ -0,0 +1,2 @@ +drop_table("context_prefix_a_table") +drop_table("context_prefix_b_table") diff --git a/testdata/migrations/20210104145901_context_tables.up.fizz b/testdata/migrations/20210104145901_context_tables.up.fizz new file mode 100644 index 00000000..ae94796f --- /dev/null +++ b/testdata/migrations/20210104145901_context_tables.up.fizz @@ -0,0 +1,9 @@ +create_table("context_prefix_a_table") { + t.Column("id", "string", { primary: true }) + t.Column("value", "string") +} + +create_table("context_prefix_b_table") { + t.Column("id", "string", { primary: true }) + t.Column("value", "string") +} diff --git a/testdata/models/ac/user.go b/testdata/models/ac/user.go new file mode 100644 index 00000000..39a5e934 --- /dev/null +++ b/testdata/models/ac/user.go @@ -0,0 +1,9 @@ +package ac + +import "context" + +type User struct{} + +func (u User) TableName(ctx context.Context) string { + return ctx.Value("name").(string) + "-usera" +} diff --git a/testdata/models/bc/user.go b/testdata/models/bc/user.go new file mode 100644 index 00000000..30b543bc --- /dev/null +++ b/testdata/models/bc/user.go @@ -0,0 +1,9 @@ +package bc + +import "context" + +type User struct{} + +func (u User) TableName(ctx context.Context) string { + return ctx.Value("name").(string) + "-userb" +} diff --git a/validations.go b/validations.go index 79e76105..ab5b845e 100644 --- a/validations.go +++ b/validations.go @@ -133,7 +133,7 @@ func (m *Model) iterateAndValidate(fn modelIterableValidator) (*validate.Errors, if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { for i := 0; i < v.Len(); i++ { val := v.Index(i) - newModel := &Model{Value: val.Addr().Interface()} + newModel := NewModel(val.Addr().Interface(), m.ctx) verrs, err := fn(newModel) if err != nil || verrs.HasAny() { From bf09b4d7ce1ea67e300d2a015fdf1f8fe2322263 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 8 Jan 2021 13:09:57 +0100 Subject: [PATCH 2/4] fix: remove ineffective and bug-ridden model `tableMap` cache This patch removes the Model's `tableMap`. The cache seems to have served a purpose before `TableName()` was introduced. Now however, it does not significantly reduce computation time and instead introduces a myriad of really difficult to debug problems. This usually happens due to cacheKey collisions. The amount of time the TCP connection will use to query the database is significantly more than calling `reflect.Elem()`, further delegitimizing the complexity of having a cache here. This patch completely removes the cache and adds failing test cases to ensure that future regressions do not resurface. Signed-off-by: aeneasr <3372410+aeneasr@users.noreply.github.com> --- model.go | 38 +++++++++++++++----------------------- model_context_test.go | 25 +++++++++++++++++++++++++ model_test.go | 12 ++++++------ testdata/models/ac/user.go | 2 +- testdata/models/bc/user.go | 2 +- 5 files changed, 48 insertions(+), 31 deletions(-) diff --git a/model.go b/model.go index 7ba7cee2..a2ff0bef 100644 --- a/model.go +++ b/model.go @@ -4,14 +4,16 @@ import ( "context" "fmt" "reflect" + "strings" "sync" "time" + nflect "github.com/gobuffalo/flect/name" + "github.com/gobuffalo/pop/v5/columns" "github.com/pkg/errors" "github.com/gobuffalo/flect" - nflect "github.com/gobuffalo/flect/name" "github.com/gofrs/uuid" ) @@ -107,6 +109,7 @@ func (m *Model) TableName() string { if s, ok := m.Value.(string); ok { return s } + if n, ok := m.Value.(TableNameAble); ok { return n.TableName() } @@ -118,21 +121,13 @@ func (m *Model) TableName() string { return n.TableName(m.ctx) } + m.isSlice() + if m.tableName != "" { return m.tableName } - t := reflect.TypeOf(m.Value) - name, cacheKey := m.typeName(t) - - defer tableMapMu.Unlock() - tableMapMu.Lock() - - if tableMap[cacheKey] == "" { - m.tableName = nflect.Tableize(name) - tableMap[cacheKey] = m.tableName - } - return tableMap[cacheKey] + return m.typeName(reflect.TypeOf(m.Value)) } func (m *Model) Columns() columns.Columns { @@ -143,7 +138,7 @@ func (m *Model) cacheKey(t reflect.Type) string { return t.PkgPath() + "." + t.Name() } -func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { +func (m *Model) typeName(t reflect.Type) (name string) { if t.Kind() == reflect.Ptr { t = t.Elem() } @@ -159,10 +154,7 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { if el.Implements(reflect.TypeOf(tableNameAble).Elem()) { v := reflect.New(el) out := v.MethodByName("TableName").Call([]reflect.Value{}) - name := out[0].String() - if tableMap[m.cacheKey(el)] == "" { - tableMap[m.cacheKey(el)] = name - } + return out[0].String() } // validates if the elem of slice or array implements TableNameAbleWithContext interface. @@ -170,15 +162,15 @@ func (m *Model) typeName(t reflect.Type) (name, cacheKey string) { if el.Implements(reflect.TypeOf(tableNameAbleWithContext).Elem()) { v := reflect.New(el) out := v.MethodByName("TableName").Call([]reflect.Value{reflect.ValueOf(m.ctx)}) - name := out[0].String() - if tableMap[m.cacheKey(el)] == "" { - tableMap[m.cacheKey(el)] = name - } + return out[0].String() + + // We do not want to cache contextualized TableNames because that would break + // the contextualization. } - return el.Name(), m.cacheKey(el) + return nflect.Tableize(name) default: - return t.Name(), m.cacheKey(t) + return nflect.Tableize(t.Name()) } } diff --git a/model_context_test.go b/model_context_test.go index f5312310..d6a14213 100644 --- a/model_context_test.go +++ b/model_context_test.go @@ -17,6 +17,8 @@ type ContextTable struct { } func (t ContextTable) TableName(ctx context.Context) string { + // This is singular on purpose! It will checck if the TableName is properly + // Respected in slices as well. return "context_prefix_" + ctx.Value("prefix").(string) + "_table" } @@ -81,4 +83,27 @@ func Test_ModelContext(t *testing.T) { t.Fatalf("Expected error to contain indicator that table does not exist but got: %s", err.Error()) } }) + + t.Run("cache_busting", func(t *testing.T) { + r := require.New(t) + + var expectedA, expectedB ContextTable + expectedA.ID = "expectedA" + expectedB.ID = "expectedB" + + cA := PDB.WithContext(context.WithValue(context.Background(), "prefix", "a")) + r.NoError(cA.Create(&expectedA)) + + cB := PDB.WithContext(context.WithValue(context.Background(), "prefix", "b")) + r.NoError(cB.Create(&expectedB)) + + var actualA, actualB []ContextTable + r.NoError(cA.All(&actualA)) + r.NoError(cB.All(&actualB)) + + r.Len(cA, 1) + r.Len(cB, 1) + + r.NotEqual(cA.ID, cB.ID, "if these are equal context switching did not work") + }) } diff --git a/model_test.go b/model_test.go index 5f8f7660..b1bf4dc8 100644 --- a/model_test.go +++ b/model_test.go @@ -61,13 +61,13 @@ func Test_TableNameCache(t *testing.T) { // A failing test case for #477 func Test_TableNameContextCache(t *testing.T) { - ctx := context.WithValue(context.Background(), "name", "context-table") + ctx := context.WithValue(context.Background(), "name", "context_table") r := assert.New(t) - r.Equal("context-table-usera", (&Model{Value: ac.User{}, ctx: ctx}).TableName()) - r.Equal("context-table-userb", (&Model{Value: bc.User{}, ctx: ctx}).TableName()) - r.Equal("context-table-usera", (&Model{Value: []ac.User{}, ctx: ctx}).TableName()) - r.Equal("context-table-userb", (&Model{Value: []bc.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_useras", (&Model{Value: ac.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_userbs", (&Model{Value: bc.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_useras", (&Model{Value: []ac.User{}, ctx: ctx}).TableName()) + r.Equal("context_table_userbs", (&Model{Value: []bc.User{}, ctx: ctx}).TableName()) } func Test_TableName(t *testing.T) { @@ -86,7 +86,7 @@ func Test_TableName(t *testing.T) { func Test_TableNameContext(t *testing.T) { r := require.New(t) - tn := "context table name" + tn := "context_table_names" ctx := context.WithValue(context.Background(), "name", tn) cases := []interface{}{ diff --git a/testdata/models/ac/user.go b/testdata/models/ac/user.go index 39a5e934..92335a16 100644 --- a/testdata/models/ac/user.go +++ b/testdata/models/ac/user.go @@ -5,5 +5,5 @@ import "context" type User struct{} func (u User) TableName(ctx context.Context) string { - return ctx.Value("name").(string) + "-usera" + return ctx.Value("name").(string) + "_useras" } diff --git a/testdata/models/bc/user.go b/testdata/models/bc/user.go index 30b543bc..4b1c6257 100644 --- a/testdata/models/bc/user.go +++ b/testdata/models/bc/user.go @@ -5,5 +5,5 @@ import "context" type User struct{} func (u User) TableName(ctx context.Context) string { - return ctx.Value("name").(string) + "-userb" + return ctx.Value("name").(string) + "_userbs" } From 333700c8729e5e1048889dc135b72da024dafed8 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 8 Jan 2021 13:10:26 +0100 Subject: [PATCH 3/4] fix: respect as naming in whereID and whereNamedID Signed-off-by: aeneasr <3372410+aeneasr@users.noreply.github.com> --- model.go | 18 ++++++++++++------ model_test.go | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/model.go b/model.go index a2ff0bef..5fb4b43c 100644 --- a/model.go +++ b/model.go @@ -5,7 +5,6 @@ import ( "fmt" "reflect" "strings" - "sync" "time" nflect "github.com/gobuffalo/flect/name" @@ -19,9 +18,6 @@ import ( var nowFunc = time.Now -var tableMap = map[string]string{} -var tableMapMu = sync.RWMutex{} - // Value is the contents of a `Model`. type Value interface{} @@ -234,11 +230,21 @@ func (m *Model) touchUpdatedAt() { } func (m *Model) whereID() string { - return fmt.Sprintf("%s.%s = ?", m.TableName(), m.IDField()) + as := m.As + if as == "" { + as = strings.ReplaceAll(m.TableName(), ".", "_") + } + + return fmt.Sprintf("%s.%s = ?", as, m.IDField()) } func (m *Model) whereNamedID() string { - return fmt.Sprintf("%s.%s = :%s", m.TableName(), m.IDField(), m.IDField()) + as := m.As + if as == "" { + as = strings.ReplaceAll(m.TableName(), ".", "_") + } + + return fmt.Sprintf("%s.%s = :%s", as, m.IDField(), m.IDField()) } func (m *Model) isSlice() bool { diff --git a/model_test.go b/model_test.go index b1bf4dc8..fb9cf89e 100644 --- a/model_test.go +++ b/model_test.go @@ -196,3 +196,25 @@ func Test_IDField(t *testing.T) { m = Model{Value: &testNormalID{ID: 1}} r.Equal("id", m.IDField()) } + +type testPrefixID struct { + ID int `db:"custom_id"` +} + +func (t testPrefixID) TableName() string { + return "foo.bar" +} + +func Test_WhereID(t *testing.T) { + r := require.New(t) + m := Model{Value: &testPrefixID{ID: 1}} + + r.Equal("foo_bar_custom_id = ?", m.whereID()) + r.Equal("foo_bar_custom_id = ?", m.whereNamedID()) + + type testNormalID struct { + ID int + } + m = Model{Value: &testNormalID{ID: 1}} + r.Equal("id", m.IDField()) +} From 84c72b7cebf1fad2bd155f7730541e67e3376325 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 8 Jan 2021 13:54:08 +0100 Subject: [PATCH 4/4] fix: resolve test failures and regressions Signed-off-by: aeneasr <3372410+aeneasr@users.noreply.github.com> --- model.go | 13 +++--------- model_context_test.go | 6 +++--- model_test.go | 48 +++++++++++++++++++++++-------------------- 3 files changed, 32 insertions(+), 35 deletions(-) diff --git a/model.go b/model.go index 5fb4b43c..1dea97f5 100644 --- a/model.go +++ b/model.go @@ -27,9 +27,8 @@ type modelIterable func(*Model) error // that is passed in to many functions. type Model struct { Value - ctx context.Context - tableName string - As string + ctx context.Context + As string } // NewModel returns a new model with the specified value and context. @@ -117,12 +116,6 @@ func (m *Model) TableName() string { return n.TableName(m.ctx) } - m.isSlice() - - if m.tableName != "" { - return m.tableName - } - return m.typeName(reflect.TypeOf(m.Value)) } @@ -164,7 +157,7 @@ func (m *Model) typeName(t reflect.Type) (name string) { // the contextualization. } - return nflect.Tableize(name) + return nflect.Tableize(el.Name()) default: return nflect.Tableize(t.Name()) } diff --git a/model_context_test.go b/model_context_test.go index d6a14213..6b03a882 100644 --- a/model_context_test.go +++ b/model_context_test.go @@ -101,9 +101,9 @@ func Test_ModelContext(t *testing.T) { r.NoError(cA.All(&actualA)) r.NoError(cB.All(&actualB)) - r.Len(cA, 1) - r.Len(cB, 1) + r.Len(actualA, 1) + r.Len(actualB, 1) - r.NotEqual(cA.ID, cB.ID, "if these are equal context switching did not work") + r.NotEqual(actualA[0].ID, actualB[0].ID, "if these are equal context switching did not work") }) } diff --git a/model_test.go b/model_test.go index fb9cf89e..d0d0c23c 100644 --- a/model_test.go +++ b/model_test.go @@ -2,6 +2,7 @@ package pop import ( "context" + "fmt" "testing" "time" @@ -16,26 +17,25 @@ import ( ) func Test_Model_TableName(t *testing.T) { - r := require.New(t) - - m := Model{Value: User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: &User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: &Users{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: []User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: &[]User{}} - r.Equal(m.TableName(), "users") - - m = Model{Value: []*User{}} - r.Equal(m.TableName(), "users") - + for k, v := range []interface{}{ + User{}, + &User{}, + + &Users{}, + Users{}, + + []*User{}, + &[]*User{}, + + []User{}, + &[]User{}, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + r := require.New(t) + m := Model{Value: v} + r.Equal("users", m.TableName()) + }) + } } type tn struct{} @@ -75,7 +75,11 @@ func Test_TableName(t *testing.T) { cases := []interface{}{ tn{}, + &tn{}, []tn{}, + &[]tn{}, + []*tn{}, + &[]*tn{}, } for _, tc := range cases { m := Model{Value: tc} @@ -209,8 +213,8 @@ func Test_WhereID(t *testing.T) { r := require.New(t) m := Model{Value: &testPrefixID{ID: 1}} - r.Equal("foo_bar_custom_id = ?", m.whereID()) - r.Equal("foo_bar_custom_id = ?", m.whereNamedID()) + r.Equal("foo_bar.custom_id = ?", m.whereID()) + r.Equal("foo_bar.custom_id = :custom_id", m.whereNamedID()) type testNormalID struct { ID int