Skip to content

Commit

Permalink
record token.FileSet for every file so that the position of parsing e…
Browse files Browse the repository at this point in the history
…rror can be acquired (#1393)
  • Loading branch information
sdghchj committed Nov 26, 2022
1 parent 30684a2 commit ba5df82
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 107 deletions.
32 changes: 27 additions & 5 deletions packages.go
@@ -1,6 +1,7 @@
package swag

import (
"fmt"
"go/ast"
goparser "go/parser"
"go/token"
Expand All @@ -19,6 +20,7 @@ type PackagesDefinitions struct {
packages map[string]*PackageDefinitions
uniqueDefinitions map[string]*TypeSpecDef
parseDependency bool
debug Debugger
}

// NewPackagesDefinitions create object PackagesDefinitions.
Expand All @@ -30,8 +32,19 @@ func NewPackagesDefinitions() *PackagesDefinitions {
}
}

// CollectAstFile collect ast.file.
func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) error {
// ParseFile parse a source file.
func (pkgDefs *PackagesDefinitions) ParseFile(packageDir, path string, src interface{}) error {
// positions are relative to FileSet
fileSet := token.NewFileSet()
astFile, err := goparser.ParseFile(fileSet, path, src, goparser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse file %s, error:%+v", path, err)
}
return pkgDefs.collectAstFile(fileSet, packageDir, path, astFile)
}

// collectAstFile collect ast.file.
func (pkgDefs *PackagesDefinitions) collectAstFile(fileSet *token.FileSet, packageDir, path string, astFile *ast.File) error {
if pkgDefs.files == nil {
pkgDefs.files = make(map[*ast.File]*AstFileInfo)
}
Expand Down Expand Up @@ -64,6 +77,7 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF
}

pkgDefs.files[astFile] = &AstFileInfo{
FileSet: fileSet,
File: astFile,
Path: path,
PackagePath: packageDir,
Expand All @@ -73,9 +87,9 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF
}

// RangeFiles for range the collection of ast.File in alphabetic order.
func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, file *ast.File) error) error {
sortedFiles := make([]*AstFileInfo, 0, len(files))
for _, info := range files {
func (pkgDefs *PackagesDefinitions) RangeFiles(handle func(filename string, file *ast.File) error) error {
sortedFiles := make([]*AstFileInfo, 0, len(pkgDefs.files))
for _, info := range pkgDefs.files {
// ignore package path prefix with 'vendor' or $GOROOT,
// because the router info of api will not be included these files.
if strings.HasPrefix(info.PackagePath, "vendor") || strings.HasPrefix(info.Path, runtime.GOROOT()) {
Expand Down Expand Up @@ -270,6 +284,14 @@ func (pkgDefs *PackagesDefinitions) evaluateAllConstVariables() {
// EvaluateConstValue evaluate a const variable.
func (pkgDefs *PackagesDefinitions) EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) {
if expr, ok := cv.Value.(ast.Expr); ok {
defer func() {
if err := recover(); err != nil {
if fi, ok := pkgDefs.files[cv.File]; ok {
pos := fi.FileSet.Position(cv.Name.NamePos)
pkgDefs.debug.Printf("warning: failed to evaluate const %s at %s:%d:%d, %v", cv.Name.Name, fi.Path, pos.Line, pos.Column, err)
}
}
}()
if recursiveStack == nil {
recursiveStack = make(map[string]struct{})
}
Expand Down
29 changes: 20 additions & 9 deletions packages_test.go
Expand Up @@ -10,35 +10,45 @@ import (
"github.com/stretchr/testify/assert"
)

func TestPackagesDefinitions_CollectAstFile(t *testing.T) {
func TestPackagesDefinitions_ParseFile(t *testing.T) {
pd := PackagesDefinitions{}
assert.NoError(t, pd.CollectAstFile("", "", nil))
packageDir := "github.com/swaggo/swag/testdata/simple"
assert.NoError(t, pd.ParseFile(packageDir, "testdata/simple/main.go", nil))
assert.Equal(t, 1, len(pd.packages))
assert.Equal(t, 1, len(pd.files))
}

func TestPackagesDefinitions_collectAstFile(t *testing.T) {
pd := PackagesDefinitions{}
fileSet := token.NewFileSet()
assert.NoError(t, pd.collectAstFile(fileSet, "", "", nil))

firstFile := &ast.File{
Name: &ast.Ident{Name: "main.go"},
}

packageDir := "github.com/swaggo/swag/testdata/simple"
assert.NoError(t, pd.CollectAstFile(packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.NotEmpty(t, pd.packages[packageDir])

absPath, _ := filepath.Abs("testdata/simple/" + firstFile.Name.String())
astFileInfo := &AstFileInfo{
FileSet: fileSet,
File: firstFile,
Path: absPath,
PackagePath: packageDir,
}
assert.Equal(t, pd.files[firstFile], astFileInfo)

// Override
assert.NoError(t, pd.CollectAstFile(packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.Equal(t, pd.files[firstFile], astFileInfo)

// Another file
secondFile := &ast.File{
Name: &ast.Ident{Name: "api.go"},
}
assert.NoError(t, pd.CollectAstFile(packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile))
assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile))
}

func TestPackagesDefinitions_rangeFiles(t *testing.T) {
Expand All @@ -62,7 +72,7 @@ func TestPackagesDefinitions_rangeFiles(t *testing.T) {
}

i, expect := 0, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}
_ = rangeFiles(pd.files, func(filename string, file *ast.File) error {
_ = pd.RangeFiles(func(filename string, file *ast.File) error {
assert.Equal(t, expect[i], filename)
i++
return nil
Expand Down Expand Up @@ -182,7 +192,8 @@ func TestPackagesDefinitions_FindTypeSpec(t *testing.T) {
}

func TestPackage_rangeFiles(t *testing.T) {
files := map[*ast.File]*AstFileInfo{
pd := NewPackagesDefinitions()
pd.files = map[*ast.File]*AstFileInfo{
{
Name: &ast.Ident{Name: "main.go"},
}: {
Expand Down Expand Up @@ -218,10 +229,10 @@ func TestPackage_rangeFiles(t *testing.T) {
sorted = append(sorted, filename)
return nil
}
assert.NoError(t, rangeFiles(files, processor))
assert.NoError(t, pd.RangeFiles(processor))
assert.Equal(t, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}, sorted)

assert.Error(t, rangeFiles(files, func(filename string, file *ast.File) error {
assert.Error(t, pd.RangeFiles(func(filename string, file *ast.File) error {
return ErrFuncTypeField
}))

Expand Down
19 changes: 4 additions & 15 deletions parser.go
Expand Up @@ -211,6 +211,8 @@ func New(options ...func(*Parser)) *Parser {
option(parser)
}

parser.packages.debug = parser.debug

return parser
}

Expand Down Expand Up @@ -276,7 +278,6 @@ func SetDebugger(logger Debugger) func(parser *Parser) {
if logger != nil {
p.debug = logger
}

}
}

Expand Down Expand Up @@ -377,7 +378,7 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st
return err
}

err = rangeFiles(parser.packages.files, parser.ParseRouterAPIInfo)
err = parser.packages.RangeFiles(parser.ParseRouterAPIInfo)
if err != nil {
return err
}
Expand Down Expand Up @@ -982,7 +983,6 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (
if err == ErrRecursiveParseStruct && ref {
return parser.getRefTypeSchema(typeSpecDef, schema), nil
}

return nil, err
}
}
Expand Down Expand Up @@ -1536,18 +1536,7 @@ func (parser *Parser) parseFile(packageDir, path string, src interface{}) error
return nil
}

// positions are relative to FileSet
astFile, err := goparser.ParseFile(token.NewFileSet(), path, src, goparser.ParseComments)
if err != nil {
return fmt.Errorf("ParseFile error:%+v", err)
}

err = parser.packages.CollectAstFile(packageDir, path, astFile)
if err != nil {
return err
}

return nil
return parser.packages.ParseFile(packageDir, path, src)
}

func (parser *Parser) checkOperationIDUniqueness() error {
Expand Down

0 comments on commit ba5df82

Please sign in to comment.