Skip to content

Commit

Permalink
fix: association many2many duplicate elem (#5473)
Browse files Browse the repository at this point in the history
* fix: association many2many duplicate elem

* chore: gofumpt style
  • Loading branch information
a631807682 committed Jul 1, 2022
1 parent 235c093 commit c74bc57
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
27 changes: 20 additions & 7 deletions callbacks/associations.go
Expand Up @@ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}

Expand All @@ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joins = reflect.Append(joins, joinValue)
}

identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))

for i := 0; i < f.Len(); i++ {
elem := f.Index(i)

if !isPtr {
elem = elem.Addr()
}
objs = append(objs, v)
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
elems = reflect.Append(elems, elem)

relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}

cacheKey := utils.ToStringKey(relPrimaryValues)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true
distinctElems = reflect.Append(distinctElems, elem)
}

}
}
}
Expand All @@ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
}

for i := 0; i < elemLen; i++ {
Expand Down
27 changes: 27 additions & 0 deletions tests/associations_many2many_test.go
Expand Up @@ -3,6 +3,7 @@ package tests_test
import (
"testing"

"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)

Expand Down Expand Up @@ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Team").Clear()
AssertAssociationCount(t, users, "Team", 0, "After Clear")
}

func TestDuplicateMany2ManyAssociation(t *testing.T) {
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
}}

user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)

var findUser1 User
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)

var findUser2 User
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}
4 changes: 2 additions & 2 deletions tests/migrate_test.go
Expand Up @@ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) {
value, ok = ct.DefaultValue()
AssertEqual(t, "", value)
AssertEqual(t, false, ok)

}

func findColumnType(dest interface{}, columnName string) (
foundColumn gorm.ColumnType, err error) {
foundColumn gorm.ColumnType, err error,
) {
columnTypes, err := DB.Migrator().ColumnTypes(dest)
if err != nil {
err = fmt.Errorf("ColumnTypes err:%v", err)
Expand Down
3 changes: 1 addition & 2 deletions tests/serializer_test.go
Expand Up @@ -113,7 +113,6 @@ func TestSerializer(t *testing.T) {
}

AssertEqual(t, result, data)

}

func TestSerializerAssignFirstOrCreate(t *testing.T) {
Expand Down Expand Up @@ -152,7 +151,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
}
AssertEqual(t, result, out)

//update record
// update record
data.Roles = append(data.Roles, "r3")
data.JobInfo.Location = "Gates Hillman Complex"
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
Expand Down

0 comments on commit c74bc57

Please sign in to comment.