Skip to content

Commit

Permalink
Add support for pointer FKs when preloading a belongs_to association (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
reggieriser committed Jan 5, 2021
1 parent 0e3d2e2 commit b2918a3
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 12 deletions.
18 changes: 10 additions & 8 deletions pop_test.go
Expand Up @@ -138,14 +138,16 @@ type Book struct {
}

type Taxi struct {
ID int `db:"id"`
Model string `db:"model"`
UserID nulls.Int `db:"user_id"`
AddressID nulls.Int `db:"address_id"`
Driver *User `belongs_to:"user" fk_id:"user_id"`
Address Address `belongs_to:"address"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
ID int `db:"id"`
Model string `db:"model"`
UserID nulls.Int `db:"user_id"`
AddressID nulls.Int `db:"address_id"`
Driver *User `belongs_to:"user" fk_id:"user_id"`
Address Address `belongs_to:"address"`
ToAddressID *int `db:"to_address_id"`
ToAddress *Address `belongs_to:"address"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

// Validate gets run every time you call a "Validate*" (ValidateAndSave, ValidateAndCreate, ValidateAndUpdate) method.
Expand Down
17 changes: 14 additions & 3 deletions preload_associations.go
Expand Up @@ -355,7 +355,9 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI

fkids := []interface{}{}
mmi.iterate(func(val reflect.Value) {
fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface())
if !isFieldNilPtr(val, fi) {
fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface())
}
})

if len(fkids) == 0 {
Expand Down Expand Up @@ -386,11 +388,15 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI

// 3) iterate over every model and fill it with the assoc.
mmi.iterate(func(mvalue reflect.Value) {
if isFieldNilPtr(mvalue, fi) {
return
}
modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
for i := 0; i < slice.Elem().Len(); i++ {
asocValue := slice.Elem().Index(i)
if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() ||
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) {
fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path))
if fkField.Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() ||
reflect.DeepEqual(fkField, mmi.mapper.FieldByName(asocValue, "ID")) {

switch {
case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
Expand Down Expand Up @@ -499,3 +505,8 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta
}
return nil
}

func isFieldNilPtr(val reflect.Value, fi *reflectx.FieldInfo) bool {
fieldValue := reflectx.FieldByIndexesReadOnly(val, fi.Index)
return fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()
}
33 changes: 33 additions & 0 deletions preload_associations_test.go
Expand Up @@ -231,3 +231,36 @@ func Test_New_Implementation_For_BelongsTo_Multiple_Fields(t *testing.T) {
SetEagerMode(EagerDefault)
})
}

func Test_New_Implementation_For_BelongsTo_Ptr_Field(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
a := require.New(t)
toAddress := Address{HouseNumber: 1, Street: "Destination Ave"}
a.NoError(tx.Create(&toAddress))

taxi := Taxi{ToAddressID: &toAddress.ID}
a.NoError(tx.Create(&taxi))

book1 := Book{TaxiID: nulls.NewInt(taxi.ID), Title: "My Book"}
a.NoError(tx.Create(&book1))

taxiNilToAddress := Taxi{ToAddressID: nil}
a.NoError(tx.Create(&taxiNilToAddress))

book2 := Book{TaxiID: nulls.NewInt(taxiNilToAddress.ID), Title: "Another Book"}
a.NoError(tx.Create(&book2))

SetEagerMode(EagerPreload)
books := []Book{}
a.NoError(tx.EagerPreload("Taxi.ToAddress").Order("created_at").All(&books))
a.Len(books, 2)
a.Equal(toAddress.Street, books[0].Taxi.ToAddress.Street)
a.NotNil(books[0].Taxi.ToAddressID)
a.Nil(books[1].Taxi.ToAddress)
a.Nil(books[1].Taxi.ToAddressID)
SetEagerMode(EagerDefault)
})
}
3 changes: 2 additions & 1 deletion testdata/migrations/20181104135856_taxis.up.fizz
Expand Up @@ -2,6 +2,7 @@ create_table("taxis") {
t.Column("id", "int", {primary: true})
t.Column("model", "string", {})
t.Column("user_id", "int", {"null": true})
t.Column("address_id", "int",{"null":true})
t.Column("address_id", "int", {"null":true})
t.Column("to_address_id", "int", {"null":true})
t.Timestamps()
}

0 comments on commit b2918a3

Please sign in to comment.