Skip to content

Commit

Permalink
Sort down migrations (#657)
Browse files Browse the repository at this point in the history
Basically, just reversing the up migration order does not work, as that puts "all" migrations before specific ones. Therefore, I added implemented the proper `Less` function for down migrations explicitly.

Related #533

Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
zepatrik and aeneasr committed Aug 10, 2021
1 parent 0ecad25 commit bb7527e
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 69 deletions.
10 changes: 9 additions & 1 deletion file_migrator.go
Expand Up @@ -81,7 +81,15 @@ func (fm *FileMigrator) findMigrations(runner func(mf Migration, tx *Connection)
Type: match.Type,
Runner: runner,
}
fm.Migrations[mf.Direction] = append(fm.Migrations[mf.Direction], mf)
switch mf.Direction {
case "up":
fm.UpMigrations.Migrations = append(fm.UpMigrations.Migrations, mf)
case "down":
fm.DownMigrations.Migrations = append(fm.DownMigrations.Migrations, mf)
default:
// the regex only matches `(up|down)` for direction, so a panic here is appropriate
panic("got unknown migration direction " + mf.Direction)
}
}
return nil
})
Expand Down
10 changes: 9 additions & 1 deletion migration_box.go
Expand Up @@ -75,7 +75,15 @@ func (fm *MigrationBox) findMigrations(runner func(f packd.File) func(mf Migrati
Type: match.Type,
Runner: runner(f),
}
fm.Migrations[mf.Direction] = append(fm.Migrations[mf.Direction], mf)
switch mf.Direction {
case "up":
fm.UpMigrations.Migrations = append(fm.UpMigrations.Migrations, mf)
case "down":
fm.DownMigrations.Migrations = append(fm.DownMigrations.Migrations, mf)
default:
// the regex only matches `(up|down)` for direction, so a panic here is appropriate
panic("got unknown migration direction " + mf.Direction)
}
return nil
})
}
14 changes: 7 additions & 7 deletions migration_box_test.go
Expand Up @@ -32,11 +32,11 @@ func Test_MigrationBox(t *testing.T) {

b, err := NewMigrationBox(packr.New("./testdata/migrations/multiple", "./testdata/migrations/multiple"), PDB)
r.NoError(err)
r.Equal(4, len(b.Migrations["up"]))
r.Equal("mysql", b.Migrations["up"][0].DBType)
r.Equal("postgres", b.Migrations["up"][1].DBType)
r.Equal("sqlite3", b.Migrations["up"][2].DBType)
r.Equal("all", b.Migrations["up"][3].DBType)
r.Equal(4, len(b.UpMigrations.Migrations))
r.Equal("mysql", b.UpMigrations.Migrations[0].DBType)
r.Equal("postgres", b.UpMigrations.Migrations[1].DBType)
r.Equal("sqlite3", b.UpMigrations.Migrations[2].DBType)
r.Equal("all", b.UpMigrations.Migrations[3].DBType)
})

t.Run("ignores clutter files", func(t *testing.T) {
Expand All @@ -45,7 +45,7 @@ func Test_MigrationBox(t *testing.T) {

b, err := NewMigrationBox(packr.New("./testdata/migrations/cluttered", "./testdata/migrations/cluttered"), PDB)
r.NoError(err)
r.Equal(1, len(b.Migrations["up"]))
r.Equal(1, len(b.UpMigrations.Migrations))
r.Equal(1, len(*logs))
r.Equal(logging.Warn, (*logs)[0].lvl)
r.Contains((*logs)[0].s, "ignoring file")
Expand All @@ -58,7 +58,7 @@ func Test_MigrationBox(t *testing.T) {

b, err := NewMigrationBox(packr.New("./testdata/migrations/unsupported_dialect", "./testdata/migrations/unsupported_dialect"), PDB)
r.NoError(err)
r.Equal(0, len(b.Migrations["up"]))
r.Equal(0, len(b.UpMigrations.Migrations))
r.Equal(1, len(*logs))
r.Equal(logging.Warn, (*logs)[0].lvl)
r.Contains((*logs)[0].s, "ignoring migration")
Expand Down
33 changes: 25 additions & 8 deletions migration_info.go
Expand Up @@ -36,14 +36,6 @@ func (mfs Migrations) Len() int {
return len(mfs)
}

func (mfs Migrations) Less(i, j int) bool {
if mfs[i].Version == mfs[j].Version {
// force "all" to the back
return mfs[i].DBType != "all"
}
return mfs[i].Version < mfs[j].Version
}

func (mfs Migrations) Swap(i, j int) {
mfs[i], mfs[j] = mfs[j], mfs[i]
}
Expand All @@ -57,3 +49,28 @@ func (mfs *Migrations) Filter(f func(mf Migration) bool) {
}
*mfs = vsf
}

type (
UpMigrations struct {
Migrations
}
DownMigrations struct {
Migrations
}
)

func (mfs UpMigrations) Less(i, j int) bool {
if mfs.Migrations[i].Version == mfs.Migrations[j].Version {
// force "all" to the back
return mfs.Migrations[i].DBType != "all"
}
return mfs.Migrations[i].Version < mfs.Migrations[j].Version
}

func (mfs DownMigrations) Less(i, j int) bool {
if mfs.Migrations[i].Version == mfs.Migrations[j].Version {
// force "all" to the back
return mfs.Migrations[i].DBType != "all"
}
return mfs.Migrations[i].Version > mfs.Migrations[j].Version
}
88 changes: 55 additions & 33 deletions migration_info_test.go
Expand Up @@ -8,43 +8,65 @@ import (
)

func TestSortingMigrations(t *testing.T) {
t.Run("case=enforces precedence for specific migrations", func(t *testing.T) {
migrations := Migrations{
{
Version: "1",
DBType: "all",
},
{
Version: "1",
DBType: "postgres",
},
{
Version: "2",
DBType: "cockroach",
},
{
Version: "2",
DBType: "all",
},
{
Version: "3",
DBType: "all",
},
{
Version: "3",
DBType: "mysql",
},
examples := Migrations{
{
Version: "1",
DBType: "all",
},
{
Version: "1",
DBType: "postgres",
},
{
Version: "2",
DBType: "cockroach",
},
{
Version: "2",
DBType: "all",
},
{
Version: "3",
DBType: "all",
},
{
Version: "3",
DBType: "mysql",
},
}

t.Run("case=enforces precedence for specific up migrations", func(t *testing.T) {
migrations := make(Migrations, len(examples))
copy(migrations, examples)

expectedOrder := Migrations{
examples[1],
examples[0],
examples[2],
examples[3],
examples[5],
examples[4],
}

sort.Sort(UpMigrations{migrations})

assert.Equal(t, expectedOrder, migrations)
})

t.Run("case=enforces precedence for specific down migrations", func(t *testing.T) {
migrations := make(Migrations, len(examples))
copy(migrations, examples)

expectedOrder := Migrations{
migrations[1],
migrations[0],
migrations[2],
migrations[3],
migrations[5],
migrations[4],
examples[5],
examples[4],
examples[2],
examples[3],
examples[1],
examples[0],
}

sort.Sort(migrations)
sort.Sort(DownMigrations{migrations})

assert.Equal(t, expectedOrder, migrations)
})
Expand Down
35 changes: 16 additions & 19 deletions migrator.go
Expand Up @@ -23,10 +23,6 @@ var mrx = regexp.MustCompile(`^(\d+)_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.(sql|fizz
func NewMigrator(c *Connection) Migrator {
return Migrator{
Connection: c,
Migrations: map[string]Migrations{
"up": {},
"down": {},
},
}
}

Expand All @@ -35,9 +31,10 @@ func NewMigrator(c *Connection) Migrator {
// When building a new migration system, you should embed this
// type into your migrator.
type Migrator struct {
Connection *Connection
SchemaPath string
Migrations map[string]Migrations
Connection *Connection
SchemaPath string
UpMigrations UpMigrations
DownMigrations DownMigrations
}

func (m Migrator) migrationIsCompatible(d dialect, mi Migration) bool {
Expand All @@ -53,10 +50,10 @@ func (m Migrator) UpLogOnly() error {
c := m.Connection
return m.exec(func() error {
mtn := c.MigrationTableName()
mfs := m.Migrations["up"]
mfs := m.UpMigrations
sort.Sort(mfs)
return c.Transaction(func(tx *Connection) error {
for _, mi := range mfs {
for _, mi := range mfs.Migrations {
if !m.migrationIsCompatible(c.Dialect, mi) {
continue
}
Expand Down Expand Up @@ -89,12 +86,12 @@ func (m Migrator) UpTo(step int) (applied int, err error) {
c := m.Connection
err = m.exec(func() error {
mtn := c.MigrationTableName()
mfs := m.Migrations["up"]
mfs := m.UpMigrations
mfs.Filter(func(mf Migration) bool {
return m.migrationIsCompatible(c.Dialect, mf)
})
sort.Sort(mfs)
for _, mi := range mfs {
for _, mi := range mfs.Migrations {
exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
if err != nil {
return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
Expand Down Expand Up @@ -139,20 +136,20 @@ func (m Migrator) Down(step int) error {
if err != nil {
return errors.Wrap(err, "migration down: unable count existing migration")
}
mfs := m.Migrations["down"]
mfs := m.DownMigrations
mfs.Filter(func(mf Migration) bool {
return m.migrationIsCompatible(c.Dialect, mf)
})
sort.Sort(sort.Reverse(mfs))
sort.Sort(mfs)
// skip all ran migration
if len(mfs) > count {
mfs = mfs[len(mfs)-count:]
if len(mfs.Migrations) > count {
mfs.Migrations = mfs.Migrations[len(mfs.Migrations)-count:]
}
// run only required steps
if step > 0 && len(mfs) >= step {
mfs = mfs[:step]
if step > 0 && len(mfs.Migrations) >= step {
mfs.Migrations = mfs.Migrations[:step]
}
for _, mi := range mfs {
for _, mi := range mfs.Migrations {
exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
if err != nil {
return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
Expand Down Expand Up @@ -228,7 +225,7 @@ func (m Migrator) Status(out io.Writer) error {
}
w := tabwriter.NewWriter(out, 0, 0, 3, ' ', tabwriter.TabIndent)
_, _ = fmt.Fprintln(w, "Version\tName\tStatus\t")
for _, mf := range m.Migrations["up"] {
for _, mf := range m.UpMigrations.Migrations {
exists, err := m.Connection.Where("version = ?", mf.Version).Exists(m.Connection.MigrationTableName())
if err != nil {
return errors.Wrapf(err, "problem with migration")
Expand Down

0 comments on commit bb7527e

Please sign in to comment.