Skip to content

Commit

Permalink
feat: Improve performance when generating spec with external dependen…
Browse files Browse the repository at this point in the history
…cies (#1108)

* Imporve performance when generating spec with external dependencies

* Fix code review comments

* Add go1.18 test code
  • Loading branch information
pytimer committed May 18, 2022
1 parent 3cedab9 commit 5f6b402
Show file tree
Hide file tree
Showing 19 changed files with 536 additions and 14 deletions.
7 changes: 7 additions & 0 deletions cmd/swag/main.go
Expand Up @@ -29,6 +29,7 @@ const (
parseDepthFlag = "parseDepth"
instanceNameFlag = "instanceName"
overridesFileFlag = "overridesFile"
parseGoListFlag = "parseGoList"
)

var initFlags = []cli.Flag{
Expand Down Expand Up @@ -110,6 +111,11 @@ var initFlags = []cli.Flag{
Value: gen.DefaultOverridesFile,
Usage: "File to read global type overrides from.",
},
&cli.BoolFlag{
Name: parseGoListFlag,
Value: true,
Usage: "Parse dependency via 'go list'",
},
}

func initAction(ctx *cli.Context) error {
Expand Down Expand Up @@ -142,6 +148,7 @@ func initAction(ctx *cli.Context) error {
ParseDepth: ctx.Int(parseDepthFlag),
InstanceName: ctx.String(instanceNameFlag),
OverridesFile: ctx.String(overridesFileFlag),
ParseGoList: ctx.Bool(parseGoListFlag),
})
}

Expand Down
4 changes: 4 additions & 0 deletions gen/gen.go
Expand Up @@ -105,6 +105,9 @@ type Config struct {

// OverridesFile defines global type overrides.
OverridesFile string

// ParseGoList whether swag use go list to parse dependency
ParseGoList bool
}

// Build builds swagger json file for given searchDir and mainAPIFile. Returns json.
Expand Down Expand Up @@ -146,6 +149,7 @@ func (g *Gen) Build(config *Config) error {
swag.SetCodeExamplesDirectory(config.CodeExampleFilesDir),
swag.SetStrict(config.Strict),
swag.SetOverrides(overrides),
swag.ParseUsingGoList(config.ParseGoList),
)

p.PropNamingStrategy = config.PropNamingStrategy
Expand Down
74 changes: 74 additions & 0 deletions golist.go
@@ -0,0 +1,74 @@
package swag

import (
"bytes"
"context"
"encoding/json"
"fmt"
"go/build"
"os/exec"
"path/filepath"
)

func listPackages(ctx context.Context, dir string, env []string, args ...string) (pkgs []*build.Package, finalErr error) {
cmd := exec.CommandContext(ctx, "go", append([]string{"list", "-json", "-e"}, args...)...)
cmd.Env = env
cmd.Dir = dir

stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}
var stderrBuf bytes.Buffer
cmd.Stderr = &stderrBuf
defer func() {
if stderrBuf.Len() > 0 {
finalErr = fmt.Errorf("%v\n%s", finalErr, stderrBuf.Bytes())
}
}()

err = cmd.Start()
if err != nil {
return nil, err
}
dec := json.NewDecoder(stdout)
for dec.More() {
var pkg build.Package
err = dec.Decode(&pkg)
if err != nil {
return nil, err
}
pkgs = append(pkgs, &pkg)
}
err = cmd.Wait()
if err != nil {
return nil, err
}
return pkgs, nil
}

func (parser *Parser) getAllGoFileInfoFromDepsByList(pkg *build.Package) error {
ignoreInternal := pkg.Goroot && !parser.ParseInternal
if ignoreInternal { // ignored internal
return nil
}

srcDir := pkg.Dir
var err error
for i := range pkg.GoFiles {
err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.GoFiles[i]), nil)
if err != nil {
return err
}
}

// parse .go source files that import "C"
for i := range pkg.CgoFiles {
err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.CgoFiles[i]), nil)
if err != nil {
return err
}
}

return nil
}
116 changes: 116 additions & 0 deletions golist_test.go
@@ -0,0 +1,116 @@
package swag

import (
"context"
"errors"
"fmt"
"go/build"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestListPackages(t *testing.T) {

cases := []struct {
name string
args []string
searchDir string
except error
}{
{
name: "errorArgs",
args: []string{"-abc"},
searchDir: "testdata/golist",
except: fmt.Errorf("exit status 2"),
},
{
name: "normal",
args: []string{"-deps"},
searchDir: "testdata/golist",
except: nil,
},
{
name: "list error",
args: []string{"-deps"},
searchDir: "testdata/golist_not_exist",
except: errors.New("searchDir not exist"),
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
_, err := listPackages(context.TODO(), c.searchDir, nil, c.args...)
if c.except != nil {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}

func TestGetAllGoFileInfoFromDepsByList(t *testing.T) {
p := New(ParseUsingGoList(true))
pwd, err := os.Getwd()
assert.NoError(t, err)
cases := []struct {
name string
buildPackage *build.Package
ignoreInternal bool
except error
}{
{
name: "normal",
buildPackage: &build.Package{
Name: "main",
ImportPath: "github.com/swaggo/swag/testdata/golist",
Dir: "testdata/golist",
GoFiles: []string{"main.go"},
CgoFiles: []string{"api/api.go"},
},
except: nil,
},
{
name: "ignore internal",
buildPackage: &build.Package{
Goroot: true,
},
ignoreInternal: true,
except: nil,
},
{
name: "gofiles error",
buildPackage: &build.Package{
Dir: "testdata/golist_not_exist",
GoFiles: []string{"main.go"},
},
except: errors.New("file not exist"),
},
{
name: "cgofiles error",
buildPackage: &build.Package{
Dir: "testdata/golist_not_exist",
CgoFiles: []string{"main.go"},
},
except: errors.New("file not exist"),
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
if c.ignoreInternal {
p.ParseInternal = false
}
c.buildPackage.Dir = filepath.Join(pwd, c.buildPackage.Dir)
err := p.getAllGoFileInfoFromDepsByList(c.buildPackage)
if c.except != nil {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}
6 changes: 6 additions & 0 deletions packages.go
Expand Up @@ -6,6 +6,7 @@ import (
"go/token"
"os"
"path/filepath"
"runtime"
"sort"
"strings"

Expand Down Expand Up @@ -78,6 +79,11 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF
func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, file *ast.File) error) error {
sortedFiles := make([]*AstFileInfo, 0, len(files))
for _, info := range files {
// ignore package path prefix with 'vendor' or $GOROOT,
// because the router info of api will not be included these files.
if strings.HasPrefix(info.PackagePath, "vendor") || strings.HasPrefix(info.Path, runtime.GOROOT()) {
continue
}
sortedFiles = append(sortedFiles, info)
}

Expand Down
15 changes: 15 additions & 0 deletions packages_test.go
Expand Up @@ -4,6 +4,7 @@ import (
"go/ast"
"go/token"
"path/filepath"
"runtime"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -151,6 +152,20 @@ func TestPackage_rangeFiles(t *testing.T) {
Path: "testdata/simple/api/api.go",
PackagePath: "api",
},
{
Name: &ast.Ident{Name: "foo.go"},
}: {
File: &ast.File{Name: &ast.Ident{Name: "foo.go"}},
Path: "vendor/foo/foo.go",
PackagePath: "vendor/foo",
},
{
Name: &ast.Ident{Name: "bar.go"},
}: {
File: &ast.File{Name: &ast.Ident{Name: "bar.go"}},
Path: filepath.Join(runtime.GOROOT(), "bar.go"),
PackagePath: "bar",
},
}

var sorted []string
Expand Down
54 changes: 40 additions & 14 deletions parser.go
@@ -1,6 +1,7 @@
package swag

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -140,6 +141,9 @@ type Parser struct {

// Overrides allows global replacements of types. A blank replacement will be skipped.
Overrides map[string]string

// parseGoList whether swag use go list to parse dependency
parseGoList bool
}

// FieldParserFactory create FieldParser.
Expand Down Expand Up @@ -261,6 +265,13 @@ func SetOverrides(overrides map[string]string) func(parser *Parser) {
}
}

// ParseUsingGoList sets whether swag use go list to parse dependency
func ParseUsingGoList(enabled bool) func(parser *Parser) {
return func(p *Parser) {
p.parseGoList = enabled
}
}

// ParseAPI parses general api info for given searchDir and mainAPIFile.
func (parser *Parser) ParseAPI(searchDir string, mainAPIFile string, parseDepth int) error {
return parser.ParseAPIMultiSearchDir([]string{searchDir}, mainAPIFile, parseDepth)
Expand All @@ -287,26 +298,41 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st
return err
}

// Use 'go list' command instead of depth.Resolve()
if parser.ParseDependency {
var tree depth.Tree
tree.ResolveInternal = true
tree.MaxDepth = parseDepth

pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath))
if err != nil {
return err
}
if parser.parseGoList {
pkgs, err := listPackages(context.Background(), filepath.Dir(absMainAPIFilePath), nil, "-deps")
if err != nil {
return fmt.Errorf("pkg %s cannot find all dependencies, %s", filepath.Dir(absMainAPIFilePath), err)
}

err = tree.Resolve(pkgName)
if err != nil {
return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err)
}
length := len(pkgs)
for i := 0; i < length; i++ {
err := parser.getAllGoFileInfoFromDepsByList(pkgs[i])
if err != nil {
return err
}
}
} else {
var t depth.Tree
t.ResolveInternal = true
t.MaxDepth = parseDepth

for i := 0; i < len(tree.Root.Deps); i++ {
err := parser.getAllGoFileInfoFromDeps(&tree.Root.Deps[i])
pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath))
if err != nil {
return err
}

err = t.Resolve(pkgName)
if err != nil {
return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err)
}
for i := 0; i < len(t.Root.Deps); i++ {
err := parser.getAllGoFileInfoFromDeps(&t.Root.Deps[i])
if err != nil {
return err
}
}
}
}

Expand Down

0 comments on commit 5f6b402

Please sign in to comment.