diff --git a/associations/associations_for_struct.go b/associations/associations_for_struct.go index bdd7e180..529b210b 100644 --- a/associations/associations_for_struct.go +++ b/associations/associations_for_struct.go @@ -27,6 +27,11 @@ var associationBuilders = map[string]associationBuilder{} // it throws an error when it finds a field that does // not exist for a model. func ForStruct(s interface{}, fields ...string) (Associations, error) { + return forStruct(s, s, fields) +} + +// forStruct is a recursive helper that passes the root model down for embedded fields +func forStruct(parent, s interface{}, fields []string) (Associations, error) { t, v := getModelDefinition(s) if t.Kind() != reflect.Struct { return nil, fmt.Errorf("could not get struct associations: not a struct but %T", s) @@ -74,7 +79,20 @@ func ForStruct(s interface{}, fields ...string) (Associations, error) { // inline embedded field if f.Anonymous { - innerAssociations, err := ForStruct(v.Field(i).Interface(), fields...) + field := v.Field(i) + // we need field to be a pointer, so that we can later set the value + // if the embedded field is of type struct {...}, we have to take its address + if field.Kind() != reflect.Ptr { + field = field.Addr() + } + if fieldIsNil(field) { + // initialize zero value + field = reflect.New(field.Type().Elem()) + // we can only get in this case if v.Field(i) is a pointer type because it could not be nil otherwise + // => it is safe to set it here as is + v.Field(i).Set(field) + } + innerAssociations, err := forStruct(parent, field.Interface(), fields) if err != nil { return nil, err } @@ -92,11 +110,12 @@ func ForStruct(s interface{}, fields ...string) (Associations, error) { for name, builder := range associationBuilders { tag := tags.Find(name) if !tag.Empty() { + pt, pv := getModelDefinition(parent) params := associationParams{ field: f, - model: s, - modelType: t, - modelValue: v, + model: parent, + modelType: pt, + modelValue: pv, popTags: tags, innerAssociations: fieldsWithInnerAssociation[f.Name], } diff --git a/docker-compose.yml b/docker-compose.yml index 54b90986..2201bec8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: '2' +version: '2.1' services: mysql: @@ -25,7 +25,7 @@ services: - ./sqldumps:/docker-entrypoint-initdb.d cockroach: image: cockroachdb/cockroach:v20.2.4 - user: ${CURRENT_UID:?"Please run as follows 'CURRENT_UID=$(id -u):$(id -g) docker-compose up'"} + user: ${CURRENT_UID:?"Please run as follows 'CURRENT_UID=$$(id -u):$$(id -g) docker-compose up'"} ports: - "26257:26257" volumes: diff --git a/executors_test.go b/executors_test.go index d428e364..e47e4d0a 100644 --- a/executors_test.go +++ b/executors_test.go @@ -1231,6 +1231,150 @@ func Test_Eager_Creation_Without_Associations(t *testing.T) { }) } +func Test_Eager_Embedded_Struct(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + type AssocFields struct { + Books Books `has_many:"books" order_by:"title asc"` + FavoriteSong Song `has_one:"song" fk_id:"u_id"` + Houses Addresses `many_to_many:"users_addresses"` + } + + type User struct { + ID int `db:"id"` + UserName string `db:"user_name"` + Email string `db:"email"` + Name nulls.String `db:"name"` + Alive nulls.Bool `db:"alive"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + BirthDate nulls.Time `db:"birth_date"` + Bio nulls.String `db:"bio"` + Price nulls.Float64 `db:"price"` + FullName nulls.String `db:"full_name" select:"name as full_name"` + + AssocFields + } + + count, _ := tx.Count(&User{}) + user := User{ + UserName: "dumb-dumb", + Name: nulls.NewString("Arthur Dent"), + AssocFields: AssocFields{ + Books: Books{{Title: "The Hitchhiker's Guide to the Galaxy", Description: "Comedy Science Fiction somewhere in Space", Isbn: "PB42"}}, + FavoriteSong: Song{Title: "Wish You Were Here", ComposedBy: Composer{Name: "Pink Floyd"}}, + Houses: Addresses{ + Address{HouseNumber: 155, Street: "Country Lane"}, + }, + }, + } + + err := tx.Eager().Create(&user) + r.NoError(err) + r.NotZero(user.ID) + + ctx, _ := tx.Count(&User{}) + r.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Book{}) + r.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Song{}) + r.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Address{}) + r.Equal(count+1, ctx) + + u := User{} + q := tx.Eager().Where("name = ?", user.Name.String) + err = q.First(&u) + r.NoError(err) + + r.Equal(user.Name.String, u.Name.String) + r.Len(u.Books, 1) + r.Equal(user.Books[0].Title, u.Books[0].Title) + r.Equal(user.FavoriteSong.Title, u.FavoriteSong.Title) + r.Len(u.Houses, 1) + r.Equal(user.Houses[0].Street, u.Houses[0].Street) + }) +} + +func Test_Eager_Embedded_Ptr_Struct(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + type AssocFields struct { + Books Books `has_many:"books" order_by:"title asc"` + FavoriteSong Song `has_one:"song" fk_id:"u_id"` + Houses Addresses `many_to_many:"users_addresses"` + } + + type User struct { + ID int `db:"id"` + UserName string `db:"user_name"` + Email string `db:"email"` + Name nulls.String `db:"name"` + Alive nulls.Bool `db:"alive"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + BirthDate nulls.Time `db:"birth_date"` + Bio nulls.String `db:"bio"` + Price nulls.Float64 `db:"price"` + FullName nulls.String `db:"full_name" select:"name as full_name"` + + *AssocFields + } + + count, _ := tx.Count(&User{}) + user := User{ + UserName: "dumb-dumb", + Name: nulls.NewString("Arthur Dent"), + AssocFields: &AssocFields{ + Books: Books{{Title: "The Hitchhiker's Guide to the Galaxy", Description: "Comedy Science Fiction somewhere in Space", Isbn: "PB42"}}, + FavoriteSong: Song{Title: "Wish You Were Here", ComposedBy: Composer{Name: "Pink Floyd"}}, + Houses: Addresses{ + Address{HouseNumber: 155, Street: "Country Lane"}, + }, + }, + } + + err := tx.Eager().Create(&user) + r.NoError(err) + r.NotZero(user.ID) + + ctx, _ := tx.Count(&User{}) + r.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Book{}) + r.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Song{}) + r.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Address{}) + r.Equal(count+1, ctx) + + u := User{} + q := tx.Eager().Where("name = ?", user.Name.String) + err = q.First(&u) + r.NoError(err) + + r.Equal(user.Name.String, u.Name.String) + r.Len(u.Books, 1) + r.Equal(user.Books[0].Title, u.Books[0].Title) + r.Equal(user.FavoriteSong.Title, u.FavoriteSong.Title) + r.Len(u.Houses, 1) + r.Equal(user.Houses[0].Street, u.Houses[0].Street) + }) +} + func Test_Create_UUID(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests")