diff --git a/go.mod b/go.mod index cf3b406..0b2f7f7 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/jackc/pgx/v4 v4.17.2 - gorm.io/gorm v1.23.7 + gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755 ) diff --git a/go.sum b/go.sum index b9ce810..51fb379 100644 --- a/go.sum +++ b/go.sum @@ -192,4 +192,6 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/gorm v1.23.7 h1:ww+9Mu5WwHKDSOQZFC4ipu/sgpKMr9EtrJ0uwBqNtB0= gorm.io/gorm v1.23.7/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755 h1:7AdrbfcvKnzejfqP5g37fdSZOXH/JvaPIzBIHTOqXKk= +gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/migrator.go b/migrator.go index 26573da..92f61fb 100644 --- a/migrator.go +++ b/migrator.go @@ -197,6 +197,8 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if err := m.Migrator.AddColumn(value, field); err != nil { return err } + m.resetPreparedStmts() + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { if field.Comment != "" { @@ -266,7 +268,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { + err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { var ( columnTypes, _ = m.DB.Migrator().ColumnTypes(value) @@ -347,6 +349,12 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { } return fmt.Errorf("failed to look up field with name: %s", field) }) + + if err != nil { + return err + } + m.resetPreparedStmts() + return nil } func (m Migrator) HasConstraint(value interface{}, name string) bool { @@ -694,3 +702,30 @@ func groupByIndexName(indexList []*Index) map[string][]*Index { func (m Migrator) GetTypeAliases(databaseTypeName string) []string { return typeAliasMap[databaseTypeName] } + +// should reset prepared stmts when table changed +func (m Migrator) resetPreparedStmts() { + if m.DB.PrepareStmt { + if pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB); ok { + pdb.Reset() + } + } +} + +func (m Migrator) DropColumn(dst interface{}, field string) error { + if err := m.Migrator.DropColumn(dst, field); err != nil { + return err + } + + m.resetPreparedStmts() + return nil +} + +func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { + if err := m.Migrator.RenameColumn(dst, oldName, field); err != nil { + return err + } + + m.resetPreparedStmts() + return nil +}