diff --git a/associations/association.go b/associations/association.go index 436c4581..8ed165a4 100644 --- a/associations/association.go +++ b/associations/association.go @@ -160,5 +160,9 @@ func fieldIsNil(f reflect.Value) bool { // IsZeroOfUnderlyingType will check if the value of anything is the equal to the Zero value of that type. func IsZeroOfUnderlyingType(x interface{}) bool { + if x == nil { + return true + } + return reflect.DeepEqual(x, reflect.Zero(reflect.TypeOf(x)).Interface()) } diff --git a/associations/association_test.go b/associations/association_test.go new file mode 100644 index 00000000..504c52e9 --- /dev/null +++ b/associations/association_test.go @@ -0,0 +1,37 @@ +package associations + +import ( + "database/sql" + "fmt" + "github.com/gobuffalo/nulls" + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_IsZeroOfUnderlyingType(t *testing.T) { + for k, tc := range []struct { + in interface{} + zero bool + }{ + {in: nil, zero: true}, + {in: 0, zero: true}, + {in: 1, zero: false}, + {in: false, zero: true}, + {in: "", zero: true}, + {in: interface{}(nil), zero: true}, + {in: uuid.NullUUID{}, zero: true}, + {in: uuid.UUID{}, zero: true}, + {in: uuid.NullUUID{Valid: true}, zero: false}, + {in: nulls.Int{}, zero: true}, + {in: nulls.String{}, zero: true}, + {in: nulls.Bool{}, zero: true}, + {in: nulls.Float64{}, zero: true}, + {in: sql.NullString{}, zero: true}, + {in: sql.NullString{Valid: true}, zero: false}, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + assert.EqualValues(t, tc.zero, IsZeroOfUnderlyingType(tc.in)) + }) + } +} diff --git a/associations/associations_for_struct.go b/associations/associations_for_struct.go index 67c8ca76..544e8786 100644 --- a/associations/associations_for_struct.go +++ b/associations/associations_for_struct.go @@ -1,7 +1,6 @@ package associations import ( - "errors" "fmt" "reflect" "regexp" @@ -30,7 +29,7 @@ var associationBuilders = map[string]associationBuilder{} func ForStruct(s interface{}, fields ...string) (Associations, error) { t, v := getModelDefinition(s) if t.Kind() != reflect.Struct { - return nil, errors.New("could not get struct associations: not a struct") + return nil, fmt.Errorf("could not get struct associations: not a struct but %T", s) } fields = trimFields(fields) associations := Associations{} diff --git a/associations/belongs_to_association.go b/associations/belongs_to_association.go index fce161d0..c64d7208 100644 --- a/associations/belongs_to_association.go +++ b/associations/belongs_to_association.go @@ -128,15 +128,22 @@ func (b *belongsToAssociation) BeforeInterface() interface{} { func (b *belongsToAssociation) BeforeSetup() error { ownerID := reflect.Indirect(reflect.ValueOf(b.ownerModel.Interface())).FieldByName("ID") - if b.ownerID.CanSet() { - if n := nulls.New(b.ownerID.Interface()); n != nil { - b.ownerID.Set(reflect.ValueOf(n.Parse(ownerID.Interface()))) - } else if b.ownerID.Kind() == reflect.Ptr { - b.ownerID.Set(ownerID.Addr()) + toSet := b.ownerID + switch b.ownerID.Type().Name() { + case "NullUUID": + b.ownerID.FieldByName("Valid").Set(reflect.ValueOf(true)) + toSet = b.ownerID.FieldByName("UUID") + } + + if toSet.CanSet() { + if n := nulls.New(toSet.Interface()); n != nil { + toSet.Set(reflect.ValueOf(n.Parse(ownerID.Interface()))) + } else if toSet.Kind() == reflect.Ptr { + toSet.Set(ownerID.Addr()) } else { - b.ownerID.Set(ownerID) + toSet.Set(ownerID) } return nil } - return fmt.Errorf("could not set '%s' to '%s'", ownerID, b.ownerID) + return fmt.Errorf("could not set '%s' to '%s'", ownerID, toSet) } diff --git a/associations/belongs_to_association_test.go b/associations/belongs_to_association_test.go index f7a09ba3..ca72c137 100644 --- a/associations/belongs_to_association_test.go +++ b/associations/belongs_to_association_test.go @@ -22,6 +22,11 @@ type barBelongsTo struct { Foo fooBelongsTo `belongs_to:"foo"` } +type barBelongsToNullable struct { + FooID uuid.NullUUID `db:"foo_id"` + Foo *fooBelongsTo `belongs_to:"foo"` +} + func Test_Belongs_To_Association(t *testing.T) { a := require.New(t) @@ -50,3 +55,17 @@ func Test_Belongs_To_Association(t *testing.T) { a.Equal(nil, before[index].BeforeInterface()) } } + +func Test_Belongs_To_Nullable_Association(t *testing.T) { + a := require.New(t) + id, _ := uuid.NewV1() + + bar := barBelongsToNullable{Foo: &fooBelongsTo{id}} + as, err := associations.ForStruct(&bar, "Foo") + a.NoError(err) + + before := as.AssociationsBeforeCreatable() + for index := range before { + a.Equal(nil, before[index].BeforeSetup()) + } +} diff --git a/finders.go b/finders.go index c90d1032..ccd4bbd5 100644 --- a/finders.go +++ b/finders.go @@ -282,7 +282,14 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error { v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) innerQuery := Q(query.Connection) innerQuery.eagerFields = inner.Fields - err = innerQuery.eagerAssociations(v.Addr().Interface()) + + switch v.Kind() { + case reflect.Ptr: + err = innerQuery.eagerAssociations(v.Interface()) + default: + err = innerQuery.eagerAssociations(v.Addr().Interface()) + } + if err != nil { return err } diff --git a/finders_test.go b/finders_test.go index 7e2daa02..cc5bd925 100644 --- a/finders_test.go +++ b/finders_test.go @@ -314,6 +314,42 @@ func Test_Find_Eager_Has_One(t *testing.T) { }) } +func Test_Find_Eager_Has_One_With_Inner_Associations_Pointer(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + user := UserPointerAssocs{Name: nulls.NewString("Mark")} + err := tx.Create(&user) + r.NoError(err) + + coolSong := Song{Title: "Hook - Blues Traveler", UserID: user.ID} + err = tx.Create(&coolSong) + r.NoError(err) + + u := UserPointerAssocs{} + err = tx.Eager("FavoriteSong.ComposedBy").Find(&u, user.ID) + r.NoError(err) + + r.NotEqual(u.ID, 0) + r.Equal(u.Name.String, "Mark") + r.Equal(u.FavoriteSong.ID, coolSong.ID) + + // eager should work with rawquery + uid := u.ID + u = UserPointerAssocs{} + err = tx.RawQuery("select * from users where id=?", uid).First(&u) + r.NoError(err) + r.Nil(u.FavoriteSong) + + err = tx.RawQuery("select * from users where id=?", uid).Eager("FavoriteSong").First(&u) + r.NoError(err) + r.Equal(u.FavoriteSong.ID, coolSong.ID) + }) +} + func Test_Find_Eager_Has_One_With_Inner_Associations_Struct(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests") diff --git a/pop_test.go b/pop_test.go index bfe7ea78..0166e9d0 100644 --- a/pop_test.go +++ b/pop_test.go @@ -106,6 +106,27 @@ type User struct { Houses Addresses `many_to_many:"users_addresses"` } +type UserPointerAssocs struct { + ID int `db:"id"` + UserName string `db:"user_name"` + Email string `db:"email"` + Name nulls.String `db:"name"` + Alive nulls.Bool `db:"alive"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + BirthDate nulls.Time `db:"birth_date"` + Bio nulls.String `db:"bio"` + Price nulls.Float64 `db:"price"` + FullName nulls.String `db:"full_name" select:"name as full_name"` + Books Books `has_many:"books" order_by:"title asc" fk_id:"user_id"` + FavoriteSong *Song `has_one:"song" fk_id:"u_id"` + Houses Addresses `many_to_many:"users_addresses"` +} + +func (UserPointerAssocs) TableName() string { + return "users" +} + // Validate gets run every time you call a "Validate*" (ValidateAndSave, ValidateAndCreate, ValidateAndUpdate) method. // This method is not required and may be deleted. func (u *User) Validate(tx *Connection) (*validate.Errors, error) { @@ -277,6 +298,23 @@ type CourseCode struct { // Course Course `belongs_to:"course"` } +type NetClient struct { + ID uuid.UUID `json:"id" db:"id"` + Hops []Hop `json:"hop_id" has_many:"hops"` +} + +type Hop struct { + ID uuid.UUID `json:"id" db:"id"` + NetClient *NetClient `json:"net_client" belongs_to:"net_client" fk_id:"NetClientID"` + NetClientID uuid.UUID `json:"net_client_id" db:"net_client_id"` + Server *Server `json:"course" belongs_to:"server" fk_id:"ServerID" oder_by:"id asc"` + ServerID uuid.NullUUID `json:"server_id" db:"server_id"` +} + +type Server struct { + ID uuid.UUID `json:"id" db:"id"` +} + type ValidatableCar struct { ID int64 `db:"id"` Name string `db:"name"` diff --git a/preload_associations.go b/preload_associations.go index 72a7f4c3..6e8fec10 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -268,13 +268,18 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf // 3) iterate over every model and fill it with the assoc. foreignField := asoc.getDBFieldTaggedWith(fk) mmi.iterate(func(mvalue reflect.Value) { - modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) valueField := reflect.Indirect(mmi.mapper.FieldByName(asocValue, foreignField.Path)) if mmi.mapper.FieldByName(mvalue, "ID").Interface() == valueField.Interface() || reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), valueField) { - + // IMPORTANT + // + // FieldByName will initialize the value. It is important that this happens AFTER + // we checked whether the field should be set. Otherwise, we'll set a zero value! + // + // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 + modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) switch { case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) @@ -330,16 +335,25 @@ func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo // 3) iterate over every model and fill it with the assoc. foreignField := asoc.getDBFieldTaggedWith(fk) mmi.iterate(func(mvalue reflect.Value) { - modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { - if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { + // IMPORTANT + // + // FieldByName will initialize the value. It is important that this happens AFTER + // we checked whether the field should be set. Otherwise, we'll set a zero value! + // + // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 + modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) + switch { + case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - continue + case modelAssociationField.Kind() == reflect.Ptr: + modelAssociationField.Elem().Set(asocValue) + default: + modelAssociationField.Set(asocValue) } - modelAssociationField.Set(asocValue) } } }) @@ -392,13 +406,18 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI if isFieldNilPtr(mvalue, fi) { return } - modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path)) - if fkField.Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || - reflect.DeepEqual(fkField, mmi.mapper.FieldByName(asocValue, "ID")) { - + field := mmi.mapper.FieldByName(asocValue, "ID") + if fkField.Interface() == field.Interface() || reflect.DeepEqual(fkField, field) { + // IMPORTANT + // + // FieldByName will initialize the value. It is important that this happens AFTER + // we checked whether the field should be set. Otherwise, we'll set a zero value! + // + // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 + modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) switch { case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) @@ -497,11 +516,17 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta mmi.iterate(func(mvalue reflect.Value) { id := mmi.mapper.FieldByName(mvalue, "ID").Interface() if assocFkIds, ok := mapAssoc[fmt.Sprintf("%v", id)]; ok { - modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) for _, fkid := range assocFkIds { if fmt.Sprintf("%v", fkid) == fmt.Sprintf("%v", mmi.mapper.FieldByName(asocValue, "ID").Interface()) { + // IMPORTANT + // + // FieldByName will initialize the value. It is important that this happens AFTER + // we checked whether the field should be set. Otherwise, we'll set a zero value! + // + // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 + modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) } } diff --git a/preload_associations_test.go b/preload_associations_test.go index 879531c2..3d70bad9 100644 --- a/preload_associations_test.go +++ b/preload_associations_test.go @@ -1,10 +1,9 @@ package pop import ( - "testing" - "github.com/gobuffalo/nulls" "github.com/stretchr/testify/require" + "testing" ) func Test_New_Implementation_For_Nplus1(t *testing.T) { @@ -53,6 +52,18 @@ func Test_New_Implementation_For_Nplus1(t *testing.T) { a.Len(book.Writers, 1) a.Equal("Larry", book.Writers[0].Name) a.Equal("Mark", book.User.Name.String) + + usersWithPointers := []UserPointerAssocs{} + a.NoError(tx.All(&usersWithPointers)) + + // FILL THE HAS-MANY and HAS_ONE + a.NoError(preload(tx, &usersWithPointers)) + + a.Len(usersWithPointers[0].Books, 1) + a.Len(usersWithPointers[1].Books, 1) + a.Len(usersWithPointers[2].Books, 1) + a.Equal(usersWithPointers[0].FavoriteSong.UserID, users[0].ID) + a.Len(usersWithPointers[0].Houses, 1) }) } @@ -101,6 +112,68 @@ func Test_New_Implementation_For_Nplus1_With_UUID(t *testing.T) { }) } +func Test_New_Implementation_For_Nplus1_With_NullUUIDs_And_FK_ID(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + + // This test suite prevents regressions of an obscure bug in the preload code which caused + // pointer values to be set with their empty values when relations did not exist. + // + // See also: https://github.com/gobuffalo/pop/issues/139 + transaction(func(tx *Connection) { + a := require.New(t) + + var server Server + a.NoError(tx.Create(&server)) + + class := &NetClient{ + // The bug only appears when we have two elements in the slice where + // one has a relation and the other one has no such relation. + Hops: []Hop{ + {Server: &server}, + {}, + }} + + // This code basically just sets up + a.NoError(tx.Eager().Create(class)) + + var expected NetClient + a.NoError(tx.EagerPreload("Hops.Server").First(&expected)) + + // What would happen before the patch resolved this issue is that: + // + // Classes.CourseCodes[0].Course would be the correct value (a filled struct) + // + // "server": { + // "id": "fa51f71f-e884-4641-8005-923258b814f9", + // "created_at": "2021-12-09T23:20:10.208019+01:00", + // "updated_at": "2021-12-09T23:20:10.208019+01:00" + // }, + // + // Classes.CourseCodes[1].Course would an "empty" struct of Course even though there is no relation set up: + // + // "server": { + // "id": "00000000-0000-0000-0000-000000000000", + // "created_at": "0001-01-01T00:00:00Z", + // "updated_at": "0001-01-01T00:00:00Z" + // }, + var foundValid, foundEmpty int + for _, hop := range expected.Hops { + if hop.ServerID.Valid { + foundValid++ + a.NotNil(hop.Server, "%+v", hop) + } else { + foundEmpty++ + a.Nil(hop.Server, "%+v", hop) + } + } + + a.Equal(1, foundValid) + a.Equal(1, foundEmpty) + }) +} + func Test_New_Implementation_For_Nplus1_Single(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests") diff --git a/testdata/migrations/20210104145902_network.down.fizz b/testdata/migrations/20210104145902_network.down.fizz new file mode 100644 index 00000000..151bae5c --- /dev/null +++ b/testdata/migrations/20210104145902_network.down.fizz @@ -0,0 +1,3 @@ +drop_table("clients") +drop_table("hops") +drop_table("servers") diff --git a/testdata/migrations/20210104145902_network.up.fizz b/testdata/migrations/20210104145902_network.up.fizz new file mode 100644 index 00000000..dc624e6c --- /dev/null +++ b/testdata/migrations/20210104145902_network.up.fizz @@ -0,0 +1,18 @@ +create_table("net_clients") { + t.Column("id", "uuid", {"primary": true}) + t.DisableTimestamps() +} + +create_table("servers") { + t.Column("id", "uuid", {"primary": true}) + t.DisableTimestamps() +} + +create_table("hops") { + t.Column("id", "uuid", {"primary": true}) + t.Column("server_id", "uuid", {"null":true}) + t.ForeignKey("server_id", {"servers": ["id"]}, {}) + t.Column("net_client_id", "uuid") + t.ForeignKey("net_client_id", {"net_clients": ["id"]}, {}) + t.DisableTimestamps() +}