From ecb6c6a014f6fa2daa89e207f3e7d75f6c4b9f91 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 19:38:11 +0800 Subject: [PATCH] Enable WithReturning by default for MariaDB, close #62, #85 --- mysql.go | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/mysql.go b/mysql.go index 20b6712..bf0ed07 100644 --- a/mysql.go +++ b/mysql.go @@ -15,6 +15,7 @@ import ( "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) type Config struct { @@ -25,6 +26,7 @@ type Config struct { SkipInitializeWithVersion bool DefaultStringSize uint DefaultDatetimePrecision *int + DisableWithReturning bool DisableDatetimePrecision bool DontSupportRenameIndex bool DontSupportRenameColumn bool @@ -86,14 +88,6 @@ func (dialector Dialector) Apply(config *gorm.Config) error { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { ctx := context.Background() - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ - CreateClauses: CreateClauses, - QueryClauses: QueryClauses, - UpdateClauses: UpdateClauses, - DeleteClauses: DeleteClauses, - }) - if dialector.DriverName == "" { dialector.DriverName = "mysql" } @@ -111,6 +105,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } } + withReturning := false if !dialector.Config.SkipInitializeWithVersion { err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion) if err != nil { @@ -122,6 +117,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportForShareClause = true dialector.Config.DontSupportNullAsDefaultValue = true + withReturning = true } else if strings.HasPrefix(dialector.ServerVersion, "5.6.") { dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameColumn = true @@ -137,6 +133,32 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } } + // register callbacks + callbackConfig := &callbacks.Config{ + CreateClauses: CreateClauses, + QueryClauses: QueryClauses, + UpdateClauses: UpdateClauses, + DeleteClauses: DeleteClauses, + } + + if !dialector.Config.DisableWithReturning && withReturning { + callbackConfig.LastInsertIDReversed = true + + if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") { + callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") + } + + if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") { + callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING") + } + + if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") { + callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING") + } + } + + callbacks.RegisterDefaultCallbacks(db, callbackConfig) + for k, v := range dialector.ClauseBuilders() { db.ClauseBuilders[k] = v } @@ -176,6 +198,8 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { if column.Name != "" { onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} } + + builder.(*gorm.Statement).AddClause(onConflict) } }