Skip to content

Commit

Permalink
fix: improve model ID field customization (#604)
Browse files Browse the repository at this point in the history
Updates places where `"id"` was hardcoded instead of using `model.IDField()`.
  • Loading branch information
zepatrik committed Oct 29, 2020
1 parent 6a95bfb commit f36afb5
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 32 deletions.
14 changes: 8 additions & 6 deletions columns/columns.go
Expand Up @@ -13,6 +13,7 @@ type Columns struct {
lock *sync.RWMutex
TableName string
TableAlias string
IDField string
}

// Add a column to the list.
Expand Down Expand Up @@ -74,7 +75,7 @@ func (c *Columns) Add(names ...string) []*Column {
} else if xs[1] == "w" {
col.Readable = false
}
} else if col.Name == "id" {
} else if col.Name == c.IDField {
col.Writeable = false
}

Expand All @@ -98,7 +99,7 @@ func (c *Columns) Remove(names ...string) {

// Writeable gets a list of the writeable columns from the column list.
func (c Columns) Writeable() *WriteableColumns {
w := &WriteableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias)}
w := &WriteableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias, c.IDField)}
for _, col := range c.Cols {
if col.Writeable {
w.Cols[col.Name] = col
Expand All @@ -109,7 +110,7 @@ func (c Columns) Writeable() *WriteableColumns {

// Readable gets a list of the readable columns from the column list.
func (c Columns) Readable() *ReadableColumns {
w := &ReadableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias)}
w := &ReadableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias, c.IDField)}
for _, col := range c.Cols {
if col.Readable {
w.Cols[col.Name] = col
Expand Down Expand Up @@ -157,17 +158,18 @@ func (c Columns) SymbolizedString() string {
}

// NewColumns constructs a list of columns for a given table name.
func NewColumns(tableName string) Columns {
return NewColumnsWithAlias(tableName, "")
func NewColumns(tableName, idField string) Columns {
return NewColumnsWithAlias(tableName, "", idField)
}

// NewColumnsWithAlias constructs a list of columns for a given table
// name, using a given alias for the table.
func NewColumnsWithAlias(tableName string, tableAlias string) Columns {
func NewColumnsWithAlias(tableName, tableAlias, idField string) Columns {
return Columns{
lock: &sync.RWMutex{},
Cols: map[string]*Column{},
TableName: tableName,
TableAlias: tableAlias,
IDField: idField,
}
}
10 changes: 5 additions & 5 deletions columns/columns_for_struct.go
Expand Up @@ -6,17 +6,17 @@ import (

// ForStruct returns a Columns instance for
// the struct passed in.
func ForStruct(s interface{}, tableName string) (columns Columns) {
return ForStructWithAlias(s, tableName, "")
func ForStruct(s interface{}, tableName, idField string) (columns Columns) {
return ForStructWithAlias(s, tableName, "", idField)
}

// ForStructWithAlias returns a Columns instance for the struct passed in.
// If the tableAlias is not empty, it will be used.
func ForStructWithAlias(s interface{}, tableName string, tableAlias string) (columns Columns) {
columns = NewColumnsWithAlias(tableName, tableAlias)
func ForStructWithAlias(s interface{}, tableName, tableAlias, idField string) (columns Columns) {
columns = NewColumnsWithAlias(tableName, tableAlias, idField)
defer func() {
if r := recover(); r != nil {
columns = NewColumnsWithAlias(tableName, tableAlias)
columns = NewColumnsWithAlias(tableName, tableAlias, idField)
columns.Add("*")
}
}()
Expand Down
46 changes: 40 additions & 6 deletions columns/columns_test.go
Expand Up @@ -21,16 +21,16 @@ type foos []foo
func Test_Column_MapsSlice(t *testing.T) {
r := require.New(t)

c1 := columns.ForStruct(&foo{}, "foo")
c2 := columns.ForStruct(&foos{}, "foo")
c1 := columns.ForStruct(&foo{}, "foo", "id")
c2 := columns.ForStruct(&foos{}, "foo", "id")
r.Equal(c1.String(), c2.String())
}

func Test_Columns_Basics(t *testing.T) {
r := require.New(t)

for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
r.Equal(len(c.Cols), 4)
r.Equal(c.Cols["first_name"], &columns.Column{Name: "first_name", Writeable: false, Readable: true, SelectSQL: "first_name as f"})
r.Equal(c.Cols["LastName"], &columns.Column{Name: "LastName", Writeable: true, Readable: true, SelectSQL: "foo.LastName"})
Expand All @@ -43,7 +43,7 @@ func Test_Columns_Add(t *testing.T) {
r := require.New(t)

for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
r.Equal(len(c.Cols), 4)
c.Add("foo", "first_name")
r.Equal(len(c.Cols), 5)
Expand All @@ -55,7 +55,7 @@ func Test_Columns_Remove(t *testing.T) {
r := require.New(t)

for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
r.Equal(len(c.Cols), 4)
c.Remove("foo", "first_name")
r.Equal(len(c.Cols), 3)
Expand All @@ -75,9 +75,43 @@ func (fooQuoter) Quote(key string) string {
func Test_Columns_Sorted(t *testing.T) {
r := require.New(t)

c := columns.ForStruct(fooWithSuffix{}, "fooWithSuffix")
c := columns.ForStruct(fooWithSuffix{}, "fooWithSuffix", "id")
r.Equal(len(c.Cols), 2)
r.Equal(c.SymbolizedString(), ":amount, :amount_units")
r.Equal(c.String(), "amount, amount_units")
r.Equal(c.QuotedString(fooQuoter{}), "`amount`, `amount_units`")
}

func Test_Columns_IDField(t *testing.T) {
type withID struct {
ID string `db:"id"`
}

r := require.New(t)
c := columns.ForStruct(withID{}, "with_id", "id")
r.Equal(1, len(c.Cols), "%+v", c)
r.Equal(&columns.Column{Name: "id", Writeable: false, Readable: true, SelectSQL: "with_id.id"}, c.Cols["id"])
}

func Test_Columns_IDField_Readonly(t *testing.T) {
type withIDReadonly struct {
ID string `db:"id" rw:"r"`
}

r := require.New(t)
c := columns.ForStruct(withIDReadonly{}, "with_id_readonly", "id")
r.Equal(1, len(c.Cols), "%+v", c)
r.Equal(&columns.Column{Name: "id", Writeable: false, Readable: true, SelectSQL: "with_id_readonly.id"}, c.Cols["id"])
}

func Test_Columns_ID_Field_Not_ID(t *testing.T) {
type withNonStandardID struct {
PK string `db:"notid"`
}

r := require.New(t)

c := columns.ForStruct(withNonStandardID{}, "non_standard_id", "notid")
r.Equal(1, len(c.Cols), "%+v", c)
r.Equal(&columns.Column{Name: "notid", Writeable: false, Readable: true, SelectSQL: "non_standard_id.notid"}, c.Cols["notid"])
}
6 changes: 3 additions & 3 deletions columns/readable_columns_test.go
Expand Up @@ -10,7 +10,7 @@ import (
func Test_Columns_ReadableString(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Readable().String()
r.Equal(u, "LastName, first_name, read")
}
Expand All @@ -19,7 +19,7 @@ func Test_Columns_ReadableString(t *testing.T) {
func Test_Columns_Readable_SelectString(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Readable().SelectString()
r.Equal(u, "first_name as f, foo.LastName, foo.read")
}
Expand All @@ -28,7 +28,7 @@ func Test_Columns_Readable_SelectString(t *testing.T) {
func Test_Columns_ReadableString_Symbolized(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Readable().SymbolizedString()
r.Equal(u, ":LastName, :first_name, :read")
}
Expand Down
8 changes: 4 additions & 4 deletions columns/writeable_columns_test.go
Expand Up @@ -10,7 +10,7 @@ import (
func Test_Columns_WriteableString_Symbolized(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Writeable().SymbolizedString()
r.Equal(u, ":LastName, :write")
}
Expand All @@ -19,7 +19,7 @@ func Test_Columns_WriteableString_Symbolized(t *testing.T) {
func Test_Columns_UpdateString(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Writeable().UpdateString()
r.Equal(u, "LastName = :LastName, write = :write")
}
Expand All @@ -35,7 +35,7 @@ func Test_Columns_QuotedUpdateString(t *testing.T) {
r := require.New(t)
q := testQuoter{}
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Writeable().QuotedUpdateString(q)
r.Equal(u, "\"LastName\" = :LastName, \"write\" = :write")
}
Expand All @@ -44,7 +44,7 @@ func Test_Columns_QuotedUpdateString(t *testing.T) {
func Test_Columns_WriteableString(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
c := columns.ForStruct(f, "foo", "id")
u := c.Writeable().String()
r.Equal(u, "LastName, write")
}
Expand Down
10 changes: 5 additions & 5 deletions executors.go
Expand Up @@ -228,7 +228,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
}

tn := m.TableName()
cols := columns.ForStructWithAlias(m.Value, tn, m.As)
cols := m.Columns()

if tn == sm.TableName() {
cols.Remove(excludeColumns...)
Expand Down Expand Up @@ -350,8 +350,8 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
}

tn := m.TableName()
cols := columns.ForStructWithAlias(model, tn, m.As)
cols.Remove("id", "created_at")
cols := columns.ForStructWithAlias(model, tn, m.As, m.IDField())
cols.Remove(m.IDField(), "created_at")

if tn == sm.TableName() {
cols.Remove(excludeColumns...)
Expand Down Expand Up @@ -393,11 +393,11 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err

cols := columns.Columns{}
if len(columnNames) > 0 && tn == sm.TableName() {
cols = columns.NewColumnsWithAlias(tn, m.As)
cols = columns.NewColumnsWithAlias(tn, m.As, sm.IDField())
cols.Add(columnNames...)

} else {
cols = columns.ForStructWithAlias(model, tn, m.As)
cols = columns.ForStructWithAlias(model, tn, m.As, m.IDField())
}
cols.Remove("id", "created_at")

Expand Down
70 changes: 70 additions & 0 deletions executors_test.go
Expand Up @@ -510,6 +510,28 @@ func Test_Create_With_Non_ID_PK_String(t *testing.T) {
})
}

func Test_Create_Non_PK_ID(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

r.NoError(tx.Create(&NonStandardID{OutfacingID: "make sure the tested entry does not have pk=0"}))

count, err := tx.Count(&NonStandardID{})
entry := &NonStandardID{
OutfacingID: "beautiful to the outside ID",
}
r.NoError(tx.Create(entry))

ctx, err := tx.Count(&NonStandardID{})
r.NoError(err)
r.Equal(count+1, ctx)
r.NotZero(entry.ID)
})
}

func Test_Eager_Create_Has_Many(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down Expand Up @@ -1470,6 +1492,54 @@ func Test_Update_UUID(t *testing.T) {
})
}

func Test_Update_With_Non_ID_PK(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

r.NoError(tx.Create(&CrookedColour{Name: "cc is not the first one"}))

cc := CrookedColour{
Name: "You?",
}
err := tx.Create(&cc)
r.NoError(err)
r.NotZero(cc.ID)
id := cc.ID

updatedName := "Me!"
cc.Name = updatedName
r.NoError(tx.Update(&cc))
r.Equal(id, cc.ID)

r.NoError(tx.Reload(&cc))
r.Equal(updatedName, cc.Name)
r.Equal(id, cc.ID)
})
}

func Test_Update_Non_PK_ID(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

client := &NonStandardID{
OutfacingID: "my awesome hydra client",
}
r.NoError(tx.Create(client))

updatedID := "your awesome hydra client"
client.OutfacingID = updatedID
r.NoError(tx.Update(client))
r.NoError(tx.Reload(client))
r.Equal(updatedID, client.OutfacingID)
})
}

func Test_Destroy(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down
18 changes: 17 additions & 1 deletion model.go
Expand Up @@ -2,6 +2,7 @@ package pop

import (
"fmt"
"github.com/gobuffalo/pop/v5/columns"
"github.com/pkg/errors"
"reflect"
"sync"
Expand Down Expand Up @@ -46,7 +47,18 @@ func (m *Model) ID() interface{} {
// IDField returns the name of the DB field used for the ID.
// By default, it will return "id".
func (m *Model) IDField() string {
field, ok := reflect.TypeOf(m.Value).Elem().FieldByName("ID")
modelType := reflect.TypeOf(m.Value)

// remove all indirections
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
}

if modelType.Kind() == reflect.String {
return "id"
}

field, ok := modelType.FieldByName("ID")
if !ok {
return "id"
}
Expand Down Expand Up @@ -101,6 +113,10 @@ func (m *Model) TableName() string {
return tableMap[cacheKey]
}

func (m *Model) Columns() columns.Columns {
return columns.ForStructWithAlias(m.Value, m.TableName(), m.As, m.IDField())
}

func (m *Model) cacheKey(t reflect.Type) string {
return t.PkgPath() + "." + t.Name()
}
Expand Down
5 changes: 5 additions & 0 deletions pop_test.go
Expand Up @@ -419,3 +419,8 @@ type CrookedSong struct {
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

type NonStandardID struct {
ID int `db:"pk"`
OutfacingID string `db:"id"`
}

0 comments on commit f36afb5

Please sign in to comment.