diff --git a/associations/association.go b/associations/association.go index 0bda8694..8fa335fd 100644 --- a/associations/association.go +++ b/associations/association.go @@ -43,7 +43,7 @@ func (a *associationComposite) InnerAssociations() InnerAssociations { // association for Song. type InnerAssociation struct { Name string - Fields string + Fields []string } // InnerAssociations is a group of InnerAssociation. diff --git a/associations/associations_for_struct.go b/associations/associations_for_struct.go index b7c4742f..368e72c0 100644 --- a/associations/associations_for_struct.go +++ b/associations/associations_for_struct.go @@ -34,29 +34,39 @@ func ForStruct(s interface{}, fields ...string) (Associations, error) { } fields = trimFields(fields) associations := Associations{} - innerAssociations := InnerAssociations{} + fieldsWithInnerAssociation := map[string]InnerAssociations{} // validate if fields contains a non existing field in struct. // and verify is it has inner associations. for i := range fields { - var innerField, field string + var innerField string if !validAssociationExpRegexp.MatchString(fields[i]) { return associations, fmt.Errorf("association '%s' does not match the format %s", fields[i], "'' or '.'") } - if strings.Contains(fields[i], ".") { - dotIndex := strings.Index(fields[i], ".") - field = fields[i][:dotIndex] - innerField = fields[i][dotIndex+1:] - fields[i] = field - } + fields[i], innerField = extractFieldAndInnerFields(fields[i]) + if _, ok := t.FieldByName(fields[i]); !ok { return associations, fmt.Errorf("field %s does not exist in model %s", fields[i], t.Name()) } if innerField != "" { - innerAssociations = append(innerAssociations, InnerAssociation{fields[i], innerField}) + var found bool + innerF, _ := extractFieldAndInnerFields(innerField) + + for j := range fieldsWithInnerAssociation[fields[i]] { + f, _ := extractFieldAndInnerFields(fieldsWithInnerAssociation[fields[i]][j].Fields[0]) + if innerF == f { + fieldsWithInnerAssociation[fields[i]][j].Fields = append(fieldsWithInnerAssociation[fields[i]][j].Fields, innerField) + found = true + break + } + } + + if !found { + fieldsWithInnerAssociation[fields[i]] = append(fieldsWithInnerAssociation[fields[i]], InnerAssociation{fields[i], []string{innerField}}) + } } } @@ -79,7 +89,7 @@ func ForStruct(s interface{}, fields ...string) (Associations, error) { modelType: t, modelValue: v, popTags: tags, - innerAssociations: innerAssociations, + innerAssociations: fieldsWithInnerAssociation[f.Name], } a, err := builder(params) @@ -121,3 +131,12 @@ func fieldIgnoredIn(fields []string, field string) bool { } return true } + +func extractFieldAndInnerFields(field string) (string, string) { + if !strings.Contains(field, ".") { + return field, "" + } + + dotIndex := strings.Index(field, ".") + return field[:dotIndex], field[dotIndex+1:] +} diff --git a/finders.go b/finders.go index fa1a220a..7de66924 100644 --- a/finders.go +++ b/finders.go @@ -273,12 +273,16 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error { return err } + if err == sql.ErrNoRows { + continue + } + // load all inner associations. innerAssociations := association.InnerAssociations() for _, inner := range innerAssociations { v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) innerQuery := Q(query.Connection) - innerQuery.eagerFields = []string{inner.Fields} + innerQuery.eagerFields = inner.Fields err = innerQuery.eagerAssociations(v.Addr().Interface()) if err != nil { return err diff --git a/finders_test.go b/finders_test.go index 7f30727e..7e2daa02 100644 --- a/finders_test.go +++ b/finders_test.go @@ -908,3 +908,43 @@ func Test_FindManyToMany(t *testing.T) { r.NoError(err) }) } + +func Test_FindMultipleInnerHasMany(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + user := User{Name: nulls.NewString("Mark")} + err := tx.Create(&user) + r.NoError(err) + + book := Book{Title: "Pop Book", Isbn: "PB1", UserID: nulls.NewInt(user.ID)} + err = tx.Create(&book) + r.NoError(err) + + writer := Writer{Name: "Jhon", BookID: book.ID} + err = tx.Create(&writer) + r.NoError(err) + + friend := Friend{FirstName: "Frank", LastName: "Kafka", WriterID: writer.ID} + err = tx.Create(&friend) + r.NoError(err) + + address := Address{Street: "St 27", HouseNumber: 27, WriterID: writer.ID} + err = tx.Create(&address) + r.NoError(err) + + u := User{} + err = tx.Eager("Books.Writers.Addresses", "Books.Writers.Friends").Find(&u, user.ID) + r.NoError(err) + + r.Len(u.Books, 1) + r.Len(u.Books[0].Writers, 1) + r.Len(u.Books[0].Writers[0].Addresses, 1) + r.Equal(u.Books[0].Writers[0].Addresses[0].HouseNumber, 27) + r.Len(u.Books[0].Writers[0].Friends, 1) + r.Equal(u.Books[0].Writers[0].Friends[0].FirstName, "Frank") + }) +} diff --git a/pop_test.go b/pop_test.go index 1d473de0..1e02497e 100644 --- a/pop_test.go +++ b/pop_test.go @@ -163,6 +163,8 @@ type Books []Book type Writer struct { ID int `db:"id"` Name string `db:"name"` + Addresses Addresses `has_many:"addresses"` + Friends []Friend `has_many:"friends"` BookID int `db:"book_id"` Book Book `belongs_to:"book"` CreatedAt time.Time `db:"created_at"` @@ -174,6 +176,7 @@ type Writers []Writer type Address struct { ID int `db:"id"` Street string `db:"street"` + WriterID int `db:"writer_id"` HouseNumber int `db:"house_number"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` @@ -207,6 +210,7 @@ func (UsersAddressQuery) TableName() string { type Friend struct { ID int `db:"id"` FirstName string `db:"first_name"` + WriterID int `db:"writer_id"` LastName string `db:"last_name"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` diff --git a/testdata/migrations/20181104135526_good_friends.up.fizz b/testdata/migrations/20181104135526_good_friends.up.fizz index 61e24d3c..211c3848 100644 --- a/testdata/migrations/20181104135526_good_friends.up.fizz +++ b/testdata/migrations/20181104135526_good_friends.up.fizz @@ -2,5 +2,6 @@ create_table("good_friends") { t.Column("id", "int", {primary: true}) t.Column("first_name", "string", {}) t.Column("last_name", "string", {}) + t.Column("writer_id", "int",{}) t.Timestamps() } \ No newline at end of file diff --git a/testdata/migrations/20181104140340_addresses.up.fizz b/testdata/migrations/20181104140340_addresses.up.fizz index d57d70af..6520f03b 100644 --- a/testdata/migrations/20181104140340_addresses.up.fizz +++ b/testdata/migrations/20181104140340_addresses.up.fizz @@ -2,5 +2,6 @@ create_table("addresses") { t.Column("id", "int", {primary: true}) t.Column("street", "string", {}) t.Column("house_number", "int", {}) + t.Column("writer_id", "int",{}) t.Timestamps() } \ No newline at end of file