From 862b95943f99f3b40e317a79d41c27ac4b742011 Mon Sep 17 00:00:00 2001 From: Andrii Zavorotnii Date: Fri, 28 Aug 2020 08:43:21 -0700 Subject: [PATCH] Fix "cannot start a transaction within a transaction" issue (#764) (#765) * Fix "cannot start a transaction within a transaction" issue [why] If the context passed to db.BeginTx(ctx, nil) is cancelled too quickly, the "BEGIN" statement may already have completed inside the DB, but we still try to cancel it with sqlite3_interrupt. In that case we get context.Canceled or context.DeadlineExceeded from exec(), even though the operation actually completed. The connection is then returned to the pool with a transaction still open, and the next db.BeginTx() call on it fails with a "cannot start a transaction within a transaction" error. [how] Handle the status code returned from the cancelled operation: report ctx.Err() only when the statement was actually interrupted, otherwise return the real result. [testing] Added a unit test that reproduces the issue. * Reduce TestQueryRowContextCancelParallel concurrency [why] The test times out in travis-ci when run with the -race option. --- sqlite3.go | 21 +++++++++--- sqlite3_go113_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++ sqlite3_go18_test.go | 40 ++++++++++++++++++++++- 3 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 sqlite3_go113_test.go diff --git a/sqlite3.go b/sqlite3.go index ababcfd2..63e1c4f7 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1918,6 +1918,14 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { return s.exec(context.Background(), list) } +func isInterruptErr(err error) bool { + sqliteErr, ok := err.(Error) + if ok { + return sqliteErr.Code == ErrInterrupt + } + return false +} + // exec executes a query that doesn't return rows. Attempts to honor context timeout. 
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) { if ctx.Done() == nil { @@ -1933,19 +1941,22 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result r, err := s.execSync(args) resultCh <- result{r, err} }() + var rv result select { - case rv := <-resultCh: - return rv.r, rv.err + case rv = <-resultCh: case <-ctx.Done(): select { - case <-resultCh: // no need to interrupt + case rv = <-resultCh: // no need to interrupt, operation completed in db default: // this is still racy and can be no-op if executed between sqlite3_* calls in execSync. C.sqlite3_interrupt(s.c.db) - <-resultCh // ensure goroutine completed + rv = <-resultCh // wait for goroutine completed + if isInterruptErr(rv.err) { + return nil, ctx.Err() + } } - return nil, ctx.Err() } + return rv.r, rv.err } func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) { diff --git a/sqlite3_go113_test.go b/sqlite3_go113_test.go new file mode 100644 index 00000000..74036f81 --- /dev/null +++ b/sqlite3_go113_test.go @@ -0,0 +1,74 @@ +// Copyright (C) 2019 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +// +build go1.13,cgo + +package sqlite3 + +import ( + "context" + "database/sql" + "database/sql/driver" + "os" + "testing" +) + +func TestBeginTxCancel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + + defer db.Close() + initDatabase(t, db, 100) + + // create several go-routines to expose racy issue + for i := 0; i < 1000; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = conn.Raw(func(driverConn interface{}) error { + d, ok := driverConn.(driver.ConnBeginTx) + if !ok { + t.Fatal("unexpected: wrong type") + } + + go cancel() // make it cancel concurrently with exec("BEGIN"); + tx, err := d.BeginTx(ctx, driver.TxOptions{}) + switch err { + case nil: + switch err := tx.Rollback(); err { + case nil, sql.ErrTxDone: + default: + return err + } + case context.Canceled: + default: + // must not fail with "cannot start a transaction within a transaction" + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }() + } +} diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index cfc89b04..5ee3d811 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -136,6 +136,44 @@ func TestShortTimeout(t *testing.T) { } } +func TestExecContextCancel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + ts := time.Now() + initDatabase(t, db, 1000) + spent := time.Since(ts) + if spent < 100*time.Millisecond { + t.Skip("test will be too racy, as ExecContext below will be too fast.") + } + + // expected to be extremely slow query + q := ` +INSERT INTO 
test_table (key1, key_id, key2, key3, key4, key5, key6, data) +SELECT t1.key1 || t2.key1, t1.key_id || t2.key_id, t1.key2 || t2.key2, t1.key3 || t2.key3, t1.key4 || t2.key4, t1.key5 || t2.key5, t1.key6 || t2.key6, t1.data || t2.data +FROM test_table t1 LEFT OUTER JOIN test_table t2` + // expect query above take ~ same time as setup above + ctx, cancel := context.WithTimeout(context.Background(), spent/2) + defer cancel() + ts = time.Now() + r, err := db.ExecContext(ctx, q) + // racy check + if r != nil { + n, err := r.RowsAffected() + t.Log(n, err, time.Since(ts)) + } + if err != context.DeadlineExceeded { + t.Fatal(err, ctx.Err()) + } +} + func TestQueryRowContextCancel(t *testing.T) { srcTempFilename := TempFilename(t) defer os.Remove(srcTempFilename) @@ -191,7 +229,7 @@ func TestQueryRowContextCancelParallel(t *testing.T) { testCtx, cancel := context.WithCancel(context.Background()) defer cancel() - for i := 0; i < 50; i++ { + for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done()