Skip to content

Commit

Permalink
dialect/sql/schema: disable foreign keys before opening a transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
masseelch committed Sep 27, 2022
1 parent 7ad7df2 commit 8b86528
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 33 deletions.
8 changes: 4 additions & 4 deletions dialect/sql/schema/atlas.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
a.sqlDialect = nil
a.atDriver = nil
}()
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
if a.universalID {
Expand Down Expand Up @@ -656,15 +656,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
}
}
defer func() { a.sqlDialect = nil }()
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
if err := a.sqlDialect.init(ctx, tx); err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions dialect/sql/schema/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
}

func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
if err := m.init(ctx); err != nil {
return err
}
tx, err := m.Tx(ctx)
if err != nil {
return err
}
if err := m.init(ctx, tx); err != nil {
return rollback(tx, err)
}
if m.universalID {
if err := m.types(ctx, tx); err != nil {
return rollback(tx, err)
Expand Down Expand Up @@ -185,7 +185,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table)
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("create table %q: %w", t.Name, err)
}
// If global unique identifier is enabled and it's not
// If global unique identifier is enabled, and it's not
// a relation table, allocate a range for the table pk.
if m.universalID && len(t.PrimaryKey) == 1 {
if err := m.allocPKRange(ctx, tx, t); err != nil {
Expand Down Expand Up @@ -606,7 +606,7 @@ func indexOf(a []string, s string) int {
type sqlDialect interface {
atBuilder
dialect.Driver
init(context.Context, dialect.ExecQuerier) error
init(context.Context) error
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/schema/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ type MySQL struct {
}

// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
func (d *MySQL) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
return fmt.Errorf("mysql: querying mysql version %w", err)
}
defer rows.Close()
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/schema/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ type Postgres struct {

// init loads the Postgres version from the database for later use in the migration process.
// It returns an error if the server version is lower than v10.
func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error {
func (d *Postgres) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
return fmt.Errorf("querying server version %w", err)
}
defer rows.Close()
Expand Down
68 changes: 61 additions & 7 deletions dialect/sql/schema/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package schema

import (
"context"
gosql "database/sql"
"fmt"
"strconv"
"strings"
Expand All @@ -19,15 +20,48 @@ import (
"ariga.io/atlas/sql/sqlite"
)

// SQLite is an SQLite migration driver.
type SQLite struct {
dialect.Driver
WithForeignKeys bool
type (
// SQLite is an SQLite migration driver.
SQLite struct {
dialect.Driver
WithForeignKeys bool
}
// SQLiteTx implements dialect.Tx.
SQLiteTx struct {
dialect.Tx
commit func() error // Override Commit to toggle foreign keys back on after Commit.
rollback func() error // Override Rollback to toggle foreign keys back on after Rollback.
}
)

// Tx implements opens a transaction.
func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) {
db := &db{d}
if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil {
return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err)
}
t, err := d.Driver.Tx(ctx)
if err != nil {
return nil, err
}
tx := &tx{ExecQuerier: t, Tx: t}
cm, err := sqlite.CommitFunc(ctx, db, tx, true)
return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil
}

// Commit ensures foreign keys are toggled back on after commit.
func (tx *SQLiteTx) Commit() error {
return tx.commit()
}

// Rollback ensures foreign keys are toggled back on after rollback.
func (tx *SQLiteTx) Rollback() error {
return tx.rollback()
}

// init makes sure that foreign_keys support is enabled.
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
func (d *SQLite) init(ctx context.Context) error {
on, err := exist(ctx, d, "PRAGMA foreign_keys")
if err != nil {
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
}
Expand Down Expand Up @@ -453,9 +487,29 @@ func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) erro
return nil
}

func (SQLite) atTypeRangeSQL(ts ...string) string {
func (*SQLite) atTypeRangeSQL(ts ...string) string {
for i := range ts {
ts[i] = fmt.Sprintf("('%s')", ts[i])
}
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
}

type tx struct {
dialect.Tx
}

func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*gosql.Rows, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, err
}
return rows.ColumnScanner.(*gosql.Rows), nil
}

func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (gosql.Result, error) {
var r gosql.Result
if err := tx.Exec(ctx, query, args, &r); err != nil {
return nil, err
}
return r, nil
}
44 changes: 32 additions & 12 deletions dialect/sql/schema/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestSQLite_Create(t *testing.T) {
name: "no tables",
before: func(mock sqliteMock) {
mock.start()
mock.ExpectCommit()
mock.commit()
},
},
{
Expand All @@ -73,7 +73,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("users", false)
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
WillReturnResult(sqlmock.NewResult(0, 1))
}
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -389,7 +389,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -428,7 +428,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
}
Expand All @@ -450,9 +450,29 @@ type sqliteMock struct {
}

func (m sqliteMock) start() {
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_keys").
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
m.ExpectExec("PRAGMA foreign_keys = off").
WillReturnResult(sqlmock.NewResult(0, 1))
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
}

func (m sqliteMock) commit() {
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
m.ExpectCommit()
m.ExpectExec("PRAGMA foreign_keys = on").
WillReturnResult(sqlmock.NewResult(0, 1))
}

func (m sqliteMock) rollback() {
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
m.ExpectRollback()
m.ExpectExec("PRAGMA foreign_keys = on").
WillReturnResult(sqlmock.NewResult(0, 1))
}

func (m sqliteMock) tableExists(table string, exists bool) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module entgo.io/ent
go 1.19

require (
ariga.io/atlas v0.6.5-0.20220907173155-3332f3c1b8c9
ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/go-openapi/inflect v0.19.0
github.com/go-sql-driver/mysql v1.6.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ ariga.io/atlas v0.6.2-0.20220819114704-2060066abac7 h1:qhVEfrV5Z9XZyQJxgogBq6c2p
ariga.io/atlas v0.6.2-0.20220819114704-2060066abac7/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE=
ariga.io/atlas v0.6.5-0.20220907173155-3332f3c1b8c9 h1:hb7cCS3+idkvWRxKIiH0pBiyO9tJ9gRiecY+ohA//VY=
ariga.io/atlas v0.6.5-0.20220907173155-3332f3c1b8c9/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE=
ariga.io/atlas v0.7.2-0.20220927063044-1d12e0ad4813 h1:KdsIBtO4ZpBAAWdg2fEpXuV5YKGpxBTWO7A9V5MM+/Y=
ariga.io/atlas v0.7.2-0.20220927063044-1d12e0ad4813/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE=
ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a h1:6/nt4DODfgxzHTTg3tYy7YkVzruGQGZ/kRvXpA45KUo=
ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
Expand Down

0 comments on commit 8b86528

Please sign in to comment.