From e5d507dd472777bec3a3744f7f9cc86065037530 Mon Sep 17 00:00:00 2001 From: sdghchj Date: Tue, 22 Nov 2022 11:46:06 +0800 Subject: [PATCH] enhancement for PR #1387: evaluate const across packages (#1388) * enhancement: evaluate const across packages --- const.go | 113 +-------------------------------- enums.go | 2 +- gen/gen.go | 5 +- generics.go | 16 ++--- generics_other.go | 2 +- generics_other_test.go | 4 +- generics_test.go | 15 ++--- package.go | 108 ++++++++++++++++++++++++++++++- packages.go | 82 +++++++++++++++++++----- packages_test.go | 10 +-- parser.go | 12 +++- parser_test.go | 28 +++----- testdata/enums/consts/const.go | 3 + testdata/enums/types/model.go | 10 +-- 14 files changed, 229 insertions(+), 181 deletions(-) create mode 100644 testdata/enums/consts/const.go diff --git a/const.go b/const.go index ffe3a50cf..d9f970bb1 100644 --- a/const.go +++ b/const.go @@ -2,121 +2,14 @@ package swag import ( "go/ast" - "go/token" - "strconv" ) -// ConstVariable a model to record an const variable +// ConstVariable a model to record a const variable type ConstVariable struct { Name *ast.Ident Type ast.Expr Value interface{} Comment *ast.CommentGroup -} - -// EvaluateValue evaluate the value -func (cv *ConstVariable) EvaluateValue(constTable map[string]*ConstVariable) interface{} { - if expr, ok := cv.Value.(ast.Expr); ok { - value, evalType := evaluateConstValue(cv.Name.Name, cv.Name.Obj.Data.(int), expr, constTable, make(map[string]struct{})) - if cv.Type == nil && evalType != nil { - cv.Type = evalType - } - if value != nil { - cv.Value = value - } - return value - } - return cv.Value -} - -func evaluateConstValue(name string, iota int, expr ast.Expr, constTable map[string]*ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { - if len(name) > 0 { - if _, ok := recursiveStack[name]; ok { - return nil, nil - } - recursiveStack[name] = struct{}{} - } - - switch valueExpr := expr.(type) { - case *ast.Ident: - if valueExpr.Name == "iota" { - return iota, nil - } - if constTable != nil { - if cv, ok := constTable[valueExpr.Name]; ok { - if expr, ok = cv.Value.(ast.Expr); ok { - value, evalType := evaluateConstValue(valueExpr.Name, cv.Name.Obj.Data.(int), expr, constTable, recursiveStack) - if cv.Type == nil { - cv.Type = evalType - } - if value != nil { - cv.Value = value - } - return value, evalType - } - return cv.Value, cv.Type - } - } - case *ast.BasicLit: - switch valueExpr.Kind { - case token.INT: - x, err := strconv.ParseInt(valueExpr.Value, 10, 64) - if err != nil { - return nil, nil - } - return int(x), nil - case token.STRING, token.CHAR: - return valueExpr.Value[1 : len(valueExpr.Value)-1], nil - } - case *ast.UnaryExpr: - x, evalType := evaluateConstValue("", iota, valueExpr.X, constTable, recursiveStack) - switch valueExpr.Op { - case token.SUB: - return -x.(int), evalType - case token.XOR: - return ^(x.(int)), evalType - } - case *ast.BinaryExpr: - x, evalTypex := evaluateConstValue("", iota, valueExpr.X, constTable, recursiveStack) - y, evalTypey := evaluateConstValue("", iota, valueExpr.Y, constTable, recursiveStack) - evalType := evalTypex - if evalType == nil { - evalType = evalTypey - } - switch valueExpr.Op { - case token.ADD: - if ix, ok := x.(int); ok { - return ix + y.(int), evalType - } else if sx, ok := x.(string); ok { - return sx + y.(string), evalType - } - case token.SUB: - return x.(int) - y.(int), evalType - case token.MUL: - return x.(int) * y.(int), evalType - case token.QUO: - return x.(int) / y.(int), evalType - case token.REM: - return x.(int) % y.(int), evalType - case token.AND: - return x.(int) & y.(int), evalType - case token.OR: - return x.(int) | y.(int), evalType - case token.XOR: - return x.(int) ^ y.(int), evalType - case token.SHL: - return x.(int) << y.(int), evalType - case token.SHR: - return x.(int) >> y.(int), evalType - } - case *ast.ParenExpr: - return evaluateConstValue("", iota, valueExpr.X, constTable, recursiveStack) - case *ast.CallExpr: - //data conversion - if ident, ok := valueExpr.Fun.(*ast.Ident); ok && len(valueExpr.Args) == 1 && IsGolangPrimitiveType(ident.Name) { - arg, _ := evaluateConstValue("", iota, valueExpr.Args[0], constTable, recursiveStack) - return arg, nil - } - } - return nil, nil + File *ast.File + Pkg *PackageDefinitions } diff --git a/enums.go b/enums.go index 4dc5547ff..38658f20a 100644 --- a/enums.go +++ b/enums.go @@ -5,7 +5,7 @@ const ( enumCommentsExtension = "x-enum-comments" ) -// EnumValue a model to record an enum const variable +// EnumValue a model to record an enum consts variable type EnumValue struct { key string Value interface{} diff --git a/gen/gen.go b/gen/gen.go index 8933d03f9..87e7d0d20 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -162,7 +162,9 @@ func (g *Gen) Build(config *Config) error { g.debug.Printf("Generate swagger docs....") - p := swag.New(swag.SetMarkdownFileDirectory(config.MarkdownFilesDir), + p := swag.New( + swag.SetParseDependency(config.ParseDependency), + swag.SetMarkdownFileDirectory(config.MarkdownFilesDir), swag.SetDebugger(config.Debugger), swag.SetExcludedDirsAndFiles(config.Excludes), swag.SetCodeExamplesDirectory(config.CodeExampleFilesDir), @@ -174,7 +176,6 @@ func (g *Gen) Build(config *Config) error { p.PropNamingStrategy = config.PropNamingStrategy p.ParseVendor = config.ParseVendor - p.ParseDependency = config.ParseDependency p.ParseInternal = config.ParseInternal p.RequiredByDefault = config.RequiredByDefault diff --git a/generics.go b/generics.go index e04bc89ea..652f4a8d0 100644 --- a/generics.go +++ b/generics.go @@ -26,7 +26,7 @@ func (t *genericTypeSpec) TypeName() string { return t.Name } -func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef { +func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string) *TypeSpecDef { if original == nil || original.TypeSpec.TypeParams == nil || len(original.TypeSpec.TypeParams.List) == 0 { return original } @@ -51,7 +51,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi arrayDepth++ } - typeDef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency) + typeDef := pkgDefs.FindTypeSpec(genericParam, file) if typeDef != nil { genericParam = typeDef.TypeName() if _, ok := pkgDefs.uniqueDefinitions[genericParam]; !ok { @@ -95,7 +95,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi NamePos: original.TypeSpec.Name.NamePos, Obj: original.TypeSpec.Name.Obj, }, - Type: pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency), + Type: pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs), Doc: original.TypeSpec.Doc, Assign: original.TypeSpec.Assign, }, @@ -159,7 +159,7 @@ func (pkgDefs *PackagesDefinitions) getParametrizedType(genTypeSpec *genericType return &ast.Ident{Name: genTypeSpec.Name} } -func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr { +func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr { switch astExpr := expr.(type) { case *ast.Ident: if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok { @@ -171,18 +171,18 @@ func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast. } case *ast.ArrayType: return &ast.ArrayType{ - Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs, parseDependency), + Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs), Len: astExpr.Len, Lbrack: astExpr.Lbrack, } case *ast.StarExpr: return &ast.StarExpr{ Star: astExpr.Star, - X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs, parseDependency), + X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs), } case *ast.IndexExpr, *ast.IndexListExpr: fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs) - typeDef := pkgDefs.FindTypeSpec(fullGenericName, file, parseDependency) + typeDef := pkgDefs.FindTypeSpec(fullGenericName, file) if typeDef != nil { return typeDef.TypeSpec.Type } @@ -205,7 +205,7 @@ func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast. Comment: field.Comment, } - newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs, parseDependency) + newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs) newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField) } diff --git a/generics_other.go b/generics_other.go index deab83167..5fd9e8231 100644 --- a/generics_other.go +++ b/generics_other.go @@ -15,7 +15,7 @@ type genericTypeSpec struct { Name string } -func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef { +func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string) *TypeSpecDef { return original } diff --git a/generics_other_test.go b/generics_other_test.go index f9c9a3f8c..1a396a04b 100644 --- a/generics_other_test.go +++ b/generics_other_test.go @@ -32,10 +32,10 @@ func TestParametrizeStruct(t *testing.T) { }, } - tr := pd.parametrizeGenericType(&ast.File{}, tSpec, "", false) + tr := pd.parametrizeGenericType(&ast.File{}, tSpec, "") assert.Equal(t, tr, tSpec) - tr = pd.parametrizeGenericType(&ast.File{}, tSpec, "", true) + tr = pd.parametrizeGenericType(&ast.File{}, tSpec, "") assert.Equal(t, tr, tSpec) } diff --git a/generics_test.go b/generics_test.go index d62f3898c..48dde59a9 100644 --- a/generics_test.go +++ b/generics_test.go @@ -110,8 +110,7 @@ func TestParseGenericsPackageAlias(t *testing.T) { expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) assert.NoError(t, err) - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.NoError(t, err) b, err := json.MarshalIndent(p.swagger, "", " ") @@ -133,7 +132,7 @@ func TestParametrizeStruct(t *testing.T) { Name: &ast.Ident{Name: "Field"}, TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, - }}, "test.Field[string, []string]", false) + }}, "test.Field[string, []string]") assert.NotNil(t, typeSpec) assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name()) assert.Equal(t, "test.Field-string-array_string", typeSpec.TypeName()) @@ -146,7 +145,7 @@ func TestParametrizeStruct(t *testing.T) { Name: &ast.Ident{Name: "Field"}, TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}}, Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, - }}, "test.Field[string, string]", false) + }}, "test.Field[string, string]") assert.Nil(t, typeSpec) // definition contains two type params, but only one is used @@ -157,7 +156,7 @@ func TestParametrizeStruct(t *testing.T) { Name: &ast.Ident{Name: "Field"}, TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, - }}, "test.Field[string]", false) + }}, "test.Field[string]") assert.Nil(t, typeSpec) // name is not a valid type name @@ -168,7 +167,7 @@ func TestParametrizeStruct(t *testing.T) { Name: &ast.Ident{Name: "Field"}, TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, - }}, "test.Field[string", false) + }}, "test.Field[string") assert.Nil(t, typeSpec) typeSpec = pd.parametrizeGenericType( @@ -178,7 +177,7 @@ func TestParametrizeStruct(t *testing.T) { Name: &ast.Ident{Name: "Field"}, TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, - }}, "test.Field[string, [string]", false) + }}, "test.Field[string, [string]") assert.Nil(t, typeSpec) typeSpec = pd.parametrizeGenericType( @@ -188,7 +187,7 @@ func TestParametrizeStruct(t *testing.T) { Name: &ast.Ident{Name: "Field"}, TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, - }}, "test.Field[string, ]string]", false) + }}, "test.Field[string, ]string]") assert.Nil(t, typeSpec) } diff --git a/package.go b/package.go index c11fe0710..bcce3f65c 100644 --- a/package.go +++ b/package.go @@ -1,6 +1,10 @@ package swag -import "go/ast" +import ( + "go/ast" + "go/token" + "strconv" +) // PackageDefinitions files and definition in a package. type PackageDefinitions struct { @@ -18,12 +22,22 @@ type PackageDefinitions struct { // package name Name string + + // package path + Path string +} + +// ConstVariableGlobalEvaluator an interface used to evaluate enums across packages +type ConstVariableGlobalEvaluator interface { + EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) + EvaluateConstValueByName(file *ast.File, pkgPath, constVariableName string, recursiveStack map[string]struct{}) (interface{}, ast.Expr) } // NewPackageDefinitions new a PackageDefinitions object -func NewPackageDefinitions(name string) *PackageDefinitions { +func NewPackageDefinitions(name, pkgPath string) *PackageDefinitions { return &PackageDefinitions{ Name: name, + Path: pkgPath, Files: make(map[string]*ast.File), TypeDefinitions: make(map[string]*TypeSpecDef), ConstTable: make(map[string]*ConstVariable), @@ -43,16 +57,104 @@ func (pkg *PackageDefinitions) AddTypeSpec(name string, typeSpec *TypeSpecDef) * } // AddConst add a const variable. -func (pkg *PackageDefinitions) AddConst(valueSpec *ast.ValueSpec) *PackageDefinitions { +func (pkg *PackageDefinitions) AddConst(astFile *ast.File, valueSpec *ast.ValueSpec) *PackageDefinitions { for i := 0; i < len(valueSpec.Names) && i < len(valueSpec.Values); i++ { variable := &ConstVariable{ Name: valueSpec.Names[i], Type: valueSpec.Type, Value: valueSpec.Values[i], Comment: valueSpec.Comment, + File: astFile, } pkg.ConstTable[valueSpec.Names[i].Name] = variable pkg.OrderedConst = append(pkg.OrderedConst, variable) } return pkg } + +func (pkg *PackageDefinitions) evaluateConstValue(file *ast.File, iota int, expr ast.Expr, globalEvaluator ConstVariableGlobalEvaluator, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + switch valueExpr := expr.(type) { + case *ast.Ident: + if valueExpr.Name == "iota" { + return iota, nil + } + if pkg.ConstTable != nil { + if cv, ok := pkg.ConstTable[valueExpr.Name]; ok { + return globalEvaluator.EvaluateConstValue(pkg, cv, recursiveStack) + } + } + case *ast.SelectorExpr: + pkgIdent, ok := valueExpr.X.(*ast.Ident) + if !ok { + return nil, nil + } + return globalEvaluator.EvaluateConstValueByName(file, pkgIdent.Name, valueExpr.Sel.Name, recursiveStack) + case *ast.BasicLit: + switch valueExpr.Kind { + case token.INT: + x, err := strconv.ParseInt(valueExpr.Value, 10, 64) + if err != nil { + return nil, nil + } + return int(x), nil + case token.STRING, token.CHAR: + return valueExpr.Value[1 : len(valueExpr.Value)-1], nil + } + case *ast.UnaryExpr: + x, evalType := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack) + if x == nil { + return nil, nil + } + switch valueExpr.Op { + case token.SUB: + return -x.(int), evalType + case token.XOR: + return ^(x.(int)), evalType + } + case *ast.BinaryExpr: + x, evalTypex := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack) + y, evalTypey := pkg.evaluateConstValue(file, iota, valueExpr.Y, globalEvaluator, recursiveStack) + if x == nil || y == nil { + return nil, nil + } + evalType := evalTypex + if evalType == nil { + evalType = evalTypey + } + switch valueExpr.Op { + case token.ADD: + if ix, ok := x.(int); ok { + return ix + y.(int), evalType + } else if sx, ok := x.(string); ok { + return sx + y.(string), evalType + } + case token.SUB: + return x.(int) - y.(int), evalType + case token.MUL: + return x.(int) * y.(int), evalType + case token.QUO: + return x.(int) / y.(int), evalType + case token.REM: + return x.(int) % y.(int), evalType + case token.AND: + return x.(int) & y.(int), evalType + case token.OR: + return x.(int) | y.(int), evalType + case token.XOR: + return x.(int) ^ y.(int), evalType + case token.SHL: + return x.(int) << y.(int), evalType + case token.SHR: + return x.(int) >> y.(int), evalType + } + case *ast.ParenExpr: + return pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack) + case *ast.CallExpr: + //data conversion + if ident, ok := valueExpr.Fun.(*ast.Ident); ok && len(valueExpr.Args) == 1 && IsGolangPrimitiveType(ident.Name) { + arg, _ := pkg.evaluateConstValue(file, iota, valueExpr.Args[0], globalEvaluator, recursiveStack) + return arg, nil + } + } + return nil, nil +} diff --git a/packages.go b/packages.go index 74547bc7e..c00ebd12c 100644 --- a/packages.go +++ b/packages.go @@ -18,6 +18,7 @@ type PackagesDefinitions struct { files map[*ast.File]*AstFileInfo packages map[string]*PackageDefinitions uniqueDefinitions map[string]*TypeSpecDef + parseDependency bool } // NewPackagesDefinitions create object PackagesDefinitions. @@ -59,7 +60,7 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF dependency.Files[path] = astFile } else { - pkgDefs.packages[packageDir] = NewPackageDefinitions(astFile.Name.Name).AddFile(path, astFile) + pkgDefs.packages[packageDir] = NewPackageDefinitions(astFile.Name.Name, packageDir).AddFile(path, astFile) } pkgDefs.files[astFile] = &AstFileInfo{ @@ -106,7 +107,7 @@ func (pkgDefs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, erro pkgDefs.parseFunctionScopedTypesFromFile(astFile, info.PackagePath, parsedSchemas) } pkgDefs.removeAllNotUniqueTypes() - pkgDefs.evaluateConstVariables() + pkgDefs.evaluateAllConstVariables() pkgDefs.collectConstEnums(parsedSchemas) return parsedSchemas, nil } @@ -159,14 +160,14 @@ func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packag } if pkgDefs.packages[typeSpecDef.PkgPath] == nil { - pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name).AddTypeSpec(typeSpecDef.Name(), typeSpecDef) + pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name, typeSpecDef.PkgPath).AddTypeSpec(typeSpecDef.Name(), typeSpecDef) } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok { pkgDefs.packages[typeSpecDef.PkgPath].AddTypeSpec(typeSpecDef.Name(), typeSpecDef) } } } } else if generalDeclaration.Tok == token.CONST { - // collect const + // collect consts pkgDefs.collectConstVariables(astFile, packagePath, generalDeclaration) } } @@ -221,7 +222,7 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as } if pkgDefs.packages[typeSpecDef.PkgPath] == nil { - pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name).AddTypeSpec(fullName, typeSpecDef) + pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name, typeSpecDef.PkgPath).AddTypeSpec(fullName, typeSpecDef) } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[fullName]; !ok { pkgDefs.packages[typeSpecDef.PkgPath].AddTypeSpec(fullName, typeSpecDef) } @@ -238,7 +239,7 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as func (pkgDefs *PackagesDefinitions) collectConstVariables(astFile *ast.File, packagePath string, generalDeclaration *ast.GenDecl) { pkg, ok := pkgDefs.packages[packagePath] if !ok { - pkg = NewPackageDefinitions(astFile.Name.Name) + pkg = NewPackageDefinitions(astFile.Name.Name, packagePath) pkgDefs.packages[packagePath] = pkg } @@ -254,19 +255,66 @@ func (pkgDefs *PackagesDefinitions) collectConstVariables(astFile *ast.File, pac valueSpec.Type = lastValueSpec.Type valueSpec.Values = lastValueSpec.Values } - pkg.AddConst(valueSpec) + pkg.AddConst(astFile, valueSpec) } } -func (pkgDefs *PackagesDefinitions) evaluateConstVariables() { - //TODO evaluate enum cross packages +func (pkgDefs *PackagesDefinitions) evaluateAllConstVariables() { for _, pkg := range pkgDefs.packages { for _, constVar := range pkg.OrderedConst { - constVar.EvaluateValue(pkg.ConstTable) + pkgDefs.EvaluateConstValue(pkg, constVar, nil) } } } +// EvaluateConstValue evaluate a const variable. +func (pkgDefs *PackagesDefinitions) EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + if expr, ok := cv.Value.(ast.Expr); ok { + if recursiveStack == nil { + recursiveStack = make(map[string]struct{}) + } + fullConstName := fullTypeName(pkg.Path, cv.Name.Name) + if _, ok = recursiveStack[fullConstName]; ok { + return nil, nil + } + recursiveStack[fullConstName] = struct{}{} + + value, evalType := pkg.evaluateConstValue(cv.File, cv.Name.Obj.Data.(int), expr, pkgDefs, recursiveStack) + if cv.Type == nil && evalType != nil { + cv.Type = evalType + } + if value != nil { + cv.Value = value + } + return value, cv.Type + } + return cv.Value, cv.Type +} + +// EvaluateConstValueByName evaluate a const variable by name. +func (pkgDefs *PackagesDefinitions) EvaluateConstValueByName(file *ast.File, pkgName, constVariableName string, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + matchedPkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports(pkgName, file) + for _, pkgPath := range matchedPkgPaths { + if pkg, ok := pkgDefs.packages[pkgPath]; ok { + if cv, ok := pkg.ConstTable[constVariableName]; ok { + return pkgDefs.EvaluateConstValue(pkg, cv, recursiveStack) + } + } + } + if pkgDefs.parseDependency { + for _, pkgPath := range externalPkgPaths { + if err := pkgDefs.loadExternalPackage(pkgPath); err == nil { + if pkg, ok := pkgDefs.packages[pkgPath]; ok { + if cv, ok := pkg.ConstTable[constVariableName]; ok { + return pkgDefs.EvaluateConstValue(pkg, cv, recursiveStack) + } + } + } + } + } + return nil, nil +} + func (pkgDefs *PackagesDefinitions) collectConstEnums(parsedSchemas map[*TypeSpecDef]*Schema) { for _, pkg := range pkgDefs.packages { for _, constVar := range pkg.OrderedConst { @@ -434,7 +482,7 @@ func (pkgDefs *PackagesDefinitions) findPackagePathFromImports(pkg string, file return } -func (pkgDefs *PackagesDefinitions) findTypeSpecFromPackagePaths(matchedPkgPaths, externalPkgPaths []string, name string, parseDependency bool) (typeDef *TypeSpecDef) { +func (pkgDefs *PackagesDefinitions) findTypeSpecFromPackagePaths(matchedPkgPaths, externalPkgPaths []string, name string) (typeDef *TypeSpecDef) { for _, pkgPath := range matchedPkgPaths { typeDef = pkgDefs.findTypeSpec(pkgPath, name) if typeDef != nil { @@ -442,7 +490,7 @@ func (pkgDefs *PackagesDefinitions) findTypeSpecFromPackagePaths(matchedPkgPaths } } - if parseDependency { + if pkgDefs.parseDependency { for _, pkgPath := range externalPkgPaths { if err := pkgDefs.loadExternalPackage(pkgPath); err == nil { typeDef = pkgDefs.findTypeSpec(pkgPath, name) @@ -460,7 +508,7 @@ func (pkgDefs *PackagesDefinitions) findTypeSpecFromPackagePaths(matchedPkgPaths // @typeName the name of the target type, if it starts with a package name, find its own package path from imports on top of @file // @file the ast.file in which @typeName is used // @pkgPath the package path of @file. -func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef { +func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef { if IsGolangPrimitiveType(typeName) { return nil } @@ -477,8 +525,8 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File } pkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports(parts[0], file) - typeDef = pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, parts[1], parseDependency) - return pkgDefs.parametrizeGenericType(file, typeDef, typeName, parseDependency) + typeDef = pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, parts[1]) + return pkgDefs.parametrizeGenericType(file, typeDef, typeName) } typeDef, ok := pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, typeName)] @@ -496,7 +544,7 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File typeDef, ok = pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, name)] if !ok { pkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports("", file) - typeDef = pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, name, parseDependency) + typeDef = pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, name) } - return pkgDefs.parametrizeGenericType(file, typeDef, typeName, parseDependency) + return pkgDefs.parametrizeGenericType(file, typeDef, typeName) } diff --git a/packages_test.go b/packages_test.go index ad012ee92..fa736d6a8 100644 --- a/packages_test.go +++ b/packages_test.go @@ -173,12 +173,12 @@ func TestPackagesDefinitions_FindTypeSpec(t *testing.T) { } var nilDef *TypeSpecDef - assert.Equal(t, nilDef, pkg.FindTypeSpec("int", nil, false)) - assert.Equal(t, nilDef, pkg.FindTypeSpec("bool", nil, false)) - assert.Equal(t, nilDef, pkg.FindTypeSpec("string", nil, false)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("int", nil)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("bool", nil)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("string", nil)) - assert.Equal(t, &userDef, pkg.FindTypeSpec("user.Model", nil, false)) - assert.Equal(t, nilDef, pkg.FindTypeSpec("Model", nil, false)) + assert.Equal(t, &userDef, pkg.FindTypeSpec("user.Model", nil)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("Model", nil)) } func TestPackage_rangeFiles(t *testing.T) { diff --git a/parser.go b/parser.go index e7412fa06..088ec5f09 100644 --- a/parser.go +++ b/parser.go @@ -214,6 +214,16 @@ func New(options ...func(*Parser)) *Parser { return parser } +// SetParseDependency sets whether to parse the dependent packages. +func SetParseDependency(parseDependency bool) func(*Parser) { + return func(p *Parser) { + p.ParseDependency = parseDependency + if p.packages != nil { + p.packages.parseDependency = parseDependency + } + } +} + // SetMarkdownFileDirectory sets the directory to search for markdown files. func SetMarkdownFileDirectory(directoryPath string) func(*Parser) { return func(p *Parser) { @@ -925,7 +935,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( return PrimitiveSchema(schemaType), nil } - typeSpecDef := parser.packages.FindTypeSpec(typeName, file, parser.ParseDependency) + typeSpecDef := parser.packages.FindTypeSpec(typeName, file) if typeSpecDef == nil { return nil, fmt.Errorf("cannot find type definition: %s", typeName) } diff --git a/parser_test.go b/parser_test.go index 644227f6b..a87b8bfea 100644 --- a/parser_test.go +++ b/parser_test.go @@ -2156,8 +2156,7 @@ func TestParseNested(t *testing.T) { t.Parallel() searchDir := "testdata/nested" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.NoError(t, err) @@ -2172,8 +2171,7 @@ func TestParseDuplicated(t *testing.T) { t.Parallel() searchDir := "testdata/duplicated" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.Errorf(t, err, "duplicated @id declarations successfully found") } @@ -2182,8 +2180,7 @@ func TestParseDuplicatedOtherMethods(t *testing.T) { t.Parallel() searchDir := "testdata/duplicated2" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.Errorf(t, err, "duplicated @id declarations successfully found") } @@ -2192,8 +2189,7 @@ func TestParseDuplicatedFunctionScoped(t *testing.T) { t.Parallel() searchDir := "testdata/duplicated_function_scoped" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.Errorf(t, err, "duplicated @id declarations successfully found") } @@ -2202,8 +2198,7 @@ func TestParseConflictSchemaName(t *testing.T) { t.Parallel() searchDir := "testdata/conflict_name" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.NoError(t, err) b, _ := json.MarshalIndent(p.swagger, "", " ") @@ -2215,8 +2210,7 @@ func TestParseConflictSchemaName(t *testing.T) { func TestParseExternalModels(t *testing.T) { searchDir := "testdata/external_models/main" mainAPIFile := "main.go" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.NoError(t, err) b, _ := json.MarshalIndent(p.swagger, "", " ") @@ -2228,9 +2222,7 @@ func TestParseExternalModels(t *testing.T) { func TestParseGoList(t *testing.T) { mainAPIFile := "main.go" - p := New(ParseUsingGoList(true)) - p.ParseDependency = true - + p := New(ParseUsingGoList(true), SetParseDependency(true)) go111moduleEnv := os.Getenv("GO111MODULE") cases := []struct { @@ -2444,8 +2436,7 @@ type ResponseWrapper struct { } } }` - parser := New() - parser.ParseDependency = true + parser := New(SetParseDependency(true)) f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) assert.NoError(t, err) @@ -3080,8 +3071,7 @@ func TestParseOutsideDependencies(t *testing.T) { searchDir := "testdata/pare_outside_dependencies" mainAPIFile := "cmd/main.go" - p := New() - p.ParseDependency = true + p := New(SetParseDependency(true)) if err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth); err != nil { t.Error("Failed to parse api: " + err.Error()) } diff --git a/testdata/enums/consts/const.go b/testdata/enums/consts/const.go new file mode 100644 index 000000000..db3ff311e --- /dev/null +++ b/testdata/enums/consts/const.go @@ -0,0 +1,3 @@ +package consts + +const Base = 1 diff --git a/testdata/enums/types/model.go b/testdata/enums/types/model.go index cd5c8e08e..79a4e83f3 100644 --- a/testdata/enums/types/model.go +++ b/testdata/enums/types/model.go @@ -1,13 +1,15 @@ package types -type Class int +import ( + "github.com/swaggo/swag/testdata/enums/consts" +) -const Base = 1 +type Class int const ( None Class = -1 - A Class = Base + (iota+1-1)*2/2%100 - (1&1 | 1) + (2 ^ 2) // AAA - B /* BBB */ + A Class = consts.Base + (iota+1-1)*2/2%100 - (1&1 | 1) + (2 ^ 2) // AAA + B /* BBB */ C D F = D + 1