From e07e2916e729695ccbd917650f628f64b8ccfc3c Mon Sep 17 00:00:00 2001 From: kaiili <35690781+kaiili@users.noreply.github.com> Date: Sun, 16 Jan 2022 16:55:57 +0800 Subject: [PATCH 1/2] Add db.Exec and db.Prepare to the sql rule --- rules/sql.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rules/sql.go b/rules/sql.go index e9cb41b292..6f4dbf1264 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -137,8 +137,8 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { }, } - rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext") - rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext") + rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") + rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} } @@ -306,8 +306,8 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { }, }, } - rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext") - rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext") + rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") + rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf") rule.noIssue.AddAll("os", "Stdout", "Stderr") rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier") From 407051fd6f519baa3d328f9ee6f629ef1c1fd583 Mon Sep 17 00:00:00 2001 From: kaiili <35690781+kaiili@users.noreply.github.com> Date: Sun, 16 Jan 2022 16:59:37 +0800 Subject: [PATCH 2/2] add test cases for G201,G202 --- testutils/source.go | 162 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 161 insertions(+), 1 deletion(-) diff --git a/testutils/source.go b/testutils/source.go index 3188187c3c..212d6a4b0b 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -1255,7 +1255,103 @@ func main() { panic(err) } defer db.Close() -}`}, 1, gosec.NewConfig()}, +}`}, 1, gosec.NewConfig()}, {[]string{` +// SQLI by db.Prepare(some) +package main + +import ( + "database/sql" + "fmt" + "log" + "os" +) + +const Table = "foo" + +func main() { + var album string + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT name FROM users where '%s' = ?", os.Args[1]) + stmt, err := db.Prepare(q) + if err != nil { + log.Fatal(err) + } + stmt.QueryRow(fmt.Sprintf("%s", os.Args[2])).Scan(&album) + if err != nil { + if err == sql.ErrNoRows { + log.Fatal(err) + } + } + defer stmt.Close() +} +`}, 1, gosec.NewConfig()}, {[]string{` +// SQLI by db.PrepareContext(some) +package main + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" +) + +const Table = "foo" + +func main() { + var album string + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT name FROM users where '%s' = ?", os.Args[1]) + stmt, err := db.PrepareContext(context.Background(), q) + if err != nil { + log.Fatal(err) + } + stmt.QueryRow(fmt.Sprintf("%s", os.Args[2])).Scan(&album) + if err != nil { + if err == sql.ErrNoRows { + log.Fatal(err) + } + } + defer stmt.Close() +} +`}, 1, gosec.NewConfig()}, {[]string{` +// false positive +package main + +import ( + "database/sql" + "fmt" + "log" + "os" +) + +const Table = "foo" + +func main() { + var album string + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + stmt, err := db.Prepare("SELECT * FROM album WHERE id = ?") + if err != nil { + log.Fatal(err) + } + stmt.QueryRow(fmt.Sprintf("%s", os.Args[1])).Scan(&album) + if err != nil { + if err == sql.ErrNoRows { + log.Fatal(err) + } + } + defer stmt.Close() +} +`}, 0, gosec.NewConfig()}, } // SampleCodeG202 - SQL query string building via string concatenation @@ -1431,6 +1527,70 @@ func main(){ } defer rows.Close() } +`}, 0, gosec.NewConfig()}, {[]string{` +// ExecContext match +package main + +import ( + "context" + "database/sql" + "fmt" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + result, err := db.ExecContext(context.Background(), "select * from foo where name = "+os.Args[1]) + if err != nil { + panic(err) + } + fmt.Println(result) +}`}, 1, gosec.NewConfig()}, {[]string{` +// Exec match +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + result, err := db.Exec("select * from foo where name = " + os.Args[1]) + if err != nil { + panic(err) + } + fmt.Println(result) +}`}, 1, gosec.NewConfig()}, {[]string{` +package main + +import ( + "database/sql" + "fmt" +) +const gender = "M" +const age = "32" + +var staticQuery = "SELECT * FROM foo WHERE age < " + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + result, err := db.Exec("SELECT * FROM foo WHERE gender = " + gender) + if err != nil { + panic(err) + } + fmt.Println(result) +} `}, 0, gosec.NewConfig()}, }