From cc152e31fec3b54146ad4e62aa34b08cd3f403a6 Mon Sep 17 00:00:00 2001 From: Jim van Kleef Date: Thu, 9 Jun 2022 01:30:28 +0200 Subject: [PATCH] also check function literals, fixes #19 --- tenv.go | 78 ++++++++++++++++++++++------------------ testdata/src/a/a_test.go | 12 +++++++ 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/tenv.go b/tenv.go index 9cdeb69..9fd7872 100644 --- a/tenv.go +++ b/tenv.go @@ -34,49 +34,58 @@ func run(pass *analysis.Pass) (interface{}, error) { inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) nodeFilter := []ast.Node{ - (*ast.File)(nil), + (*ast.FuncDecl)(nil), + (*ast.FuncLit)(nil), } inspect.Preorder(nodeFilter, func(n ast.Node) { switch n := n.(type) { - case *ast.File: - for _, decl := range n.Decls { - - funcDecl, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - checkFunc(pass, funcDecl, pass.Fset.File(n.Pos()).Name()) - } + case *ast.FuncDecl: + checkFuncDecl(pass, n, pass.Fset.File(n.Pos()).Name()) + case *ast.FuncLit: + checkFuncLit(pass, n, pass.Fset.File(n.Pos()).Name()) } }) return nil, nil } -func checkFunc(pass *analysis.Pass, n *ast.FuncDecl, fileName string) { - argName, ok := targetRunner(n, fileName) - if ok { - for _, stmt := range n.Body.List { - switch stmt := stmt.(type) { - case *ast.ExprStmt: - if !checkExprStmt(pass, stmt, n, argName) { - continue - } - case *ast.IfStmt: - if !checkIfStmt(pass, stmt, n, argName) { - continue - } - case *ast.AssignStmt: - if !checkAssignStmt(pass, stmt, n, argName) { - continue - } +func checkFuncDecl(pass *analysis.Pass, f *ast.FuncDecl, fileName string) { + argName, ok := targetRunner(f.Type.Params.List, fileName) + if !ok { + return + } + checkStmts(pass, f.Body.List, f.Name.Name, argName) +} + +func checkFuncLit(pass *analysis.Pass, f *ast.FuncLit, fileName string) { + argName, ok := targetRunner(f.Type.Params.List, fileName) + if !ok { + return + } + checkStmts(pass, f.Body.List, "function literal", argName) +} + +func checkStmts(pass *analysis.Pass, stmts []ast.Stmt, funcName, argName string) { + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.ExprStmt: + if !checkExprStmt(pass, stmt, funcName, argName) { + continue + } + case *ast.IfStmt: + if !checkIfStmt(pass, stmt, funcName, argName) { + continue + } + case *ast.AssignStmt: + if !checkAssignStmt(pass, stmt, funcName, argName) { + continue } } } } -func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, n *ast.FuncDecl, argName string) bool { +func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, funcName, argName string) bool { callExpr, ok := stmt.X.(*ast.CallExpr) if !ok { return false @@ -94,12 +103,12 @@ func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, n *ast.FuncDecl, arg if argName == "" { argName = "testing" } - pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name) + pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName) } return true } -func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, n *ast.FuncDecl, argName string) bool { +func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, funcName, argName string) bool { assignStmt, ok := stmt.Init.(*ast.AssignStmt) if !ok { return false @@ -121,12 +130,12 @@ func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, n *ast.FuncDecl, argName if argName == "" { argName = "testing" } - pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name) + pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName) } return true } -func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, n *ast.FuncDecl, argName string) bool { +func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, funcName, argName string) bool { rhs, ok := stmt.Rhs[0].(*ast.CallExpr) if !ok { return false @@ -144,13 +153,12 @@ func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, n *ast.FuncDecl, if argName == "" { argName = "testing" } - pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name) + pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName) } return true } -func targetRunner(funcDecl *ast.FuncDecl, fileName string) (string, bool) { - params := funcDecl.Type.Params.List +func targetRunner(params []*ast.Field, fileName string) (string, bool) { for _, p := range params { switch typ := p.Type.(type) { case *ast.StarExpr: diff --git a/testdata/src/a/a_test.go b/testdata/src/a/a_test.go index dd385ec..0be79cf 100644 --- a/testdata/src/a/a_test.go +++ b/testdata/src/a/a_test.go @@ -55,3 +55,15 @@ func FuzzF(f *testing.F) { _ = err } } + +func TestFunctionLiteral(t *testing.T) { + testsetup() + t.Run("test", func(t *testing.T) { + os.Setenv("a", "b") // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal" + err := os.Setenv("a", "b") // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal" + _ = err + if err := os.Setenv("a", "b"); err != nil { // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal" + _ = err + } + }) +}