Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: associations for embedded fields #707

Merged
merged 2 commits into from Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 23 additions & 4 deletions associations/associations_for_struct.go
Expand Up @@ -27,6 +27,11 @@ var associationBuilders = map[string]associationBuilder{}
// it throws an error when it finds a field that does
// not exist for a model.
func ForStruct(s interface{}, fields ...string) (Associations, error) {
return forStruct(s, s, fields)
}

// forStruct is a recursive helper that passes the root model down for embedded fields
func forStruct(parent, s interface{}, fields []string) (Associations, error) {
t, v := getModelDefinition(s)
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("could not get struct associations: not a struct but %T", s)
Expand Down Expand Up @@ -74,7 +79,20 @@ func ForStruct(s interface{}, fields ...string) (Associations, error) {

// inline embedded field
if f.Anonymous {
innerAssociations, err := ForStruct(v.Field(i).Interface(), fields...)
field := v.Field(i)
// we need field to be a pointer, so that we can later set the value
// if the embedded field is of type struct {...}, we have to take its address
if field.Kind() != reflect.Ptr {
field = field.Addr()
}
if fieldIsNil(field) {
// initialize zero value
field = reflect.New(field.Type().Elem())
// we can only get in this case if v.Field(i) is a pointer type because it could not be nil otherwise
// => it is safe to set it here as is
v.Field(i).Set(field)
}
innerAssociations, err := forStruct(parent, field.Interface(), fields)
if err != nil {
return nil, err
}
Expand All @@ -92,11 +110,12 @@ func ForStruct(s interface{}, fields ...string) (Associations, error) {
for name, builder := range associationBuilders {
tag := tags.Find(name)
if !tag.Empty() {
pt, pv := getModelDefinition(parent)
params := associationParams{
field: f,
model: s,
modelType: t,
modelValue: v,
model: parent,
modelType: pt,
modelValue: pv,
popTags: tags,
innerAssociations: fieldsWithInnerAssociation[f.Name],
}
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.yml
@@ -1,4 +1,4 @@
version: '2'
version: '2.1'

services:
mysql:
Expand All @@ -25,7 +25,7 @@ services:
- ./sqldumps:/docker-entrypoint-initdb.d
cockroach:
image: cockroachdb/cockroach:v20.2.4
user: ${CURRENT_UID:?"Please run as follows 'CURRENT_UID=$(id -u):$(id -g) docker-compose up'"}
user: ${CURRENT_UID:?"Please run as follows 'CURRENT_UID=$$(id -u):$$(id -g) docker-compose up'"}
ports:
- "26257:26257"
volumes:
Expand Down
144 changes: 144 additions & 0 deletions executors_test.go
Expand Up @@ -1231,6 +1231,150 @@ func Test_Eager_Creation_Without_Associations(t *testing.T) {
})
}

func Test_Eager_Embedded_Struct(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

type AssocFields struct {
Books Books `has_many:"books" order_by:"title asc"`
FavoriteSong Song `has_one:"song" fk_id:"u_id"`
Houses Addresses `many_to_many:"users_addresses"`
}

type User struct {
ID int `db:"id"`
UserName string `db:"user_name"`
Email string `db:"email"`
Name nulls.String `db:"name"`
Alive nulls.Bool `db:"alive"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
BirthDate nulls.Time `db:"birth_date"`
Bio nulls.String `db:"bio"`
Price nulls.Float64 `db:"price"`
FullName nulls.String `db:"full_name" select:"name as full_name"`

AssocFields
}

count, _ := tx.Count(&User{})
user := User{
UserName: "dumb-dumb",
Name: nulls.NewString("Arthur Dent"),
AssocFields: AssocFields{
Books: Books{{Title: "The Hitchhiker's Guide to the Galaxy", Description: "Comedy Science Fiction somewhere in Space", Isbn: "PB42"}},
FavoriteSong: Song{Title: "Wish You Were Here", ComposedBy: Composer{Name: "Pink Floyd"}},
Houses: Addresses{
Address{HouseNumber: 155, Street: "Country Lane"},
},
},
}

err := tx.Eager().Create(&user)
r.NoError(err)
r.NotZero(user.ID)

ctx, _ := tx.Count(&User{})
r.Equal(count+1, ctx)

ctx, _ = tx.Count(&Book{})
r.Equal(count+1, ctx)

ctx, _ = tx.Count(&Song{})
r.Equal(count+1, ctx)

ctx, _ = tx.Count(&Address{})
r.Equal(count+1, ctx)

u := User{}
q := tx.Eager().Where("name = ?", user.Name.String)
err = q.First(&u)
r.NoError(err)

r.Equal(user.Name.String, u.Name.String)
r.Len(u.Books, 1)
r.Equal(user.Books[0].Title, u.Books[0].Title)
r.Equal(user.FavoriteSong.Title, u.FavoriteSong.Title)
r.Len(u.Houses, 1)
r.Equal(user.Houses[0].Street, u.Houses[0].Street)
})
}

func Test_Eager_Embedded_Ptr_Struct(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

type AssocFields struct {
Books Books `has_many:"books" order_by:"title asc"`
FavoriteSong Song `has_one:"song" fk_id:"u_id"`
Houses Addresses `many_to_many:"users_addresses"`
}

type User struct {
ID int `db:"id"`
UserName string `db:"user_name"`
Email string `db:"email"`
Name nulls.String `db:"name"`
Alive nulls.Bool `db:"alive"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
BirthDate nulls.Time `db:"birth_date"`
Bio nulls.String `db:"bio"`
Price nulls.Float64 `db:"price"`
FullName nulls.String `db:"full_name" select:"name as full_name"`

*AssocFields
}

count, _ := tx.Count(&User{})
user := User{
UserName: "dumb-dumb",
Name: nulls.NewString("Arthur Dent"),
AssocFields: &AssocFields{
Books: Books{{Title: "The Hitchhiker's Guide to the Galaxy", Description: "Comedy Science Fiction somewhere in Space", Isbn: "PB42"}},
FavoriteSong: Song{Title: "Wish You Were Here", ComposedBy: Composer{Name: "Pink Floyd"}},
Houses: Addresses{
Address{HouseNumber: 155, Street: "Country Lane"},
},
},
}

err := tx.Eager().Create(&user)
r.NoError(err)
r.NotZero(user.ID)

ctx, _ := tx.Count(&User{})
r.Equal(count+1, ctx)

ctx, _ = tx.Count(&Book{})
r.Equal(count+1, ctx)

ctx, _ = tx.Count(&Song{})
r.Equal(count+1, ctx)

ctx, _ = tx.Count(&Address{})
r.Equal(count+1, ctx)

u := User{}
q := tx.Eager().Where("name = ?", user.Name.String)
err = q.First(&u)
r.NoError(err)

r.Equal(user.Name.String, u.Name.String)
r.Len(u.Books, 1)
r.Equal(user.Books[0].Title, u.Books[0].Title)
r.Equal(user.FavoriteSong.Title, u.FavoriteSong.Title)
r.Len(u.Houses, 1)
r.Equal(user.Houses[0].Street, u.Houses[0].Street)
})
}

func Test_Create_UUID(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down