Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support context-aware tablenames #614

Merged
merged 5 commits into from Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions belongs_to.go
Expand Up @@ -19,15 +19,15 @@ func (c *Connection) BelongsToAs(model interface{}, as string) *Query {
// BelongsTo adds a "where" clause based on the "ID" of the
// "model" passed into it.
func (q *Query) BelongsTo(model interface{}) *Query {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID())
return q
}

// BelongsToAs adds a "where" clause based on the "ID" of the
// "model" passed into it, using an alias.
func (q *Query) BelongsToAs(model interface{}, as string) *Query {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
q.Where(fmt.Sprintf("%s = ?", as), m.ID())
return q
}
Expand All @@ -42,8 +42,8 @@ func (c *Connection) BelongsToThrough(bt, thru interface{}) *Query {
// through the associated "thru" model.
func (q *Query) BelongsToThrough(bt, thru interface{}) *Query {
q.belongsToThroughClauses = append(q.belongsToThroughClauses, belongsToThroughClause{
BelongsTo: &Model{Value: bt},
Through: &Model{Value: thru},
BelongsTo: NewModel(bt, q.Connection.Context()),
Through: NewModel(thru, q.Connection.Context()),
})
return q
}
7 changes: 4 additions & 3 deletions belongs_to_test.go
@@ -1,6 +1,7 @@
package pop

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -14,7 +15,7 @@ func Test_BelongsTo(t *testing.T) {

q := PDB.BelongsTo(&User{ID: 1})

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())

sql, _ := q.ToSQL(m)
r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE user_id = ?"), sql)
Expand All @@ -28,7 +29,7 @@ func Test_BelongsToAs(t *testing.T) {

q := PDB.BelongsToAs(&User{ID: 1}, "u_id")

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())

sql, _ := q.ToSQL(m)
r.Equal(ts("SELECT enemies.A FROM enemies AS enemies WHERE u_id = ?"), sql)
Expand All @@ -43,7 +44,7 @@ func Test_BelongsToThrough(t *testing.T) {
q := PDB.BelongsToThrough(&User{ID: 1}, &Friend{})
qs := "SELECT enemies.A FROM enemies AS enemies, good_friends AS good_friends WHERE good_friends.user_id = ? AND enemies.id = good_friends.enemy_id"

m := &Model{Value: &Enemy{}}
m := NewModel(new(Enemy), context.Background())
sql, _ := q.ToSQL(m)
r.Equal(ts(qs), sql)
}
10 changes: 10 additions & 0 deletions connection.go
Expand Up @@ -33,6 +33,16 @@ func (c *Connection) URL() string {
return c.Dialect.URL()
}

// Context returns the connection's context set by "Context()" or context.TODO()
// if no context is set.
func (c *Connection) Context() context.Context {
if c, ok := c.Store.(interface{ Context() context.Context }); ok {
return c.Context()
}

return context.TODO()
}

// MigrationURL returns the datasource connection string used for running the migrations
func (c *Connection) MigrationURL() string {
return c.Dialect.MigrationURL()
Expand Down
3 changes: 2 additions & 1 deletion connection_details.go
Expand Up @@ -2,13 +2,14 @@ package pop

import (
"fmt"
"github.com/luna-duclos/instrumentedsql"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/luna-duclos/instrumentedsql"

"github.com/gobuffalo/pop/v5/internal/defaults"
"github.com/gobuffalo/pop/v5/logging"
"github.com/pkg/errors"
Expand Down
1 change: 1 addition & 0 deletions connection_instrumented.go
Expand Up @@ -3,6 +3,7 @@ package pop
import (
"database/sql"
"database/sql/driver"

mysqld "github.com/go-sql-driver/mysql"
"github.com/gobuffalo/pop/v5/logging"
pgx "github.com/jackc/pgx/v4/stdlib"
Expand Down
3 changes: 2 additions & 1 deletion connection_instrumented_nosqlite_test.go
Expand Up @@ -3,8 +3,9 @@
package pop

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func TestInstrumentation_WithoutSqlite(t *testing.T) {
Expand Down
5 changes: 3 additions & 2 deletions connection_instrumented_test.go
Expand Up @@ -3,12 +3,13 @@ package pop
import (
"context"
"fmt"
"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
"os"
"strings"
"sync"
"time"

"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
)

func testInstrumentedDriver(p *suite.Suite) {
Expand Down
3 changes: 2 additions & 1 deletion dialect_sqlite.go
Expand Up @@ -5,7 +5,6 @@ package pop
import (
"database/sql/driver"
"fmt"
"github.com/mattn/go-sqlite3"
"io"
"net/url"
"os"
Expand All @@ -15,6 +14,8 @@ import (
"sync"
"time"

"github.com/mattn/go-sqlite3"

"github.com/gobuffalo/fizz"
"github.com/gobuffalo/fizz/translators"
_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
Expand Down
28 changes: 14 additions & 14 deletions executors.go
Expand Up @@ -13,7 +13,7 @@ import (

// Reload fetch fresh data for a given model, using its ID.
func (c *Connection) Reload(model interface{}) error {
sm := Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.Find(m.Value, m.ID())
})
Expand Down Expand Up @@ -51,7 +51,7 @@ func (q *Query) ExecWithCount() (int, error) {
//
// If model is a slice, each item of the slice is validated then saved in the database.
func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand All @@ -77,7 +77,7 @@ func IsZeroOfUnderlyingType(x interface{}) bool {
//
// If model is a slice, each item of the slice is saved in the database.
func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
id, err := m.fieldByName("ID")
if err != nil {
Expand All @@ -95,7 +95,7 @@ func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is validated then created in the database.
func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
verrs, err := sm.validateAndOnlyCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
Expand All @@ -140,14 +140,14 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
verrs, err := sm.validateAndOnlyCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
}
}

sm := &Model{Value: model}
sm := NewModel(model, c.Context())
verrs, err = sm.validateCreate(c)
if err != nil || verrs.HasAny() {
return verrs, err
Expand All @@ -170,7 +170,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {

c.disableEager()

sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Create", func() error {
var localIsEager = isEager
Expand Down Expand Up @@ -203,7 +203,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
}

if localIsEager {
sm := &Model{Value: i}
sm := NewModel(i, c.Context())
err = sm.iterate(func(m *Model) error {
id, err := m.fieldByName("ID")
if err != nil {
Expand Down Expand Up @@ -255,7 +255,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
continue
}

sm := &Model{Value: i}
sm := NewModel(i, c.Context())
err = sm.iterate(func(m *Model) error {
fbn, err := m.fieldByName("ID")
if err != nil {
Expand Down Expand Up @@ -318,7 +318,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is validated then updated in the database.
func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
if err := sm.beforeValidate(c); err != nil {
return nil, err
}
Expand All @@ -337,7 +337,7 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri
//
// If model is a slice, each item of the slice is updated in the database.
func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Update", func() error {
var err error
Expand Down Expand Up @@ -377,7 +377,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
//
// If model is a slice, each item of the slice is updated in the database.
func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Update", func() error {
var err error
Expand Down Expand Up @@ -419,7 +419,7 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err
//
// If model is a slice, each item of the slice is deleted from the database.
func (c *Connection) Destroy(model interface{}) error {
sm := &Model{Value: model}
sm := NewModel(model, c.Context())
return sm.iterate(func(m *Model) error {
return c.timeFunc("Destroy", func() error {
var err error
Expand Down
14 changes: 7 additions & 7 deletions finders.go
Expand Up @@ -29,7 +29,7 @@ func (c *Connection) Find(model interface{}, id interface{}) error {
//
// q.Find(&User{}, 1)
func (q *Query) Find(model interface{}, id interface{}) error {
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
idq := m.whereID()
switch t := id.(type) {
case uuid.UUID:
Expand Down Expand Up @@ -69,7 +69,7 @@ func (c *Connection) First(model interface{}) error {
func (q *Query) First(model interface{}) error {
err := q.Connection.timeFunc("First", func() error {
q.Limit(1)
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
return err
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func (q *Query) Last(model interface{}) error {
err := q.Connection.timeFunc("Last", func() error {
q.Limit(1)
q.Order("created_at DESC, id DESC")
m := &Model{Value: model}
m := NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
return err
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (c *Connection) All(models interface{}) error {
// q.Where("name = ?", "mark").All(&[]User{})
func (q *Query) All(models interface{}) error {
err := q.Connection.timeFunc("All", func() error {
m := &Model{Value: models}
m := NewModel(models, q.Connection.Context())
err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
if err != nil {
return err
Expand Down Expand Up @@ -258,7 +258,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error {
}
}

sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context()))
query = query.RawQuery(sqlSentence, args...)

if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
Expand Down Expand Up @@ -302,7 +302,7 @@ func (q *Query) Exists(model interface{}) (bool, error) {
tmpQuery.Paginator = nil
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context()))

// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset
Expand Down Expand Up @@ -348,7 +348,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) {
tmpQuery.Paginator = nil
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context()))
// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset

Expand Down
2 changes: 1 addition & 1 deletion finders_test.go
Expand Up @@ -101,7 +101,7 @@ func Test_Select(t *testing.T) {

q := tx.Select("name", "email", "\n", "\t\n", "")

sm := &Model{Value: &User{}}
sm := NewModel(new(User), tx.Context())
sql, _ := q.ToSQL(sm)
r.Equal(tx.Dialect.TranslateSQL("SELECT email, name FROM users AS users"), sql)

Expand Down
3 changes: 2 additions & 1 deletion match_test.go
@@ -1,8 +1,9 @@
package pop

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func Test_ParseMigrationFilenameFizzDown(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion migration_info_test.go
@@ -1,9 +1,10 @@
package pop

import (
"github.com/stretchr/testify/assert"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSortingMigrations(t *testing.T) {
Expand Down