Skip to content

Commit

Permalink
sqlite3: handle trailing comments and multiple SQL statements in Queries
Browse files Browse the repository at this point in the history
This commit fixes *SQLiteConn.Query to properly handle trailing comments
after a SQL query statement. Previously, trailing comments could lead to
an infinite loop.

It also changes Query to error if the provided SQL statement contains
multiple queries ("SELECT 1; SELECT 2;") - previously only the last
query was executed ("SELECT 1; SELECT 2;" would yield only 2).

This may be a breaking change as previously: Query consumed all of its
args - despite only using the last query (Query now only uses the args
required to satisfy the first query and errors if there is a mismatch);
Query used only the last query and there may be code using this library
that depends on this behavior.

Personally, I believe the behavior introduced by this commit is correct
and any code relying on the prior undocumented behavior incorrect, but
it could still be a break.
  • Loading branch information
charlievieth committed Apr 20, 2023
1 parent 5880fdc commit e430326
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 106 deletions.
92 changes: 36 additions & 56 deletions sqlite3.go
Expand Up @@ -31,7 +31,6 @@ package sqlite3
#endif
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#ifdef __CYGWIN__
# include <errno.h>
Expand Down Expand Up @@ -80,16 +79,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
return rv;
}
static const char *
_trim_leading_spaces(const char *str) {
if (str) {
while (isspace(*str)) {
str++;
}
}
return str;
}
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
Expand All @@ -110,11 +99,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
static int
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
{
int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
if (pzTail) {
*pzTail = _trim_leading_spaces(*pzTail);
}
return rv;
return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
}
#else
Expand All @@ -137,12 +122,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
static int
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
{
int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
if (pzTail) {
*pzTail = _trim_leading_spaces(*pzTail);
}
return rv;
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
}
#endif
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
Expand Down Expand Up @@ -938,46 +920,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
op := pquery // original pointer
defer C.free(unsafe.Pointer(op))

var stmtArgs []driver.NamedValue
var tail *C.char
s := new(SQLiteStmt) // escapes to the heap so reuse it
start := 0
for {
*s = SQLiteStmt{c: c, cls: true} // reset
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
s := &SQLiteStmt{c: c, cls: true} // TODO: delay allocating this
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
if s.s == nil {
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
}
na := s.NumInput()
if n := len(args); n != na {
s.finalize()
if n < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
}
rows, err := s.query(ctx, args)
if err != nil && err != driver.ErrSkip {
s.finalize() // WARN
return rows, err
}

na := s.NumInput()
if len(args)-start < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
}
// consume the number of arguments used in the current
// statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
rows, err := s.query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.finalize()
return rows, err
// Consume the rest of the query
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
var stmt *C.sqlite3_stmt
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
if rv != C.SQLITE_OK {
rows.Close()
return nil, c.lastError()
}
start += na
if tail == nil || *tail == '\000' {
return rows, nil
if stmt != nil {
rows.Close()
return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
}
rows.Close()
s.finalize()
pquery = tail
}

return rows, nil
}

// Begin transaction.
Expand Down Expand Up @@ -2029,7 +2009,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.query(context.Background(), list)
}

func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
if err := s.bind(args); err != nil {
return nil, err
}
Expand Down
182 changes: 132 additions & 50 deletions sqlite3_test.go
Expand Up @@ -18,6 +18,7 @@ import (
"math/rand"
"net/url"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
Expand Down Expand Up @@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) {
defer db.Close()

_, err = db.Exec(`
create table foo (id integer); -- one comment
insert into foo(id) values(?);
insert into foo(id) values(?);
insert into foo(id) values(?); -- another comment
CREATE TABLE foo (id INTEGER); -- one comment
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?); -- another comment
`, 1, 2, 3)
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
}

func TestQueryer(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()

_, err = db.Exec(`
create table foo (id integer);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
if seed {
if _, err := db.Exec(`create table foo (id integer);`); err != nil {
t.Fatal(err)
}
_, err := db.Exec(`
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
`, 3, 2, 1)
if err != nil {
t.Fatal(err)
}
}

_, err = db.Exec(`
insert into foo(id) values(?);
insert into foo(id) values(?);
insert into foo(id) values(?);
`, 3, 2, 1)
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
rows, err := db.Query(`
select id from foo order by id;
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
defer rows.Close()
n := 0
for rows.Next() {
var id int
err = rows.Scan(&id)
// Capture panic so tests can continue
defer func() {
if e := recover(); e != nil {
buf := make([]byte, 32*1024)
n := runtime.Stack(buf, false)
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
}
}()
test(t, db)
}

func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
var values []interface{}
testQuery(t, true, func(t *testing.T, db *sql.DB) {
rows, err := db.Query(query, args...)
if err != nil {
t.Error("Failed to db.Query:", err)
t.Fatal(err)
}
if id != n + 1 {
t.Error("Failed to db.Query: not matched results")
if rows == nil {
t.Fatal("nil rows")
}
n = n + 1
for i := 0; rows.Next(); i++ {
if i > 1_000 {
t.Fatal("To many iterations of rows.Next():", i)
}
var v interface{}
if err := rows.Scan(&v); err != nil {
t.Fatal(err)
}
values = append(values, v)
}
if err := rows.Err(); err != nil {
t.Fatal(err)
}
if err := rows.Close(); err != nil {
t.Fatal(err)
}
})
return values
}

func TestQuery(t *testing.T) {
queries := []struct {
query string
args []interface{}
}{
{"SELECT id FROM foo ORDER BY id;", nil},
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},

// Comments
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
{
`-- FOO
SELECT id FROM foo ORDER BY id; -- BAR
/* BAZ */`,
nil,
},
}
if err := rows.Err(); err != nil {
t.Errorf("Post-scan failed: %v\n", err)
want := []interface{}{
int64(1),
int64(2),
int64(3),
}
for _, q := range queries {
t.Run("", func(t *testing.T) {
got := testQueryValues(t, q.query, q.args...)
if !reflect.DeepEqual(got, want) {
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
}
})
}
if n != 3 {
t.Errorf("Expected 3 rows but retrieved %v", n)
}

func TestQueryNoSQL(t *testing.T) {
got := testQueryValues(t, "")
if got != nil {
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
}
}

func testQueryError(t *testing.T, query string, args ...interface{}) {
testQuery(t, true, func(t *testing.T, db *sql.DB) {
rows, err := db.Query(query, args...)
if err == nil {
t.Error("Expected an error got:", err)
}
if rows != nil {
t.Error("Returned rows should be nil on error!")
// Attempt to iterate over rows to make sure they don't panic.
for i := 0; rows.Next(); i++ {
if i > 1_000 {
t.Fatal("To many iterations of rows.Next():", i)
}
}
if err := rows.Err(); err != nil {
t.Error(err)
}
rows.Close()
}
})
}

func TestQueryNotEnoughArgs(t *testing.T) {
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
}

func TestQueryTooManyArgs(t *testing.T) {
// TODO: test error message / kind
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
}

func TestQueryMultipleStatements(t *testing.T) {
testQueryError(t, "SELECT 1; SELECT 2;")
}

func TestQueryInvalidTable(t *testing.T) {
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
}

func TestStress(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
Expand Down Expand Up @@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkRows", F: benchmarkRows},
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
}

func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
Expand Down Expand Up @@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) {
}
}
}

func benchmarkQueryStep(b *testing.B) {
var i int
for n := 0; n < b.N; n++ {
if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil {
b.Fatal(err)
}
}
}

0 comments on commit e430326

Please sign in to comment.