Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Nested generic fields not fully working, if generic type is from… #1305

Merged
merged 2 commits into from Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 32 additions & 14 deletions generics.go
Expand Up @@ -167,6 +167,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
// splitStructName splits a generic struct name in his parts
func splitStructName(fullGenericForm string) (string, []string) {
// split only at the first '[' and remove the last ']'
if fullGenericForm[len(fullGenericForm)-1] != ']' {
return "", nil
}

genericParams := strings.SplitN(strings.TrimSpace(fullGenericForm)[:len(fullGenericForm)-1], "[", 2)
if len(genericParams) == 1 {
return "", nil
Expand Down Expand Up @@ -224,12 +228,11 @@ func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[strin
func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
switch fieldType := field.(type) {
case *ast.IndexListExpr:
spec := &TypeSpecDef{
File: file,
TypeSpec: getGenericTypeSpec(fieldType.X),
PkgPath: file.Name.Name,
fullName, err := getGenericTypeName(file, fieldType.X)
if err != nil {
return "", err
}
fullName := spec.FullName() + "["
fullName += "["

for _, index := range fieldType.Indices {
var fieldName string
Expand All @@ -252,11 +255,6 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {

return strings.TrimRight(fullName, ",") + "]", nil
case *ast.IndexExpr:
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ := getFieldType(file, file.Name)

x, err := getFieldType(file, fieldType.X)
if err != nil {
return "", err
Expand All @@ -267,18 +265,38 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", err
}

packageName := ""
if !strings.Contains(x, ".") {
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
}

return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil
}

return "", fmt.Errorf("unknown field type %#v", field)
}

func getGenericTypeSpec(field ast.Expr) *ast.TypeSpec {
func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
case *ast.Ident:
return indexType.Obj.Decl.(*ast.TypeSpec)
spec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
case *ast.ArrayType:
return indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec)
spec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
}
return nil
return "", fmt.Errorf("unknown type %#v", field)
}
145 changes: 145 additions & 0 deletions generics_test.go
Expand Up @@ -93,6 +93,87 @@ func TestParseGenericsNames(t *testing.T) {
assert.Equal(t, string(expected), string(b))
}

func TestParametrizeStruct(t *testing.T) {
pd := PackagesDefinitions{
packages: make(map[string]*PackageDefinitions),
}
// valid
typeSpec := pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, []string]", false)
assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name())

// definition contains one type params, but two type params are provided
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, string]", false)
assert.Nil(t, typeSpec)

// definition contains two type params, but only one is used
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string]", false)
assert.Nil(t, typeSpec)

// name is not a valid type name
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string", false)
assert.Nil(t, typeSpec)

typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, [string]", false)
assert.Nil(t, typeSpec)

typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, ]string]", false)
assert.Nil(t, typeSpec)
}

func TestSplitStructNames(t *testing.T) {
t.Parallel()

field, params := splitStructName("test.Field")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field]")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string"}, params)

field, params = splitStructName("test.Field[string, []string]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string", "[]string"}, params)
}

func TestGetGenericFieldType(t *testing.T) {
field, err := getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
Expand Down Expand Up @@ -124,6 +205,34 @@ func TestGetGenericFieldType(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,int]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}},
},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,[]int]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.BadExpr{},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}},
},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.BadExpr{}}},
},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
Expand All @@ -148,4 +257,40 @@ func TestGetGenericFieldType(t *testing.T) {
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}},
)
assert.NoError(t, err)
assert.Equal(t, "field.Name[string]", field)
}

func TestGetGenericTypeName(t *testing.T) {
field, err := getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field", field)

field, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.ArrayType{Elt: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field", field)

field, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}},
)
assert.NoError(t, err)
assert.Equal(t, "field.Name", field)

_, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.BadExpr{},
)
assert.Error(t, err)
}
6 changes: 6 additions & 0 deletions testdata/generics_property/api/api.go
@@ -1,9 +1,14 @@
package api

import (
"github.com/swaggo/swag/testdata/generics_property/web"
"net/http"
)

type NestedResponse struct {
web.GenericResponse[[]string, *uint8]
}

// @Summary List Posts
// @Description Get All of the Posts
// @Accept json
Expand All @@ -12,6 +17,7 @@ import (
// @Success 200 {object} web.PostResponse "ok"
// @Success 201 {object} web.PostResponses "ok"
// @Success 202 {object} web.StringResponse "ok"
// @Success 203 {object} NestedResponse "ok"
// @Router /posts [get]
func GetPosts(w http.ResponseWriter, r *http.Request) {
}
25 changes: 20 additions & 5 deletions testdata/generics_property/expected.json
Expand Up @@ -39,11 +39,6 @@
"type": "integer",
"name": "rows",
"in": "query"
},
{
"type": "string",
"name": "search",
"in": "query"
}
],
"responses": {
Expand All @@ -64,12 +59,32 @@
"schema": {
"$ref": "#/definitions/web.StringResponse"
}
},
"203": {
"description": "ok",
"schema": {
"$ref": "#/definitions/api.NestedResponse"
}
}
}
}
}
},
"definitions": {
"api.NestedResponse": {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "string"
}
},
"items2": {
"type": "integer"
}
}
},
"types.Field-string": {
"type": "object",
"properties": {
Expand Down
2 changes: 1 addition & 1 deletion testdata/generics_property/web/handler.go
Expand Up @@ -28,7 +28,7 @@ func (String) Where(ps ...PostSelector) String {

type PostPager struct {
Pager[String, PostSelector]
Search string `json:"search" form:"search"`
Search types.Field[string] `json:"search" form:"search"`
}

type PostResponse struct {
Expand Down