diff --git a/ddlmod.go b/ddlmod.go index 463a1d6..9c93e6a 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -13,6 +13,7 @@ import ( var ( sqliteSeparator = "`|\"|'|\t" + uniqueRegexp = regexp.MustCompile(fmt.Sprintf(`^CONSTRAINT [%v]?[\w-]+[%v]? UNIQUE (.*)$`, sqliteSeparator, sqliteSeparator)) indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator)) tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator)) separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator)) @@ -103,11 +104,24 @@ func parseDDL(strs ...string) (*ddl, error) { for _, f := range result.fields { fUpper := strings.ToUpper(f) - if strings.HasPrefix(fUpper, "CHECK") || - strings.HasPrefix(fUpper, "CONSTRAINT") { + if strings.HasPrefix(fUpper, "CHECK") { + continue + } + if strings.HasPrefix(fUpper, "CONSTRAINT") { + matches := uniqueRegexp.FindStringSubmatch(f) + if len(matches) > 0 { + if columns := getAllColumns(matches[1]); len(columns) == 1 { + for idx, column := range result.columns { + if column.NameValue.String == columns[0] { + column.UniqueValue = sql.NullBool{Bool: true, Valid: true} + result.columns[idx] = column + break + } + } + } + } continue } - if strings.HasPrefix(fUpper, "PRIMARY KEY") { for _, name := range getAllColumns(f) { for idx, column := range result.columns { @@ -159,14 +173,7 @@ func parseDDL(strs ...string) (*ddl, error) { } } } else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 { - for _, column := range getAllColumns(matches[1]) { - for idx, c := range result.columns { - if c.NameValue.String == column { - c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true} - result.columns[idx] = c - } - } - } + // don't report Unique by UniqueIndex } else { return nil, errors.New("invalid DDL") } @@ -268,20 +275,6 @@ func (d *ddl) getColumns() []string { return res } -func (d *ddl) alterColumn(name, sql string) bool { - reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") - - for i := 0; i < len(d.fields); i++ { - if reg.MatchString(d.fields[i]) { - d.fields[i] = sql - return false - } - } - - d.fields = append(d.fields, sql) - return true -} - func (d *ddl) removeColumn(name string) bool { reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") diff --git a/ddlmod_test.go b/ddlmod_test.go index 02d3669..6e81e2b 100644 --- a/ddlmod_test.go +++ b/ddlmod_test.go @@ -16,11 +16,12 @@ func TestParseDDL(t *testing.T) { columns []migrator.ColumnType }{ {"with_fk", []string{ - "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))", + "CREATE TABLE `notes` (" + + "`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))", "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", }, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: false, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, @@ -56,28 +57,54 @@ func TestParseDDL(t *testing.T) { ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, - UniqueValue: sql.NullBool{Bool: true, Valid: true}, + UniqueValue: sql.NullBool{Bool: false, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}, }, }, - }, - { + }, { "unique index", []string{ "CREATE TABLE `test-b` (`field` integer NOT NULL)", "CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0", }, 1, - []migrator.ColumnType{ - { - NameValue: sql.NullString{String: "field", Valid: true}, - DataTypeValue: sql.NullString{String: "integer", Valid: true}, - ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, - PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, - UniqueValue: sql.NullBool{Bool: true, Valid: true}, - NullableValue: sql.NullBool{Bool: false, Valid: true}, - }, + []migrator.ColumnType{{ + NameValue: sql.NullString{String: "field", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: false, Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + }}, + }, { + "normal index", + []string{ + "CREATE TABLE `test-c` (`field` integer NOT NULL)", + "CREATE INDEX `idx_uq` ON `test-c`(`field`)", }, + 1, + []migrator.ColumnType{{ + NameValue: sql.NullString{String: "field", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: false, Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + }}, + }, { + "unique constraint", + []string{ + "CREATE TABLE `unique_struct` (`name` text,CONSTRAINT `uni_unique_struct_name` UNIQUE (`name`))", + }, + 2, + []migrator.ColumnType{{ + NameValue: sql.NullString{String: "name", Valid: true}, + DataTypeValue: sql.NullString{String: "text", Valid: true}, + ColumnTypeValue: sql.NullString{String: "text", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, + }}, }, { "non-unique index", diff --git a/go.mod b/go.mod index 0c031e7..1d14b55 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/mattn/go-sqlite3 v1.14.17 - gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55 + gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde ) require ( diff --git a/go.sum b/go.sum index 45bb183..8816d52 100644 --- a/go.sum +++ b/go.sum @@ -8,3 +8,5 @@ gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55 h1:sC1Xj4TYrLqg1n3AN10w871An7wJM0gzgcm8jkIkECQ= gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/migrator.go b/migrator.go index 1b85d6d..70e0e99 100644 --- a/migrator.go +++ b/migrator.go @@ -79,14 +79,28 @@ func (m Migrator) AlterColumn(value interface{}, name string) error { return m.RunWithoutForeignKey(func() error { return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { - if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) { - return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile()) + var sqlArgs []interface{} + for i, f := range ddl.fields { + if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 1 && matches[1] == field.DBName { + ddl.fields[i] = fmt.Sprintf("`%v` ?", field.DBName) + sqlArgs = []interface{}{m.FullDataTypeOf(field)} + // table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`. + // FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint. + if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") { + uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName) + if uni != nil { + uniSQL, uniArgs := uni.Build() + ddl.addConstraint(uniName, uniSQL) + sqlArgs = append(sqlArgs, uniArgs...) + } + } + break + } } - - return ddl, []interface{}{m.FullDataTypeOf(field)}, nil + return ddl, sqlArgs, nil } - - return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name) + return nil, nil, fmt.Errorf("failed to alter field with name %v", name) }) }) } @@ -153,7 +167,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) return m.recreateTable(value, &table, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { @@ -164,12 +178,8 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { ) if constraint != nil { - constraintName = constraint.Name - constraintSql, constraintValues = buildConstraint(constraint) - } else if chk != nil { - constraintName = chk.Name - constraintSql = "CONSTRAINT ? CHECK (?)" - constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} + constraintName = constraint.GetName() + constraintSql, constraintValues = constraint.Build() } else { return nil, nil, nil } @@ -182,11 +192,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.recreateTable(value, &table, @@ -200,11 +208,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } m.DB.Raw( @@ -317,26 +323,44 @@ func (m Migrator) DropIndex(value interface{}, name string) error { }) } -func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { - sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete - } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } +type Index struct { + Seq int + Name string + Unique bool + Origin string + Partial bool +} - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) - return +// GetIndexes return Indexes []gorm.Index and execErr error, +// See the [doc] +// +// [doc]: https://www.sqlite.org/pragma.html#pragma_index_list +func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { + indexes := make([]gorm.Index, 0) + err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + rst := make([]*Index, 0) + if err := m.DB.Debug().Raw(fmt.Sprintf("PRAGMA index_list(%q)", stmt.Table)).Scan(&rst).Error; err != nil { + return err + } + for _, index := range rst { + if index.Origin == "u" { // skip the index was created by a UNIQUE constraint + continue + } + var columns []string + if err := m.DB.Raw(fmt.Sprintf("SELECT name from PRAGMA_index_info(%q)", index.Name)).Scan(&columns).Error; err != nil { // alias `PRAGMA index_info(?)` + return err + } + indexes = append(indexes, &migrator.Index{ + TableName: stmt.Table, + NameValue: index.Name, + ColumnList: columns, + PrimaryKeyValue: sql.NullBool{Bool: index.Origin == "pk", Valid: true}, // The exceptions are INTEGER PRIMARY KEY + UniqueValue: sql.NullBool{Bool: index.Unique, Valid: true}, + }) + } + return nil + }) + return indexes, err } func (m Migrator) getRawDDL(table string) (string, error) {