diff --git a/generics.go b/generics.go index d33498fff..38d74f1a4 100644 --- a/generics.go +++ b/generics.go @@ -167,6 +167,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful // splitStructName splits a generic struct name in his parts func splitStructName(fullGenericForm string) (string, []string) { // split only at the first '[' and remove the last ']' + if fullGenericForm[len(fullGenericForm)-1] != ']' { + return "", nil + } + genericParams := strings.SplitN(strings.TrimSpace(fullGenericForm)[:len(fullGenericForm)-1], "[", 2) if len(genericParams) == 1 { return "", nil @@ -224,12 +228,11 @@ func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[strin func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) { switch fieldType := field.(type) { case *ast.IndexListExpr: - spec := &TypeSpecDef{ - File: file, - TypeSpec: getGenericTypeSpec(fieldType.X), - PkgPath: file.Name.Name, + fullName, err := getGenericTypeName(file, fieldType.X) + if err != nil { + return "", err } - fullName := spec.FullName() + "[" + fullName += "[" for _, index := range fieldType.Indices { var fieldName string @@ -252,11 +255,6 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) { return strings.TrimRight(fullName, ",") + "]", nil case *ast.IndexExpr: - if file.Name == nil { - return "", errors.New("file name is nil") - } - packageName, _ := getFieldType(file, file.Name) - x, err := getFieldType(file, fieldType.X) if err != nil { return "", err @@ -267,18 +265,38 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) { return "", err } + packageName := "" + if !strings.Contains(x, ".") { + if file.Name == nil { + return "", errors.New("file name is nil") + } + packageName, _ = getFieldType(file, file.Name) + } + return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil } return "", fmt.Errorf("unknown field type %#v", field) } -func getGenericTypeSpec(field ast.Expr) *ast.TypeSpec { +func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) { switch indexType := field.(type) { case *ast.Ident: - return indexType.Obj.Decl.(*ast.TypeSpec) + spec := &TypeSpecDef{ + File: file, + TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec), + PkgPath: file.Name.Name, + } + return spec.FullName(), nil case *ast.ArrayType: - return indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec) + spec := &TypeSpecDef{ + File: file, + TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec), + PkgPath: file.Name.Name, + } + return spec.FullName(), nil + case *ast.SelectorExpr: + return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil } - return nil + return "", fmt.Errorf("unknown type %#v", field) } diff --git a/generics_test.go b/generics_test.go index 109cb6a12..e5604188c 100644 --- a/generics_test.go +++ b/generics_test.go @@ -93,6 +93,87 @@ func TestParseGenericsNames(t *testing.T) { assert.Equal(t, string(expected), string(b)) } +func TestParametrizeStruct(t *testing.T) { + pd := PackagesDefinitions{ + packages: make(map[string]*PackageDefinitions), + } + // valid + typeSpec := pd.parametrizeStruct(&TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, []string]", false) + assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name()) + + // definition contains one type params, but two type params are provided + typeSpec = pd.parametrizeStruct(&TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, string]", false) + assert.Nil(t, typeSpec) + + // definition contains two type params, but only one is used + typeSpec = pd.parametrizeStruct(&TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string]", false) + assert.Nil(t, typeSpec) + + // name is not a valid type name + typeSpec = pd.parametrizeStruct(&TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string", false) + assert.Nil(t, typeSpec) + + typeSpec = pd.parametrizeStruct(&TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, [string]", false) + assert.Nil(t, typeSpec) + + typeSpec = pd.parametrizeStruct(&TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, ]string]", false) + assert.Nil(t, typeSpec) +} + +func TestSplitStructNames(t *testing.T) { + t.Parallel() + + field, params := splitStructName("test.Field") + assert.Empty(t, field) + assert.Nil(t, params) + + field, params = splitStructName("test.Field]") + assert.Empty(t, field) + assert.Nil(t, params) + + field, params = splitStructName("test.Field[string") + assert.Empty(t, field) + assert.Nil(t, params) + + field, params = splitStructName("test.Field[string]") + assert.Equal(t, "test.Field", field) + assert.Equal(t, []string{"string"}, params) + + field, params = splitStructName("test.Field[string, []string]") + assert.Equal(t, "test.Field", field) + assert.Equal(t, []string{"string", "[]string"}, params) +} + func TestGetGenericFieldType(t *testing.T) { field, err := getFieldType( &ast.File{Name: &ast.Ident{Name: "test"}}, @@ -124,6 +205,34 @@ func TestGetGenericFieldType(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "test.Field[string,int]", field) + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}}, + }, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field[string,[]int]", field) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.BadExpr{}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}}, + }, + ) + assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.BadExpr{}}}, + }, + ) + assert.Error(t, err) + field, err = getFieldType( &ast.File{Name: &ast.Ident{Name: "test"}}, &ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}}, @@ -148,4 +257,40 @@ func TestGetGenericFieldType(t *testing.T) { &ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}}, ) assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}}, + ) + assert.NoError(t, err) + assert.Equal(t, "field.Name[string]", field) +} + +func TestGetGenericTypeName(t *testing.T) { + field, err := getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field", field) + + field, err = getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.ArrayType{Elt: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}}, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field", field) + + field, err = getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, + ) + assert.NoError(t, err) + assert.Equal(t, "field.Name", field) + + _, err = getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.BadExpr{}, + ) + assert.Error(t, err) } diff --git a/testdata/generics_property/api/api.go b/testdata/generics_property/api/api.go index 43206653e..2e6fc7caf 100644 --- a/testdata/generics_property/api/api.go +++ b/testdata/generics_property/api/api.go @@ -1,9 +1,14 @@ package api import ( + "github.com/swaggo/swag/testdata/generics_property/web" "net/http" ) +type NestedResponse struct { + web.GenericResponse[[]string, *uint8] +} + // @Summary List Posts // @Description Get All of the Posts // @Accept json @@ -12,6 +17,7 @@ import ( // @Success 200 {object} web.PostResponse "ok" // @Success 201 {object} web.PostResponses "ok" // @Success 202 {object} web.StringResponse "ok" +// @Success 203 {object} NestedResponse "ok" // @Router /posts [get] func GetPosts(w http.ResponseWriter, r *http.Request) { } diff --git a/testdata/generics_property/expected.json b/testdata/generics_property/expected.json index 459ad0a09..24b6f94c4 100644 --- a/testdata/generics_property/expected.json +++ b/testdata/generics_property/expected.json @@ -39,11 +39,6 @@ "type": "integer", "name": "rows", "in": "query" - }, - { - "type": "string", - "name": "search", - "in": "query" } ], "responses": { @@ -64,12 +59,32 @@ "schema": { "$ref": "#/definitions/web.StringResponse" } + }, + "203": { + "description": "ok", + "schema": { + "$ref": "#/definitions/api.NestedResponse" + } } } } } }, "definitions": { + "api.NestedResponse": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "string" + } + }, + "items2": { + "type": "integer" + } + } + }, "types.Field-string": { "type": "object", "properties": { diff --git a/testdata/generics_property/web/handler.go b/testdata/generics_property/web/handler.go index a46af59eb..e22aef025 100644 --- a/testdata/generics_property/web/handler.go +++ b/testdata/generics_property/web/handler.go @@ -28,7 +28,7 @@ func (String) Where(ps ...PostSelector) String { type PostPager struct { Pager[String, PostSelector] - Search string `json:"search" form:"search"` + Search types.Field[string] `json:"search" form:"search"` } type PostResponse struct {