From c9ad7ea95c641cdabdcbb84ce767ba503125309d Mon Sep 17 00:00:00 2001 From: doniacld Date: Wed, 27 Oct 2021 21:29:45 +0200 Subject: [PATCH] Update error-strings rule (#608) --- rule/error-strings.go | 91 ++++++++++++++++++--- test/error-strings_test.go | 11 +++ testdata/golint/error-strings-pkg-errors.go | 18 ++++ 3 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 test/error-strings_test.go create mode 100644 testdata/golint/error-strings-pkg-errors.go diff --git a/rule/error-strings.go b/rule/error-strings.go index b8a5b7ed7..a08c528cd 100644 --- a/rule/error-strings.go +++ b/rule/error-strings.go @@ -17,10 +17,25 @@ type ErrorStringsRule struct{} func (r *ErrorStringsRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure { var failures []lint.Failure + var printFunctions = map[string]map[string]struct{}{ + "fmt": { + "Errorf": {}, + }, + "errors": { + "Errorf": {}, + "WithMessage": {}, + "Wrap": {}, + "New": {}, + "WithMessagef": {}, + "Wrapf": {}, + }, + } + fileAst := file.AST walker := lintErrorStrings{ - file: file, - fileAst: fileAst, + file: file, + fileAst: fileAst, + printFunctions: printFunctions, onFailure: func(failure lint.Failure) { failures = append(failures, failure) }, @@ -37,24 +52,31 @@ func (r *ErrorStringsRule) Name() string { } type lintErrorStrings struct { - file *lint.File - fileAst *ast.File - onFailure func(lint.Failure) + file *lint.File + fileAst *ast.File + printFunctions map[string]map[string]struct{} + onFailure func(lint.Failure) } +// Visit browses the AST func (w lintErrorStrings) Visit(n ast.Node) ast.Visitor { ce, ok := n.(*ast.CallExpr) if !ok { return w } - if !isPkgDot(ce.Fun, "errors", "New") && !isPkgDot(ce.Fun, "fmt", "Errorf") { + + if len(ce.Args) < 1 { return w } - if len(ce.Args) < 1 { + + // expression matches the known pkg.function + ok = w.match(ce) + if !ok { return w } - str, ok := ce.Args[0].(*ast.BasicLit) - if !ok || str.Kind != token.STRING { + + str, ok := w.getMessage(ce) + if !ok { return w } s, _ := strconv.Unquote(str.Value) // can assume well-formed Go @@ -65,7 +87,6 @@ func (w lintErrorStrings) Visit(n ast.Node) ast.Visitor { if clean { return w } - w.onFailure(lint.Failure{ Node: str, Confidence: conf, @@ -75,6 +96,56 @@ func (w lintErrorStrings) Visit(n ast.Node) ast.Visitor { return w } +// match returns true if the expression corresponds to the known pkg.function +// i.e.: errors.Wrap +func (w lintErrorStrings) match(expr *ast.CallExpr) bool { + sel, ok := expr.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + // retrieve the package + id, ok := sel.X.(*ast.Ident) + functions, ok := w.printFunctions[id.Name] + if !ok { + return false + } + // retrieve the function + _, ok = functions[sel.Sel.Name] + if !ok { + return false + } + return true +} + +// getMessage returns the message depending on its position +// returns false if the cast is unsuccessful +func (lintErrorStrings) getMessage(expr *ast.CallExpr) (s *ast.BasicLit, success bool) { + str, ok := checkArg(expr, 0) + if ok { + return str, true + } + + if len(expr.Args) < 2 { + return s, false + } + str, ok = checkArg(expr, 1) + if !ok { + return s, false + } + return str, true +} + +func checkArg(expr *ast.CallExpr, arg int) (s *ast.BasicLit, success bool) { + str, ok := expr.Args[arg].(*ast.BasicLit) + if !ok { + return s, false + } + if str.Kind != token.STRING { + return s, false + } + return str, true +} + func lintErrorString(s string) (isClean bool, conf float64) { const basicConfidence = 0.8 const capConfidence = basicConfidence - 0.2 diff --git a/test/error-strings_test.go b/test/error-strings_test.go new file mode 100644 index 000000000..2a535aa98 --- /dev/null +++ b/test/error-strings_test.go @@ -0,0 +1,11 @@ +package test + +import ( + "github.com/mgechev/revive/rule" + "testing" +) + +func TestErrorStrings(t *testing.T) { + testRule(t, "error-strings-pkg-errors", &rule.ErrorStringsRule{}) +} + diff --git a/testdata/golint/error-strings-pkg-errors.go b/testdata/golint/error-strings-pkg-errors.go new file mode 100644 index 000000000..c83723c26 --- /dev/null +++ b/testdata/golint/error-strings-pkg-errors.go @@ -0,0 +1,18 @@ +// Package foo ... +package foo + +import ( + "github.com/pkg/errors" +) + +// Check for the error strings themselves. + +func errorsStrings(x int) error { + var err error + err = errors.Wrap(err, "This %d is too low") // MATCH /error strings should not be capitalized or end with punctuation or a newline/ + err = errors.New("This %d is too low") // MATCH /error strings should not be capitalized or end with punctuation or a newline/ + err = errors.Wrapf(err, "This %d is too low", x) // MATCH /error strings should not be capitalized or end with punctuation or a newline/ + err = errors.WithMessage(err, "This %d is too low") // MATCH /error strings should not be capitalized or end with punctuation or a newline/ + err = errors.WithMessagef(err, "This %d is too low", x) // MATCH /error strings should not be capitalized or end with punctuation or a newline/ + return err +}