diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index 254cfb78cf..a34d5db50c 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -164,7 +164,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er a.sqlDialect = nil a.atDriver = nil }() - if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil { + if err := a.sqlDialect.init(ctx); err != nil { return err } if a.universalID { @@ -656,15 +656,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { } } defer func() { a.sqlDialect = nil }() + if err := a.sqlDialect.init(ctx); err != nil { + return err + } // Open a transaction for backwards compatibility, // even if the migration is not transactional. tx, err := a.sqlDialect.Tx(ctx) if err != nil { return err } - if err := a.sqlDialect.init(ctx, tx); err != nil { - return err - } a.atDriver, err = a.sqlDialect.atOpen(tx) if err != nil { return err diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index 92c61e48f3..a6f0d12855 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -139,13 +139,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error { } func (m *Migrate) create(ctx context.Context, tables ...*Table) error { + if err := m.init(ctx); err != nil { + return err + } tx, err := m.Tx(ctx) if err != nil { return err } - if err := m.init(ctx, tx); err != nil { - return rollback(tx, err) - } if m.universalID { if err := m.types(ctx, tx); err != nil { return rollback(tx, err) @@ -185,7 +185,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table) if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create table %q: %w", t.Name, err) } - // If global unique identifier is enabled and it's not + // If global unique identifier is enabled, and it's not // a relation table, allocate a range for the table pk. if m.universalID && len(t.PrimaryKey) == 1 { if err := m.allocPKRange(ctx, tx, t); err != nil { @@ -606,7 +606,7 @@ func indexOf(a []string, s string) int { type sqlDialect interface { atBuilder dialect.Driver - init(context.Context, dialect.ExecQuerier) error + init(context.Context) error table(context.Context, dialect.Tx, string) (*Table, error) tableExist(context.Context, dialect.ExecQuerier, string) (bool, error) fkExist(context.Context, dialect.Tx, string) (bool, error) diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index 67f9bb0122..5aef0e7296 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -29,9 +29,9 @@ type MySQL struct { } // init loads the MySQL version from the database for later use in the migration process. -func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error { +func (d *MySQL) init(ctx context.Context) error { rows := &sql.Rows{} - if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil { + if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil { return fmt.Errorf("mysql: querying mysql version %w", err) } defer rows.Close() diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index e048a697ff..e93eff6056 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -1375,9 +1375,9 @@ type mysqlMock struct { } func (m mysqlMock) start(version string) { - m.ExpectBegin() m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")). WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version)) + m.ExpectBegin() } func (m mysqlMock) tableExists(table string, exists bool) { diff --git a/dialect/sql/schema/postgres.go b/dialect/sql/schema/postgres.go index 140b3f5a3b..530dcce071 100644 --- a/dialect/sql/schema/postgres.go +++ b/dialect/sql/schema/postgres.go @@ -29,9 +29,9 @@ type Postgres struct { // init loads the Postgres version from the database for later use in the migration process. // It returns an error if the server version is lower than v10. -func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error { +func (d *Postgres) init(ctx context.Context) error { rows := &sql.Rows{} - if err := tx.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil { + if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil { return fmt.Errorf("querying server version %w", err) } defer rows.Close() diff --git a/dialect/sql/schema/postgres_test.go b/dialect/sql/schema/postgres_test.go index d7bdd60b26..69b3468511 100644 --- a/dialect/sql/schema/postgres_test.go +++ b/dialect/sql/schema/postgres_test.go @@ -1010,9 +1010,9 @@ type pgMock struct { } func (m pgMock) start(version string) { - m.ExpectBegin() m.ExpectQuery(escape("SHOW server_version_num")). WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version)) + m.ExpectBegin() } func (m pgMock) tableExists(table string, exists bool) { diff --git a/dialect/sql/schema/sqlite.go b/dialect/sql/schema/sqlite.go index d26ee32f35..775ab68522 100644 --- a/dialect/sql/schema/sqlite.go +++ b/dialect/sql/schema/sqlite.go @@ -6,6 +6,7 @@ package schema import ( "context" + stdsql "database/sql" "fmt" "strconv" "strings" @@ -19,15 +20,51 @@ import ( "ariga.io/atlas/sql/sqlite" ) -// SQLite is an SQLite migration driver. -type SQLite struct { - dialect.Driver - WithForeignKeys bool +type ( + // SQLite is an SQLite migration driver. + SQLite struct { + dialect.Driver + WithForeignKeys bool + } + // SQLiteTx implements dialect.Tx. + SQLiteTx struct { + dialect.Tx + commit func() error // Override Commit to toggle foreign keys back on after Commit. + rollback func() error // Override Rollback to toggle foreign keys back on after Rollback. + } +) + +// Tx implements opens a transaction. +func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) { + db := &db{d} + if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil { + return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err) + } + t, err := d.Driver.Tx(ctx) + if err != nil { + return nil, err + } + tx := &tx{t} + cm, err := sqlite.CommitFunc(ctx, db, tx, true) + if err != nil { + return nil, err + } + return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil +} + +// Commit ensures foreign keys are toggled back on after commit. +func (tx *SQLiteTx) Commit() error { + return tx.commit() +} + +// Rollback ensures foreign keys are toggled back on after rollback. +func (tx *SQLiteTx) Rollback() error { + return tx.rollback() } // init makes sure that foreign_keys support is enabled. -func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error { - on, err := exist(ctx, tx, "PRAGMA foreign_keys") +func (d *SQLite) init(ctx context.Context) error { + on, err := exist(ctx, d, "PRAGMA foreign_keys") if err != nil { return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err) } @@ -453,9 +490,29 @@ func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) erro return nil } -func (SQLite) atTypeRangeSQL(ts ...string) string { +func (*SQLite) atTypeRangeSQL(ts ...string) string { for i := range ts { ts[i] = fmt.Sprintf("('%s')", ts[i]) } return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", ")) } + +type tx struct { + dialect.Tx +} + +func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { + rows := &sql.Rows{} + if err := tx.Query(ctx, query, args, rows); err != nil { + return nil, err + } + return rows.ColumnScanner.(*stdsql.Rows), nil +} + +func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { + var r stdsql.Result + if err := tx.Exec(ctx, query, args, &r); err != nil { + return nil, err + } + return r, nil +} diff --git a/dialect/sql/schema/sqlite_test.go b/dialect/sql/schema/sqlite_test.go index c59d2db9ab..433bb8bf99 100644 --- a/dialect/sql/schema/sqlite_test.go +++ b/dialect/sql/schema/sqlite_test.go @@ -47,7 +47,7 @@ func TestSQLite_Create(t *testing.T) { name: "no tables", before: func(mock sqliteMock) { mock.start() - mock.ExpectCommit() + mock.commit() }, }, { @@ -73,7 +73,7 @@ func TestSQLite_Create(t *testing.T) { mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -120,7 +120,7 @@ func TestSQLite_Create(t *testing.T) { mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -170,7 +170,7 @@ func TestSQLite_Create(t *testing.T) { mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -204,7 +204,7 @@ func TestSQLite_Create(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -234,7 +234,7 @@ func TestSQLite_Create(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -275,7 +275,7 @@ func TestSQLite_Create(t *testing.T) { mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))). WillReturnResult(sqlmock.NewResult(0, 1)) } - mock.ExpectCommit() + mock.commit() }, }, { @@ -306,7 +306,7 @@ func TestSQLite_Create(t *testing.T) { WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -347,7 +347,7 @@ func TestSQLite_Create(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -389,7 +389,7 @@ func TestSQLite_Create(t *testing.T) { mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("groups", 1<<32). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, { @@ -428,7 +428,7 @@ func TestSQLite_Create(t *testing.T) { mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("groups", 1<<32). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.commit() }, }, } @@ -450,9 +450,21 @@ type sqliteMock struct { } func (m sqliteMock) start() { - m.ExpectBegin() m.ExpectQuery("PRAGMA foreign_keys"). WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1)) + m.ExpectExec("PRAGMA foreign_keys = off"). + WillReturnResult(sqlmock.NewResult(0, 1)) + m.ExpectBegin() + m.ExpectQuery("PRAGMA foreign_key_check"). + WillReturnRows(sqlmock.NewRows([]string{})) // empty +} + +func (m sqliteMock) commit() { + m.ExpectQuery("PRAGMA foreign_key_check"). + WillReturnRows(sqlmock.NewRows([]string{})) // empty + m.ExpectCommit() + m.ExpectExec("PRAGMA foreign_keys = on"). + WillReturnResult(sqlmock.NewResult(0, 1)) } func (m sqliteMock) tableExists(table string, exists bool) { diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index d171b4a778..520e2bf830 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -1939,7 +1939,18 @@ func (f writerFunc) Write(p []byte) (int, error) { return f(p) } func NoSchemaChanges(t *testing.T, client *ent.Client) { w := writerFunc(func(p []byte) (int, error) { stmt := strings.Trim(string(p), "\n;") - if stmt != "BEGIN" && stmt != "COMMIT" { + ok := []string{"BEGIN", "COMMIT"} + if strings.Contains(t.Name(), "SQLite") { + ok = append(ok, "PRAGMA foreign_keys = off", "PRAGMA foreign_keys = on") + } + if !func() bool { + for _, s := range ok { + if s == stmt { + return true + } + } + return false + }() { t.Errorf("expect no statement to execute. got: %q", stmt) } return len(p), nil diff --git a/entc/integration/migrate/migrate_test.go b/entc/integration/migrate/migrate_test.go index 93ea490c97..44be1ced0e 100644 --- a/entc/integration/migrate/migrate_test.go +++ b/entc/integration/migrate/migrate_test.go @@ -39,6 +39,7 @@ import ( "ariga.io/atlas/sql/postgres" atlas "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqltool" + "entgo.io/ent/schema/field" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -224,6 +225,75 @@ func TestSQLite(t *testing.T) { Versioned(t, vdrv, "sqlite3://file?mode=memory&cache=shared&_fk=1", versioned.NewClient(versioned.Driver(vdrv))) } +// https://github.com/ent/ent/issues/2954 +func TestSQLite_ForeignKeyTx(t *testing.T) { + var ( + usersColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "name", Type: field.TypeString}, + } + usersTable = &schema.Table{ + Name: "users", + Columns: usersColumns, + PrimaryKey: []*schema.Column{usersColumns[0]}, + } + userFollowingColumns = []*schema.Column{ + {Name: "user_id", Type: field.TypeInt}, + {Name: "follower_id", Type: field.TypeInt}, + } + userFollowingTable = &schema.Table{ + Name: "user_following", + Columns: userFollowingColumns, + PrimaryKey: []*schema.Column{userFollowingColumns[0], userFollowingColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "user_following_user_id", + Columns: []*schema.Column{userFollowingColumns[0]}, + RefColumns: []*schema.Column{usersColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "user_following_follower_id", + Columns: []*schema.Column{userFollowingColumns[1]}, + RefColumns: []*schema.Column{usersColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } + ) + userFollowingTable.ForeignKeys[0].RefTable = usersTable + userFollowingTable.ForeignKeys[1].RefTable = usersTable + + drv, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + defer drv.Close() + ctx := context.Background() + m, err := schema.NewMigrate(drv) + require.NoError(t, err) + + // Migrate once. + require.NoError(t, m.Create(ctx, usersTable, userFollowingTable)) + + // Add data. + var exec = func(stmt string) { + _, err := drv.DB().ExecContext(ctx, stmt) + require.NoError(t, err) + } + exec("INSERT INTO `users` (`id`, `name`) VALUES (1, 'Ariel'), (2, 'Jannik');") + exec("INSERT INTO `user_following` (`user_id`, `follower_id`) VALUES (1,2), (2,1);") + var n int + require.NoError(t, drv.DB().QueryRow("SELECT COUNT(*) FROM `user_following`").Scan(&n)) + require.Equal(t, 2, n) + + // Modify a column in the users table. + usersTable.Columns[1].Nullable = true + require.NoError(t, m.Create(ctx, usersTable, userFollowingTable)) + + // Ensure the data in the join table does still exist. + require.NoError(t, drv.DB().QueryRow("SELECT COUNT(*) FROM `user_following`").Scan(&n)) + require.Equal(t, 2, n) +} + func TestStorageKey(t *testing.T) { require.Equal(t, "user_pet_id", migratev2.PetsTable.ForeignKeys[0].Symbol) require.Equal(t, "user_friend_id1", migratev2.FriendsTable.ForeignKeys[0].Symbol) diff --git a/go.mod b/go.mod index 976b412c9b..a29e7df07c 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module entgo.io/ent go 1.19 require ( - ariga.io/atlas v0.6.5-0.20220907173155-3332f3c1b8c9 + ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/go-openapi/inflect v0.19.0 github.com/go-sql-driver/mysql v1.6.0 diff --git a/go.sum b/go.sum index 62c43a3f88..83496d98bf 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ ariga.io/atlas v0.6.2-0.20220819114704-2060066abac7 h1:qhVEfrV5Z9XZyQJxgogBq6c2p ariga.io/atlas v0.6.2-0.20220819114704-2060066abac7/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= ariga.io/atlas v0.6.5-0.20220907173155-3332f3c1b8c9 h1:hb7cCS3+idkvWRxKIiH0pBiyO9tJ9gRiecY+ohA//VY= ariga.io/atlas v0.6.5-0.20220907173155-3332f3c1b8c9/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= +ariga.io/atlas v0.7.2-0.20220927063044-1d12e0ad4813 h1:KdsIBtO4ZpBAAWdg2fEpXuV5YKGpxBTWO7A9V5MM+/Y= +ariga.io/atlas v0.7.2-0.20220927063044-1d12e0ad4813/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= +ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a h1:6/nt4DODfgxzHTTg3tYy7YkVzruGQGZ/kRvXpA45KUo= +ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=