Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generics with package alias #1360

Merged
merged 8 commits into from Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
160 changes: 61 additions & 99 deletions generics.go
Expand Up @@ -6,69 +6,32 @@ package swag
import (
"errors"
"fmt"
"github.com/go-openapi/spec"
"go/ast"
"strings"
"sync"
"unicode"
)

var genericDefinitionsMutex = &sync.RWMutex{}
var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{}
"github.com/go-openapi/spec"
)

type genericTypeSpec struct {
ArrayDepth int
TypeSpec *TypeSpecDef
Name string
}

func (s *genericTypeSpec) Type() ast.Expr {
if s.TypeSpec != nil {
return &ast.SelectorExpr{
X: &ast.Ident{Name: ""},
Sel: &ast.Ident{Name: s.Name},
}
func (t *genericTypeSpec) TypeName() string {
if t.TypeSpec != nil {
return t.TypeSpec.TypeName()
}

return &ast.Ident{Name: s.Name}
}

func (s *genericTypeSpec) TypeDocName() string {
if s.TypeSpec != nil {
return strings.Replace(TypeDocName(s.TypeSpec.FullName(), s.TypeSpec.TypeSpec), "-", "_", -1)
}

return s.Name
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()

if typeSpecDef.TypeSpec.TypeParams != nil {
fullName = fullName + "["
for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List {
if i > 0 {
fullName = fullName + "-"
}

fullName = fullName + typeParam.Names[0].Name
}
fullName = fullName + "]"
}

return fullName
return t.Name
}

func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
genericDefinitionsMutex.RLock()
tSpec, ok := genericsDefinitions[original][fullGenericForm]
genericDefinitionsMutex.RUnlock()
if ok {
return tSpec
if original == nil || original.TypeSpec.TypeParams == nil || len(original.TypeSpec.TypeParams.List) == 0 {
return original
}

pkgName := strings.Split(fullGenericForm, ".")[0]
genericTypeName, genericParams := splitStructName(fullGenericForm)
name, genericParams := splitGenericsTypeName(fullGenericForm)
if genericParams == nil {
return nil
}
Expand All @@ -88,76 +51,62 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi
arrayDepth++
}

tdef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency)
if tdef != nil && !strings.Contains(genericParam, ".") {
genericParam = fullTypeName(file.Name.Name, genericParam)
typeDef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency)
if typeDef != nil {
genericParam = typeDef.TypeName()
if _, ok := pkgDefs.uniqueDefinitions[genericParam]; !ok {
pkgDefs.uniqueDefinitions[genericParam] = typeDef
}
}

genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
TypeSpec: typeDef,
Name: genericParam,
}
}

parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Doc: original.TypeSpec.Doc,
Comment: original.TypeSpec.Comment,
Assign: original.TypeSpec.Assign,
},
}

ident := &ast.Ident{
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
}

if strings.Contains(genericTypeName, ".") {
genericTypeName = strings.Split(genericTypeName, ".")[1]
}

var typeName = []string{TypeDocName(fullTypeName(pkgName, genericTypeName), parametrizedTypeSpec.TypeSpec)}

name = fmt.Sprintf("%s%s-", string(IgnoreNameOverridePrefix), original.TypeName())
var nameParts []string
for _, def := range original.TypeSpec.TypeParams.List {
if specDef, ok := genericParamTypeDefs[def.Names[0].Name]; ok {
var prefix = ""
if specDef.ArrayDepth > 0 {
if specDef.ArrayDepth == 1 {
prefix = "array_"
if specDef.ArrayDepth > 1 {
prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth)
}
} else if specDef.ArrayDepth > 1 {
prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth)
}
typeName = append(typeName, prefix+specDef.TypeDocName())
nameParts = append(nameParts, prefix+specDef.TypeName())
}
}

ident.Name = strings.Join(typeName, "-")
ident.Name = strings.Replace(ident.Name, ".", "_", -1)
pkgNamePrefix := pkgName + "_"
if strings.HasPrefix(ident.Name, pkgNamePrefix) {
ident.Name = fullTypeName(pkgName, ident.Name[len(pkgNamePrefix):])
}
ident.Name = string(IgnoreNameOverridePrefix) + ident.Name

parametrizedTypeSpec.TypeSpec.Name = ident
name += strings.Replace(strings.Join(nameParts, "-"), ".", "_", -1)

newType := pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency)
if typeSpec, ok := pkgDefs.uniqueDefinitions[name]; ok {
return typeSpec
}

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
parametrizedTypeSpec.TypeSpec.Type = newType
if genericsDefinitions[original] == nil {
genericsDefinitions[original] = map[string]*TypeSpecDef{}
parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{
Name: name,
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
},
Type: pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency),
Doc: original.TypeSpec.Doc,
Assign: original.TypeSpec.Assign,
},
}
genericsDefinitions[original][fullGenericForm] = parametrizedTypeSpec
pkgDefs.uniqueDefinitions[name] = parametrizedTypeSpec

return parametrizedTypeSpec
}

// splitStructName splits a generic struct name in his parts
func splitStructName(fullGenericForm string) (string, []string) {
// splitGenericsTypeName splits a generic struct name in his parts
func splitGenericsTypeName(fullGenericForm string) (string, []string) {
//remove all spaces character
fullGenericForm = strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
Expand Down Expand Up @@ -197,11 +146,24 @@ func splitStructName(fullGenericForm string) (string, []string) {
return genericTypeName, genericParams
}

func (pkgDefs *PackagesDefinitions) getParametrizedType(genTypeSpec *genericTypeSpec) ast.Expr {
if genTypeSpec.TypeSpec != nil && strings.Contains(genTypeSpec.Name, ".") {
parts := strings.SplitN(genTypeSpec.Name, ".", 2)
return &ast.SelectorExpr{
X: &ast.Ident{Name: parts[0]},
Sel: &ast.Ident{Name: parts[1]},
}
}

//a primitive type name or a type name in current package
return &ast.Ident{Name: genTypeSpec.Name}
}

func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr {
switch astExpr := expr.(type) {
case *ast.Ident:
if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok {
retType := genTypeSpec.Type()
retType := pkgDefs.getParametrizedType(genTypeSpec)
for i := 0; i < genTypeSpec.ArrayDepth; i++ {
retType = &ast.ArrayType{Elt: retType}
}
Expand All @@ -220,7 +182,7 @@ func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.
}
case *ast.IndexExpr, *ast.IndexListExpr:
fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs)
typeDef := pkgDefs.findGenericTypeSpec(fullGenericName, file, parseDependency)
typeDef := pkgDefs.FindTypeSpec(fullGenericName, file, parseDependency)
if typeDef != nil {
return typeDef.TypeSpec.Type
}
Expand Down Expand Up @@ -274,7 +236,7 @@ func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTyp
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
return tSpec.TypeName(), nil
default:
return getFieldType(file, field)
}
Expand Down Expand Up @@ -343,14 +305,14 @@ func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
return tSpec.TypeName(), nil
case *ast.ArrayType:
tSpec := &TypeSpecDef{
File: file,
TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
return tSpec.TypeName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil
}
Expand Down
4 changes: 0 additions & 4 deletions generics_other.go
Expand Up @@ -15,10 +15,6 @@ type genericTypeSpec struct {
Name string
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
return original
}
Expand Down
35 changes: 27 additions & 8 deletions generics_test.go
Expand Up @@ -103,20 +103,39 @@ func TestParseGenericsNames(t *testing.T) {
assert.Equal(t, string(expected), string(b))
}

func TestParseGenericsPackageAlias(t *testing.T) {
t.Parallel()

searchDir := "testdata/generics_package_alias"
expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json"))
assert.NoError(t, err)

p := New()
err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, err := json.MarshalIndent(p.swagger, "", " ")
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
}

func TestParametrizeStruct(t *testing.T) {
pd := PackagesDefinitions{
packages: make(map[string]*PackageDefinitions),
packages: make(map[string]*PackageDefinitions),
uniqueDefinitions: make(map[string]*TypeSpecDef),
}
// valid
typeSpec := pd.parametrizeGenericType(
&ast.File{Name: &ast.Ident{Name: "test2"}},
&TypeSpecDef{
File: &ast.File{Name: &ast.Ident{Name: "test"}},
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, []string]", false)
assert.NotNil(t, typeSpec)
assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name())
assert.Equal(t, "test.Field-string-array_string", typeSpec.TypeName())

// definition contains one type params, but two type params are provided
typeSpec = pd.parametrizeGenericType(
Expand Down Expand Up @@ -172,30 +191,30 @@ func TestParametrizeStruct(t *testing.T) {
assert.Nil(t, typeSpec)
}

func TestSplitStructNames(t *testing.T) {
func TestSplitGenericsTypeNames(t *testing.T) {
t.Parallel()

field, params := splitStructName("test.Field")
field, params := splitGenericsTypeName("test.Field")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field]")
field, params = splitGenericsTypeName("test.Field]")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string")
field, params = splitGenericsTypeName("test.Field[string")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string] ")
field, params = splitGenericsTypeName("test.Field[string] ")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string"}, params)

field, params = splitStructName("test.Field[string, []string]")
field, params = splitGenericsTypeName("test.Field[string, []string]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string", "[]string"}, params)

field, params = splitStructName("test.Field[test.Field[ string, []string] ]")
field, params = splitGenericsTypeName("test.Field[test.Field[ string, []string] ]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"test.Field[string,[]string]"}, params)
}
Expand Down