From 4ef63c9c0db77925ab91b95237f9e3802c4710a4 Mon Sep 17 00:00:00 2001 From: Joshua Hull Date: Fri, 2 Sep 2022 02:45:11 +0000 Subject: [PATCH] Rollback on constraint failure (#1071) Always rollback on a commit error --- sqlite3.go | 6 ++++-- sqlite3_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index e037857d..9c0f4d89 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -494,10 +494,12 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) { // Commit transaction. func (tx *SQLiteTx) Commit() error { _, err := tx.c.exec(context.Background(), "COMMIT", nil) - if err != nil && err.(Error).Code == C.SQLITE_BUSY { - // sqlite3 will leave the transaction open in this scenario. + if err != nil { + // sqlite3 may leave the transaction open in this scenario. // However, database/sql considers the transaction complete once we // return from Commit() - we must clean up to honour its semantics. + // We don't know if the ROLLBACK is strictly necessary, but according + // to sqlite's docs, there is no harm in calling ROLLBACK unnecessarily. tx.c.exec(context.Background(), "ROLLBACK", nil) } return err diff --git a/sqlite3_test.go b/sqlite3_test.go index 33d03fd4..326361ec 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -248,6 +248,43 @@ func TestForeignKeys(t *testing.T) { } } +func TestDeferredForeignKey(t *testing.T) { + fname := TempFilename(t) + uri := "file:" + fname + "?_foreign_keys=1" + db, err := sql.Open("sqlite3", uri) + if err != nil { + os.Remove(fname) + t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) + } + _, err = db.Exec("CREATE TABLE bar (id INTEGER PRIMARY KEY)") + if err != nil { + t.Errorf("failed creating tables: %v", err) + } + _, err = db.Exec("CREATE TABLE foo (bar_id INTEGER, FOREIGN KEY(bar_id) REFERENCES bar(id) DEFERRABLE INITIALLY DEFERRED)") + if err != nil { + t.Errorf("failed creating tables: %v", err) + } + tx, err := db.Begin() + if err != nil { + t.Errorf("Failed to begin transaction: %v", err) + } + _, err = tx.Exec("INSERT INTO foo (bar_id) VALUES (123)") + if err != nil { + t.Errorf("Failed to insert row: %v", err) + } + err = tx.Commit() + if err == nil { + t.Errorf("Expected an error: %v", err) + } + _, err = db.Begin() + if err != nil { + t.Errorf("Failed to begin transaction: %v", err) + } + + db.Close() + os.Remove(fname) +} + func TestRecursiveTriggers(t *testing.T) { cases := map[string]bool{ "?_recursive_triggers=1": true,