Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix "cannot start a transaction within a transaction" issue (#764) #765

Merged
merged 2 commits into from Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 16 additions & 5 deletions sqlite3.go
Expand Up @@ -1894,6 +1894,14 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), list)
}

func isInterruptErr(err error) bool {
sqliteErr, ok := err.(Error)
if ok {
return sqliteErr.Code == ErrInterrupt
}
return false
}

// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if ctx.Done() == nil {
Expand All @@ -1909,19 +1917,22 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
r, err := s.execSync(args)
resultCh <- result{r, err}
}()
var rv result
select {
case rv := <-resultCh:
return rv.r, rv.err
case rv = <-resultCh:
case <-ctx.Done():
select {
case <-resultCh: // no need to interrupt
case rv = <-resultCh: // no need to interrupt, operation completed in db
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
C.sqlite3_interrupt(s.c.db)
<-resultCh // ensure goroutine completed
rv = <-resultCh // wait for goroutine completed
if isInterruptErr(rv.err) {
return nil, ctx.Err()
}
}
return nil, ctx.Err()
}
return rv.r, rv.err
}

func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
Expand Down
74 changes: 74 additions & 0 deletions sqlite3_go113_test.go
@@ -0,0 +1,74 @@
// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

// +build go1.13,cgo

package sqlite3

import (
"context"
"database/sql"
"database/sql/driver"
"os"
"testing"
)

func TestBeginTxCancel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)

db, err := sql.Open("sqlite3", srcTempFilename)
if err != nil {
t.Fatal(err)
}

db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)

defer db.Close()
initDatabase(t, db, 100)

// create several go-routines to expose racy issue
for i := 0; i < 1000; i++ {
func() {
ctx, cancel := context.WithCancel(context.Background())
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()

err = conn.Raw(func(driverConn interface{}) error {
d, ok := driverConn.(driver.ConnBeginTx)
if !ok {
t.Fatal("unexpected: wrong type")
}

go cancel() // make it cancel concurrently with exec("BEGIN");
tx, err := d.BeginTx(ctx, driver.TxOptions{})
switch err {
case nil:
switch err := tx.Rollback(); err {
case nil, sql.ErrTxDone:
default:
return err
}
case context.Canceled:
default:
// must not fail with "cannot start a transaction within a transaction"
return err
}
return nil
})
if err != nil {
t.Fatal(err)
}
}()
}
}
40 changes: 39 additions & 1 deletion sqlite3_go18_test.go
Expand Up @@ -136,6 +136,44 @@ func TestShortTimeout(t *testing.T) {
}
}

func TestExecContextCancel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)

db, err := sql.Open("sqlite3", srcTempFilename)
if err != nil {
t.Fatal(err)
}

defer db.Close()

ts := time.Now()
initDatabase(t, db, 1000)
spent := time.Since(ts)
if spent < 100*time.Millisecond {
t.Skip("test will be too racy, as ExecContext below will be too fast.")
}

// expected to be extremely slow query
q := `
INSERT INTO test_table (key1, key_id, key2, key3, key4, key5, key6, data)
SELECT t1.key1 || t2.key1, t1.key_id || t2.key_id, t1.key2 || t2.key2, t1.key3 || t2.key3, t1.key4 || t2.key4, t1.key5 || t2.key5, t1.key6 || t2.key6, t1.data || t2.data
FROM test_table t1 LEFT OUTER JOIN test_table t2`
// expect query above take ~ same time as setup above
ctx, cancel := context.WithTimeout(context.Background(), spent/2)
defer cancel()
ts = time.Now()
r, err := db.ExecContext(ctx, q)
// racy check
if r != nil {
n, err := r.RowsAffected()
t.Log(n, err, time.Since(ts))
}
if err != context.DeadlineExceeded {
t.Fatal(err, ctx.Err())
}
}

func TestQueryRowContextCancel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)
Expand Down Expand Up @@ -191,7 +229,7 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
testCtx, cancel := context.WithCancel(context.Background())
defer cancel()

for i := 0; i < 50; i++ {
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
Expand Down