diff --git a/go.mod b/go.mod index 47e0bd3..f7e4ea2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/go-sql-driver/mysql v1.6.0 - github.com/jinzhu/now v1.1.5 // indirect gorm.io/gorm v1.23.8 ) + +require github.com/jinzhu/now v1.1.5 // indirect diff --git a/migrator.go b/migrator.go index 5c706c4..47b116b 100644 --- a/migrator.go +++ b/migrator.go @@ -317,10 +317,8 @@ func groupByIndexName(indexList []*Index) map[string][]*Index { } func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) { - if strings.Contains(table, ".") { - if tables := strings.Split(table, `.`); len(tables) == 2 { - return tables[0], tables[1] - } + if tables := strings.Split(table, `.`); len(tables) == 2 { + return tables[0], tables[1] } return m.CurrentDatabase(), table diff --git a/mysql.go b/mysql.go index 6f764de..5772a65 100644 --- a/mysql.go +++ b/mysql.go @@ -11,6 +11,7 @@ import ( "time" "github.com/go-sql-driver/mysql" + "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" @@ -76,22 +77,20 @@ func (dialector Dialector) NowFunc(n int) func() time.Time { } func (dialector Dialector) Apply(config *gorm.Config) error { - if config.NowFunc == nil { - if dialector.DefaultDatetimePrecision == nil { - dialector.DefaultDatetimePrecision = &defaultDatetimePrecision - } - - // while maintaining the readability of the code, separate the business logic from - // the general part and leave it to the function to do it here. - config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision) + if config.NowFunc != nil { + return nil } + if dialector.DefaultDatetimePrecision == nil { + dialector.DefaultDatetimePrecision = &defaultDatetimePrecision + } + // while maintaining the readability of the code, separate the business logic from + // the general part and leave it to the function to do it here. + config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision) return nil } func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - ctx := context.Background() - if dialector.DriverName == "" { dialector.DriverName = "mysql" } @@ -111,7 +110,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { withReturning := false if !dialector.Config.SkipInitializeWithVersion { - err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion) + err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion) if err != nil { return err } @@ -121,9 +120,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportForShareClause = true dialector.Config.DontSupportNullAsDefaultValue = true - if checkVersion(dialector.ServerVersion, "10.5") { - withReturning = true - } + withReturning = checkVersion(dialector.ServerVersion, "10.5") } else if strings.HasPrefix(dialector.ServerVersion, "5.6.") { dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameColumn = true @@ -176,7 +173,7 @@ const ( ClauseOnConflict = "ON CONFLICT" // ClauseValues for clause.ClauseBuilder VALUES key ClauseValues = "VALUES" - // ClauseValues for clause.ClauseBuilder FOR key + // ClauseFor for clause.ClauseBuilder FOR key ClauseFor = "FOR" ) @@ -393,11 +390,11 @@ func (dialector Dialector) getSchemaStringType(field *schema.Field) string { } func (dialector Dialector) getSchemaTimeType(field *schema.Field) string { - precision := "" if !dialector.DisableDatetimePrecision && field.Precision == 0 { field.Precision = *dialector.DefaultDatetimePrecision } + var precision string if field.Precision > 0 { precision = fmt.Sprintf("(%d)", field.Precision) } @@ -421,27 +418,31 @@ func (dialector Dialector) getSchemaBytesType(field *schema.Field) string { } func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string { - sqlType := "bigint" + constraint := func(sqlType string) string { + if field.DataType == schema.Uint { + sqlType += " unsigned" + } + if field.NotNull { + sqlType += " NOT NULL" + } + if field.AutoIncrement { + sqlType += " AUTO_INCREMENT" + } + return sqlType + } + switch { case field.Size <= 8: - sqlType = "tinyint" + return constraint("tinyint") case field.Size <= 16: - sqlType = "smallint" + return constraint("smallint") case field.Size <= 24: - sqlType = "mediumint" + return constraint("mediumint") case field.Size <= 32: - sqlType = "int" - } - - if field.DataType == schema.Uint { - sqlType += " unsigned" - } - - if field.AutoIncrement { - sqlType += " AUTO_INCREMENT" + return constraint("int") + default: + return constraint("bigint") } - - return sqlType } func (dialector Dialector) getSchemaCustomType(field *schema.Field) string { @@ -462,23 +463,25 @@ func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error } -var versionTrimerRegexp = regexp.MustCompile(`^(\d+).*$`) - // checkVersion newer or equal returns true, old returns false func checkVersion(newVersion, oldVersion string) bool { if newVersion == oldVersion { return true } - newVersions := strings.Split(newVersion, ".") - oldVersions := strings.Split(oldVersion, ".") + var ( + versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`) + + newVersions = strings.Split(newVersion, ".") + oldVersions = strings.Split(oldVersion, ".") + ) for idx, nv := range newVersions { if len(oldVersions) <= idx { return true } - nvi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(nv, "$1")) - ovi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(oldVersions[idx], "$1")) + nvi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(nv, "$1")) + ovi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(oldVersions[idx], "$1")) if nvi == ovi { continue }