diff --git a/migrator.go b/migrator.go index f20f767..25078e1 100644 --- a/migrator.go +++ b/migrator.go @@ -249,8 +249,27 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return m.DB.Connection(func(tx *gorm.DB) error { fileType := clause.Expr{SQL: m.DataTypeOf(field)} if fieldColumnType.DatabaseTypeName() != fileType.SQL { - if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { - return err + filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() + if field.AutoIncrement && filedColumnAutoIncrement { // update + serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) + if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { + if err := m.UpdateSequence(tx, stmt, field, serialDatabaseType); err != nil { + return err + } + } + } else if field.AutoIncrement && !filedColumnAutoIncrement { // create + serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) + if err := m.CreateSequence(tx, stmt, field, serialDatabaseType); err != nil { + return err + } + } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete + if err := m.DeleteSequence(tx, stmt, field, fileType); err != nil { + return err + } + } else { + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { + return err + } } } @@ -475,3 +494,91 @@ func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{} } return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table } + +func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, + serialDatabaseType string) (err error) { + + _, table := m.CurrentSchema(stmt, stmt.Table) + tableName := table.(string) + + sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_") + if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, + clause.Expr{SQL: serialDatabaseType}).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')", + clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?", + clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil { + return err + } + return +} + +func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, + serialDatabaseType string) (err error) { + + sequenceName, err := m.getColumnSequenceName(tx, stmt, field) + if err != nil { + return err + } + + if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { + return err + } + return +} + +func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, + fileType clause.Expr) (err error) { + + sequenceName, err := m.getColumnSequenceName(tx, stmt, field) + if err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { + return err + } + + if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", + m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil { + return err + } + + if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil { + return err + } + + return +} + +func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) ( + sequenceName string, err error) { + _, table := m.CurrentSchema(stmt, stmt.Table) + + // DefaultValueValue is reset by ColumnTypes, search again. + var columnDefault string + err = tx.Raw( + `SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`, + table, field.DBName).Scan(&columnDefault).Error + + if err != nil { + return + } + + sequenceName = strings.TrimSuffix( + strings.TrimPrefix(columnDefault, `nextval('`), + `'::regclass)`, + ) + return +} diff --git a/postgres.go b/postgres.go index eae2e35..027703c 100644 --- a/postgres.go +++ b/postgres.go @@ -206,3 +206,16 @@ func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error { tx.Exec("ROLLBACK TO SAVEPOINT " + name) return nil } + +func getSerialDatabaseType(s string) (dbType string, ok bool) { + switch s { + case "smallserial": + return "smallint", true + case "serial": + return "integer", true + case "bigserial": + return "bigint", true + default: + return "", false + } +}