Skip to content

Commit

Permalink
Add receiver nil checks for Conn methods
Browse files Browse the repository at this point in the history
Fixes #17
  • Loading branch information
zombiezen committed Aug 15, 2021
1 parent af0400e commit 2ed604b
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 12 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Expand Up @@ -16,6 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`crawshaw.io/sqlite` API, but they can all be migrated automatically with the
migration tool. ([#16](https://github.com/zombiezen/go-sqlite/issues/16))

### Changed

- Method calls to a `nil` `*sqlite.Conn` will return an error rather than panic.
([#17](https://github.com/zombiezen/go-sqlite/issues/17))

### Removed

- Removed `OpenFlags` that are only used for VFS.
Expand Down
3 changes: 3 additions & 0 deletions auth.go
Expand Up @@ -25,6 +25,9 @@ type Authorizer interface {
// SetAuthorizer registers an authorizer for the database connection.
// SetAuthorizer(nil) clears any authorizer previously set.
func (c *Conn) SetAuthorizer(auth Authorizer) error {
if c == nil {
return fmt.Errorf("sqlite: set authorizer: nil connection")
}
if auth == nil {
c.releaseAuthorizer()
res := ResultCode(lib.Xsqlite3_set_authorizer(c.tls, c.conn, 0, 0))
Expand Down
19 changes: 11 additions & 8 deletions blob.go
Expand Up @@ -38,6 +38,9 @@ var (
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(dbn, table, column string, row int64, write bool) (*Blob, error) {
if c == nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q: nil connection", table, column)
}
return c.openBlob(dbn, table, column, row, write)
}

Expand All @@ -52,7 +55,7 @@ func (c *Conn) openBlob(dbn, table, column string, row int64, write bool) (_ *Bl
var err error
cdb, err = libc.CString(dbn)
if err != nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
defer libc.Xfree(c.tls, cdb)
}
Expand All @@ -62,7 +65,7 @@ func (c *Conn) openBlob(dbn, table, column string, row int64, write bool) (_ *Bl
}
buf, err := malloc(c.tls, blobBufSize)
if err != nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
defer func() {
if err != nil {
Expand All @@ -72,23 +75,23 @@ func (c *Conn) openBlob(dbn, table, column string, row int64, write bool) (_ *Bl

ctable, err := libc.CString(table)
if err != nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
defer libc.Xfree(c.tls, ctable)
ccolumn, err := libc.CString(column)
if err != nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
defer libc.Xfree(c.tls, ccolumn)

blobPtrPtr, err := malloc(c.tls, ptrSize)
if err != nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
defer libc.Xfree(c.tls, blobPtrPtr)
for {
if err := c.interrupted(); err != nil {
return nil, fmt.Errorf("sqlite: open blob %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
res := ResultCode(lib.Xsqlite3_blob_open(
c.tls,
Expand All @@ -103,7 +106,7 @@ func (c *Conn) openBlob(dbn, table, column string, row int64, write bool) (_ *Bl
switch res {
case ResultLockedSharedCache:
if err := reserr(waitForUnlockNotify(c.tls, c.conn, c.unlockNote)); err != nil {
return nil, fmt.Errorf("sqlite: open %q.%q blob: %w", table, column, err)
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, err)
}
// loop
case ResultOK:
Expand All @@ -115,7 +118,7 @@ func (c *Conn) openBlob(dbn, table, column string, row int64, write bool) (_ *Bl
size: lib.Xsqlite3_blob_bytes(c.tls, blobPtr),
}, nil
default:
return nil, fmt.Errorf("sqlite: open %q.%q blob: %w", table, column, c.extreserr(res))
return nil, fmt.Errorf("sqlite: open blob %q.%q: %w", table, column, c.extreserr(res))
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions func.go
Expand Up @@ -433,6 +433,9 @@ type FunctionImpl struct {
//
// https://sqlite.org/appfunc.html
func (c *Conn) CreateFunction(name string, impl *FunctionImpl) error {
if c == nil {
return fmt.Errorf("sqlite: create function: nil connection")
}
if name == "" {
return fmt.Errorf("sqlite: create function: no name provided")
}
Expand Down
9 changes: 9 additions & 0 deletions session.go
Expand Up @@ -45,6 +45,9 @@ type Session struct {
//
// https://www.sqlite.org/session/sqlite3session_create.html
func (c *Conn) CreateSession(db string) (*Session, error) {
if c == nil {
return nil, fmt.Errorf("sqlite: create session: nil connection")
}
var cdb uintptr
if db == "" || db == "main" {
cdb = mainCString
Expand Down Expand Up @@ -210,6 +213,9 @@ func (s *Session) WritePatchset(w io.Writer) error {
// resolve the conflict. See https://www.sqlite.org/session/sqlite3changeset_apply.html
// for more details.
func (c *Conn) ApplyChangeset(r io.Reader, filterFn func(tableName string) bool, conflictFn ConflictHandler) error {
if c == nil {
return fmt.Errorf("sqlite: apply changeset: nil connection")
}
if conflictFn == nil {
return fmt.Errorf("sqlite: apply changeset: no conflict handler provided")
}
Expand Down Expand Up @@ -293,6 +299,9 @@ func changesetApplyConflict(tls *libc.TLS, pCtx uintptr, eConflict int32, p uint
// ApplyInverseChangeset applies the inverse of a changeset to the database.
// See ApplyChangeset and InvertChangeset for more details.
func (c *Conn) ApplyInverseChangeset(r io.Reader, filterFn func(tableName string) bool, conflictFn ConflictHandler) error {
if c == nil {
return fmt.Errorf("sqlite: apply changeset: nil connection")
}
pr, pw := io.Pipe()
go func() {
err := InvertChangeset(pw, pr)
Expand Down
39 changes: 35 additions & 4 deletions sqlite.go
Expand Up @@ -168,6 +168,9 @@ func openConn(path string, openFlags OpenFlags) (_ *Conn, err error) {
// Close closes the database connection using sqlite3_close and finalizes
// persistent prepared statements. https://www.sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
if c == nil {
return fmt.Errorf("sqlite: close: nil connection")
}
c.cancelInterrupt()
c.closed = true
for _, stmt := range c.stmts {
Expand All @@ -191,15 +194,20 @@ func (c *Conn) Close() error {
// AutocommitEnabled reports whether conn is in autocommit mode.
// https://sqlite.org/c3ref/get_autocommit.html
func (c *Conn) AutocommitEnabled() bool {
if c == nil {
return false
}
return lib.Xsqlite3_get_autocommit(c.tls, c.conn) != 0
}

// CheckReset reports whether any statement on this connection is in the process
// of returning results.
func (c *Conn) CheckReset() string {
for _, stmt := range c.stmts {
if stmt.lastHasRow {
return stmt.query
if c != nil {
for _, stmt := range c.stmts {
if stmt.lastHasRow {
return stmt.query
}
}
}
return ""
Expand All @@ -224,6 +232,9 @@ func (c *Conn) CheckReset() string {
//
// SetInterrupt returns the old doneCh assigned to the connection.
func (c *Conn) SetInterrupt(doneCh <-chan struct{}) (oldDoneCh <-chan struct{}) {
if c == nil {
return nil
}
if c.closed {
panic("sqlite.Conn is closed")
}
Expand Down Expand Up @@ -261,7 +272,9 @@ func (c *Conn) SetInterrupt(doneCh <-chan struct{}) (oldDoneCh <-chan struct{})
//
// https://www.sqlite.org/c3ref/busy_timeout.html
func (c *Conn) SetBusyTimeout(d time.Duration) {
lib.Xsqlite3_busy_timeout(c.tls, c.conn, int32(d/time.Millisecond))
if c != nil {
lib.Xsqlite3_busy_timeout(c.tls, c.conn, int32(d/time.Millisecond))
}
}

func (c *Conn) interrupted() error {
Expand Down Expand Up @@ -317,6 +330,9 @@ func (c *Conn) Prep(query string) *Stmt {
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) Prepare(query string) (*Stmt, error) {
if c == nil {
return nil, fmt.Errorf("sqlite: prepare %q: nil connection", query)
}
if stmt := c.stmts[query]; stmt != nil {
if err := stmt.Reset(); err != nil {
return nil, err
Expand Down Expand Up @@ -349,6 +365,9 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareTransient(query string) (stmt *Stmt, trailingBytes int, err error) {
if c == nil {
return nil, 0, fmt.Errorf("sqlite: prepare %q: nil connection", query)
}
// TODO(soon)
// if stmt != nil {
// runtime.SetFinalizer(stmt, func(stmt *Stmt) {
Expand Down Expand Up @@ -416,13 +435,19 @@ func (c *Conn) prepare(query string, flags uint32) (*Stmt, int, error) {
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int {
if c == nil {
return 0
}
return int(lib.Xsqlite3_changes(c.tls, c.conn))
}

// LastInsertRowID reports the rowid of the most recently successful INSERT.
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
if c == nil {
return 0
}
return lib.Xsqlite3_last_insert_rowid(c.tls, c.conn)
}

Expand Down Expand Up @@ -1163,6 +1188,9 @@ func (limit Limit) String() string {
//
// https://sqlite.org/c3ref/limit.html
func (c *Conn) Limit(id Limit, value int32) int32 {
if c == nil {
return 0
}
return lib.Xsqlite3_limit(c.tls, c.conn, int32(id), int32(value))
}

Expand All @@ -1176,6 +1204,9 @@ func (c *Conn) Limit(id Limit, value int32) int32 {
// Writes to the sqlite_dbpage virtual table.
// Direct writes to shadow tables.
func (c *Conn) SetDefensive(enabled bool) error {
if c == nil {
return fmt.Errorf("sqlite: set defensive=%t: nil connection", enabled)
}
varArgs := libc.Xmalloc(c.tls, ptrSize)
if varArgs == 0 {
return fmt.Errorf("sqlite: set defensive=%t: cannot allocate memory", enabled)
Expand Down

0 comments on commit 2ed604b

Please sign in to comment.