Skip to content

Commit

Permalink
fix: alter serial column (#98)
Browse files Browse the repository at this point in the history
* fix: alter sercial column

* chore: spelling mistakes

* fix: drop sequence
  • Loading branch information
a631807682 committed Apr 24, 2022
1 parent 71eadf4 commit 6fa74f7
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 2 deletions.
111 changes: 109 additions & 2 deletions migrator.go
Expand Up @@ -249,8 +249,27 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.DB.Connection(func(tx *gorm.DB) error {
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
if field.AutoIncrement && filedColumnAutoIncrement { // update
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
if err := m.UpdateSequence(tx, stmt, field, serialDatabaseType); err != nil {
return err
}
}
} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if err := m.CreateSequence(tx, stmt, field, serialDatabaseType); err != nil {
return err
}
} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
if err := m.DeleteSequence(tx, stmt, field, fileType); err != nil {
return err
}
} else {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}
}
}

Expand Down Expand Up @@ -473,3 +492,91 @@ func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}
}
return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table
}

func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
serialDatabaseType string) (err error) {

_, table := m.CurrentSchema(stmt, stmt.Table)
tableName := table.(string)

sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_")
if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName},
clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
return err
}

if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')",
clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil {
return err
}

if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?",
clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil {
return err
}
return
}

func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
serialDatabaseType string) (err error) {

sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
if err != nil {
return err
}

if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
return err
}

if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
return err
}
return
}

func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
fileType clause.Expr) (err error) {

sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
if err != nil {
return err
}

if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}

if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT",
m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil {
return err
}

if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil {
return err
}

return
}

func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) (
sequenceName string, err error) {
_, table := m.CurrentSchema(stmt, stmt.Table)

// DefaultValueValue is reset by ColumnTypes, search again.
var columnDefault string
err = tx.Raw(
`SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`,
table, field.DBName).Scan(&columnDefault).Error

if err != nil {
return
}

sequenceName = strings.TrimSuffix(
strings.TrimPrefix(columnDefault, `nextval('`),
`'::regclass)`,
)
return
}
13 changes: 13 additions & 0 deletions postgres.go
Expand Up @@ -206,3 +206,16 @@ func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}

func getSerialDatabaseType(s string) (dbType string, ok bool) {
switch s {
case "smallserial":
return "smallint", true
case "serial":
return "integer", true
case "bigserial":
return "bigint", true
default:
return "", false
}
}

0 comments on commit 6fa74f7

Please sign in to comment.