diff --git a/gorm.go b/gorm.go index 1f1dac213..d317e437d 100644 --- a/gorm.go +++ b/gorm.go @@ -31,6 +31,8 @@ type Config struct { NowFunc func() time.Time // DryRun generate sql without execute DryRun bool + // DryRunMigration prevent AutoMigrate() to change the schema + DryRunMigration bool // PrepareStmt executes the given query in cached statement PrepareStmt bool // DisableAutomaticPing @@ -97,6 +99,7 @@ type DB struct { // Session session config when create session with Session() method type Session struct { DryRun bool + DryRunMigration bool PrepareStmt bool NewDB bool Initialized bool @@ -274,6 +277,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.DryRun = true } + if config.DryRunMigration { + tx.Config.DryRunMigration = true + } + if config.QueryFields { tx.Config.QueryFields = true } @@ -457,10 +464,12 @@ func (db *DB) Use(plugin Plugin) error { // ToSQL for generate SQL string. // // db.ToSQL(func(tx *gorm.DB) *gorm.DB { -// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) -// .Limit(10).Offset(5) -// .Order("name ASC") -// .First(&User{}) +// +// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) +// .Limit(10).Offset(5) +// .Order("name ASC") +// .First(&User{}) +// // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) diff --git a/migrator/migrator.go b/migrator/migrator.go index e6782a13b..b61a24d75 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -94,6 +94,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { + if tx.DryRunMigration { + return fmt.Errorf("create table for model %T: %w", value, gorm.ErrDryRunModeUnsupported) + } + if err := tx.Migrator().CreateTable(value); err != nil { return err } @@ -117,6 +121,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column + if tx.DryRunMigration { + return fmt.Errorf("create column for model %T: %w", value, gorm.ErrDryRunModeUnsupported) + } + if err := tx.Migrator().AddColumn(value, dbName); err != nil { return err } @@ -130,6 +138,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { + if tx.DryRunMigration { + return fmt.Errorf("create constraint %s for model %T: %w", constraint.Name, value, gorm.ErrDryRunModeUnsupported) + } + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } @@ -139,6 +151,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, chk := range stmt.Schema.ParseCheckConstraints() { if !tx.Migrator().HasConstraint(value, chk.Name) { + if tx.DryRunMigration { + return fmt.Errorf("create constraint %s for model %T: %w", chk.Name, value, gorm.ErrDryRunModeUnsupported) + } + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } @@ -147,6 +163,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if !tx.Migrator().HasIndex(value, idx.Name) { + if tx.DryRunMigration { + return fmt.Errorf("create index %s for model %T: %w", idx.Name, value, gorm.ErrDryRunModeUnsupported) + } + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } @@ -478,6 +498,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } if alterColumn && !field.IgnoreMigration { + if m.DB.DryRunMigration { + return fmt.Errorf("alter column %s for model %T: %w", field.Name, value, gorm.ErrDryRunModeUnsupported) + } + return m.DB.Migrator().AlterColumn(value, field.Name) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0b5bc5ebd..36389743a 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "errors" "fmt" "math/rand" "reflect" @@ -959,3 +960,64 @@ func TestMigrateArrayTypeModel(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, "integer[]", ct.DatabaseTypeName()) } + +type Origin struct { + ID int64 `gorm:"primaryKey"` + Data string `gorm:"null"` +} + +func (Origin) TableName() string { + return "test" +} + +func TestDryRunAutoMigrate(t *testing.T) { + type ChangeColumn struct { + Origin `gorm:"-"` + ID int64 `gorm:"primaryKey"` + Data int64 `gorm:""` + } + + type AddIndex struct { + Origin `gorm:"-"` + ID int64 `gorm:"primaryKey"` + Data string `gorm:"null;index"` + } + + type AddConstraint struct { + Origin `gorm:"-"` + ID int64 `gorm:"primaryKey"` + Data string `gorm:"null;check:,data <> 'migrate'"` + } + + var tests = []struct { + from, to interface{} + dryrunErr bool + }{ + {&Origin{}, &Origin{}, false}, + {&Origin{}, &ChangeColumn{}, true}, + {&Origin{}, &AddIndex{}, true}, + {&Origin{}, &AddConstraint{}, true}, + } + + for _, test := range tests { + name := strings.ReplaceAll(fmt.Sprintf("%T", test.to), "*", "") + t.Run(name, func(t *testing.T) { + DB.Migrator().DropTable(test.from, test.to) + t.Cleanup(func() { + DB.Migrator().DropTable(test.from, test.to) + }) + + err := DB.Migrator().CreateTable(test.from) + AssertEqual(t, nil, err) + + err = DB.Session(&gorm.Session{DryRunMigration: true}).AutoMigrate(test.to) + if err != nil { + t.Log("migrate error:", err) + } + AssertEqual(t, test.dryrunErr, errors.Is(err, gorm.ErrDryRunModeUnsupported)) + + err = DB.AutoMigrate(test.to) + AssertEqual(t, false, errors.Is(err, gorm.ErrDryRunModeUnsupported)) + }) + } +}