From 9f54e0f83e2a8d2976c07037ad74aa20c62797a5 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 13 Sep 2023 00:24:18 +0900 Subject: [PATCH] feat: support trailing line comment for mage:import (#480) --- mage/import_test.go | 20 +++++++++ mage/testdata/mageimport/trailing/magefile.go | 5 +++ .../mageimport/trailing/other/other.go | 7 +++ parse/parse.go | 44 +++++++++++++------ 4 files changed, 63 insertions(+), 13 deletions(-) create mode 100644 mage/testdata/mageimport/trailing/magefile.go create mode 100644 mage/testdata/mageimport/trailing/other/other.go diff --git a/mage/import_test.go b/mage/import_test.go index 83949b27..df40a3fb 100644 --- a/mage/import_test.go +++ b/mage/import_test.go @@ -285,6 +285,26 @@ func TestMageImportsOneLine(t *testing.T) { t.Fatalf("expected: %q got: %q", expected, actual) } } +func TestMageImportsTrailing(t *testing.T) { + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/mageimport/trailing", + Stdout: stdout, + Stderr: stderr, + Args: []string{"build"}, + } + + code := Invoke(inv) + if code != 0 { + t.Fatalf("expected to exit with code 0, but got %v, stderr:\n%s", code, stderr) + } + actual := stdout.String() + expected := "build\n" + if actual != expected { + t.Fatalf("expected: %q got: %q", expected, actual) + } +} func TestMageImportsTaggedPackage(t *testing.T) { stdout := &bytes.Buffer{} diff --git a/mage/testdata/mageimport/trailing/magefile.go b/mage/testdata/mageimport/trailing/magefile.go new file mode 100644 index 00000000..ed3691b9 --- /dev/null +++ b/mage/testdata/mageimport/trailing/magefile.go @@ -0,0 +1,5 @@ +// +build mage + +package main + +import _ "github.com/magefile/mage/mage/testdata/mageimport/oneline/other" //mage:import diff --git a/mage/testdata/mageimport/trailing/other/other.go b/mage/testdata/mageimport/trailing/other/other.go new file mode 100644 index 00000000..5d40570b --- /dev/null +++ b/mage/testdata/mageimport/trailing/other/other.go @@ -0,0 +1,7 @@ +package other + +import "fmt" + +func Build() { + fmt.Println("build") +} diff --git a/parse/parse.go b/parse/parse.go index 48cf1ece..cf8b62a5 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -456,19 +456,18 @@ func setImports(gocmd string, pi *PkgInfo) error { } func getImportPath(imp *ast.ImportSpec) (path, alias string, ok bool) { - if imp.Doc == nil || len(imp.Doc.List) == 9 { - return "", "", false - } - // import is always the last comment - s := imp.Doc.List[len(imp.Doc.List)-1].Text - - // trim comment start and normalize for anyone who has spaces or not between - // "//"" and the text - vals := strings.Fields(strings.ToLower(s[2:])) - if len(vals) == 0 { - return "", "", false - } - if vals[0] != importTag { + leadingVals := getImportPathFromCommentGroup(imp.Doc) + trailingVals := getImportPathFromCommentGroup(imp.Comment) + + var vals []string + if len(leadingVals) > 0 { + vals = leadingVals + if len(trailingVals) > 0 { + log.Println("warning:", importTag, "specified both before and after, picking first") + } + } else if len(trailingVals) > 0 { + vals = trailingVals + } else { return "", "", false } path, ok = lit2string(imp.Path) @@ -489,6 +488,25 @@ func getImportPath(imp *ast.ImportSpec) (path, alias string, ok bool) { } } +func getImportPathFromCommentGroup(comments *ast.CommentGroup) []string { + if comments == nil || len(comments.List) == 9 { + return nil + } + // import is always the last comment + s := comments.List[len(comments.List)-1].Text + + // trim comment start and normalize for anyone who has spaces or not between + // "//"" and the text + vals := strings.Fields(strings.ToLower(s[2:])) + if len(vals) == 0 { + return nil + } + if vals[0] != importTag { + return nil + } + return vals +} + func isNamespace(t *doc.Type) bool { if len(t.Decl.Specs) != 1 { return false