Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Run func for all range variables #15

Merged
merged 2 commits into from Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
108 changes: 39 additions & 69 deletions pkg/paralleltest/paralleltest.go
Expand Up @@ -2,6 +2,7 @@ package paralleltest

import (
"go/ast"
"go/types"
"strings"

"golang.org/x/tools/go/analysis"
Expand Down Expand Up @@ -34,9 +35,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
funcDecl := node.(*ast.FuncDecl)
var funcHasParallelMethod,
rangeStatementOverTestCasesExists,
rangeStatementHasParallelMethod,
testLoopVariableReinitialised bool
var testRunLoopIdentifier string
rangeStatementHasParallelMethod bool
var loopVariableUsedInRun *string
var numberOfTestRun int
var positionOfTestRunNode []ast.Node
var rangeNode ast.Node
Expand Down Expand Up @@ -81,6 +81,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
case *ast.RangeStmt:
rangeNode = v

var loopVars []types.Object
for _, expr := range []ast.Expr{v.Key, v.Value} {
if id, ok := expr.(*ast.Ident); ok {
loopVars = append(loopVars, pass.TypesInfo.ObjectOf(id))
}
}

ast.Inspect(v, func(n ast.Node) bool {
// nolint: gocritic
switch r := n.(type) {
Expand All @@ -90,26 +97,20 @@ func run(pass *analysis.Pass) (interface{}, error) {
innerTestVar := getRunCallbackParameterName(r.X)

rangeStatementOverTestCasesExists = true
testRunLoopIdentifier = methodRunFirstArgumentObjectName(r.X)

if !rangeStatementHasParallelMethod {
rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X, innerTestVar)
}

if loopVariableUsedInRun == nil {
if run, ok := r.X.(*ast.CallExpr); ok {
loopVariableUsedInRun = loopVarReferencedInRun(run, loopVars, pass.TypesInfo)
}
}
}
}
return true
})

// Check for the range loop value identifier re assignment
// More info here https://gist.github.com/kunwardeep/80c2e9f3d3256c894898bae82d9f75d0
if rangeStatementOverTestCasesExists {
var rangeValueIdentifier string
if i, ok := v.Value.(*ast.Ident); ok {
rangeValueIdentifier = i.Name
}

testLoopVariableReinitialised = testCaseLoopVariableReinitialised(v.Body.List, rangeValueIdentifier, testRunLoopIdentifier)
}
}
}

Expand All @@ -120,12 +121,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
if rangeStatementOverTestCasesExists && rangeNode != nil {
if !rangeStatementHasParallelMethod {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s missing the call to method parallel in test Run\n", funcDecl.Name.Name)
} else {
if testRunLoopIdentifier == "" {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not use range value in test Run\n", funcDecl.Name.Name)
} else if !testLoopVariableReinitialised {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, testRunLoopIdentifier)
}
} else if loopVariableUsedInRun != nil {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, *loopVariableUsedInRun)
}
}

Expand All @@ -140,38 +137,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}

func testCaseLoopVariableReinitialised(statements []ast.Stmt, rangeValueIdentifier string, testRunLoopIdentifier string) bool {
if len(statements) > 1 {
for _, s := range statements {
leftIdentifier, rightIdentifier := getLeftAndRightIdentifier(s)
if leftIdentifier == testRunLoopIdentifier && rightIdentifier == rangeValueIdentifier {
return true
}
}
}
return false
}

// Return the left hand side and the right hand side identifiers name
func getLeftAndRightIdentifier(s ast.Stmt) (string, string) {
var leftIdentifier, rightIdentifier string
// nolint: gocritic
switch v := s.(type) {
case *ast.AssignStmt:
if len(v.Rhs) == 1 {
if i, ok := v.Rhs[0].(*ast.Ident); ok {
rightIdentifier = i.Name
}
}
if len(v.Lhs) == 1 {
if i, ok := v.Lhs[0].(*ast.Ident); ok {
leftIdentifier = i.Name
}
}
}
return leftIdentifier, rightIdentifier
}

func methodParallelIsCalledInMethodRun(node ast.Node, testVar string) bool {
var methodParallelCalled bool
// nolint: gocritic
Expand Down Expand Up @@ -247,22 +212,6 @@ func getRunCallbackParameterName(node ast.Node) string {
return ""
}

// Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T)
func methodRunFirstArgumentObjectName(node ast.Node) string {
// nolint: gocritic
switch n := node.(type) {
case *ast.CallExpr:
for _, arg := range n.Args {
if s, ok := arg.(*ast.SelectorExpr); ok {
if i, ok := s.X.(*ast.Ident); ok {
return i.Name
}
}
}
}
return ""
}

// Checks if the function has the param type *testing.T; if it does, then the
// parameter name is returned, too.
func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
Expand Down Expand Up @@ -291,3 +240,24 @@ func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {

return false, ""
}

func loopVarReferencedInRun(call *ast.CallExpr, vars []types.Object, typeInfo *types.Info) (found *string) {
if len(call.Args) != 2 {
return
}

ast.Inspect(call.Args[1], func(n ast.Node) bool {
ident, ok := n.(*ast.Ident)
if !ok {
return true
}
for _, o := range vars {
if typeInfo.ObjectOf(ident) == o {
found = &ident.Name
}
}
return true
})

return
}
2 changes: 1 addition & 1 deletion pkg/paralleltest/testdata/src/t/t_test.go
Expand Up @@ -81,7 +81,7 @@ func TestFunctionRangeNotUsingRangeValueInTDotRun(t *testing.T) {
testCases := []struct {
name string
}{{name: "foo"}}
for _, tc := range testCases { // want "Range statement for test TestFunctionRangeNotUsingRangeValueInTDotRun does not use range value in test Run"
for _, tc := range testCases { // want "Range statement for test TestFunctionRangeNotUsingRangeValueInTDotRun does not reinitialise the variable tc"
t.Run("tc.name", func(t *testing.T) {
t.Parallel()
fmt.Println(tc.name)
Expand Down