Skip to content

Commit

Permalink
Enable WithReturning by default for MariaDB, close #62, #85
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Oct 8, 2022
1 parent ff88afd commit ecb6c6a
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions mysql.go
Expand Up @@ -15,6 +15,7 @@ import (
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)

type Config struct {
Expand All @@ -25,6 +26,7 @@ type Config struct {
SkipInitializeWithVersion bool
DefaultStringSize uint
DefaultDatetimePrecision *int
DisableWithReturning bool
DisableDatetimePrecision bool
DontSupportRenameIndex bool
DontSupportRenameColumn bool
Expand Down Expand Up @@ -86,14 +88,6 @@ func (dialector Dialector) Apply(config *gorm.Config) error {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
ctx := context.Background()

// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
})

if dialector.DriverName == "" {
dialector.DriverName = "mysql"
}
Expand All @@ -111,6 +105,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
}

withReturning := false
if !dialector.Config.SkipInitializeWithVersion {
err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion)
if err != nil {
Expand All @@ -122,6 +117,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportNullAsDefaultValue = true
withReturning = true
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true
Expand All @@ -137,6 +133,32 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
}

// register callbacks
callbackConfig := &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
}

if !dialector.Config.DisableWithReturning && withReturning {
callbackConfig.LastInsertIDReversed = true

if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") {
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
}

if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") {
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
}

if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") {
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
}
}

callbacks.RegisterDefaultCallbacks(db, callbackConfig)

for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
Expand Down Expand Up @@ -176,6 +198,8 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
if column.Name != "" {
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
}

builder.(*gorm.Statement).AddClause(onConflict)
}
}

Expand Down

0 comments on commit ecb6c6a

Please sign in to comment.