diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 309c5fcd2..032bf4a1c 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -16,27 +16,27 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { - if len(onConflict.Columns) > 0 { - builder.WriteByte('(') - for idx, column := range onConflict.Columns { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(column) - } - builder.WriteString(`) `) - } - - if len(onConflict.TargetWhere.Exprs) > 0 { - builder.WriteString(" WHERE ") - onConflict.TargetWhere.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } } if onConflict.DoNothing { diff --git a/tests/postgres_test.go b/tests/postgres_test.go index b5b672a93..f45b26185 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -148,3 +149,53 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +}