Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Go generics cannot find common package object type definition #1281

Merged
merged 3 commits into from Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions gen/gen_test.go
Expand Up @@ -641,6 +641,20 @@ func TestGen_parseOverrides(t *testing.T) {
"github.com/foo/bar": "",
},
},
{
Name: "generic-simple",
Data: `replace types.Field[string] string`,
Expected: map[string]string{
"types.Field[string]": "string",
},
},
{
Name: "generic-double",
Data: `replace types.Field[string,string] string`,
Expected: map[string]string{
"types.Field[string,string]": "string",
},
},
{
Name: "comment",
Data: `// this is a comment
Expand Down
41 changes: 28 additions & 13 deletions generics.go
Expand Up @@ -4,6 +4,7 @@
package swag

import (
"errors"
"fmt"
"go/ast"
"strings"
Expand All @@ -19,7 +20,10 @@ type genericTypeSpec struct {

func (s *genericTypeSpec) Type() ast.Expr {
if s.TypeSpec != nil {
return s.TypeSpec.TypeSpec.Type
return &ast.SelectorExpr{
X: &ast.Ident{Name: ""},
Sel: &ast.Ident{Name: s.Name},
}
}

return &ast.Ident{Name: s.Name}
Expand Down Expand Up @@ -78,16 +82,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
}

tdef := pkgDefs.FindTypeSpec(genericParam, original.File, parseDependency)
if tdef == nil {
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
Name: genericParam,
}
} else {
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
}
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
Name: genericParam,
}
}

Expand Down Expand Up @@ -249,10 +247,27 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", err
}

fullName += fieldName + ", "
fullName += fieldName + ","
}

return strings.TrimRight(fullName, ",") + "]", nil
case *ast.IndexExpr:
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ := getFieldType(file, file.Name)

x, err := getFieldType(file, fieldType.X)
if err != nil {
return "", err
}

i, err := getFieldType(file, fieldType.Index)
if err != nil {
return "", err
}

return strings.TrimRight(fullName, ", ") + "]", nil
return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil
}

return "", fmt.Errorf("unknown field type %#v", field)
Expand Down
63 changes: 63 additions & 0 deletions generics_test.go
Expand Up @@ -5,6 +5,7 @@ package swag

import (
"encoding/json"
"go/ast"
"io/ioutil"
"path/filepath"
"testing"
Expand All @@ -20,6 +21,11 @@ func TestParseGenericsBasic(t *testing.T) {
assert.NoError(t, err)

p := New()
p.Overrides = map[string]string{
"types.Field[string]": "string",
"types.DoubleField[string,string]": "string",
}

err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, err := json.MarshalIndent(p.swagger, "", " ")
Expand Down Expand Up @@ -86,3 +92,60 @@ func TestParseGenericsNames(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
}

func TestGetGenericFieldType(t *testing.T) {
field, err := getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}},
},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}},
},
)
assert.NoError(t, err)
assert.Equal(t, "Field[string]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}},
},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,int]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string]", field)

field, err = getFieldType(
&ast.File{Name: nil},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.BadExpr{}, Index: &ast.Ident{Name: "string"}},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}},
)
assert.Error(t, err)
}
25 changes: 13 additions & 12 deletions testdata/generics_arrays/api/api.go
Expand Up @@ -3,43 +3,44 @@ package api
import (
"net/http"

"github.com/swaggo/swag/testdata/generics_arrays/types"
"github.com/swaggo/swag/testdata/generics_arrays/web"
)

// @Summary List Posts
// @Description Get All of the Posts
// @Accept json
// @Produce json
// @Param data body web.GenericListBody[web.Post] true "Some ID"
// @Success 200 {object} web.GenericListResponse[web.Post]
// @Success 222 {object} web.GenericListResponseMulti[web.Post, web.Post]
// @Param data body web.GenericListBody[types.Post] true "Some ID"
// @Success 200 {object} web.GenericListResponse[types.Post]
// @Success 222 {object} web.GenericListResponseMulti[types.Post, types.Post]
// @Router /posts [get]
func GetPosts(w http.ResponseWriter, r *http.Request) {
_ = web.GenericListResponseMulti[web.Post, web.Post]{}
_ = web.GenericListResponseMulti[types.Post, types.Post]{}
}

// @Summary Add new pets to the store
// @Description get string by ID
// @Accept json
// @Produce json
// @Param data body web.GenericListBodyMulti[web.Post, web.Post] true "Some ID"
// @Success 200 {object} web.GenericListResponse[web.Post]
// @Success 222 {object} web.GenericListResponseMulti[web.Post, web.Post]
// @Param data body web.GenericListBodyMulti[types.Post, types.Post] true "Some ID"
// @Success 200 {object} web.GenericListResponse[types.Post]
// @Success 222 {object} web.GenericListResponseMulti[types.Post, types.Post]
// @Router /posts-multi [get]
func GetPostMulti(w http.ResponseWriter, r *http.Request) {
//write your code
_ = web.GenericListResponseMulti[web.Post, web.Post]{}
_ = web.GenericListResponseMulti[types.Post, types.Post]{}
}

// @Summary Add new pets to the store
// @Description get string by ID
// @Accept json
// @Produce json
// @Param data body web.GenericListBodyMulti[web.Post, []web.Post] true "Some ID"
// @Success 200 {object} web.GenericListResponse[[]web.Post]
// @Success 222 {object} web.GenericListResponseMulti[web.Post, []web.Post]
// @Param data body web.GenericListBodyMulti[types.Post, []types.Post] true "Some ID"
// @Success 200 {object} web.GenericListResponse[[]types.Post]
// @Success 222 {object} web.GenericListResponseMulti[types.Post, []types.Post]
// @Router /posts-multis [get]
func GetPostArray(w http.ResponseWriter, r *http.Request) {
//write your code
_ = web.GenericListResponseMulti[web.Post, []web.Post]{}
_ = web.GenericListResponseMulti[types.Post, []types.Post]{}
}