diff --git a/rules/sql.go b/rules/sql.go index 844eaf5d3e..e9cb41b292 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -261,6 +261,19 @@ func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro switch stmt := n.(type) { case *ast.AssignStmt: for _, expr := range stmt.Rhs { + if call, ok := expr.(*ast.CallExpr); ok { + selector, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + continue + } + sqlQueryCall, ok := selector.X.(*ast.CallExpr) + if ok && s.ContainsCallExpr(sqlQueryCall, ctx) != nil { + issue, err := s.checkQuery(sqlQueryCall, ctx) + if err == nil && issue != nil { + return issue, err + } + } + } if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil { return s.checkQuery(sqlQueryCall, ctx) } diff --git a/testutils/source.go b/testutils/source.go index 106d8a2b05..a6f2af83cd 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -1189,6 +1189,72 @@ func main(){ panic(err) } defer rows.Close() +}`}, 1, gosec.NewConfig()}, {[]string{` +// Format string with \n\r +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main(){ + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT * FROM foo where\nname = '%s'", os.Args[1]) + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +}`}, 1, gosec.NewConfig()}, {[]string{` +// SQLI by db.Query(some).Scan(&other) +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main() { + var name string + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT name FROM users where id = '%s'", os.Args[1]) + row := db.QueryRow(q) + err = row.Scan(&name) + if err != nil { + panic(err) + } + defer db.Close() +}`}, 1, gosec.NewConfig()}, {[]string{` +// SQLI by db.Query(some).Scan(&other) +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main() { + var name string + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT name FROM users where id = '%s'", os.Args[1]) + err = db.QueryRow(q).Scan(&name) + if err != nil { + panic(err) + } + defer db.Close() }`}, 1, gosec.NewConfig()}, }