diff --git a/callbacks/associations.go b/callbacks/associations.go index fd3141cfe..4a50e6c24 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} @@ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joins = reflect.Append(joins, joinValue) } + identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) - for i := 0; i < f.Len(); i++ { elem := f.Index(i) - + if !isPtr { + elem = elem.Addr() + } objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + elems = reflect.Append(elems, elem) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } } + + cacheKey := utils.ToStringKey(relPrimaryValues) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + identityMap[cacheKey] = true + distinctElems = reflect.Append(distinctElems, elem) + } + } } } @@ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems, selectColumns, restricted, nil) + saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 28b441bd8..7b45befb6 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Team").Clear() AssertAssociationCount(t, users, "Team", 0, "After Clear") } + +func TestDuplicateMany2ManyAssociation(t *testing.T) { + user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-2"}, + }} + + user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-3"}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0bbef382a..3d6a78589 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) { value, ok = ct.DefaultValue() AssertEqual(t, "", value) AssertEqual(t, false, ok) - } func findColumnType(dest interface{}, columnName string) ( - foundColumn gorm.ColumnType, err error) { + foundColumn gorm.ColumnType, err error, +) { columnTypes, err := DB.Migrator().ColumnTypes(dest) if err != nil { err = fmt.Errorf("ColumnTypes err:%v", err) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 80e015ffa..7232f9df4 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -113,7 +113,6 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) - } func TestSerializerAssignFirstOrCreate(t *testing.T) { @@ -152,7 +151,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { } AssertEqual(t, result, out) - //update record + // update record data.Roles = append(data.Roles, "r3") data.JobInfo.Location = "Gates Hillman Complex" if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {