From c035852cb19c3d67d16483a53d4c31547436d70f Mon Sep 17 00:00:00 2001 From: Defool Date: Sun, 11 Sep 2022 21:36:24 +0800 Subject: [PATCH 1/3] DryRun for migrator --- migrator/migrator.go | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index e6782a13b..90c855b05 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -8,9 +8,11 @@ import ( "reflect" "regexp" "strings" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) @@ -30,6 +32,17 @@ type Config struct { gorm.Dialector } +type printSQLLogger struct{} + +func (l *printSQLLogger) LogMode(ll logger.LogLevel) logger.Interface { return l } +func (l *printSQLLogger) Info(ctx context.Context, log string, args ...interface{}) {} +func (l *printSQLLogger) Warn(ctx context.Context, log string, args ...interface{}) {} +func (l *printSQLLogger) Error(ctx context.Context, log string, args ...interface{}) {} +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string @@ -92,14 +105,19 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) - if !tx.Migrator().HasTable(value) { - if err := tx.Migrator().CreateTable(value); err != nil { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{}}) + } + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, err := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err } @@ -117,10 +135,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, dbName); err != nil { + if err := execTx.Migrator().AddColumn(value, dbName); err != nil { return err } - } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { // found, smart migrate return err } @@ -129,8 +147,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } @@ -138,16 +156,16 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } } for _, idx := range stmt.Schema.ParseIndexes() { - if !tx.Migrator().HasIndex(value, idx.Name) { - if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } } From ac9262f736c9ca9c8747b9f05325403035ce02ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 24 Dec 2022 12:05:30 +0800 Subject: [PATCH 2/3] Update migrator.go --- migrator/migrator.go | 72 +++++++++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 90c855b05..014865c99 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -424,32 +424,49 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - alterColumn := false + var ( + alterColumn, isSameType bool + ) - // check type - if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { - alterColumn = true - } + if !field.PrimaryKey { + // check type + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } - // check size - if length, ok := columnType.Length(); length != int64(field.Size) { - if length > 0 && field.Size > 0 { - alterColumn = true - } else { - // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here - matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if !field.PrimaryKey && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + if !isSameType { alterColumn = true } } } - // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } } } @@ -471,17 +488,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check default value if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && !strings.EqualFold(field.DefaultValue, "NULL") dv, dvNotNull := columnType.DefaultValue() - if dvNotNull && field.DefaultValueInterface == nil { + if dvNotNull && !currentDefaultNotNull { // defalut value -> null alterColumn = true - } else if !dvNotNull && field.DefaultValueInterface != nil { + } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if dv != field.DefaultValue { + } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || + (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { // default value not equal // not both null - if !(field.DefaultValueInterface == nil && !dvNotNull) { + if currentDefaultNotNull || dvNotNull { alterColumn = true } } @@ -496,7 +515,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.Name) + return m.DB.Migrator().AlterColumn(value, field.DBName) } return nil @@ -881,3 +900,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { return nil, errors.New("not support") } + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} From 011eeccd219890c7e144467aa78553a24819549e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 24 Dec 2022 12:06:26 +0800 Subject: [PATCH 3/3] Update migrator.go --- migrator/migrator.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 014865c99..eafe7bb29 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -32,15 +32,14 @@ type Config struct { gorm.Dialector } -type printSQLLogger struct{} +type printSQLLogger struct { + logger.Interface +} -func (l *printSQLLogger) LogMode(ll logger.LogLevel) logger.Interface { return l } -func (l *printSQLLogger) Info(ctx context.Context, log string, args ...interface{}) {} -func (l *printSQLLogger) Warn(ctx context.Context, log string, args ...interface{}) {} -func (l *printSQLLogger) Error(ctx context.Context, log string, args ...interface{}) {} func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) } // GormDataTypeInterface gorm data type interface @@ -109,7 +108,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { execTx := queryTx if m.DB.DryRun { queryTx.DryRun = false - execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{}}) + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) } if !queryTx.Migrator().HasTable(value) { if err := execTx.Migrator().CreateTable(value); err != nil {