Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checkers: add sqlQuery checker #932

Merged
merged 1 commit into from May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
164 changes: 164 additions & 0 deletions checkers/sqlQuery_checker.go
@@ -0,0 +1,164 @@
package checkers

import (
"go/ast"
"go/types"

"github.com/go-lintpack/lintpack"
"github.com/go-lintpack/lintpack/astwalk"
"github.com/go-toolsmith/astcast"
)

func init() {
var info lintpack.CheckerInfo
info.Name = "sqlQuery"
info.Tags = []string{"diagnostic", "experimental"}
info.Summary = "Detects issue in Query() and Exec() calls"
info.Before = `_, err := db.Query("UPDATE ...")`
info.After = `_, err := db.Exec("UPDATE ...")`

collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker {
return astwalk.WalkerForStmt(&sqlQueryChecker{ctx: ctx})
})
}

type sqlQueryChecker struct {
astwalk.WalkHandler
ctx *lintpack.CheckerContext
}

func (c *sqlQueryChecker) VisitStmt(stmt ast.Stmt) {
assign := astcast.ToAssignStmt(stmt)
if len(assign.Lhs) != 2 { // Query() has 2 return values.
return
}
if len(assign.Rhs) != 1 {
return
}

call := astcast.ToCallExpr(assign.Rhs[0])
funcExpr := astcast.ToSelectorExpr(call.Fun)
if !c.funcIsQuery(funcExpr) {
return
}

// If Query() is called, but first return value is ignored,
// there is no way to close/read the returned rows.
// This can cause a connection leak.
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name != "_" {
return
}

if c.typeHasExecMethod(c.ctx.TypesInfo.TypeOf(funcExpr.X)) {
c.warnAndSuggestExec(funcExpr)
} else {
c.warnRowsIgnored(funcExpr)
}
}

func (c *sqlQueryChecker) funcIsQuery(funcExpr *ast.SelectorExpr) bool {
switch funcExpr.Sel.Name {
case "Query", "QueryContext":
// Stdlib and friends.
case "Queryx", "QueryxContext":
// sqlx.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<3

default:
return false
}

// To avoid false positives (unrelated types can have Query method)
// check that the 1st returned type has Row-like name.
typ, ok := c.ctx.TypesInfo.TypeOf(funcExpr).Underlying().(*types.Signature)
if !ok || typ.Results() == nil || typ.Results().Len() != 2 {
return false
}
if !c.typeIsRowsLike(typ.Results().At(0).Type()) {
return false
}

return true
}

func (c *sqlQueryChecker) typeIsRowsLike(typ types.Type) bool {
switch typ := typ.(type) {
case *types.Pointer:
return c.typeIsRowsLike(typ.Elem())
case *types.Named:
return typ.Obj().Name() == "Rows"
default:
return false
}
}

func (c *sqlQueryChecker) funcIsExec(fn *types.Func) bool {
if fn.Name() != "Exec" {
return false
}

// Expect exactly 2 results.
sig := fn.Type().(*types.Signature)
if sig.Results() == nil || sig.Results().Len() != 2 {
return false
}

// Expect at least 1 param and it should be a string (query).
params := sig.Params()
if params == nil || params.Len() == 0 {
return false
}
if typ, ok := params.At(0).Type().(*types.Basic); !ok || typ.Kind() != types.String {
return false
}

return true
}

func (c *sqlQueryChecker) typeHasExecMethod(typ types.Type) bool {
switch typ := typ.(type) {
case *types.Struct:
for i := 0; i < typ.NumFields(); i++ {
if c.typeHasExecMethod(typ.Field(i).Type()) {
return true
}
}
case *types.Interface:
for i := 0; i < typ.NumMethods(); i++ {
if c.funcIsExec(typ.Method(i)) {
return true
}
}
case *types.Pointer:
return c.typeHasExecMethod(typ.Elem())
case *types.Named:
for i := 0; i < typ.NumMethods(); i++ {
if c.funcIsExec(typ.Method(i)) {
return true
}
}
switch ut := typ.Underlying().(type) {
case *types.Interface:
return c.typeHasExecMethod(ut)
case *types.Struct:
// Check embedded types.
for i := 0; i < ut.NumFields(); i++ {
field := ut.Field(i)
if !field.Embedded() {
continue
}
if c.typeHasExecMethod(field.Type()) {
return true
}
}
}
}

return false
}

func (c *sqlQueryChecker) warnAndSuggestExec(funcExpr *ast.SelectorExpr) {
c.ctx.Warn(funcExpr, "use %s.Exec() if returned result is not needed", funcExpr.X)
}

func (c *sqlQueryChecker) warnRowsIgnored(funcExpr *ast.SelectorExpr) {
c.ctx.Warn(funcExpr, "ignoring Query() rows result may lead to a connection leak")
}
32 changes: 32 additions & 0 deletions checkers/testdata/sqlQuery/negative_tests.go
@@ -0,0 +1,32 @@
package checker_test

import (
"database/sql"
)

func queryResultIsUsed(db *sql.DB, qe QueryExecer, mydb *myDatabase) {
const queryString = "SELECT * FROM users"

rows1, err := db.Query(queryString)
_ = rows1

rows2, err := qe.Query(queryString)
_ = rows2

rows3, err := mydb.Query(queryString)
_ = rows3

_ = err
}

func execIsUsed(db *sql.DB, qe QueryExecer, mydb *myDatabase) {
const queryString = "UPDATE users SET name = 'gopher'"

var err error

_, err = db.Exec(queryString)
_, err = qe.Exec(queryString)
_, err = mydb.Exec(queryString)

_ = err
}
59 changes: 59 additions & 0 deletions checkers/testdata/sqlQuery/positive_tests.go
@@ -0,0 +1,59 @@
package checker_test

import (
"database/sql"
)

type myDatabase struct {
*sql.DB
}

type Rows struct{}

type Row struct{}

type Queryer interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Queryx(query string, args ...interface{}) (*Rows, error)
QueryRowx(query string, args ...interface{}) *Row
}

type Execer interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}

type QueryExecer interface {
Queryer
Execer
}

type QueryExecerAlias = QueryExecer

func resultIgnored(db *sql.DB, q Queryer, qe QueryExecer, qea QueryExecerAlias, mydb *myDatabase) {
const queryString = "UPDATE users SET name = 'gopher'"

var err error

/*! use db.Exec() if returned result is not needed */
_, err = db.Query(queryString)

/*! use qe.Exec() if returned result is not needed */
_, err = qe.Query(queryString)

/*! use qe.Exec() if returned result is not needed */
_, err = qe.Queryx(queryString)

/*! use mydb.Exec() if returned result is not needed */
_, err = mydb.Query(queryString)

/*! ignoring Query() rows result may lead to a connection leak */
_, err = q.Query(queryString)

/*! use qea.Exec() if returned result is not needed */
_, err = qea.Query(queryString)

/*! ignoring Query() rows result may lead to a connection leak */
_, err = q.Queryx(queryString)

_ = err
}