Skip to content

Commit

Permalink
fix issue #1345 about generics
Browse files Browse the repository at this point in the history
  • Loading branch information
sdghchj authored and sdghchj committed Oct 15, 2022
1 parent 04c699c commit 4f31e5d
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 26 deletions.
81 changes: 57 additions & 24 deletions generics.go
Expand Up @@ -144,7 +144,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi

parametrizedTypeSpec.TypeSpec.Name = ident

newType := resolveGenericType(original.TypeSpec.Type, genericParamTypeDefs)
newType := pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency)

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
Expand Down Expand Up @@ -197,22 +197,37 @@ func splitStructName(fullGenericForm string) (string, []string) {
return genericTypeName, genericParams
}

func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr {
func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr {
switch astExpr := expr.(type) {
case *ast.Ident:
if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok {
if genTypeSpec.ArrayDepth > 0 {
genTypeSpec.ArrayDepth--
return &ast.ArrayType{Elt: resolveGenericType(expr, genericParamTypeDefs)}
retType := genTypeSpec.Type()
for i := 0; i < genTypeSpec.ArrayDepth; i++ {
retType = &ast.ArrayType{Elt: retType}
}
return genTypeSpec.Type()
return retType
}
return expr
case *ast.ArrayType:
return &ast.ArrayType{
Elt: resolveGenericType(astExpr.Elt, genericParamTypeDefs),
Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs, parseDependency),
Len: astExpr.Len,
Lbrack: astExpr.Lbrack,
}
case *ast.StarExpr:
return &ast.StarExpr{
Star: astExpr.Star,
X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs, parseDependency),
}
case *ast.IndexExpr, *ast.IndexListExpr:
fullGenericName, err := getGenericFieldType(file, expr, genericParamTypeDefs)
if err != nil {
panic(err)
}
typeDef := pkgDefs.findGenericTypeSpec(fullGenericName, file, parseDependency)
if typeDef != nil {
return typeDef.TypeSpec.Type
}
case *ast.StructType:
newStructTypeDef := &ast.StructType{
Struct: astExpr.Struct,
Expand All @@ -231,7 +246,7 @@ func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericT
Comment: field.Comment,
}

newField.Type = resolveGenericType(field.Type, genericParamTypeDefs)
newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs, parseDependency)
if newField.Type == nil {
newField.Type = field.Type
}
Expand All @@ -243,32 +258,42 @@ func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericT
return nil
}

func getExtendedGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
switch fieldType := field.(type) {
case *ast.ArrayType:
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt)
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt, genericParamTypeDefs)
return "[]" + fieldName, err
case *ast.StarExpr:
return getExtendedGenericFieldType(file, fieldType.X)
return getExtendedGenericFieldType(file, fieldType.X, genericParamTypeDefs)
case *ast.Ident:
if genericParamTypeDefs != nil {
if typeSpec, ok := genericParamTypeDefs[fieldType.Name]; ok {
if typeSpec.TypeSpec == nil && IsGolangPrimitiveType(typeSpec.Name) {
return typeSpec.Name, nil
}
return typeSpec.TypeSpec.FullName(), nil
}
}
return getFieldType(file, field)
default:
return getFieldType(file, field)
}
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
var fullName string
var baseName string
var err error
switch fieldType := field.(type) {
case *ast.IndexListExpr:
baseName, err = getGenericTypeName(file, fieldType.X)
baseName, err = getGenericTypeName(file, fieldType.X, genericParamTypeDefs)
if err != nil {
return "", err
}
fullName = baseName + "["

for _, index := range fieldType.Indices {
fieldName, err := getExtendedGenericFieldType(file, index)
fieldName, err := getExtendedGenericFieldType(file, index, genericParamTypeDefs)
if err != nil {
return "", err
}
Expand All @@ -278,12 +303,12 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {

fullName = strings.TrimRight(fullName, ",") + "]"
case *ast.IndexExpr:
baseName, err = getGenericTypeName(file, fieldType.X)
baseName, err = getGenericTypeName(file, fieldType.X, genericParamTypeDefs)
if err != nil {
return "", err
}

indexName, err := getExtendedGenericFieldType(file, fieldType.Index)
indexName, err := getExtendedGenericFieldType(file, fieldType.Index, genericParamTypeDefs)
if err != nil {
return "", err
}
Expand All @@ -306,28 +331,36 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return strings.TrimLeft(fmt.Sprintf("%s.%s", packageName, fullName), "."), nil
}

func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
func getGenericTypeName(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
switch fieldType := field.(type) {
case *ast.Ident:
if indexType.Obj == nil {
return getFieldType(file, field)
if fieldType.Obj == nil {
if genericParamTypeDefs != nil {
if typeSpec, ok := genericParamTypeDefs[fieldType.Name]; ok {
if typeSpec.TypeSpec == nil && IsGolangPrimitiveType(typeSpec.Name) {
return typeSpec.Name, nil
}
return typeSpec.TypeSpec.FullName(), nil
}
}
return fieldType.Name, nil
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
case *ast.ArrayType:
tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil
}
return "", fmt.Errorf("unknown type %#v", field)
}
Expand All @@ -344,7 +377,7 @@ func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*
case *ast.MapType:
case *ast.FuncType:
case *ast.IndexExpr:
name, err := getExtendedGenericFieldType(file, expr)
name, err := getExtendedGenericFieldType(file, expr, nil)
if err == nil {
if schema, err := parser.getTypeSchema(name, file, false); err == nil {
return spec.MapProperty(schema), nil
Expand Down
4 changes: 4 additions & 0 deletions generics_test.go
Expand Up @@ -296,27 +296,31 @@ func TestGetGenericTypeName(t *testing.T) {
field, err := getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "test.Field", field)

field, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.ArrayType{Elt: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "test.Field", field)

field, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "field.Name", field)

_, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.BadExpr{},
nil,
)
assert.Error(t, err)
}
Expand Down
2 changes: 1 addition & 1 deletion parser.go
Expand Up @@ -1331,7 +1331,7 @@ func getFieldType(file *ast.File, field ast.Expr) (string, error) {

return fullName, nil
default:
return getGenericFieldType(file, field)
return getGenericFieldType(file, field, nil)
}
}

Expand Down
92 changes: 92 additions & 0 deletions testdata/generics_property/expected.json
Expand Up @@ -167,6 +167,26 @@
"properties": {
"value": {
"$ref": "#/definitions/api.Person"
},
"value2": {
"$ref": "#/definitions/api.Person"
},
"value3": {
"type": "array",
"items": {
"$ref": "#/definitions/api.Person"
}
},
"value4": {
"type": "object",
"properties": {
"subValue1": {
"$ref": "#/definitions/api.Person"
},
"subValue2": {
"type": "string"
}
}
}
}
},
Expand All @@ -178,6 +198,32 @@
"items": {
"$ref": "#/definitions/api.Person"
}
},
"value2": {
"type": "array",
"items": {
"$ref": "#/definitions/api.Person"
}
},
"value3": {
"type": "array",
"items": {
"type": "array",
"items": {
"$ref": "#/definitions/api.Person"
}
}
},
"value4": {
"type": "object",
"properties": {
"subValue1": {
"$ref": "#/definitions/api.Person"
},
"subValue2": {
"type": "string"
}
}
}
}
},
Expand All @@ -189,6 +235,32 @@
"items": {
"$ref": "#/definitions/types.Post"
}
},
"value2": {
"type": "array",
"items": {
"$ref": "#/definitions/types.Post"
}
},
"value3": {
"type": "array",
"items": {
"type": "array",
"items": {
"$ref": "#/definitions/types.Post"
}
}
},
"value4": {
"type": "object",
"properties": {
"subValue1": {
"$ref": "#/definitions/types.Post"
},
"subValue2": {
"type": "string"
}
}
}
}
},
Expand All @@ -197,6 +269,26 @@
"properties": {
"value": {
"type": "string"
},
"value2": {
"type": "string"
},
"value3": {
"type": "array",
"items": {
"type": "string"
}
},
"value4": {
"type": "object",
"properties": {
"subValue1": {
"type": "string"
},
"subValue2": {
"type": "string"
}
}
}
}
},
Expand Down
10 changes: 9 additions & 1 deletion testdata/generics_property/types/post.go
@@ -1,7 +1,15 @@
package types

type SubField1[T any, T2 any] struct {
SubValue1 T
SubValue2 T2
}

type Field[T any] struct {
Value T
Value T
Value2 *T
Value3 []T
Value4 SubField1[T, string]
}

type APIBase struct {
Expand Down

0 comments on commit 4f31e5d

Please sign in to comment.