Skip to content

Commit

Permalink
dialect/sql/schema: disable foreign keys before opening a transaction (
Browse files Browse the repository at this point in the history
…#2966)

* dialect/sql/schema: disable foreign keys before opening a transaction

* dialect/sql/schema: disable foreign keys before opening a transaction

* fix tests

* add test for bug

* apply CR
  • Loading branch information
masseelch committed Sep 28, 2022
1 parent e02622a commit c41d223
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 36 deletions.
8 changes: 4 additions & 4 deletions dialect/sql/schema/atlas.go
Expand Up @@ -164,7 +164,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
a.sqlDialect = nil
a.atDriver = nil
}()
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
if a.universalID {
Expand Down Expand Up @@ -656,15 +656,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
}
}
defer func() { a.sqlDialect = nil }()
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
if err := a.sqlDialect.init(ctx, tx); err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions dialect/sql/schema/migrate.go
Expand Up @@ -139,13 +139,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
}

func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
if err := m.init(ctx); err != nil {
return err
}
tx, err := m.Tx(ctx)
if err != nil {
return err
}
if err := m.init(ctx, tx); err != nil {
return rollback(tx, err)
}
if m.universalID {
if err := m.types(ctx, tx); err != nil {
return rollback(tx, err)
Expand Down Expand Up @@ -185,7 +185,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table)
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("create table %q: %w", t.Name, err)
}
// If global unique identifier is enabled and it's not
// If global unique identifier is enabled, and it's not
// a relation table, allocate a range for the table pk.
if m.universalID && len(t.PrimaryKey) == 1 {
if err := m.allocPKRange(ctx, tx, t); err != nil {
Expand Down Expand Up @@ -606,7 +606,7 @@ func indexOf(a []string, s string) int {
type sqlDialect interface {
atBuilder
dialect.Driver
init(context.Context, dialect.ExecQuerier) error
init(context.Context) error
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/schema/mysql.go
Expand Up @@ -29,9 +29,9 @@ type MySQL struct {
}

// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
func (d *MySQL) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
return fmt.Errorf("mysql: querying mysql version %w", err)
}
defer rows.Close()
Expand Down
2 changes: 1 addition & 1 deletion dialect/sql/schema/mysql_test.go
Expand Up @@ -1375,9 +1375,9 @@ type mysqlMock struct {
}

func (m mysqlMock) start(version string) {
m.ExpectBegin()
m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version))
m.ExpectBegin()
}

func (m mysqlMock) tableExists(table string, exists bool) {
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/schema/postgres.go
Expand Up @@ -29,9 +29,9 @@ type Postgres struct {

// init loads the Postgres version from the database for later use in the migration process.
// It returns an error if the server version is lower than v10.
func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error {
func (d *Postgres) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
return fmt.Errorf("querying server version %w", err)
}
defer rows.Close()
Expand Down
2 changes: 1 addition & 1 deletion dialect/sql/schema/postgres_test.go
Expand Up @@ -1010,9 +1010,9 @@ type pgMock struct {
}

func (m pgMock) start(version string) {
m.ExpectBegin()
m.ExpectQuery(escape("SHOW server_version_num")).
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version))
m.ExpectBegin()
}

func (m pgMock) tableExists(table string, exists bool) {
Expand Down
71 changes: 64 additions & 7 deletions dialect/sql/schema/sqlite.go
Expand Up @@ -6,6 +6,7 @@ package schema

import (
"context"
stdsql "database/sql"
"fmt"
"strconv"
"strings"
Expand All @@ -19,15 +20,51 @@ import (
"ariga.io/atlas/sql/sqlite"
)

// SQLite is an SQLite migration driver.
type SQLite struct {
dialect.Driver
WithForeignKeys bool
type (
// SQLite is an SQLite migration driver.
SQLite struct {
dialect.Driver
WithForeignKeys bool
}
// SQLiteTx implements dialect.Tx.
SQLiteTx struct {
dialect.Tx
commit func() error // Override Commit to toggle foreign keys back on after Commit.
rollback func() error // Override Rollback to toggle foreign keys back on after Rollback.
}
)

// Tx implements opens a transaction.
func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) {
db := &db{d}
if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil {
return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err)
}
t, err := d.Driver.Tx(ctx)
if err != nil {
return nil, err
}
tx := &tx{t}
cm, err := sqlite.CommitFunc(ctx, db, tx, true)
if err != nil {
return nil, err
}
return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil
}

// Commit ensures foreign keys are toggled back on after commit.
func (tx *SQLiteTx) Commit() error {
return tx.commit()
}

// Rollback ensures foreign keys are toggled back on after rollback.
func (tx *SQLiteTx) Rollback() error {
return tx.rollback()
}

// init makes sure that foreign_keys support is enabled.
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
func (d *SQLite) init(ctx context.Context) error {
on, err := exist(ctx, d, "PRAGMA foreign_keys")
if err != nil {
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
}
Expand Down Expand Up @@ -453,9 +490,29 @@ func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) erro
return nil
}

func (SQLite) atTypeRangeSQL(ts ...string) string {
func (*SQLite) atTypeRangeSQL(ts ...string) string {
for i := range ts {
ts[i] = fmt.Sprintf("('%s')", ts[i])
}
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
}

type tx struct {
dialect.Tx
}

func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, err
}
return rows.ColumnScanner.(*stdsql.Rows), nil
}

func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
var r stdsql.Result
if err := tx.Exec(ctx, query, args, &r); err != nil {
return nil, err
}
return r, nil
}
36 changes: 24 additions & 12 deletions dialect/sql/schema/sqlite_test.go
Expand Up @@ -47,7 +47,7 @@ func TestSQLite_Create(t *testing.T) {
name: "no tables",
before: func(mock sqliteMock) {
mock.start()
mock.ExpectCommit()
mock.commit()
},
},
{
Expand All @@ -73,7 +73,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("users", false)
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
WillReturnResult(sqlmock.NewResult(0, 1))
}
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -389,7 +389,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -428,7 +428,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
}
Expand All @@ -450,9 +450,21 @@ type sqliteMock struct {
}

func (m sqliteMock) start() {
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_keys").
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
m.ExpectExec("PRAGMA foreign_keys = off").
WillReturnResult(sqlmock.NewResult(0, 1))
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
}

func (m sqliteMock) commit() {
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
m.ExpectCommit()
m.ExpectExec("PRAGMA foreign_keys = on").
WillReturnResult(sqlmock.NewResult(0, 1))
}

func (m sqliteMock) tableExists(table string, exists bool) {
Expand Down
13 changes: 12 additions & 1 deletion entc/integration/integration_test.go
Expand Up @@ -1939,7 +1939,18 @@ func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
func NoSchemaChanges(t *testing.T, client *ent.Client) {
w := writerFunc(func(p []byte) (int, error) {
stmt := strings.Trim(string(p), "\n;")
if stmt != "BEGIN" && stmt != "COMMIT" {
ok := []string{"BEGIN", "COMMIT"}
if strings.Contains(t.Name(), "SQLite") {
ok = append(ok, "PRAGMA foreign_keys = off", "PRAGMA foreign_keys = on")
}
if !func() bool {
for _, s := range ok {
if s == stmt {
return true
}
}
return false
}() {
t.Errorf("expect no statement to execute. got: %q", stmt)
}
return len(p), nil
Expand Down

0 comments on commit c41d223

Please sign in to comment.