Skip to content

Commit

Permalink
Fix generics issue #1345 (#1349)
Browse files Browse the repository at this point in the history
* fix issue #1345 about generics

* fix issue #1345 about generics

* fix issue #1345 about generics

* fix issue #1345 about generics

* fix tests coverage

* fix tests coverage

* fix tests coverage

* no need to wrap schema by map

Co-authored-by: sdghchj <sdghchj@qq.com>
  • Loading branch information
sdghchj committed Oct 16, 2022
1 parent 04c699c commit 7f90377
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 29 deletions.
77 changes: 51 additions & 26 deletions generics.go
Expand Up @@ -144,7 +144,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi

parametrizedTypeSpec.TypeSpec.Name = ident

newType := resolveGenericType(original.TypeSpec.Type, genericParamTypeDefs)
newType := pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency)

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
Expand Down Expand Up @@ -197,22 +197,33 @@ func splitStructName(fullGenericForm string) (string, []string) {
return genericTypeName, genericParams
}

func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr {
func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr {
switch astExpr := expr.(type) {
case *ast.Ident:
if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok {
if genTypeSpec.ArrayDepth > 0 {
genTypeSpec.ArrayDepth--
return &ast.ArrayType{Elt: resolveGenericType(expr, genericParamTypeDefs)}
retType := genTypeSpec.Type()
for i := 0; i < genTypeSpec.ArrayDepth; i++ {
retType = &ast.ArrayType{Elt: retType}
}
return genTypeSpec.Type()
return retType
}
case *ast.ArrayType:
return &ast.ArrayType{
Elt: resolveGenericType(astExpr.Elt, genericParamTypeDefs),
Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs, parseDependency),
Len: astExpr.Len,
Lbrack: astExpr.Lbrack,
}
case *ast.StarExpr:
return &ast.StarExpr{
Star: astExpr.Star,
X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs, parseDependency),
}
case *ast.IndexExpr, *ast.IndexListExpr:
fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs)
typeDef := pkgDefs.findGenericTypeSpec(fullGenericName, file, parseDependency)
if typeDef != nil {
return typeDef.TypeSpec.Type
}
case *ast.StructType:
newStructTypeDef := &ast.StructType{
Struct: astExpr.Struct,
Expand All @@ -225,37 +236,51 @@ func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericT

for _, field := range astExpr.Fields.List {
newField := &ast.Field{
Type: field.Type,
Doc: field.Doc,
Names: field.Names,
Tag: field.Tag,
Comment: field.Comment,
}

newField.Type = resolveGenericType(field.Type, genericParamTypeDefs)
if newField.Type == nil {
newField.Type = field.Type
}
newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs, parseDependency)

newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}
return newStructTypeDef
}
return nil
return expr
}

func getExtendedGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
switch fieldType := field.(type) {
case *ast.ArrayType:
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt)
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt, genericParamTypeDefs)
return "[]" + fieldName, err
case *ast.StarExpr:
return getExtendedGenericFieldType(file, fieldType.X)
return getExtendedGenericFieldType(file, fieldType.X, genericParamTypeDefs)
case *ast.Ident:
if genericParamTypeDefs != nil {
if typeSpec, ok := genericParamTypeDefs[fieldType.Name]; ok {
return typeSpec.Name, nil
}
}
if fieldType.Obj == nil {
return fieldType.Name, nil
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
default:
return getFieldType(file, field)
}
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
var fullName string
var baseName string
var err error
Expand All @@ -268,7 +293,7 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
fullName = baseName + "["

for _, index := range fieldType.Indices {
fieldName, err := getExtendedGenericFieldType(file, index)
fieldName, err := getExtendedGenericFieldType(file, index, genericParamTypeDefs)
if err != nil {
return "", err
}
Expand All @@ -283,7 +308,7 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", err
}

indexName, err := getExtendedGenericFieldType(file, fieldType.Index)
indexName, err := getExtendedGenericFieldType(file, fieldType.Index, genericParamTypeDefs)
if err != nil {
return "", err
}
Expand All @@ -307,27 +332,27 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
}

func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
switch fieldType := field.(type) {
case *ast.Ident:
if indexType.Obj == nil {
return getFieldType(file, field)
if fieldType.Obj == nil {
return fieldType.Name, nil
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
case *ast.ArrayType:
tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil
}
return "", fmt.Errorf("unknown type %#v", field)
}
Expand All @@ -344,10 +369,10 @@ func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*
case *ast.MapType:
case *ast.FuncType:
case *ast.IndexExpr:
name, err := getExtendedGenericFieldType(file, expr)
name, err := getExtendedGenericFieldType(file, expr, nil)
if err == nil {
if schema, err := parser.getTypeSchema(name, file, false); err == nil {
return spec.MapProperty(schema), nil
return schema, nil
}
}

Expand Down
8 changes: 7 additions & 1 deletion generics_other.go
Expand Up @@ -9,6 +9,12 @@ import (
"go/ast"
)

type genericTypeSpec struct {
ArrayDepth int
TypeSpec *TypeSpecDef
Name string
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}
Expand All @@ -17,7 +23,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi
return original
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
return "", fmt.Errorf("unknown field type %#v", field)
}

Expand Down
2 changes: 1 addition & 1 deletion parser.go
Expand Up @@ -1331,7 +1331,7 @@ func getFieldType(file *ast.File, field ast.Expr) (string, error) {

return fullName, nil
default:
return getGenericFieldType(file, field)
return getGenericFieldType(file, field, nil)
}
}

Expand Down
2 changes: 2 additions & 0 deletions testdata/generics_property/api/api.go
Expand Up @@ -22,6 +22,8 @@ type CreateMovie struct {
Producer types.Field[*Person]
Audience Audience[Person]
AudienceNames Audience[string]
Detail1 types.Field[types.Field[Person]]
Detail2 types.Field[types.Field[string]]
}

type Person struct {
Expand Down

0 comments on commit 7f90377

Please sign in to comment.