diff --git a/pop_test.go b/pop_test.go index 4ac25b68..1d473de0 100644 --- a/pop_test.go +++ b/pop_test.go @@ -138,14 +138,16 @@ type Book struct { } type Taxi struct { - ID int `db:"id"` - Model string `db:"model"` - UserID nulls.Int `db:"user_id"` - AddressID nulls.Int `db:"address_id"` - Driver *User `belongs_to:"user" fk_id:"user_id"` - Address Address `belongs_to:"address"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt time.Time `db:"updated_at"` + ID int `db:"id"` + Model string `db:"model"` + UserID nulls.Int `db:"user_id"` + AddressID nulls.Int `db:"address_id"` + Driver *User `belongs_to:"user" fk_id:"user_id"` + Address Address `belongs_to:"address"` + ToAddressID *int `db:"to_address_id"` + ToAddress *Address `belongs_to:"address"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` } // Validate gets run every time you call a "Validate*" (ValidateAndSave, ValidateAndCreate, ValidateAndUpdate) method. diff --git a/preload_associations.go b/preload_associations.go index 4a03b579..8ddea3c9 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -355,7 +355,9 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI fkids := []interface{}{} mmi.iterate(func(val reflect.Value) { - fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface()) + if !isFieldNilPtr(val, fi) { + fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface()) + } }) if len(fkids) == 0 { @@ -386,11 +388,15 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI // 3) iterate over every model and fill it with the assoc. mmi.iterate(func(mvalue reflect.Value) { + if isFieldNilPtr(mvalue, fi) { + return + } modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) - if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || - reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) { + fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path)) + if fkField.Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || + reflect.DeepEqual(fkField, mmi.mapper.FieldByName(asocValue, "ID")) { switch { case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: @@ -499,3 +505,8 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta } return nil } + +func isFieldNilPtr(val reflect.Value, fi *reflectx.FieldInfo) bool { + fieldValue := reflectx.FieldByIndexesReadOnly(val, fi.Index) + return fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() +} diff --git a/preload_associations_test.go b/preload_associations_test.go index af1f392f..9d735fd4 100644 --- a/preload_associations_test.go +++ b/preload_associations_test.go @@ -231,3 +231,36 @@ func Test_New_Implementation_For_BelongsTo_Multiple_Fields(t *testing.T) { SetEagerMode(EagerDefault) }) } + +func Test_New_Implementation_For_BelongsTo_Ptr_Field(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + a := require.New(t) + toAddress := Address{HouseNumber: 1, Street: "Destination Ave"} + a.NoError(tx.Create(&toAddress)) + + taxi := Taxi{ToAddressID: &toAddress.ID} + a.NoError(tx.Create(&taxi)) + + book1 := Book{TaxiID: nulls.NewInt(taxi.ID), Title: "My Book"} + a.NoError(tx.Create(&book1)) + + taxiNilToAddress := Taxi{ToAddressID: nil} + a.NoError(tx.Create(&taxiNilToAddress)) + + book2 := Book{TaxiID: nulls.NewInt(taxiNilToAddress.ID), Title: "Another Book"} + a.NoError(tx.Create(&book2)) + + SetEagerMode(EagerPreload) + books := []Book{} + a.NoError(tx.EagerPreload("Taxi.ToAddress").Order("created_at").All(&books)) + a.Len(books, 2) + a.Equal(toAddress.Street, books[0].Taxi.ToAddress.Street) + a.NotNil(books[0].Taxi.ToAddressID) + a.Nil(books[1].Taxi.ToAddress) + a.Nil(books[1].Taxi.ToAddressID) + SetEagerMode(EagerDefault) + }) +} diff --git a/testdata/migrations/20181104135856_taxis.up.fizz b/testdata/migrations/20181104135856_taxis.up.fizz index 9cf85238..a11185c0 100644 --- a/testdata/migrations/20181104135856_taxis.up.fizz +++ b/testdata/migrations/20181104135856_taxis.up.fizz @@ -2,6 +2,7 @@ create_table("taxis") { t.Column("id", "int", {primary: true}) t.Column("model", "string", {}) t.Column("user_id", "int", {"null": true}) - t.Column("address_id", "int",{"null":true}) + t.Column("address_id", "int", {"null":true}) + t.Column("to_address_id", "int", {"null":true}) t.Timestamps() } \ No newline at end of file