Skip to content

Commit

Permalink
feat: support context-aware tablenames
Browse files Browse the repository at this point in the history
This patch adds a feature which enables pop to pass down the connection context to the model's `TableName()` function by implementing `TableName(ctx context.Context) string`. The context can be used to dynamically generate tablenames which can be important for prefixed or generic tables and other use cases.

Signed-off-by: aeneasr <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
aeneasr committed Jan 4, 2021
1 parent 0e3d2e2 commit 713f63a
Show file tree
Hide file tree
Showing 26 changed files with 262 additions and 55 deletions.
8 changes: 4 additions & 4 deletions belongs_to.go
Expand Up @@ -19,15 +19,15 @@ func (c *Connection) BelongsToAs(model interface{}, as string) *Query {
// BelongsTo adds a "where" clause based on the "ID" of the
// "model" passed into it.
func (q *Query) BelongsTo(model interface{}) *Query {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID())
return q
}

// BelongsToAs adds a "where" clause based on the "ID" of the
// "model" passed into it, using an alias.
func (q *Query) BelongsToAs(model interface{}, as string) *Query {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
q.Where(fmt.Sprintf("%s = ?", as), m.ID())
return q
}
Expand All @@ -42,8 +42,8 @@ func (c *Connection) BelongsToThrough(bt, thru interface{}) *Query {
// through the associated "thru" model.
func (q *Query) BelongsToThrough(bt, thru interface{}) *Query {
q.belongsToThroughClauses = append(q.belongsToThroughClauses, belongsToThroughClause{
BelongsTo: &Model{Value: bt},
Through: &Model{Value: thru},
BelongsTo: NewModel(bt, q.Connection.Context()),
Through: NewModel(thru, q.Connection.Context()),
})
return q
}
7 changes: 4 additions & 3 deletions belongs_to_test.go
@@ -1,6 +1,7 @@
package pop

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -14,7 +15,7 @@ func Test_BelongsTo(t *testing.T) {

q := PDB.BelongsTo(&User{ID: 1})

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())

sql, _ := q.ToSQL(m)
r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE user_id = ?"), sql)
Expand All @@ -28,7 +29,7 @@ func Test_BelongsToAs(t *testing.T) {

q := PDB.BelongsToAs(&User{ID: 1}, "u_id")

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())

sql, _ := q.ToSQL(m)
r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE u_id = ?"), sql)
Expand All @@ -43,7 +44,7 @@ func Test_BelongsToThrough(t *testing.T) {
q := PDB.BelongsToThrough(&User{ID: 1}, &Friend{})
qs := "SELECT enemies.A FROM enemies AS enemies, good_friends AS good_friends WHERE good_friends.user_id = ? AND enemies.id = good_friends.enemy_id"

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())
sql, _ := q.ToSQL(m)
r.Equal(ts(qs), sql)
}
10 changes: 10 additions & 0 deletions connection.go
Expand Up @@ -33,6 +33,16 @@ func (c *Connection) URL() string {
return c.Dialect.URL()
}

// Context returns the connection's context set by "Context()" or context.TODO()
// if no context is set.
func (c *Connection) Context() context.Context {
if c, ok := c.Store.(interface{ Context() context.Context }); ok {
return c.Context()
}

return context.TODO()
}

// MigrationURL returns the datasource connection string used for running the migrations
func (c *Connection) MigrationURL() string {
return c.Dialect.MigrationURL()
Expand Down
3 changes: 2 additions & 1 deletion connection_details.go
Expand Up @@ -2,13 +2,14 @@ package pop

import (
"fmt"
"github.com/luna-duclos/instrumentedsql"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/luna-duclos/instrumentedsql"

"github.com/gobuffalo/pop/v5/internal/defaults"
"github.com/gobuffalo/pop/v5/logging"
"github.com/pkg/errors"
Expand Down
1 change: 1 addition & 0 deletions connection_instrumented.go
Expand Up @@ -3,6 +3,7 @@ package pop
import (
"database/sql"
"database/sql/driver"

mysqld "github.com/go-sql-driver/mysql"
"github.com/gobuffalo/pop/v5/logging"
pgx "github.com/jackc/pgx/v4/stdlib"
Expand Down
3 changes: 2 additions & 1 deletion connection_instrumented_nosqlite_test.go
Expand Up @@ -3,8 +3,9 @@
package pop

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func TestInstrumentation_WithoutSqlite(t *testing.T) {
Expand Down
5 changes: 3 additions & 2 deletions connection_instrumented_test.go
Expand Up @@ -3,12 +3,13 @@ package pop
import (
"context"
"fmt"
"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
"os"
"strings"
"sync"
"time"

"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
)

func testInstrumentedDriver(p *suite.Suite) {
Expand Down
3 changes: 2 additions & 1 deletion dialect_sqlite.go
Expand Up @@ -5,7 +5,6 @@ package pop
import (
"database/sql/driver"
"fmt"
"github.com/mattn/go-sqlite3"
"io"
"net/url"
"os"
Expand All @@ -15,6 +14,8 @@ import (
"sync"
"time"

"github.com/mattn/go-sqlite3"

"github.com/gobuffalo/fizz"
"github.com/gobuffalo/fizz/translators"
_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
Expand Down
28 changes: 14 additions & 14 deletions executors.go
Expand Up @@ -13,7 +13,7 @@ import (

// Reload fetch fresh data for a given model, using its ID.
func (c *Connection) Reload(model interface{}) error {
sm := Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.Find(m.Value, m.ID())
})
Expand Down Expand Up @@ -51,7 +51,7 @@ func (q *Query) ExecWithCount() (int, error) {
//
// If model is a slice, each item of the slice is validated then saved in the database.
func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand All @@ -77,7 +77,7 @@ func IsZeroOfUnderlyingType(x interface{}) bool {
//
// If model is a slice, each item of the slice is saved in the database.
func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
id, err := m.fieldByName("ID")
if err != nil {
Expand All @@ -95,7 +95,7 @@ func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is validated then created in the database.
func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
verrs, err := sm.validateAndOnlyCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
Expand All @@ -140,14 +140,14 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
verrs, err := sm.validateAndOnlyCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
}
}

sm := &Model{Value: model}
sm := NewModel(model, c.Context())
verrs, err = sm.validateCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
Expand All @@ -170,7 +170,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {

c.disableEager()

sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Create", func() error {
var localIsEager = isEager
Expand Down Expand Up @@ -203,7 +203,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
}

if localIsEager {
sm := &Model{Value: i}
sm := NewModel(i, c.Context())
err = sm.iterate(func(m *Model) error {
id, err := m.fieldByName("ID")
if err != nil {
Expand Down Expand Up @@ -255,7 +255,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
err = sm.iterate(func(m *Model) error {
fbn, err := m.fieldByName("ID")
if err != nil {
Expand Down Expand Up @@ -318,7 +318,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is validated then updated in the database.
func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand All @@ -337,7 +337,7 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri
//
// If model is a slice, each item of the slice is updated in the database.
func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Update", func() error {
var err error
Expand Down Expand Up @@ -377,7 +377,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is updated in the database.
func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Update", func() error {
var err error
Expand Down Expand Up @@ -419,7 +419,7 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err
//
// If model is a slice, each item of the slice is deleted from the database.
func (c *Connection) Destroy(model interface{}) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Destroy", func() error {
var err error
Expand Down
14 changes: 7 additions & 7 deletions finders.go
Expand Up @@ -29,7 +29,7 @@ func (c *Connection) Find(model interface{}, id interface{}) error {
//
// q.Find(&User{}, 1)
func (q *Query) Find(model interface{}, id interface{}) error {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
idq := m.whereID()
switch t := id.(type) {
case uuid.UUID:
Expand Down Expand Up @@ -69,7 +69,7 @@ func (c *Connection) First(model interface{}) error {
func (q *Query) First(model interface{}) error {
err := q.Connection.timeFunc("First", func() error {
q.Limit(1)
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
return err
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func (q *Query) Last(model interface{}) error {
err := q.Connection.timeFunc("Last", func() error {
q.Limit(1)
q.Order("created_at DESC, id DESC")
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
return err
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (c *Connection) All(models interface{}) error {
// q.Where("name = ?", "mark").All(&[]User{})
func (q *Query) All(models interface{}) error {
err := q.Connection.timeFunc("All", func() error {
m := &Model{Value: models}
m := NewModel(models, q.Connection.Context())
err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
if err != nil {
return err
Expand Down Expand Up @@ -258,7 +258,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error {
}
}

sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context()))
query = query.RawQuery(sqlSentence, args...)

if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
Expand Down Expand Up @@ -302,7 +302,7 @@ func (q *Query) Exists(model interface{}) (bool, error) {
tmpQuery.Paginator = nil
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context()))

// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset
Expand Down Expand Up @@ -348,7 +348,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) {
tmpQuery.Paginator = nil
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context()))
// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset

Expand Down
2 changes: 1 addition & 1 deletion finders_test.go
Expand Up @@ -101,7 +101,7 @@ func Test_Select(t *testing.T) {

q := tx.Select("name", "email", "\n", "\t\n", "")

sm := &Model{Value: &User{}}
sm := NewModel(new(User), tx.Context())
sql, _ := q.ToSQL(sm)
r.Equal(tx.Dialect.TranslateSQL("SELECT email, name FROM users AS users"), sql)

Expand Down
3 changes: 2 additions & 1 deletion match_test.go
@@ -1,8 +1,9 @@
package pop

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func Test_ParseMigrationFilenameFizzDown(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion migration_info_test.go
@@ -1,9 +1,10 @@
package pop

import (
"github.com/stretchr/testify/assert"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSortingMigrations(t *testing.T) {
Expand Down

0 comments on commit 713f63a

Please sign in to comment.