diff --git a/rules/bad_defer.go b/rules/bad_defer.go index f6ca0be81f..13b42070da 100644 --- a/rules/bad_defer.go +++ b/rules/bad_defer.go @@ -38,11 +38,10 @@ func contains(methods []string, method string) bool { func (r *badDefer) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { if deferStmt, ok := n.(*ast.DeferStmt); ok { for _, deferTyp := range r.types { - if issue := r.checkChild(n, c, deferStmt.Call, deferTyp); issue != nil { - return issue, nil - } - if issue := r.checkFunction(n, c, deferStmt, deferTyp); issue != nil { - return issue, nil + if typ, method, err := gosec.GetCallInfo(deferStmt.Call, c); err == nil { + if normalize(typ) == deferTyp.typ && contains(deferTyp.methods, method) { + return gosec.NewIssue(c, n, r.ID(), fmt.Sprintf(r.What, method, typ), r.Severity, r.Confidence), nil + } } } } @@ -50,42 +49,6 @@ func (r *badDefer) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { return nil, nil } -func (r *badDefer) checkChild(n ast.Node, c *gosec.Context, callExp *ast.CallExpr, deferTyp deferType) *gosec.Issue { - if typ, method, err := gosec.GetCallInfo(callExp, c); err == nil { - if normalize(typ) == deferTyp.typ && contains(deferTyp.methods, method) { - return gosec.NewIssue(c, n, r.ID(), fmt.Sprintf(r.What, method, typ), r.Severity, r.Confidence) - } - } - return nil -} - -func (r *badDefer) checkFunction(n ast.Node, c *gosec.Context, deferStmt *ast.DeferStmt, deferTyp deferType) *gosec.Issue { - if anonFunc, isAnonFunc := deferStmt.Call.Fun.(*ast.FuncLit); isAnonFunc { - for _, subElem := range anonFunc.Body.List { - if issue := r.checkStmt(n, c, subElem, deferTyp); issue != nil { - return issue - } - } - } - return nil -} - -func (r *badDefer) checkStmt(n ast.Node, c *gosec.Context, subElem ast.Stmt, deferTyp deferType) *gosec.Issue { - switch stmt := subElem.(type) { - case *ast.AssignStmt: - for _, rh := range stmt.Rhs { - if e, isCallExp := rh.(*ast.CallExpr); isCallExp { - return r.checkChild(n, c, e, deferTyp) - } - } - case *ast.IfStmt: - if s, is := stmt.Init.(*ast.AssignStmt); is { - return r.checkStmt(n, c, s, deferTyp) - } - } - return nil -} - // NewDeferredClosing detects unsafe defer of error returning methods func NewDeferredClosing(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { return &badDefer{ diff --git a/testutils/source.go b/testutils/source.go index b389db218c..b57d1d23b2 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -2192,120 +2192,37 @@ func main() { // SampleCodeG307 - Unsafe defer of os.Close SampleCodeG307 = []CodeSample{ {[]string{`package main - import ( + "bufio" "fmt" "io/ioutil" "os" ) - func check(e error) { if e != nil { panic(e) } } - func main() { - d1 := []byte("hello\ngo\n") err := ioutil.WriteFile("/tmp/dat1", d1, 0744) check(err) - allowed := ioutil.WriteFile("/tmp/dat1", d1, 0600) check(allowed) - f, err := os.Create("/tmp/dat2") check(err) - defer f.Close() - - d2 := []byte{115, 111, 109, 101, 10} - n2, err := f.Write(d2) - - defer check(err) - fmt.Printf("wrote %d bytes\n", n2) - -}`}, 1, gosec.NewConfig()}, - {[]string{`package main - -import ( - "fmt" - "io/ioutil" - "log" - "os" -) - -func check(e error) { - if e != nil { - panic(e) - } -} - -func main() { - - d1 := []byte("hello\ngo\n") - err := ioutil.WriteFile("/tmp/dat1", d1, 0744) - check(err) - - allowed := ioutil.WriteFile("/tmp/dat1", d1, 0600) - check(allowed) - - f, err := os.Create("/tmp/dat2") - check(err) - - defer func() { - if err := f.Close(); err != nil { - log.Println(err) - } - }() - d2 := []byte{115, 111, 109, 101, 10} n2, err := f.Write(d2) - defer check(err) fmt.Printf("wrote %d bytes\n", n2) - -}`}, 1, gosec.NewConfig()}, - {[]string{`package main - -import ( - "fmt" - "io/ioutil" - "log" - "os" -) - -func check(e error) { - if e != nil { - panic(e) - } -} - -func main() { - - d1 := []byte("hello\ngo\n") - err := ioutil.WriteFile("/tmp/dat1", d1, 0744) - check(err) - - allowed := ioutil.WriteFile("/tmp/dat1", d1, 0600) - check(allowed) - - f, err := os.Create("/tmp/dat2") - check(err) - - defer func() { - err := f.Close() - if err != nil { - log.Println(err) - } - }() - - d2 := []byte{115, 111, 109, 101, 10} - n2, err := f.Write(d2) - - defer check(err) - fmt.Printf("wrote %d bytes\n", n2) - + n3, err := f.WriteString("writes\n") + fmt.Printf("wrote %d bytes\n", n3) + f.Sync() + w := bufio.NewWriter(f) + n4, err := w.WriteString("buffered\n") + fmt.Printf("wrote %d bytes\n", n4) + w.Flush() }`}, 1, gosec.NewConfig()}, }