diff --git a/rules/sql.go b/rules/sql.go index 6f4dbf1264..ee99737d64 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -15,9 +15,9 @@ package rules import ( + "fmt" "go/ast" "regexp" - "strings" "github.com/securego/gosec/v2" ) @@ -30,6 +30,51 @@ type sqlStatement struct { patterns []*regexp.Regexp } +var sqlCallIdents = map[string]map[string]int{ + "*database/sql.DB": { + "Exec": 0, + "ExecContext": 1, + "Query": 0, + "QueryContext": 1, + "QueryRow": 0, + "QueryRowContext": 1, + "Prepare": 0, + "PrepareContext": 1, + }, + "*database/sql.Tx": { + "Exec": 0, + "ExecContext": 1, + "Query": 0, + "QueryContext": 1, + "QueryRow": 0, + "QueryRowContext": 1, + "Prepare": 0, + "PrepareContext": 1, + }, +} + +// findQueryArg locates the argument taking raw SQL +func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) { + typeName, fnName, err := gosec.GetCallInfo(call, ctx) + if err != nil { + return nil, err + } + i := -1 + if ni, ok := sqlCallIdents[typeName]; ok { + if i, ok = ni[fnName]; !ok { + i = -1 + } + } + if i == -1 { + return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName) + } + if i >= len(call.Args) { + return nil, nil + } + query := call.Args[i] + return query, nil +} + func (s *sqlStatement) ID() string { return s.MetaData.ID } @@ -69,16 +114,10 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool { // checkQuery verifies if the query parameters is a string concatenation func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) { - _, fnName, err := gosec.GetCallInfo(call, ctx) + query, err := findQueryArg(call, ctx) if err != nil { return nil, err } - var query ast.Node - if strings.HasSuffix(fnName, "Context") { - query = call.Args[1] - } else { - query = call.Args[0] - } if be, ok := query.(*ast.BinaryExpr); ok { operands := gosec.GetBinaryExprOperands(be) @@ -137,8 +176,11 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { }, } - rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") - rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") + for s, si := range sqlCallIdents { + for i := range si { + rule.Add(s, i) + } + } return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} } @@ -171,16 +213,10 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool { } func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) { - _, fnName, err := gosec.GetCallInfo(call, ctx) + query, err := findQueryArg(call, ctx) if err != nil { return nil, err } - var query ast.Node - if strings.HasSuffix(fnName, "Context") { - query = call.Args[1] - } else { - query = call.Args[0] - } if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil { decl := ident.Obj.Decl @@ -306,8 +342,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { }, }, } - rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") - rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") + for s, si := range sqlCallIdents { + for i := range si { + rule.Add(s, i) + } + } rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf") rule.noIssue.AddAll("os", "Stdout", "Stderr") rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")