Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialect/sql/schema: disable foreign keys before opening a transaction #2966

Merged
merged 5 commits into from Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions dialect/sql/schema/atlas.go
Expand Up @@ -164,7 +164,7 @@ func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) er
a.sqlDialect = nil
a.atDriver = nil
}()
if err := a.sqlDialect.init(ctx, a.sqlDialect); err != nil {
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
if a.universalID {
Expand Down Expand Up @@ -656,15 +656,15 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
}
}
defer func() { a.sqlDialect = nil }()
if err := a.sqlDialect.init(ctx); err != nil {
return err
}
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
if err := a.sqlDialect.init(ctx, tx); err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions dialect/sql/schema/migrate.go
Expand Up @@ -139,13 +139,13 @@ func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
}

func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
if err := m.init(ctx); err != nil {
return err
}
tx, err := m.Tx(ctx)
if err != nil {
return err
}
if err := m.init(ctx, tx); err != nil {
return rollback(tx, err)
}
if m.universalID {
if err := m.types(ctx, tx); err != nil {
return rollback(tx, err)
Expand Down Expand Up @@ -185,7 +185,7 @@ func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table)
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("create table %q: %w", t.Name, err)
}
// If global unique identifier is enabled and it's not
// If global unique identifier is enabled, and it's not
// a relation table, allocate a range for the table pk.
if m.universalID && len(t.PrimaryKey) == 1 {
if err := m.allocPKRange(ctx, tx, t); err != nil {
Expand Down Expand Up @@ -606,7 +606,7 @@ func indexOf(a []string, s string) int {
type sqlDialect interface {
atBuilder
dialect.Driver
init(context.Context, dialect.ExecQuerier) error
init(context.Context) error
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.ExecQuerier, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/schema/mysql.go
Expand Up @@ -29,9 +29,9 @@ type MySQL struct {
}

// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, conn dialect.ExecQuerier) error {
func (d *MySQL) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := conn.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil {
return fmt.Errorf("mysql: querying mysql version %w", err)
}
defer rows.Close()
Expand Down
2 changes: 1 addition & 1 deletion dialect/sql/schema/mysql_test.go
Expand Up @@ -1375,9 +1375,9 @@ type mysqlMock struct {
}

func (m mysqlMock) start(version string) {
m.ExpectBegin()
m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")).
WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version))
m.ExpectBegin()
}

func (m mysqlMock) tableExists(table string, exists bool) {
Expand Down
4 changes: 2 additions & 2 deletions dialect/sql/schema/postgres.go
Expand Up @@ -29,9 +29,9 @@ type Postgres struct {

// init loads the Postgres version from the database for later use in the migration process.
// It returns an error if the server version is lower than v10.
func (d *Postgres) init(ctx context.Context, tx dialect.ExecQuerier) error {
func (d *Postgres) init(ctx context.Context) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil {
return fmt.Errorf("querying server version %w", err)
}
defer rows.Close()
Expand Down
2 changes: 1 addition & 1 deletion dialect/sql/schema/postgres_test.go
Expand Up @@ -1010,9 +1010,9 @@ type pgMock struct {
}

func (m pgMock) start(version string) {
m.ExpectBegin()
m.ExpectQuery(escape("SHOW server_version_num")).
WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version))
m.ExpectBegin()
}

func (m pgMock) tableExists(table string, exists bool) {
Expand Down
71 changes: 64 additions & 7 deletions dialect/sql/schema/sqlite.go
Expand Up @@ -6,6 +6,7 @@ package schema

import (
"context"
masseelch marked this conversation as resolved.
Show resolved Hide resolved
stdsql "database/sql"
"fmt"
"strconv"
"strings"
Expand All @@ -19,15 +20,51 @@ import (
"ariga.io/atlas/sql/sqlite"
)

// SQLite is an SQLite migration driver.
type SQLite struct {
dialect.Driver
WithForeignKeys bool
type (
// SQLite is an SQLite migration driver.
SQLite struct {
dialect.Driver
WithForeignKeys bool
}
// SQLiteTx implements dialect.Tx.
SQLiteTx struct {
dialect.Tx
commit func() error // Override Commit to toggle foreign keys back on after Commit.
rollback func() error // Override Rollback to toggle foreign keys back on after Rollback.
}
)

// Tx implements opens a transaction.
func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) {
db := &db{d}
if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil {
return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err)
}
t, err := d.Driver.Tx(ctx)
if err != nil {
return nil, err
}
tx := &tx{t}
cm, err := sqlite.CommitFunc(ctx, db, tx, true)
if err != nil {
return nil, err
}
return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil
}

// Commit ensures foreign keys are toggled back on after commit.
func (tx *SQLiteTx) Commit() error {
return tx.commit()
}

// Rollback ensures foreign keys are toggled back on after rollback.
func (tx *SQLiteTx) Rollback() error {
return tx.rollback()
}

// init makes sure that foreign_keys support is enabled.
func (d *SQLite) init(ctx context.Context, tx dialect.ExecQuerier) error {
on, err := exist(ctx, tx, "PRAGMA foreign_keys")
func (d *SQLite) init(ctx context.Context) error {
on, err := exist(ctx, d, "PRAGMA foreign_keys")
if err != nil {
return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err)
}
Expand Down Expand Up @@ -453,9 +490,29 @@ func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) erro
return nil
}

func (SQLite) atTypeRangeSQL(ts ...string) string {
func (*SQLite) atTypeRangeSQL(ts ...string) string {
for i := range ts {
ts[i] = fmt.Sprintf("('%s')", ts[i])
}
return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
}

type tx struct {
dialect.Tx
}

func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, err
}
return rows.ColumnScanner.(*stdsql.Rows), nil
}

func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
var r stdsql.Result
if err := tx.Exec(ctx, query, args, &r); err != nil {
return nil, err
}
return r, nil
}
36 changes: 24 additions & 12 deletions dialect/sql/schema/sqlite_test.go
Expand Up @@ -47,7 +47,7 @@ func TestSQLite_Create(t *testing.T) {
name: "no tables",
before: func(mock sqliteMock) {
mock.start()
mock.ExpectCommit()
mock.commit()
},
},
{
Expand All @@ -73,7 +73,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("users", false)
mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestSQLite_Create(t *testing.T) {
mock.tableExists("pets", false)
mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))).
WillReturnResult(sqlmock.NewResult(0, 1))
}
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestSQLite_Create(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -389,7 +389,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
{
Expand Down Expand Up @@ -428,7 +428,7 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")).
WithArgs("groups", 1<<32).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
mock.commit()
},
},
}
Expand All @@ -450,9 +450,21 @@ type sqliteMock struct {
}

func (m sqliteMock) start() {
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_keys").
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1))
m.ExpectExec("PRAGMA foreign_keys = off").
WillReturnResult(sqlmock.NewResult(0, 1))
m.ExpectBegin()
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
}

func (m sqliteMock) commit() {
m.ExpectQuery("PRAGMA foreign_key_check").
WillReturnRows(sqlmock.NewRows([]string{})) // empty
m.ExpectCommit()
m.ExpectExec("PRAGMA foreign_keys = on").
WillReturnResult(sqlmock.NewResult(0, 1))
}

func (m sqliteMock) tableExists(table string, exists bool) {
Expand Down
13 changes: 12 additions & 1 deletion entc/integration/integration_test.go
Expand Up @@ -1939,7 +1939,18 @@ func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
func NoSchemaChanges(t *testing.T, client *ent.Client) {
w := writerFunc(func(p []byte) (int, error) {
stmt := strings.Trim(string(p), "\n;")
if stmt != "BEGIN" && stmt != "COMMIT" {
ok := []string{"BEGIN", "COMMIT"}
if strings.Contains(t.Name(), "SQLite") {
ok = append(ok, "PRAGMA foreign_keys = off", "PRAGMA foreign_keys = on")
}
if !func() bool {
for _, s := range ok {
if s == stmt {
return true
}
}
return false
}() {
t.Errorf("expect no statement to execute. got: %q", stmt)
}
return len(p), nil
Expand Down