diff --git a/mysql.go b/mysql.go index e927f3d..0e4c7a2 100644 --- a/mysql.go +++ b/mysql.go @@ -5,6 +5,8 @@ import ( "database/sql" "fmt" "math" + "regexp" + "strconv" "strings" "time" @@ -119,7 +121,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportForShareClause = true dialector.Config.DontSupportNullAsDefaultValue = true - withReturning = true + if checkVersion(dialector.ServerVersion, "10.5") { + withReturning = true + } } else if strings.HasPrefix(dialector.ServerVersion, "5.6.") { dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameColumn = true @@ -454,3 +458,29 @@ func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error { func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error } + +var versionTrimerRegexp = regexp.MustCompile(`^(\d+).*$`) + +// checkVersion newer or equal returns true, old returns false +func checkVersion(newVersion, oldVersion string) bool { + if newVersion == oldVersion { + return true + } + + newVersions := strings.Split(newVersion, ".") + oldVersions := strings.Split(oldVersion, ".") + for idx, nv := range newVersions { + if len(oldVersions) <= idx { + return true + } + + nvi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(nv, "$1")) + ovi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(oldVersions[idx], "$1")) + if nvi == ovi { + continue + } + return nvi > ovi + } + + return false +} diff --git a/mysql_test.go b/mysql_test.go index 4341ec8..26a07b6 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -58,3 +58,20 @@ func BenchmarkDialector_QuoteTo(b *testing.B) { buf.Reset() } } + +func TestCheckVersion(t *testing.T) { + versions := map[string]string{ + "5.6.1": "5.6", + "5.10.2": "5.6", + "5.10": "5.6", + "10.6.26-MariaDB-1:10.4.26+maria~ubu2004": "10.6", + "10.6.26-MariaDB-1:10.4.26+maria~ubu2005": "10.6.3", + "10.4.26-MariaDB-1:10.4.26+maria~ubu2004": "5.6", + } + + for k, v := range versions { + if !checkVersion(k, v) || checkVersion(v, k) { + t.Fatalf("returns %v when comparing %v, %v", checkVersion(k, v), k, v) + } + } +}