Skip to content

Commit

Permalink
Fix context cancellation racy handling
Browse files Browse the repository at this point in the history
[why]
Context cancellation goroutine is not in sync with Next() method lifetime.
It leads to sql.ErrNoRows instead of context.Canceled often (easy to reproduce).
It leads to interruption of next query executed on same connection (harder to reproduce).

[how]
Do query in goroutine, wait when interruption done.

[testing]
Add unit test that reproduces error cases.
  • Loading branch information
azavorotnii committed Sep 6, 2019
1 parent d3c6909 commit 844d7fd
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 36 deletions.
89 changes: 53 additions & 36 deletions sqlite3.go
Expand Up @@ -328,7 +328,7 @@ type SQLiteRows struct {
decltype []string
cls bool
closed bool
done chan struct{}
ctx context.Context // no better alternative to pass context into Next() method
}

type functionInfo struct {
Expand Down Expand Up @@ -1846,22 +1846,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
decltype: nil,
cls: s.cls,
closed: false,
done: make(chan struct{}),
}

if ctxdone := ctx.Done(); ctxdone != nil {
go func(db *C.sqlite3) {
select {
case <-ctxdone:
select {
case <-rows.done:
default:
C.sqlite3_interrupt(db)
rows.Close()
}
case <-rows.done:
}
}(s.c.db)
ctx: ctx,
}

return rows, nil
Expand Down Expand Up @@ -1890,28 +1875,40 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
}

func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if ctx.Done() == nil {
return s.execSync(args)
}

type result struct {
r driver.Result
err error
}
resultCh := make(chan result)
go func() {
r, err := s.execSync(args)
resultCh <- result{r, err}
}()
select {
case rv := <- resultCh:
return rv.r, rv.err
case <-ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
C.sqlite3_interrupt(s.c.db)
<-resultCh // ensure goroutine completed
}
return nil, ctx.Err()
}
}

func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}

if ctxdone := ctx.Done(); ctxdone != nil {
done := make(chan struct{})
defer close(done)
go func(db *C.sqlite3) {
select {
case <-done:
case <-ctxdone:
select {
case <-done:
default:
C.sqlite3_interrupt(db)
}
}
}(s.c.db)
}

var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
Expand All @@ -1932,9 +1929,6 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.closed = true
if rc.done != nil {
close(rc.done)
}
if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close()
Expand Down Expand Up @@ -1980,8 +1974,31 @@ func (rc *SQLiteRows) DeclTypes() []string {

// Next move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
if rc.ctx.Done() == nil {
return rc.nextSync(dest)
}
resultCh := make(chan error)
go func() {
resultCh <- rc.nextSync(dest)
}()
select {
case err := <- resultCh:
return err
case <-rc.ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
C.sqlite3_interrupt(rc.s.c.db)
<-resultCh // ensure goroutine completed
}
return rc.ctx.Err()
}
}

func (rc *SQLiteRows) nextSync(dest []driver.Value) error {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()

if rc.s.closed {
return io.EOF
}
Expand Down
83 changes: 83 additions & 0 deletions sqlite3_go18_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"io/ioutil"
"math/rand"
"os"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -135,6 +136,88 @@ func TestShortTimeout(t *testing.T) {
}
}

func TestQueryRowContextCancel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)

db, err := sql.Open("sqlite3", srcTempFilename)
if err != nil {
t.Fatal(err)
}
defer db.Close()
initDatabase(t, db, 100)

const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
var keyID string
unexpectedErrors := make(map[string]int)
for i := 0; i < 10000; i++ {
ctx, cancel := context.WithCancel(context.Background())
row := db.QueryRowContext(ctx, query)

cancel()
// it is fine to get "nil" as context cancellation can be handled with delay
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
unexpectedErrors[err.Error()]++
}
}
for errText, count := range unexpectedErrors {
t.Error(errText, count)
}
}

func TestQueryRowContextCancelParallel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)

db, err := sql.Open("sqlite3", srcTempFilename)
if err != nil {
t.Fatal(err)
}
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)

defer db.Close()
initDatabase(t, db, 100)

const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
wg := sync.WaitGroup{}
defer wg.Wait()

testCtx, cancel := context.WithCancel(context.Background())
defer cancel()

for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()

var keyID string
for {
select {
case <-testCtx.Done():
return
default:
}
ctx, cancel := context.WithCancel(context.Background())
row := db.QueryRowContext(ctx, query)

cancel()
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
}
}()
}

var keyID string
for i := 0; i < 10000; i++ {
// NOTE: testCtx is not cancelled during query execution
row := db.QueryRowContext(testCtx, query)

if err := row.Scan(&keyID); err != nil {
t.Fatal(i, err)
}
}
}

func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down

0 comments on commit 844d7fd

Please sign in to comment.