Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove unnecessary judgments and fix misspellings #95

Merged
merged 1 commit into from Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Expand Up @@ -4,6 +4,7 @@ go 1.14

require (
github.com/go-sql-driver/mysql v1.6.0
github.com/jinzhu/now v1.1.5 // indirect
gorm.io/gorm v1.23.8
)

require github.com/jinzhu/now v1.1.5 // indirect
6 changes: 2 additions & 4 deletions migrator.go
Expand Up @@ -317,10 +317,8 @@ func groupByIndexName(indexList []*Index) map[string][]*Index {
}

func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
if strings.Contains(table, ".") {
if tables := strings.Split(table, `.`); len(tables) == 2 {
return tables[0], tables[1]
}
if tables := strings.Split(table, `.`); len(tables) == 2 {
return tables[0], tables[1]
}

return m.CurrentDatabase(), table
Expand Down
77 changes: 40 additions & 37 deletions mysql.go
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/go-sql-driver/mysql"

"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
Expand Down Expand Up @@ -76,22 +77,20 @@ func (dialector Dialector) NowFunc(n int) func() time.Time {
}

func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NowFunc == nil {
if dialector.DefaultDatetimePrecision == nil {
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}

// while maintaining the readability of the code, separate the business logic from
// the general part and leave it to the function to do it here.
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
if config.NowFunc != nil {
return nil
}

if dialector.DefaultDatetimePrecision == nil {
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}
// while maintaining the readability of the code, separate the business logic from
// the general part and leave it to the function to do it here.
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
return nil
}

func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
ctx := context.Background()

if dialector.DriverName == "" {
dialector.DriverName = "mysql"
}
Expand All @@ -111,7 +110,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {

withReturning := false
if !dialector.Config.SkipInitializeWithVersion {
err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion)
err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
if err != nil {
return err
}
Expand All @@ -121,9 +120,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportNullAsDefaultValue = true
if checkVersion(dialector.ServerVersion, "10.5") {
withReturning = true
}
withReturning = checkVersion(dialector.ServerVersion, "10.5")
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true
Expand Down Expand Up @@ -176,7 +173,7 @@ const (
ClauseOnConflict = "ON CONFLICT"
// ClauseValues for clause.ClauseBuilder VALUES key
ClauseValues = "VALUES"
// ClauseValues for clause.ClauseBuilder FOR key
// ClauseFor for clause.ClauseBuilder FOR key
ClauseFor = "FOR"
)

Expand Down Expand Up @@ -393,11 +390,11 @@ func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
}

func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
precision := ""
if !dialector.DisableDatetimePrecision && field.Precision == 0 {
field.Precision = *dialector.DefaultDatetimePrecision
}

var precision string
if field.Precision > 0 {
precision = fmt.Sprintf("(%d)", field.Precision)
}
Expand All @@ -421,27 +418,31 @@ func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
}

func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
sqlType := "bigint"
constraint := func(sqlType string) string {
if field.DataType == schema.Uint {
sqlType += " unsigned"
}
if field.NotNull {
sqlType += " NOT NULL"
}
if field.AutoIncrement {
sqlType += " AUTO_INCREMENT"
}
return sqlType
}

switch {
case field.Size <= 8:
sqlType = "tinyint"
return constraint("tinyint")
case field.Size <= 16:
sqlType = "smallint"
return constraint("smallint")
case field.Size <= 24:
sqlType = "mediumint"
return constraint("mediumint")
case field.Size <= 32:
sqlType = "int"
}

if field.DataType == schema.Uint {
sqlType += " unsigned"
}

if field.AutoIncrement {
sqlType += " AUTO_INCREMENT"
return constraint("int")
default:
return constraint("bigint")
}

return sqlType
}

func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
Expand All @@ -462,23 +463,25 @@ func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
}

var versionTrimerRegexp = regexp.MustCompile(`^(\d+).*$`)

// checkVersion newer or equal returns true, old returns false
func checkVersion(newVersion, oldVersion string) bool {
if newVersion == oldVersion {
return true
}

newVersions := strings.Split(newVersion, ".")
oldVersions := strings.Split(oldVersion, ".")
var (
versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`)

newVersions = strings.Split(newVersion, ".")
oldVersions = strings.Split(oldVersion, ".")
)
for idx, nv := range newVersions {
if len(oldVersions) <= idx {
return true
}

nvi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(nv, "$1"))
ovi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(oldVersions[idx], "$1"))
nvi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(nv, "$1"))
ovi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(oldVersions[idx], "$1"))
if nvi == ovi {
continue
}
Expand Down