diff --git a/formater.go b/formater.go index e1ce225dc..1c903f48c 100644 --- a/formater.go +++ b/formater.go @@ -7,6 +7,7 @@ import ( "go/ast" goparser "go/parser" "go/token" + "io" "io/ioutil" "log" "os" @@ -30,7 +31,7 @@ type Formater struct { mainFile string } -// NewFormater create a new formater +// NewFormater create a new formater instance. func NewFormater() *Formater { formater := &Formater{ debug: log.New(os.Stdout, "", log.LstdFlags), @@ -110,7 +111,7 @@ func (f *Formater) visit(path string, fileInfo os.FileInfo, err error) error { return nil } -// FormatMain format the main.go comment +// FormatMain format the main.go comment. func (f *Formater) FormatMain(mainFilepath string) error { fileSet := token.NewFileSet() astFile, err := goparser.ParseFile(fileSet, mainFilepath, nil, goparser.ParseComments) @@ -171,17 +172,14 @@ func writeFormatedComments(filepath string, formatedComments bytes.Buffer, oldCo commentHash, commentContent := commentSplit[0], commentSplit[1] if !isBlankComment(commentContent) { - oldComment := oldCommentsMap[commentHash] - if strings.Contains(replaceSrc, oldComment) { - replaceSrc = strings.Replace(replaceSrc, oldComment, commentContent, 1) - } + replaceSrc = strings.Replace(replaceSrc, oldCommentsMap[commentHash], commentContent, 1) } } } return writeBack(filepath, []byte(replaceSrc), srcBytes) } -func formatFuncDoc(commentList []*ast.Comment, formatedComments *bytes.Buffer, oldCommentsMap map[string]string) { +func formatFuncDoc(commentList []*ast.Comment, formatedComments io.Writer, oldCommentsMap map[string]string) { tabw := tabwriter.NewWriter(formatedComments, 0, 0, 2, ' ', 0) for _, comment := range commentList { @@ -315,7 +313,6 @@ func backupFile(filename string, data []byte, perm os.FileMode) (string, error) if err != nil { return "", err } - bakname := f.Name() if chmodSupported { _ = f.Chmod(perm) } @@ -325,5 +322,5 @@ func backupFile(filename string, data []byte, perm os.FileMode) (string, error) if err1 := f.Close(); err == nil { err = err1 } - return bakname, err + return f.Name(), err } diff --git a/gen/gen_test.go b/gen/gen_test.go index 05d0ee833..7d732025f 100644 --- a/gen/gen_test.go +++ b/gen/gen_test.go @@ -398,7 +398,7 @@ func TestGen_configWithOutputTypesAll(t *testing.T) { if _, err := os.Stat(expectedFile); os.IsNotExist(err) { t.Fatal(err) } - os.Remove(expectedFile) + _ = os.Remove(expectedFile) } } @@ -428,7 +428,7 @@ func TestGen_configWithOutputTypesSingle(t *testing.T) { if _, err := os.Stat(expectedFile); os.IsNotExist(err) { t.Fatal(err) } - os.Remove(expectedFile) + _ = os.Remove(expectedFile) } } } @@ -563,23 +563,6 @@ func TestGen_cgoImports(t *testing.T) { } } -func TestGen_duplicateRoute(t *testing.T) { - config := &Config{ - SearchDir: "../testdata/duplicate_route", - MainAPIFile: "./main.go", - OutputDir: "../testdata/duplicate_route/docs", - PropNamingStrategy: "", - ParseDependency: true, - } - err := New().Build(config) - assert.NoError(t, err) - - // with Strict enabled should cause an error instead of warning about the duplicate route - config.Strict = true - err = New().Build(config) - assert.EqualError(t, err, "route GET /testapi/endpoint is declared multiple times") -} - func TestGen_parseOverrides(t *testing.T) { testCases := []struct { Name string diff --git a/packages_test.go b/packages_test.go index 6db095f45..9163aa0a0 100644 --- a/packages_test.go +++ b/packages_test.go @@ -110,6 +110,63 @@ func TestPackagesDefinitions_ParseTypes(t *testing.T) { assert.NoError(t, err) } +func TestPackagesDefinitions_FindTypeSpec(t *testing.T) { + userDef := TypeSpecDef{ + File: &ast.File{ + Name: &ast.Ident{Name: "user.go"}, + }, + TypeSpec: &ast.TypeSpec{ + Name: ast.NewIdent("User"), + }, + PkgPath: "user", + } + var pkg = PackagesDefinitions{ + uniqueDefinitions: map[string]*TypeSpecDef{ + "user.Model": &userDef, + }, + } + + var nilDef *TypeSpecDef + assert.Equal(t, nilDef, pkg.FindTypeSpec("int", nil, false)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("bool", nil, false)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("string", nil, false)) + + assert.Equal(t, &userDef, pkg.FindTypeSpec("user.Model", nil, false)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("Model", nil, false)) +} + +func TestPackage_rangeFiles(t *testing.T) { + files := map[*ast.File]*AstFileInfo{ + { + Name: &ast.Ident{Name: "main.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "main.go"}}, + Path: "testdata/simple/main.go", + PackagePath: "main", + }, + { + Name: &ast.Ident{Name: "api.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "api.go"}}, + Path: "testdata/simple/api/api.go", + PackagePath: "api", + }, + } + + var sorted []string + processor := func(filename string, file *ast.File) error { + sorted = append(sorted, filename) + return nil + } + assert.NoError(t, rangeFiles(files, processor)) + assert.Equal(t, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}, sorted) + + assert.Error(t, rangeFiles(files, func(filename string, file *ast.File) error { + return ErrFuncTypeField + })) + +} + func TestPackagesDefinitions_findTypeSpec(t *testing.T) { pd := PackagesDefinitions{} var nilTypeSpec *TypeSpecDef @@ -131,4 +188,5 @@ func TestPackagesDefinitions_findTypeSpec(t *testing.T) { } assert.Equal(t, &userTypeSpec, pd.findTypeSpec("model", "User")) assert.Equal(t, nilTypeSpec, pd.findTypeSpec("others", "User")) + } diff --git a/parser_test.go b/parser_test.go index 5282f8e88..223b9f51d 100644 --- a/parser_test.go +++ b/parser_test.go @@ -2745,21 +2745,29 @@ func TestParser_ParseRouterApiDuplicateRoute(t *testing.T) { t.Parallel() src := ` -package test +package api -// @Router /api/{id} [get] -func Test1(){ +import ( + "net/http" +) + +// @Router /api/endpoint [get] +func FunctionOne(w http.ResponseWriter, r *http.Request) { + //write your code } -// @Router /api/{id} [get] -func Test2(){ + +// @Router /api/endpoint [get] +func FunctionTwo(w http.ResponseWriter, r *http.Request) { + //write your code } + ` f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) assert.NoError(t, err) p := New(SetStrict(true)) err = p.ParseRouterAPIInfo("", f) - assert.EqualError(t, err, "route GET /api/{id} is declared multiple times") + assert.EqualError(t, err, "route GET /api/endpoint is declared multiple times") p = New() err = p.ParseRouterAPIInfo("", f) diff --git a/testdata/duplicate_route/api/api.go b/testdata/duplicate_route/api/api.go deleted file mode 100644 index cf7daed85..000000000 --- a/testdata/duplicate_route/api/api.go +++ /dev/null @@ -1,17 +0,0 @@ -package api - -import ( - "net/http" - - _ "github.com/swaggo/swag/testdata/simple/web" -) - -// @Router /testapi/endpoint [get] -func FunctionOne(w http.ResponseWriter, r *http.Request) { - //write your code -} - -// @Router /testapi/endpoint [get] -func FunctionTwo(w http.ResponseWriter, r *http.Request) { - //write your code -} diff --git a/testdata/duplicate_route/main.go b/testdata/duplicate_route/main.go deleted file mode 100644 index 1330c302b..000000000 --- a/testdata/duplicate_route/main.go +++ /dev/null @@ -1,13 +0,0 @@ -package main - -import ( - "net/http" - - "github.com/swaggo/swag/testdata/duplicate_route/api" -) - -func main() { - http.HandleFunc("/testapi/endpoint", api.FunctionOne) - http.HandleFunc("/testapi/endpoint", api.FunctionTwo) - http.ListenAndServe(":8080", nil) -}