Skip to content

Commit

Permalink
Merge pull request #744 from azavorotnii/ctx_cancel
Browse files Browse the repository at this point in the history
Fix context cancellation racy handling
  • Loading branch information
mattn committed Nov 18, 2019
2 parents 0cf797e + 7e1a61d commit 590d44c
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 37 deletions.
96 changes: 59 additions & 37 deletions sqlite3.go
Expand Up @@ -328,7 +328,7 @@ type SQLiteRows struct {
decltype []string
cls bool
closed bool
done chan struct{}
ctx context.Context // no better alternative to pass context into Next() method
}

type functionInfo struct {
Expand Down Expand Up @@ -1847,22 +1847,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
decltype: nil,
cls: s.cls,
closed: false,
done: make(chan struct{}),
}

if ctxdone := ctx.Done(); ctxdone != nil {
go func(db *C.sqlite3) {
select {
case <-ctxdone:
select {
case <-rows.done:
default:
C.sqlite3_interrupt(db)
rows.Close()
}
case <-rows.done:
}
}(s.c.db)
ctx: ctx,
}

return rows, nil
Expand Down Expand Up @@ -1890,29 +1875,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), list)
}

// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if ctx.Done() == nil {
return s.execSync(args)
}

type result struct {
r driver.Result
err error
}
resultCh := make(chan result)
go func() {
r, err := s.execSync(args)
resultCh <- result{r, err}
}()
select {
case rv := <- resultCh:
return rv.r, rv.err
case <-ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
C.sqlite3_interrupt(s.c.db)
<-resultCh // ensure goroutine completed
}
return nil, ctx.Err()
}
}

func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}

if ctxdone := ctx.Done(); ctxdone != nil {
done := make(chan struct{})
defer close(done)
go func(db *C.sqlite3) {
select {
case <-done:
case <-ctxdone:
select {
case <-done:
default:
C.sqlite3_interrupt(db)
}
}
}(s.c.db)
}

var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
Expand All @@ -1933,9 +1932,6 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.closed = true
if rc.done != nil {
close(rc.done)
}
if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close()
Expand Down Expand Up @@ -1979,13 +1975,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
return rc.declTypes()
}

// Next move cursor to next.
// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()

if rc.s.closed {
return io.EOF
}

if rc.ctx.Done() == nil {
return rc.nextSyncLocked(dest)
}
resultCh := make(chan error)
go func() {
resultCh <- rc.nextSyncLocked(dest)
}()
select {
case err := <- resultCh:
return err
case <-rc.ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
C.sqlite3_interrupt(rc.s.c.db)
<-resultCh // ensure goroutine completed
}
return rc.ctx.Err()
}
}

// nextSyncLocked moves cursor to next; must be called with locked mutex.
func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
rv := C._sqlite3_step_internal(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
Expand Down
88 changes: 88 additions & 0 deletions sqlite3_go18_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"io/ioutil"
"math/rand"
"os"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
}
}

func TestQueryRowContextCancel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)

db, err := sql.Open("sqlite3", srcTempFilename)
if err != nil {
t.Fatal(err)
}
defer db.Close()
initDatabase(t, db, 100)

const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
var keyID string
unexpectedErrors := make(map[string]int)
for i := 0; i < 10000; i++ {
ctx, cancel := context.WithCancel(context.Background())
row := db.QueryRowContext(ctx, query)

cancel()
// it is fine to get "nil" as context cancellation can be handled with delay
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
if err.Error() == "sql: Rows are closed" {
// see https://github.com/golang/go/issues/24431
// fixed in 1.11.1 to properly return context error
continue
}
unexpectedErrors[err.Error()]++
}
}
for errText, count := range unexpectedErrors {
t.Error(errText, count)
}
}

func TestQueryRowContextCancelParallel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)

db, err := sql.Open("sqlite3", srcTempFilename)
if err != nil {
t.Fatal(err)
}
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)

defer db.Close()
initDatabase(t, db, 100)

const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
wg := sync.WaitGroup{}
defer wg.Wait()

testCtx, cancel := context.WithCancel(context.Background())
defer cancel()

for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()

var keyID string
for {
select {
case <-testCtx.Done():
return
default:
}
ctx, cancel := context.WithCancel(context.Background())
row := db.QueryRowContext(ctx, query)

cancel()
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
}
}()
}

var keyID string
for i := 0; i < 10000; i++ {
// note that testCtx is not cancelled during query execution
row := db.QueryRowContext(testCtx, query)

if err := row.Scan(&keyID); err != nil {
t.Fatal(i, err)
}
}
}

func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down

0 comments on commit 590d44c

Please sign in to comment.