Skip to content

Commit

Permalink
feat: support context-aware tablenames (#614)
Browse files Browse the repository at this point in the history
This patch adds a feature which enables pop to pass down the connection context to the model's TableName() function by implementing TableName(ctx context.Context) string. The context can be used to dynamically generate tablenames which can be important for prefixed or generic tables and other use cases.
  • Loading branch information
aeneasr committed Jan 18, 2021
1 parent b2918a3 commit 0fb7635
Show file tree
Hide file tree
Showing 26 changed files with 357 additions and 104 deletions.
8 changes: 4 additions & 4 deletions belongs_to.go
Expand Up @@ -19,15 +19,15 @@ func (c *Connection) BelongsToAs(model interface{}, as string) *Query {
// BelongsTo adds a "where" clause based on the "ID" of the
// "model" passed into it.
func (q *Query) BelongsTo(model interface{}) *Query {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID())
return q
}

// BelongsToAs adds a "where" clause based on the "ID" of the
// "model" passed into it, using an alias.
func (q *Query) BelongsToAs(model interface{}, as string) *Query {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
q.Where(fmt.Sprintf("%s = ?", as), m.ID())
return q
}
Expand All @@ -42,8 +42,8 @@ func (c *Connection) BelongsToThrough(bt, thru interface{}) *Query {
// through the associated "thru" model.
func (q *Query) BelongsToThrough(bt, thru interface{}) *Query {
q.belongsToThroughClauses = append(q.belongsToThroughClauses, belongsToThroughClause{
BelongsTo: &Model{Value: bt},
Through: &Model{Value: thru},
BelongsTo: NewModel(bt, q.Connection.Context()),
Through: NewModel(thru, q.Connection.Context()),
})
return q
}
7 changes: 4 additions & 3 deletions belongs_to_test.go
@@ -1,6 +1,7 @@
package pop

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -14,7 +15,7 @@ func Test_BelongsTo(t *testing.T) {

q := PDB.BelongsTo(&User{ID: 1})

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())

sql, _ := q.ToSQL(m)
r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE user_id = ?"), sql)
Expand All @@ -28,7 +29,7 @@ func Test_BelongsToAs(t *testing.T) {

q := PDB.BelongsToAs(&User{ID: 1}, "u_id")

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())

sql, _ := q.ToSQL(m)
r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE u_id = ?"), sql)
Expand All @@ -43,7 +44,7 @@ func Test_BelongsToThrough(t *testing.T) {
q := PDB.BelongsToThrough(&User{ID: 1}, &Friend{})
qs := "SELECT enemies.A FROM enemies AS enemies, good_friends AS good_friends WHERE good_friends.user_id = ? AND enemies.id = good_friends.enemy_id"

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())
sql, _ := q.ToSQL(m)
r.Equal(ts(qs), sql)
}
10 changes: 10 additions & 0 deletions connection.go
Expand Up @@ -33,6 +33,16 @@ func (c *Connection) URL() string {
return c.Dialect.URL()
}

// Context returns the connection's context set by "Context()" or context.TODO()
// if no context is set.
func (c *Connection) Context() context.Context {
if c, ok := c.Store.(interface{ Context() context.Context }); ok {
return c.Context()
}

return context.TODO()
}

// MigrationURL returns the datasource connection string used for running the migrations
func (c *Connection) MigrationURL() string {
return c.Dialect.MigrationURL()
Expand Down
3 changes: 2 additions & 1 deletion connection_details.go
Expand Up @@ -2,13 +2,14 @@ package pop

import (
"fmt"
"github.com/luna-duclos/instrumentedsql"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/luna-duclos/instrumentedsql"

"github.com/gobuffalo/pop/v5/internal/defaults"
"github.com/gobuffalo/pop/v5/logging"
"github.com/pkg/errors"
Expand Down
1 change: 1 addition & 0 deletions connection_instrumented.go
Expand Up @@ -3,6 +3,7 @@ package pop
import (
"database/sql"
"database/sql/driver"

mysqld "github.com/go-sql-driver/mysql"
"github.com/gobuffalo/pop/v5/logging"
pgx "github.com/jackc/pgx/v4/stdlib"
Expand Down
3 changes: 2 additions & 1 deletion connection_instrumented_nosqlite_test.go
Expand Up @@ -3,8 +3,9 @@
package pop

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func TestInstrumentation_WithoutSqlite(t *testing.T) {
Expand Down
5 changes: 3 additions & 2 deletions connection_instrumented_test.go
Expand Up @@ -3,12 +3,13 @@ package pop
import (
"context"
"fmt"
"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
"os"
"strings"
"sync"
"time"

"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
)

func testInstrumentedDriver(p *suite.Suite) {
Expand Down
3 changes: 2 additions & 1 deletion dialect_sqlite.go
Expand Up @@ -5,7 +5,6 @@ package pop
import (
"database/sql/driver"
"fmt"
"github.com/mattn/go-sqlite3"
"io"
"net/url"
"os"
Expand All @@ -15,6 +14,8 @@ import (
"sync"
"time"

"github.com/mattn/go-sqlite3"

"github.com/gobuffalo/fizz"
"github.com/gobuffalo/fizz/translators"
_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
Expand Down
28 changes: 14 additions & 14 deletions executors.go
Expand Up @@ -13,7 +13,7 @@ import (

// Reload fetch fresh data for a given model, using its ID.
func (c *Connection) Reload(model interface{}) error {
sm := Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.Find(m.Value, m.ID())
})
Expand Down Expand Up @@ -51,7 +51,7 @@ func (q *Query) ExecWithCount() (int, error) {
//
// If model is a slice, each item of the slice is validated then saved in the database.
func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand All @@ -77,7 +77,7 @@ func IsZeroOfUnderlyingType(x interface{}) bool {
//
// If model is a slice, each item of the slice is saved in the database.
func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
id, err := m.fieldByName("ID")
if err != nil {
Expand All @@ -95,7 +95,7 @@ func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is validated then created in the database.
func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
verrs, err := sm.validateAndOnlyCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
Expand All @@ -140,14 +140,14 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
verrs, err := sm.validateAndOnlyCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
}
}

sm := &Model{Value: model}
sm := NewModel(model, c.Context())
verrs, err = sm.validateCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
Expand All @@ -170,7 +170,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {

c.disableEager()

sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Create", func() error {
var localIsEager = isEager
Expand Down Expand Up @@ -203,7 +203,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
}

if localIsEager {
sm := &Model{Value: i}
sm := NewModel(i, c.Context())
err = sm.iterate(func(m *Model) error {
id, err := m.fieldByName("ID")
if err != nil {
Expand Down Expand Up @@ -255,7 +255,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
err = sm.iterate(func(m *Model) error {
fbn, err := m.fieldByName("ID")
if err != nil {
Expand Down Expand Up @@ -318,7 +318,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is validated then updated in the database.
func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand All @@ -337,7 +337,7 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri
//
// If model is a slice, each item of the slice is updated in the database.
func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Update", func() error {
var err error
Expand Down Expand Up @@ -377,7 +377,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is updated in the database.
func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Update", func() error {
var err error
Expand Down Expand Up @@ -419,7 +419,7 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err
//
// If model is a slice, each item of the slice is deleted from the database.
func (c *Connection) Destroy(model interface{}) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Destroy", func() error {
var err error
Expand Down
14 changes: 7 additions & 7 deletions finders.go
Expand Up @@ -29,7 +29,7 @@ func (c *Connection) Find(model interface{}, id interface{}) error {
//
// q.Find(&User{}, 1)
func (q *Query) Find(model interface{}, id interface{}) error {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
idq := m.whereID()
switch t := id.(type) {
case uuid.UUID:
Expand Down Expand Up @@ -69,7 +69,7 @@ func (c *Connection) First(model interface{}) error {
func (q *Query) First(model interface{}) error {
err := q.Connection.timeFunc("First", func() error {
q.Limit(1)
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
return err
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func (q *Query) Last(model interface{}) error {
err := q.Connection.timeFunc("Last", func() error {
q.Limit(1)
q.Order("created_at DESC, id DESC")
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
return err
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (c *Connection) All(models interface{}) error {
// q.Where("name = ?", "mark").All(&[]User{})
func (q *Query) All(models interface{}) error {
err := q.Connection.timeFunc("All", func() error {
m := &Model{Value: models}
m := NewModel(models, q.Connection.Context())
err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
if err != nil {
return err
Expand Down Expand Up @@ -258,7 +258,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error {
}
}

sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context()))
query = query.RawQuery(sqlSentence, args...)

if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
Expand Down Expand Up @@ -302,7 +302,7 @@ func (q *Query) Exists(model interface{}) (bool, error) {
tmpQuery.Paginator = nil
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context()))

// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset
Expand Down Expand Up @@ -348,7 +348,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) {
tmpQuery.Paginator = nil
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context()))
// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset

Expand Down
2 changes: 1 addition & 1 deletion finders_test.go
Expand Up @@ -101,7 +101,7 @@ func Test_Select(t *testing.T) {

q := tx.Select("name", "email", "\n", "\t\n", "")

sm := &Model{Value: &User{}}
sm := NewModel(new(User), tx.Context())
sql, _ := q.ToSQL(sm)
r.Equal(tx.Dialect.TranslateSQL("SELECT email, name FROM users AS users"), sql)

Expand Down
3 changes: 2 additions & 1 deletion match_test.go
@@ -1,8 +1,9 @@
package pop

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func Test_ParseMigrationFilenameFizzDown(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion migration_info_test.go
@@ -1,9 +1,10 @@
package pop

import (
"github.com/stretchr/testify/assert"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSortingMigrations(t *testing.T) {
Expand Down

0 comments on commit 0fb7635

Please sign in to comment.