Skip to content

Commit

Permalink
Fix generics with package alias (#1360)
Browse files Browse the repository at this point in the history
* fix issue #1353 about generics
Signed-off-by: sdghchj <sdghchj@qq.com>
  • Loading branch information
sdghchj committed Oct 26, 2022
1 parent da6d718 commit 0da94ff
Show file tree
Hide file tree
Showing 22 changed files with 627 additions and 426 deletions.
160 changes: 61 additions & 99 deletions generics.go
Expand Up @@ -6,69 +6,32 @@ package swag
import (
"errors"
"fmt"
"github.com/go-openapi/spec"
"go/ast"
"strings"
"sync"
"unicode"
)

var genericDefinitionsMutex = &sync.RWMutex{}
var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{}
"github.com/go-openapi/spec"
)

type genericTypeSpec struct {
ArrayDepth int
TypeSpec *TypeSpecDef
Name string
}

func (s *genericTypeSpec) Type() ast.Expr {
if s.TypeSpec != nil {
return &ast.SelectorExpr{
X: &ast.Ident{Name: ""},
Sel: &ast.Ident{Name: s.Name},
}
func (t *genericTypeSpec) TypeName() string {
if t.TypeSpec != nil {
return t.TypeSpec.TypeName()
}

return &ast.Ident{Name: s.Name}
}

func (s *genericTypeSpec) TypeDocName() string {
if s.TypeSpec != nil {
return strings.Replace(TypeDocName(s.TypeSpec.FullName(), s.TypeSpec.TypeSpec), "-", "_", -1)
}

return s.Name
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()

if typeSpecDef.TypeSpec.TypeParams != nil {
fullName = fullName + "["
for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List {
if i > 0 {
fullName = fullName + "-"
}

fullName = fullName + typeParam.Names[0].Name
}
fullName = fullName + "]"
}

return fullName
return t.Name
}

func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
genericDefinitionsMutex.RLock()
tSpec, ok := genericsDefinitions[original][fullGenericForm]
genericDefinitionsMutex.RUnlock()
if ok {
return tSpec
if original == nil || original.TypeSpec.TypeParams == nil || len(original.TypeSpec.TypeParams.List) == 0 {
return original
}

pkgName := strings.Split(fullGenericForm, ".")[0]
genericTypeName, genericParams := splitStructName(fullGenericForm)
name, genericParams := splitGenericsTypeName(fullGenericForm)
if genericParams == nil {
return nil
}
Expand All @@ -88,76 +51,62 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi
arrayDepth++
}

tdef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency)
if tdef != nil && !strings.Contains(genericParam, ".") {
genericParam = fullTypeName(file.Name.Name, genericParam)
typeDef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency)
if typeDef != nil {
genericParam = typeDef.TypeName()
if _, ok := pkgDefs.uniqueDefinitions[genericParam]; !ok {
pkgDefs.uniqueDefinitions[genericParam] = typeDef
}
}

genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
TypeSpec: typeDef,
Name: genericParam,
}
}

parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Doc: original.TypeSpec.Doc,
Comment: original.TypeSpec.Comment,
Assign: original.TypeSpec.Assign,
},
}

ident := &ast.Ident{
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
}

if strings.Contains(genericTypeName, ".") {
genericTypeName = strings.Split(genericTypeName, ".")[1]
}

var typeName = []string{TypeDocName(fullTypeName(pkgName, genericTypeName), parametrizedTypeSpec.TypeSpec)}

name = fmt.Sprintf("%s%s-", string(IgnoreNameOverridePrefix), original.TypeName())
var nameParts []string
for _, def := range original.TypeSpec.TypeParams.List {
if specDef, ok := genericParamTypeDefs[def.Names[0].Name]; ok {
var prefix = ""
if specDef.ArrayDepth > 0 {
if specDef.ArrayDepth == 1 {
prefix = "array_"
if specDef.ArrayDepth > 1 {
prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth)
}
} else if specDef.ArrayDepth > 1 {
prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth)
}
typeName = append(typeName, prefix+specDef.TypeDocName())
nameParts = append(nameParts, prefix+specDef.TypeName())
}
}

ident.Name = strings.Join(typeName, "-")
ident.Name = strings.Replace(ident.Name, ".", "_", -1)
pkgNamePrefix := pkgName + "_"
if strings.HasPrefix(ident.Name, pkgNamePrefix) {
ident.Name = fullTypeName(pkgName, ident.Name[len(pkgNamePrefix):])
}
ident.Name = string(IgnoreNameOverridePrefix) + ident.Name

parametrizedTypeSpec.TypeSpec.Name = ident
name += strings.Replace(strings.Join(nameParts, "-"), ".", "_", -1)

newType := pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency)
if typeSpec, ok := pkgDefs.uniqueDefinitions[name]; ok {
return typeSpec
}

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
parametrizedTypeSpec.TypeSpec.Type = newType
if genericsDefinitions[original] == nil {
genericsDefinitions[original] = map[string]*TypeSpecDef{}
parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{
Name: name,
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
},
Type: pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency),
Doc: original.TypeSpec.Doc,
Assign: original.TypeSpec.Assign,
},
}
genericsDefinitions[original][fullGenericForm] = parametrizedTypeSpec
pkgDefs.uniqueDefinitions[name] = parametrizedTypeSpec

return parametrizedTypeSpec
}

// splitStructName splits a generic struct name in his parts
func splitStructName(fullGenericForm string) (string, []string) {
// splitGenericsTypeName splits a generic struct name in his parts
func splitGenericsTypeName(fullGenericForm string) (string, []string) {
//remove all spaces character
fullGenericForm = strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
Expand Down Expand Up @@ -197,11 +146,24 @@ func splitStructName(fullGenericForm string) (string, []string) {
return genericTypeName, genericParams
}

func (pkgDefs *PackagesDefinitions) getParametrizedType(genTypeSpec *genericTypeSpec) ast.Expr {
if genTypeSpec.TypeSpec != nil && strings.Contains(genTypeSpec.Name, ".") {
parts := strings.SplitN(genTypeSpec.Name, ".", 2)
return &ast.SelectorExpr{
X: &ast.Ident{Name: parts[0]},
Sel: &ast.Ident{Name: parts[1]},
}
}

//a primitive type name or a type name in current package
return &ast.Ident{Name: genTypeSpec.Name}
}

func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr {
switch astExpr := expr.(type) {
case *ast.Ident:
if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok {
retType := genTypeSpec.Type()
retType := pkgDefs.getParametrizedType(genTypeSpec)
for i := 0; i < genTypeSpec.ArrayDepth; i++ {
retType = &ast.ArrayType{Elt: retType}
}
Expand All @@ -220,7 +182,7 @@ func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.
}
case *ast.IndexExpr, *ast.IndexListExpr:
fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs)
typeDef := pkgDefs.findGenericTypeSpec(fullGenericName, file, parseDependency)
typeDef := pkgDefs.FindTypeSpec(fullGenericName, file, parseDependency)
if typeDef != nil {
return typeDef.TypeSpec.Type
}
Expand Down Expand Up @@ -274,7 +236,7 @@ func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTyp
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
return tSpec.TypeName(), nil
default:
return getFieldType(file, field)
}
Expand Down Expand Up @@ -343,14 +305,14 @@ func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
return tSpec.TypeName(), nil
case *ast.ArrayType:
tSpec := &TypeSpecDef{
File: file,
TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
return tSpec.TypeName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil
}
Expand Down
4 changes: 0 additions & 4 deletions generics_other.go
Expand Up @@ -15,10 +15,6 @@ type genericTypeSpec struct {
Name string
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
return original
}
Expand Down
35 changes: 27 additions & 8 deletions generics_test.go
Expand Up @@ -103,20 +103,39 @@ func TestParseGenericsNames(t *testing.T) {
assert.Equal(t, string(expected), string(b))
}

func TestParseGenericsPackageAlias(t *testing.T) {
t.Parallel()

searchDir := "testdata/generics_package_alias"
expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json"))
assert.NoError(t, err)

p := New()
err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, err := json.MarshalIndent(p.swagger, "", " ")
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
}

func TestParametrizeStruct(t *testing.T) {
pd := PackagesDefinitions{
packages: make(map[string]*PackageDefinitions),
packages: make(map[string]*PackageDefinitions),
uniqueDefinitions: make(map[string]*TypeSpecDef),
}
// valid
typeSpec := pd.parametrizeGenericType(
&ast.File{Name: &ast.Ident{Name: "test2"}},
&TypeSpecDef{
File: &ast.File{Name: &ast.Ident{Name: "test"}},
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, []string]", false)
assert.NotNil(t, typeSpec)
assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name())
assert.Equal(t, "test.Field-string-array_string", typeSpec.TypeName())

// definition contains one type params, but two type params are provided
typeSpec = pd.parametrizeGenericType(
Expand Down Expand Up @@ -172,30 +191,30 @@ func TestParametrizeStruct(t *testing.T) {
assert.Nil(t, typeSpec)
}

func TestSplitStructNames(t *testing.T) {
func TestSplitGenericsTypeNames(t *testing.T) {
t.Parallel()

field, params := splitStructName("test.Field")
field, params := splitGenericsTypeName("test.Field")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field]")
field, params = splitGenericsTypeName("test.Field]")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string")
field, params = splitGenericsTypeName("test.Field[string")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string] ")
field, params = splitGenericsTypeName("test.Field[string] ")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string"}, params)

field, params = splitStructName("test.Field[string, []string]")
field, params = splitGenericsTypeName("test.Field[string, []string]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string", "[]string"}, params)

field, params = splitStructName("test.Field[test.Field[ string, []string] ]")
field, params = splitGenericsTypeName("test.Field[test.Field[ string, []string] ]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"test.Field[string,[]string]"}, params)
}
Expand Down

0 comments on commit 0da94ff

Please sign in to comment.