From 844d7fde16e0936dbaea1745bf2b1b7bffca73b7 Mon Sep 17 00:00:00 2001 From: Andrii Zavorotnii Date: Thu, 5 Sep 2019 15:26:33 -0700 Subject: [PATCH] Fix context cancellation racy handling [why] Context cancellation goroutine is not in sync with Next() method lifetime. It leads to sql.ErrNoRows instead of context.Canceled often (easy to reproduce). It leads to interruption of next query executed on same connection (harder to reproduce). [how] Do query in goroutine, wait when interruption done. [testing] Add unit test that reproduces error cases. --- sqlite3.go | 89 ++++++++++++++++++++++++++------------------ sqlite3_go18_test.go | 83 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 36 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 4000173c..5fb8423b 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -328,7 +328,7 @@ type SQLiteRows struct { decltype []string cls bool closed bool - done chan struct{} + ctx context.Context // no better alternative to pass context into Next() method } type functionInfo struct { @@ -1846,22 +1846,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, decltype: nil, cls: s.cls, closed: false, - done: make(chan struct{}), - } - - if ctxdone := ctx.Done(); ctxdone != nil { - go func(db *C.sqlite3) { - select { - case <-ctxdone: - select { - case <-rows.done: - default: - C.sqlite3_interrupt(db) - rows.Close() - } - case <-rows.done: - } - }(s.c.db) + ctx: ctx, } return rows, nil @@ -1890,28 +1875,40 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { } func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) { + if ctx.Done() == nil { + return s.execSync(args) + } + + type result struct { + r driver.Result + err error + } + resultCh := make(chan result) + go func() { + r, err := s.execSync(args) + resultCh <- result{r, err} + }() + select { + case rv := <- resultCh: + return rv.r, rv.err + case <-ctx.Done(): + select { + case <-resultCh: // no need to interrupt + default: + C.sqlite3_interrupt(s.c.db) + <-resultCh // ensure goroutine completed + } + return nil, ctx.Err() + } +} + +func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) { if err := s.bind(args); err != nil { C.sqlite3_reset(s.s) C.sqlite3_clear_bindings(s.s) return nil, err } - if ctxdone := ctx.Done(); ctxdone != nil { - done := make(chan struct{}) - defer close(done) - go func(db *C.sqlite3) { - select { - case <-done: - case <-ctxdone: - select { - case <-done: - default: - C.sqlite3_interrupt(db) - } - } - }(s.c.db) - } - var rowid, changes C.longlong rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { @@ -1932,9 +1929,6 @@ func (rc *SQLiteRows) Close() error { return nil } rc.closed = true - if rc.done != nil { - close(rc.done) - } if rc.cls { rc.s.mu.Unlock() return rc.s.Close() @@ -1980,8 +1974,31 @@ func (rc *SQLiteRows) DeclTypes() []string { // Next move cursor to next. func (rc *SQLiteRows) Next(dest []driver.Value) error { + if rc.ctx.Done() == nil { + return rc.nextSync(dest) + } + resultCh := make(chan error) + go func() { + resultCh <- rc.nextSync(dest) + }() + select { + case err := <- resultCh: + return err + case <-rc.ctx.Done(): + select { + case <-resultCh: // no need to interrupt + default: + C.sqlite3_interrupt(rc.s.c.db) + <-resultCh // ensure goroutine completed + } + return rc.ctx.Err() + } +} + +func (rc *SQLiteRows) nextSync(dest []driver.Value) error { rc.s.mu.Lock() defer rc.s.mu.Unlock() + if rc.s.closed { return io.EOF } diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index c9e79e7a..92f0e796 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "math/rand" "os" + "sync" "testing" "time" ) @@ -135,6 +136,88 @@ func TestShortTimeout(t *testing.T) { } } +func TestQueryRowContextCancel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + defer db.Close() + initDatabase(t, db, 100) + + const query = `SELECT key_id FROM test_table ORDER BY key2 ASC` + var keyID string + unexpectedErrors := make(map[string]int) + for i := 0; i < 10000; i++ { + ctx, cancel := context.WithCancel(context.Background()) + row := db.QueryRowContext(ctx, query) + + cancel() + // it is fine to get "nil" as context cancellation can be handled with delay + if err := row.Scan(&keyID); err != nil && err != context.Canceled { + unexpectedErrors[err.Error()]++ + } + } + for errText, count := range unexpectedErrors { + t.Error(errText, count) + } +} + +func TestQueryRowContextCancelParallel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + + defer db.Close() + initDatabase(t, db, 100) + + const query = `SELECT key_id FROM test_table ORDER BY key2 ASC` + wg := sync.WaitGroup{} + defer wg.Wait() + + testCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + var keyID string + for { + select { + case <-testCtx.Done(): + return + default: + } + ctx, cancel := context.WithCancel(context.Background()) + row := db.QueryRowContext(ctx, query) + + cancel() + _ = row.Scan(&keyID) // see TestQueryRowContextCancel + } + }() + } + + var keyID string + for i := 0; i < 10000; i++ { + // NOTE: testCtx is not cancelled during query execution + row := db.QueryRowContext(testCtx, query) + + if err := row.Scan(&keyID); err != nil { + t.Fatal(i, err) + } + } +} + func TestExecCancel(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil {