diff --git a/cmd/swag/main.go b/cmd/swag/main.go index 579178032..dde1e0b8c 100644 --- a/cmd/swag/main.go +++ b/cmd/swag/main.go @@ -29,6 +29,7 @@ const ( parseDepthFlag = "parseDepth" instanceNameFlag = "instanceName" overridesFileFlag = "overridesFile" + parseGoListFlag = "parseGoList" ) var initFlags = []cli.Flag{ @@ -110,6 +111,11 @@ var initFlags = []cli.Flag{ Value: gen.DefaultOverridesFile, Usage: "File to read global type overrides from.", }, + &cli.BoolFlag{ + Name: parseGoListFlag, + Value: true, + Usage: "Parse dependency via 'go list'", + }, } func initAction(ctx *cli.Context) error { @@ -142,6 +148,7 @@ func initAction(ctx *cli.Context) error { ParseDepth: ctx.Int(parseDepthFlag), InstanceName: ctx.String(instanceNameFlag), OverridesFile: ctx.String(overridesFileFlag), + ParseGoList: ctx.Bool(parseGoListFlag), }) } diff --git a/gen/gen.go b/gen/gen.go index ce9f0db4d..7198433f6 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -105,6 +105,9 @@ type Config struct { // OverridesFile defines global type overrides. OverridesFile string + + // ParseGoList whether swag use go list to parse dependency + ParseGoList bool } // Build builds swagger json file for given searchDir and mainAPIFile. Returns json. @@ -146,6 +149,7 @@ func (g *Gen) Build(config *Config) error { swag.SetCodeExamplesDirectory(config.CodeExampleFilesDir), swag.SetStrict(config.Strict), swag.SetOverrides(overrides), + swag.ParseUsingGoList(config.ParseGoList), ) p.PropNamingStrategy = config.PropNamingStrategy diff --git a/golist.go b/golist.go new file mode 100644 index 000000000..b8663abde --- /dev/null +++ b/golist.go @@ -0,0 +1,74 @@ +package swag + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "go/build" + "os/exec" + "path/filepath" +) + +func listPackages(ctx context.Context, dir string, env []string, args ...string) (pkgs []*build.Package, finalErr error) { + cmd := exec.CommandContext(ctx, "go", append([]string{"list", "-json", "-e"}, args...)...) + cmd.Env = env + cmd.Dir = dir + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + var stderrBuf bytes.Buffer + cmd.Stderr = &stderrBuf + defer func() { + if stderrBuf.Len() > 0 { + finalErr = fmt.Errorf("%v\n%s", finalErr, stderrBuf.Bytes()) + } + }() + + err = cmd.Start() + if err != nil { + return nil, err + } + dec := json.NewDecoder(stdout) + for dec.More() { + var pkg build.Package + err = dec.Decode(&pkg) + if err != nil { + return nil, err + } + pkgs = append(pkgs, &pkg) + } + err = cmd.Wait() + if err != nil { + return nil, err + } + return pkgs, nil +} + +func (parser *Parser) getAllGoFileInfoFromDepsByList(pkg *build.Package) error { + ignoreInternal := pkg.Goroot && !parser.ParseInternal + if ignoreInternal { // ignored internal + return nil + } + + srcDir := pkg.Dir + var err error + for i := range pkg.GoFiles { + err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.GoFiles[i]), nil) + if err != nil { + return err + } + } + + // parse .go source files that import "C" + for i := range pkg.CgoFiles { + err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.CgoFiles[i]), nil) + if err != nil { + return err + } + } + + return nil +} diff --git a/golist_test.go b/golist_test.go new file mode 100644 index 000000000..6b11da7b7 --- /dev/null +++ b/golist_test.go @@ -0,0 +1,116 @@ +package swag + +import ( + "context" + "errors" + "fmt" + "go/build" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListPackages(t *testing.T) { + + cases := []struct { + name string + args []string + searchDir string + except error + }{ + { + name: "errorArgs", + args: []string{"-abc"}, + searchDir: "testdata/golist", + except: fmt.Errorf("exit status 2"), + }, + { + name: "normal", + args: []string{"-deps"}, + searchDir: "testdata/golist", + except: nil, + }, + { + name: "list error", + args: []string{"-deps"}, + searchDir: "testdata/golist_not_exist", + except: errors.New("searchDir not exist"), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + _, err := listPackages(context.TODO(), c.searchDir, nil, c.args...) + if c.except != nil { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestGetAllGoFileInfoFromDepsByList(t *testing.T) { + p := New(ParseUsingGoList(true)) + pwd, err := os.Getwd() + assert.NoError(t, err) + cases := []struct { + name string + buildPackage *build.Package + ignoreInternal bool + except error + }{ + { + name: "normal", + buildPackage: &build.Package{ + Name: "main", + ImportPath: "github.com/swaggo/swag/testdata/golist", + Dir: "testdata/golist", + GoFiles: []string{"main.go"}, + CgoFiles: []string{"api/api.go"}, + }, + except: nil, + }, + { + name: "ignore internal", + buildPackage: &build.Package{ + Goroot: true, + }, + ignoreInternal: true, + except: nil, + }, + { + name: "gofiles error", + buildPackage: &build.Package{ + Dir: "testdata/golist_not_exist", + GoFiles: []string{"main.go"}, + }, + except: errors.New("file not exist"), + }, + { + name: "cgofiles error", + buildPackage: &build.Package{ + Dir: "testdata/golist_not_exist", + CgoFiles: []string{"main.go"}, + }, + except: errors.New("file not exist"), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.ignoreInternal { + p.ParseInternal = false + } + c.buildPackage.Dir = filepath.Join(pwd, c.buildPackage.Dir) + err := p.getAllGoFileInfoFromDepsByList(c.buildPackage) + if c.except != nil { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/packages.go b/packages.go index dd1a0e6c7..fc7a6bd2b 100644 --- a/packages.go +++ b/packages.go @@ -6,6 +6,7 @@ import ( "go/token" "os" "path/filepath" + "runtime" "sort" "strings" @@ -78,6 +79,11 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, file *ast.File) error) error { sortedFiles := make([]*AstFileInfo, 0, len(files)) for _, info := range files { + // ignore package path prefix with 'vendor' or $GOROOT, + // because the router info of api will not be included these files. + if strings.HasPrefix(info.PackagePath, "vendor") || strings.HasPrefix(info.Path, runtime.GOROOT()) { + continue + } sortedFiles = append(sortedFiles, info) } diff --git a/packages_test.go b/packages_test.go index 9163aa0a0..d74ba4d3a 100644 --- a/packages_test.go +++ b/packages_test.go @@ -4,6 +4,7 @@ import ( "go/ast" "go/token" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" @@ -151,6 +152,20 @@ func TestPackage_rangeFiles(t *testing.T) { Path: "testdata/simple/api/api.go", PackagePath: "api", }, + { + Name: &ast.Ident{Name: "foo.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "foo.go"}}, + Path: "vendor/foo/foo.go", + PackagePath: "vendor/foo", + }, + { + Name: &ast.Ident{Name: "bar.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "bar.go"}}, + Path: filepath.Join(runtime.GOROOT(), "bar.go"), + PackagePath: "bar", + }, } var sorted []string diff --git a/parser.go b/parser.go index a390714f4..c8a18a6db 100644 --- a/parser.go +++ b/parser.go @@ -1,6 +1,7 @@ package swag import ( + "context" "encoding/json" "errors" "fmt" @@ -140,6 +141,9 @@ type Parser struct { // Overrides allows global replacements of types. A blank replacement will be skipped. Overrides map[string]string + + // parseGoList whether swag use go list to parse dependency + parseGoList bool } // FieldParserFactory create FieldParser. @@ -261,6 +265,13 @@ func SetOverrides(overrides map[string]string) func(parser *Parser) { } } +// ParseUsingGoList sets whether swag use go list to parse dependency +func ParseUsingGoList(enabled bool) func(parser *Parser) { + return func(p *Parser) { + p.parseGoList = enabled + } +} + // ParseAPI parses general api info for given searchDir and mainAPIFile. func (parser *Parser) ParseAPI(searchDir string, mainAPIFile string, parseDepth int) error { return parser.ParseAPIMultiSearchDir([]string{searchDir}, mainAPIFile, parseDepth) @@ -287,26 +298,41 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st return err } + // Use 'go list' command instead of depth.Resolve() if parser.ParseDependency { - var tree depth.Tree - tree.ResolveInternal = true - tree.MaxDepth = parseDepth - - pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath)) - if err != nil { - return err - } + if parser.parseGoList { + pkgs, err := listPackages(context.Background(), filepath.Dir(absMainAPIFilePath), nil, "-deps") + if err != nil { + return fmt.Errorf("pkg %s cannot find all dependencies, %s", filepath.Dir(absMainAPIFilePath), err) + } - err = tree.Resolve(pkgName) - if err != nil { - return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err) - } + length := len(pkgs) + for i := 0; i < length; i++ { + err := parser.getAllGoFileInfoFromDepsByList(pkgs[i]) + if err != nil { + return err + } + } + } else { + var t depth.Tree + t.ResolveInternal = true + t.MaxDepth = parseDepth - for i := 0; i < len(tree.Root.Deps); i++ { - err := parser.getAllGoFileInfoFromDeps(&tree.Root.Deps[i]) + pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath)) if err != nil { return err } + + err = t.Resolve(pkgName) + if err != nil { + return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err) + } + for i := 0; i < len(t.Root.Deps); i++ { + err := parser.getAllGoFileInfoFromDeps(&t.Root.Deps[i]) + if err != nil { + return err + } + } } } diff --git a/parser_test.go b/parser_test.go index 392a04fe2..45d02bdfd 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,6 +3,7 @@ package swag import ( "bytes" "encoding/json" + "errors" "go/ast" goparser "go/parser" "go/token" @@ -2159,6 +2160,109 @@ func TestParseExternalModels(t *testing.T) { assert.Equal(t, string(expected), string(b)) } +func TestParseGoList(t *testing.T) { + mainAPIFile := "main.go" + p := New(ParseUsingGoList(true)) + p.ParseDependency = true + + go111moduleEnv := os.Getenv("GO111MODULE") + + cases := []struct { + name string + gomodule bool + searchDir string + err error + run func(searchDir string) error + }{ + { + name: "disableGOMODULE", + gomodule: false, + searchDir: "testdata/golist_disablemodule", + run: func(searchDir string) error { + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + { + name: "enableGOMODULE", + gomodule: true, + searchDir: "testdata/golist", + run: func(searchDir string) error { + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + { + name: "invalid_main", + gomodule: true, + searchDir: "testdata/golist_invalid", + err: errors.New("no such file or directory"), + run: func(searchDir string) error { + return p.ParseAPI(searchDir, "invalid/main.go", defaultParseDepth) + }, + }, + { + name: "internal_invalid_pkg", + gomodule: true, + searchDir: "testdata/golist_invalid", + err: errors.New("expected 'package', found This"), + run: func(searchDir string) error { + mockErrGoFile := "testdata/golist_invalid/err.go" + f, err := os.OpenFile(mockErrGoFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write([]byte(`package invalid + +function a() {}`)) + if err != nil { + return err + } + defer os.Remove(mockErrGoFile) + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + { + name: "invalid_pkg", + gomodule: true, + searchDir: "testdata/golist_invalid", + err: errors.New("expected 'package', found This"), + run: func(searchDir string) error { + mockErrGoFile := "testdata/invalid_external_pkg/invalid/err.go" + f, err := os.OpenFile(mockErrGoFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write([]byte(`package invalid + +function a() {}`)) + if err != nil { + return err + } + defer os.Remove(mockErrGoFile) + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.gomodule { + os.Setenv("GO111MODULE", "on") + } else { + os.Setenv("GO111MODULE", "off") + } + err := c.run(c.searchDir) + os.Setenv("GO111MODULE", go111moduleEnv) + if c.err == nil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} + func TestParser_ParseStructArrayObject(t *testing.T) { t.Parallel() diff --git a/testdata/golist/api/api.go b/testdata/golist/api/api.go new file mode 100644 index 000000000..09191daff --- /dev/null +++ b/testdata/golist/api/api.go @@ -0,0 +1,38 @@ +package api + +/* +#include "foo.h" +*/ +import "C" +import ( + "fmt" + "net/http" +) + +func PrintInt(i, j int) { + res := C.add(C.int(i), C.int(j)) + fmt.Println(res) +} + +type Foo struct { + ID int `json:"id"` + Name string `json:"name"` + PhotoUrls []string `json:"photoUrls"` + Status string `json:"status"` +} + +// GetFoo example +// @Summary Get foo +// @Description get foo +// @ID foo +// @Accept json +// @Produce json +// @Param some_id query int true "Some ID" +// @Param some_foo formData Foo true "Foo" +// @Success 200 {string} string "ok" +// @Failure 400 {object} web.APIError "We need ID!!" +// @Failure 404 {object} web.APIError "Can not find ID" +// @Router /testapi/foo [get] +func GetFoo(w http.ResponseWriter, r *http.Request) { + // write your code +} diff --git a/testdata/golist/api/foo.c b/testdata/golist/api/foo.c new file mode 100644 index 000000000..082fc1e6f --- /dev/null +++ b/testdata/golist/api/foo.c @@ -0,0 +1,3 @@ +int add(int a, int b) { + return a + b; +} \ No newline at end of file diff --git a/testdata/golist/api/foo.h b/testdata/golist/api/foo.h new file mode 100644 index 000000000..0228a4c3c --- /dev/null +++ b/testdata/golist/api/foo.h @@ -0,0 +1 @@ +int add(int, int); \ No newline at end of file diff --git a/testdata/golist/main.go b/testdata/golist/main.go new file mode 100644 index 000000000..70143eb88 --- /dev/null +++ b/testdata/golist/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "net/http" + + "github.com/swaggo/swag/example/basic/api" + goapi "github.com/swaggo/swag/testdata/golist/api" +) + +// @title Swagger Example API +// @version 1.0 +// @description This is a sample server Petstore server. +// @termsOfService http://swagger.io/terms/ + +// @contact.name API Support +// @contact.url http://www.swagger.io/support +// @contact.email support@swagger.io + +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html + +// @securityDefinitions.apikey ApiKeyAuth +// @in header +// @name Authorization + +// @query.collection.format multi +// @host petstore.swagger.io +// @BasePath /v2 +func main() { + goapi.PrintInt(10, 5) + http.HandleFunc("/testapi/get-string-by-int/", api.GetStringByInt) + http.HandleFunc("/testapi/get-struct-array-by-string/", api.GetStructArrayByString) + http.HandleFunc("/testapi/upload", api.Upload) + http.HandleFunc("/testapi/foo", goapi.GetFoo) + http.ListenAndServe(":8080", nil) +} diff --git a/testdata/golist_disablemodule/api/api.go b/testdata/golist_disablemodule/api/api.go new file mode 100644 index 000000000..1f1f28d60 --- /dev/null +++ b/testdata/golist_disablemodule/api/api.go @@ -0,0 +1,14 @@ +package api + +/* +#include "foo.h" +*/ +import "C" +import ( + "fmt" +) + +func PrintInt(i, j int) { + res := C.add(C.int(i), C.int(j)) + fmt.Println(res) +} diff --git a/testdata/golist_disablemodule/api/foo.c b/testdata/golist_disablemodule/api/foo.c new file mode 100644 index 000000000..082fc1e6f --- /dev/null +++ b/testdata/golist_disablemodule/api/foo.c @@ -0,0 +1,3 @@ +int add(int a, int b) { + return a + b; +} \ No newline at end of file diff --git a/testdata/golist_disablemodule/api/foo.h b/testdata/golist_disablemodule/api/foo.h new file mode 100644 index 000000000..0228a4c3c --- /dev/null +++ b/testdata/golist_disablemodule/api/foo.h @@ -0,0 +1 @@ +int add(int, int); \ No newline at end of file diff --git a/testdata/golist_disablemodule/main.go b/testdata/golist_disablemodule/main.go new file mode 100644 index 000000000..ab7d61735 --- /dev/null +++ b/testdata/golist_disablemodule/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "net/http" + + "github.com/swaggo/swag/example/basic/api" + internalapi "github.com/swaggo/swag/testdata/golist_disablemodule/api" +) + +// @title Swagger Example API +// @version 1.0 +// @description This is a sample server Petstore server. +// @termsOfService http://swagger.io/terms/ + +// @contact.name API Support +// @contact.url http://www.swagger.io/support +// @contact.email support@swagger.io + +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html + +// @securityDefinitions.apikey ApiKeyAuth +// @in header +// @name Authorization + +// @host petstore.swagger.io +// @BasePath /v2 +func main() { + internalapi.PrintInt(0, 1) + http.HandleFunc("/testapi/get-string-by-int/", api.GetStringByInt) + http.HandleFunc("/testapi/get-struct-array-by-string/", api.GetStructArrayByString) + http.HandleFunc("/testapi/upload", api.Upload) + http.ListenAndServe(":8080", nil) +} diff --git a/testdata/golist_invalid/main.go b/testdata/golist_invalid/main.go new file mode 100644 index 000000000..69bae8025 --- /dev/null +++ b/testdata/golist_invalid/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "net/http" + + "github.com/swaggo/swag/example/basic/api" + "github.com/swaggo/swag/testdata/invalid_external_pkg/invalid" +) + +// @title Swagger Example API +// @version 1.0 +// @description This is a sample server Petstore server. +// @termsOfService http://swagger.io/terms/ + +// @contact.name API Support +// @contact.url http://www.swagger.io/support +// @contact.email support@swagger.io + +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html + +// @securityDefinitions.apikey ApiKeyAuth +// @in header +// @name Authorization + +// @host petstore.swagger.io +// @BasePath /v2 +func main() { + invalid.Foo() + http.HandleFunc("/testapi/upload", api.Upload) + http.ListenAndServe(":8080", nil) +} diff --git a/testdata/invalid_external_pkg/invalid/normal.go b/testdata/invalid_external_pkg/invalid/normal.go new file mode 100644 index 000000000..f58a341df --- /dev/null +++ b/testdata/invalid_external_pkg/invalid/normal.go @@ -0,0 +1,5 @@ +package invalid + +func Foo() { + +} diff --git a/testdata/invalid_external_pkg/main.go b/testdata/invalid_external_pkg/main.go new file mode 100644 index 000000000..38dd16da6 --- /dev/null +++ b/testdata/invalid_external_pkg/main.go @@ -0,0 +1,3 @@ +package main + +func main() {}