From 0326f2bffbd8adf4ac347390faacdba89f494bd6 Mon Sep 17 00:00:00 2001 From: Antony Dovgal Date: Tue, 4 Jan 2022 14:49:40 +0300 Subject: [PATCH] checkers: ignore labeled continue in select statement (#1158) Fixes #1130 --- checkers/internal/lintutil/astfind.go | 32 +++++++++++++------ .../testdata/unlabelStmt/negative_tests.go | 13 ++++++++ checkers/typeSwitchVar_checker.go | 2 +- checkers/unlabelStmt_checker.go | 21 +++++++++--- 4 files changed, 53 insertions(+), 15 deletions(-) diff --git a/checkers/internal/lintutil/astfind.go b/checkers/internal/lintutil/astfind.go index 3c0a95afc..a6d0ad7c4 100644 --- a/checkers/internal/lintutil/astfind.go +++ b/checkers/internal/lintutil/astfind.go @@ -7,21 +7,35 @@ import ( ) // FindNode applies pred for root and all it's childs until it returns true. +// If followFunc is defined, it's called before following any node to check whether it needs to be followed. +// followFunc has to return true in order to continuing traversing the node and return false otherwise. // Matched node is returned. // If none of the nodes matched predicate, nil is returned. -func FindNode(root ast.Node, pred func(ast.Node) bool) ast.Node { - var found ast.Node - astutil.Apply(root, nil, func(cur *astutil.Cursor) bool { - if pred(cur.Node()) { - found = cur.Node() - return false +func FindNode(root ast.Node, followFunc, pred func(ast.Node) bool) ast.Node { + var ( + found ast.Node + preFunc func(*astutil.Cursor) bool + ) + + if followFunc != nil { + preFunc = func(cur *astutil.Cursor) bool { + return followFunc(cur.Node()) } - return true - }) + } + + astutil.Apply(root, + preFunc, + func(cur *astutil.Cursor) bool { + if pred(cur.Node()) { + found = cur.Node() + return false + } + return true + }) return found } // ContainsNode reports whether `FindNode(root, pred)!=nil`. func ContainsNode(root ast.Node, pred func(ast.Node) bool) bool { - return FindNode(root, pred) != nil + return FindNode(root, nil, pred) != nil } diff --git a/checkers/testdata/unlabelStmt/negative_tests.go b/checkers/testdata/unlabelStmt/negative_tests.go index 8bd60c7e4..d8283ac39 100644 --- a/checkers/testdata/unlabelStmt/negative_tests.go +++ b/checkers/testdata/unlabelStmt/negative_tests.go @@ -147,3 +147,16 @@ outer2: _ = x } } + +func twoLoopsWithSelect(c chan int) { +outer2: + for { + println("foo") + for { + select { + case <-c: + continue outer2 + } + } + } +} diff --git a/checkers/typeSwitchVar_checker.go b/checkers/typeSwitchVar_checker.go index 6bbec5037..1e11e4937 100644 --- a/checkers/typeSwitchVar_checker.go +++ b/checkers/typeSwitchVar_checker.go @@ -74,7 +74,7 @@ func (c *typeSwitchVarChecker) checkTypeSwitch(root *ast.TypeSwitchStmt) { // Create artificial node just for matching. assert1 := ast.TypeAssertExpr{X: expr, Type: clause.List[0]} for _, stmt := range clause.Body { - assert2 := lintutil.FindNode(stmt, func(x ast.Node) bool { + assert2 := lintutil.FindNode(stmt, nil, func(x ast.Node) bool { return astequal.Node(&assert1, x) }) if object == c.ctx.TypesInfo.ObjectOf(identOf(assert2)) { diff --git a/checkers/unlabelStmt_checker.go b/checkers/unlabelStmt_checker.go index fab864ec5..bcca24d2a 100644 --- a/checkers/unlabelStmt_checker.go +++ b/checkers/unlabelStmt_checker.go @@ -87,6 +87,7 @@ func (c *unlabelStmtChecker) VisitStmt(stmt ast.Stmt) { // Only for loops: if last stmt in list is a loop // that contains labeled "continue" to the outer loop label, // it can be refactored to use "break" instead. + // Exceptions: select statements with a labeled "continue" are ignored. if c.isLoop(labeled.Stmt) { body := c.blockStmtOf(labeled.Stmt) if len(body.List) == 0 { @@ -96,11 +97,21 @@ func (c *unlabelStmtChecker) VisitStmt(stmt ast.Stmt) { if !c.isLoop(last) { return } - br := lintutil.FindNode(c.blockStmtOf(last), func(n ast.Node) bool { - br, ok := n.(*ast.BranchStmt) - return ok && br.Label != nil && - br.Label.Name == name && br.Tok == token.CONTINUE - }) + br := lintutil.FindNode(c.blockStmtOf(last), + func(n ast.Node) bool { + switch n.(type) { + case *ast.SelectStmt: + return false + default: + return true + } + }, + func(n ast.Node) bool { + br, ok := n.(*ast.BranchStmt) + return ok && br.Label != nil && + br.Label.Name == name && br.Tok == token.CONTINUE + }) + if br != nil { c.warnLabeledContinue(br, name) }