diff --git a/sqlite3.go b/sqlite3.go index 7f0e7c00..3b5b633b 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -864,6 +864,14 @@ func (c *SQLiteConn) Begin() (driver.Tx, error) { func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) { if _, err := c.exec(ctx, c.txlock, nil); err != nil { + select { + case <-ctx.Done(): + // context cancellation can be handled between c.txlock completes inside DB and we notice its result. + // in such case we may leave transaction opened, but exec() returns context cancellation error. + // Rollback will make sure we don't leave connection in that state. It is no-op otherwise. + _, _ = c.exec(context.Background(), "ROLLBACK", nil) + default: + } return nil, err } return &SQLiteTx{c}, nil @@ -1895,13 +1903,14 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result return rv.r, rv.err case <-ctx.Done(): select { - case <-resultCh: // no need to interrupt + case rv := <-resultCh: // no need to interrupt, operation completed in db + return rv.r, rv.err default: // this is still racy and can be no-op if executed between sqlite3_* calls in execSync. C.sqlite3_interrupt(s.c.db) - <-resultCh // ensure goroutine completed + <-resultCh // wait for goroutine completed + return nil, ctx.Err() } - return nil, ctx.Err() } } @@ -1992,7 +2001,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { resultCh <- rc.nextSyncLocked(dest) }() select { - case err := <- resultCh: + case err := <-resultCh: return err case <-rc.ctx.Done(): select { diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index cfc89b04..3742079c 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -136,6 +136,44 @@ func TestShortTimeout(t *testing.T) { } } +func TestBeginTxCancel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + defer db.Close() + initDatabase(t, db, 100) + + wg := sync.WaitGroup{} + // create many go-routines to expose racy issue + for i := 0; i < 10000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel := context.WithCancel(context.Background()) + go cancel() // make it cancel concurrently with exec("BEGIN"); + tx, err := db.BeginTx(ctx, nil) + switch err { + case nil: + switch err := tx.Rollback(); err { + case nil, sql.ErrTxDone: + default: + t.Error(err) + } + case context.Canceled: + default: + // must not fail with "cannot start a transaction within a transaction" + t.Error(err) + } + }() + } + wg.Wait() +} + func TestQueryRowContextCancel(t *testing.T) { srcTempFilename := TempFilename(t) defer os.Remove(srcTempFilename)