diff --git a/cmd/swag/main.go b/cmd/swag/main.go index 0a18900f0..579178032 100644 --- a/cmd/swag/main.go +++ b/cmd/swag/main.go @@ -112,8 +112,8 @@ var initFlags = []cli.Flag{ }, } -func initAction(c *cli.Context) error { - strategy := c.String(propertyStrategyFlag) +func initAction(ctx *cli.Context) error { + strategy := ctx.String(propertyStrategyFlag) switch strategy { case swag.CamelCase, swag.SnakeCase, swag.PascalCase: @@ -121,27 +121,27 @@ func initAction(c *cli.Context) error { return fmt.Errorf("not supported %s propertyStrategy", strategy) } - outputTypes := strings.Split(c.String(outputTypesFlag), ",") + outputTypes := strings.Split(ctx.String(outputTypesFlag), ",") if len(outputTypes) == 0 { return fmt.Errorf("no output types specified") } return gen.New().Build(&gen.Config{ - SearchDir: c.String(searchDirFlag), - Excludes: c.String(excludeFlag), - MainAPIFile: c.String(generalInfoFlag), + SearchDir: ctx.String(searchDirFlag), + Excludes: ctx.String(excludeFlag), + MainAPIFile: ctx.String(generalInfoFlag), PropNamingStrategy: strategy, - OutputDir: c.String(outputFlag), + OutputDir: ctx.String(outputFlag), OutputTypes: outputTypes, - ParseVendor: c.Bool(parseVendorFlag), - ParseDependency: c.Bool(parseDependencyFlag), - MarkdownFilesDir: c.String(markdownFilesFlag), - ParseInternal: c.Bool(parseInternalFlag), - GeneratedTime: c.Bool(generatedTimeFlag), - CodeExampleFilesDir: c.String(codeExampleFilesFlag), - ParseDepth: c.Int(parseDepthFlag), - InstanceName: c.String(instanceNameFlag), - OverridesFile: c.String(overridesFileFlag), + ParseVendor: ctx.Bool(parseVendorFlag), + ParseDependency: ctx.Bool(parseDependencyFlag), + MarkdownFilesDir: ctx.String(markdownFilesFlag), + ParseInternal: ctx.Bool(parseInternalFlag), + GeneratedTime: ctx.Bool(generatedTimeFlag), + CodeExampleFilesDir: ctx.String(codeExampleFilesFlag), + ParseDepth: ctx.Int(parseDepthFlag), + InstanceName: ctx.String(instanceNameFlag), + OverridesFile: ctx.String(overridesFileFlag), }) } @@ -192,8 +192,8 @@ func main() { }, }, } - err := app.Run(os.Args) - if err != nil { + + if err := app.Run(os.Args); err != nil { log.Fatal(err) } } diff --git a/field_parser.go b/field_parser.go index b5cf26242..a00b490dc 100644 --- a/field_parser.go +++ b/field_parser.go @@ -10,10 +10,17 @@ import ( "sync" "unicode" + "github.com/go-openapi/jsonreference" "github.com/go-openapi/spec" ) -var _ FieldParser = &tagBaseFieldParser{} +var _ FieldParser = &tagBaseFieldParser{p: nil, field: nil, tag: ""} + +const ( + requiredLabel = "required" + swaggerTypeTag = "swaggertype" + swaggerIgnoreTag = "swaggerignore" +) type tagBaseFieldParser struct { p *Parser @@ -22,46 +29,47 @@ type tagBaseFieldParser struct { } func newTagBaseFieldParser(p *Parser, field *ast.Field) FieldParser { - ps := &tagBaseFieldParser{ + fieldParser := tagBaseFieldParser{ p: p, field: field, + tag: "", } - if ps.field.Tag != nil { - ps.tag = reflect.StructTag(strings.Replace(field.Tag.Value, "`", "", -1)) + if fieldParser.field.Tag != nil { + fieldParser.tag = reflect.StructTag(strings.ReplaceAll(field.Tag.Value, "`", "")) } - return ps + return &fieldParser } -func (ps *tagBaseFieldParser) ShouldSkip() (bool, error) { +func (ps *tagBaseFieldParser) ShouldSkip() bool { // Skip non-exported fields. 
if !ast.IsExported(ps.field.Names[0].Name) { - return true, nil + return true } if ps.field.Tag == nil { - return false, nil + return false } - ignoreTag := ps.tag.Get("swaggerignore") + ignoreTag := ps.tag.Get(swaggerIgnoreTag) if strings.EqualFold(ignoreTag, "true") { - return true, nil + return true } // json:"tag,hoge" - name := strings.TrimSpace(strings.Split(ps.tag.Get("json"), ",")[0]) + name := strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0]) if name == "-" { - return true, nil + return true } - return false, nil + return false } func (ps *tagBaseFieldParser) FieldName() (string, error) { var name string if ps.field.Tag != nil { // json:"tag,hoge" - name = strings.TrimSpace(strings.Split(ps.tag.Get("json"), ",")[0]) + name = strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0]) if name != "" { return name, nil @@ -79,34 +87,40 @@ func (ps *tagBaseFieldParser) FieldName() (string, error) { } func toSnakeCase(in string) string { - runes := []rune(in) - length := len(runes) - - var out []rune - for i := 0; i < length; i++ { - if i > 0 && unicode.IsUpper(runes[i]) && - ((i+1 < length && unicode.IsLower(runes[i+1])) || unicode.IsLower(runes[i-1])) { + var ( + runes = []rune(in) + length = len(runes) + out []rune + ) + + for idx := 0; idx < length; idx++ { + if idx > 0 && unicode.IsUpper(runes[idx]) && + ((idx+1 < length && unicode.IsLower(runes[idx+1])) || unicode.IsLower(runes[idx-1])) { out = append(out, '_') } - out = append(out, unicode.ToLower(runes[i])) + + out = append(out, unicode.ToLower(runes[idx])) } return string(out) } func toLowerCamelCase(in string) string { - runes := []rune(in) + var flag bool - var out []rune - flag := false + out := make([]rune, len(in)) + + runes := []rune(in) for i, curr := range runes { if (i == 0 && unicode.IsUpper(curr)) || (flag && unicode.IsUpper(curr)) { - out = append(out, unicode.ToLower(curr)) + out[i] = unicode.ToLower(curr) flag = true - } else { - out = append(out, curr) - flag = false + + continue } + + out[i] = curr + flag = false } return string(out) @@ -117,7 +131,7 @@ func (ps *tagBaseFieldParser) CustomSchema() (*spec.Schema, error) { return nil, nil } - typeTag := ps.tag.Get("swaggertype") + typeTag := ps.tag.Get(swaggerTypeTag) if typeTag != "" { return BuildCustomSchema(strings.Split(typeTag, ",")) } @@ -126,7 +140,6 @@ func (ps *tagBaseFieldParser) CustomSchema() (*spec.Schema, error) { } type structField struct { - desc string schemaType string arrayType string formatType string @@ -138,11 +151,8 @@ type structField struct { maxItems *int64 minItems *int64 exampleValue interface{} - defaultValue interface{} - extensions map[string]interface{} enums []interface{} enumVarNames []interface{} - readOnly bool unique bool } @@ -155,32 +165,42 @@ func splitNotWrapped(s string, sep rune) []string { '{': '}', } - result := make([]string, 0) - current := "" - var openCount = 0 - var openChar rune + var ( + result = make([]string, 0) + current = strings.Builder{} + openCount = 0 + openChar rune + ) + for _, char := range s { - if openChar == 0 && openCloseMap[char] != 0 { + switch { + case openChar == 0 && openCloseMap[char] != 0: openChar = char + openCount++ - current += string(char) - } else if char == openChar { + + current.WriteRune(char) + case char == openChar: openCount++ - current = current + string(char) - } else if openCount > 0 && char == openCloseMap[openChar] { + + current.WriteRune(char) + case openCount > 0 && char == openCloseMap[openChar]: openCount-- - current += string(char) - } else if openCount 
== 0 && char == sep { - result = append(result, current) + + current.WriteRune(char) + case openCount == 0 && char == sep: + result = append(result, current.String()) + openChar = 0 - current = "" - } else { - current += string(char) + + current = strings.Builder{} + default: + current.WriteRune(char) } } - if current != "" { - result = append(result, current) + if current.String() != "" { + result = append(result, current.String()) } return result @@ -196,157 +216,110 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error { if ps.field.Doc != nil { schema.Description = strings.TrimSpace(ps.field.Doc.Text()) } + if schema.Description == "" && ps.field.Comment != nil { schema.Description = strings.TrimSpace(ps.field.Comment.Text()) } + return nil } - structField := &structField{ + field := &structField{ schemaType: types[0], formatType: ps.tag.Get(formatTag), - readOnly: ps.tag.Get(readOnlyTag) == "true", } if len(types) > 1 && (types[0] == ARRAY || types[0] == OBJECT) { - structField.arrayType = types[1] + field.arrayType = types[1] } - if ps.field.Doc != nil { - structField.desc = strings.TrimSpace(ps.field.Doc.Text()) - } - if structField.desc == "" && ps.field.Comment != nil { - structField.desc = strings.TrimSpace(ps.field.Comment.Text()) - } - - jsonTag := ps.tag.Get(jsonTag) - // json:"name,string" or json:",string" + jsonTagValue := ps.tag.Get(jsonTag) - exampleTag, ok := ps.tag.Lookup(exampleTag) - if ok { - structField.exampleValue = exampleTag - if !strings.Contains(jsonTag, ",string") { - example, err := defineTypeOfExample(structField.schemaType, structField.arrayType, exampleTag) - if err != nil { - return err - } - structField.exampleValue = example - } + bindingTagValue := ps.tag.Get(bindingTag) + if bindingTagValue != "" { + parseValidTags(bindingTagValue, field) } - bindingTag := ps.tag.Get(bindingTag) - if bindingTag != "" { - ps.parseValidTags(bindingTag, structField) + validateTagValue := ps.tag.Get(validateTag) + if validateTagValue != "" { + parseValidTags(validateTagValue, field) } - validateTag := ps.tag.Get(validateTag) - if validateTag != "" { - ps.parseValidTags(validateTag, structField) - } - - extensionsTag := ps.tag.Get(extensionsTag) - if extensionsTag != "" { - structField.extensions = map[string]interface{}{} - for _, val := range splitNotWrapped(extensionsTag, ',') { - parts := strings.SplitN(val, "=", 2) - if len(parts) == 2 { - structField.extensions[parts[0]] = parts[1] - } else { - if len(parts[0]) > 0 && string(parts[0][0]) == "!" { - structField.extensions[parts[0][1:]] = false - } else { - structField.extensions[parts[0]] = true - } - } - } - } - - enumsTag := ps.tag.Get(enumsTag) - if enumsTag != "" { - enumType := structField.schemaType - if structField.schemaType == ARRAY { - enumType = structField.arrayType - } - - structField.enums = nil - for _, e := range strings.Split(enumsTag, ",") { - value, err := defineType(enumType, e) - if err != nil { - return err - } - structField.enums = append(structField.enums, value) - } - } - varnamesTag := ps.tag.Get("x-enum-varnames") - if varnamesTag != "" { - if structField.extensions == nil { - structField.extensions = map[string]interface{}{} - } - varNames := strings.Split(varnamesTag, ",") - if len(varNames) != len(structField.enums) { - return fmt.Errorf("invalid count of x-enum-varnames. 
expected %d, got %d", len(structField.enums), len(varNames)) - } - structField.enumVarNames = nil - for _, v := range varNames { - structField.enumVarNames = append(structField.enumVarNames, v) - } - structField.extensions["x-enum-varnames"] = structField.enumVarNames - } - defaultTag := ps.tag.Get(defaultTag) - if defaultTag != "" { - value, err := defineType(structField.schemaType, defaultTag) + enumsTagValue := ps.tag.Get(enumsTag) + if enumsTagValue != "" { + err := parseEnumTags(enumsTagValue, field) if err != nil { return err } - structField.defaultValue = value } - if IsNumericType(structField.schemaType) || IsNumericType(structField.arrayType) { + if IsNumericType(field.schemaType) || IsNumericType(field.arrayType) { maximum, err := getFloatTag(ps.tag, maximumTag) if err != nil { return err } + if maximum != nil { - structField.maximum = maximum + field.maximum = maximum } minimum, err := getFloatTag(ps.tag, minimumTag) if err != nil { return err } + if minimum != nil { - structField.minimum = minimum + field.minimum = minimum } multipleOf, err := getFloatTag(ps.tag, multipleOfTag) if err != nil { return err } + if multipleOf != nil { - structField.multipleOf = multipleOf + field.multipleOf = multipleOf } } - if structField.schemaType == STRING || structField.arrayType == STRING { - maxLength, err := getIntTag(ps.tag, "maxLength") + if field.schemaType == STRING || field.arrayType == STRING { + maxLength, err := getIntTag(ps.tag, maxLengthTag) if err != nil { return err } + if maxLength != nil { - structField.maxLength = maxLength + field.maxLength = maxLength } - minLength, err := getIntTag(ps.tag, "minLength") + minLength, err := getIntTag(ps.tag, minLengthTag) if err != nil { return err } + if minLength != nil { - structField.minLength = minLength + field.minLength = minLength + } + } + + // json:"name,string" or json:",string" + exampleTagValue, ok := ps.tag.Lookup(exampleTag) + if ok { + field.exampleValue = exampleTagValue + + if !strings.Contains(jsonTagValue, ",string") { + example, err := defineTypeOfExample(field.schemaType, field.arrayType, exampleTagValue) + if err != nil { + return err + } + + field.exampleValue = example } } // perform this after setting everything else (min, max, etc...) - if strings.Contains(jsonTag, ",string") { // @encoding/json: "It applies only to fields of string, floating point, integer, or boolean types." + if strings.Contains(jsonTagValue, ",string") { + // @encoding/json: "It applies only to fields of string, floating point, integer, or boolean types." 
defaultValues := map[string]string{ // Zero Values as string STRING: "", @@ -355,51 +328,103 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error { NUMBER: "0", } - defaultValue, ok := defaultValues[structField.schemaType] + defaultValue, ok := defaultValues[field.schemaType] if ok { - structField.schemaType = STRING + field.schemaType = STRING + *schema = *PrimitiveSchema(field.schemaType) - if structField.exampleValue == nil { + if field.exampleValue == nil { // if exampleValue is not defined by the user, // we will force an example with a correct value // (eg: int->"0", bool:"false") - structField.exampleValue = defaultValue + field.exampleValue = defaultValue } } } - if structField.schemaType == STRING && types[0] != STRING { - *schema = *PrimitiveSchema(structField.schemaType) + if ps.field.Doc != nil { + schema.Description = strings.TrimSpace(ps.field.Doc.Text()) } - schema.Description = structField.desc - schema.ReadOnly = structField.readOnly + if schema.Description == "" && ps.field.Comment != nil { + schema.Description = strings.TrimSpace(ps.field.Comment.Text()) + } + + schema.ReadOnly = ps.tag.Get(readOnlyTag) == "true" + if !reflect.ValueOf(schema.Ref).IsZero() && schema.ReadOnly { schema.AllOf = []spec.Schema{*spec.RefSchema(schema.Ref.String())} - schema.Ref = spec.Ref{} // clear out existing ref + schema.Ref = spec.Ref{ + Ref: jsonreference.Ref{ + HasFullURL: false, + HasURLPathOnly: false, + HasFragmentOnly: false, + HasFileScheme: false, + HasFullFilePath: false, + }, + } // clear out existing ref + } + + defaultTagValue := ps.tag.Get(defaultTag) + if defaultTagValue != "" { + value, err := defineType(field.schemaType, defaultTagValue) + if err != nil { + return err + } + + schema.Default = value } - schema.Default = structField.defaultValue - schema.Example = structField.exampleValue - if structField.schemaType != ARRAY { - schema.Format = structField.formatType + + schema.Example = field.exampleValue + + if field.schemaType != ARRAY { + schema.Format = field.formatType + } + + extensionsTagValue := ps.tag.Get(extensionsTag) + if extensionsTagValue != "" { + schema.Extensions = setExtensionParam(extensionsTagValue) } - schema.Extensions = structField.extensions + + varNamesTag := ps.tag.Get("x-enum-varnames") + if varNamesTag != "" { + if schema.Extensions == nil { + schema.Extensions = map[string]interface{}{} + } + + varNames := strings.Split(varNamesTag, ",") + if len(varNames) != len(field.enums) { + return fmt.Errorf("invalid count of x-enum-varnames. 
expected %d, got %d", len(field.enums), len(varNames)) + } + + field.enumVarNames = nil + + for _, v := range varNames { + field.enumVarNames = append(field.enumVarNames, v) + } + + schema.Extensions["x-enum-varnames"] = field.enumVarNames + } + eleSchema := schema - if structField.schemaType == ARRAY { + + if field.schemaType == ARRAY { // For Array only - schema.MaxItems = structField.maxItems - schema.MinItems = structField.minItems - schema.UniqueItems = structField.unique + schema.MaxItems = field.maxItems + schema.MinItems = field.minItems + schema.UniqueItems = field.unique eleSchema = schema.Items.Schema - eleSchema.Format = structField.formatType - } - eleSchema.Maximum = structField.maximum - eleSchema.Minimum = structField.minimum - eleSchema.MultipleOf = structField.multipleOf - eleSchema.MaxLength = structField.maxLength - eleSchema.MinLength = structField.minLength - eleSchema.Enum = structField.enums + eleSchema.Format = field.formatType + } + + eleSchema.Maximum = field.maximum + eleSchema.Minimum = field.minimum + eleSchema.MultipleOf = field.multipleOf + eleSchema.MaxLength = field.maxLength + eleSchema.MinLength = field.minLength + eleSchema.Enum = field.enums + return nil } @@ -439,7 +464,7 @@ func (ps *tagBaseFieldParser) IsRequired() (bool, error) { bindingTag := ps.tag.Get(bindingTag) if bindingTag != "" { for _, val := range strings.Split(bindingTag, ",") { - if val == "required" { + if val == requiredLabel { return true, nil } } @@ -448,7 +473,7 @@ func (ps *tagBaseFieldParser) IsRequired() (bool, error) { validateTag := ps.tag.Get(validateTag) if validateTag != "" { for _, val := range strings.Split(validateTag, ",") { - if val == "required" { + if val == requiredLabel { return true, nil } } @@ -457,27 +482,24 @@ func (ps *tagBaseFieldParser) IsRequired() (bool, error) { return false, nil } -func (ps *tagBaseFieldParser) parseValidTags(validTag string, sf *structField) { +func parseValidTags(validTag string, sf *structField) { // `validate:"required,max=10,min=1"` // ps. required checked by IsRequired(). 
for _, val := range strings.Split(validTag, ",") { var ( - valKey string valValue string + keyVal = strings.Split(val, "=") ) - kv := strings.Split(val, "=") - switch len(kv) { + + switch len(keyVal) { case 1: - valKey = kv[0] case 2: - valKey = kv[0] - valValue = kv[1] + valValue = strings.ReplaceAll(strings.ReplaceAll(keyVal[1], utf8HexComma, ","), utf8Pipe, "|") default: continue } - valValue = strings.Replace(strings.Replace(valValue, utf8HexComma, ",", -1), utf8Pipe, "|", -1) - switch valKey { + switch keyVal[0] { case "max", "lte": sf.setMax(valValue) case "min", "gte": @@ -497,6 +519,26 @@ func (ps *tagBaseFieldParser) parseValidTags(validTag string, sf *structField) { } } +func parseEnumTags(enumTag string, field *structField) error { + enumType := field.schemaType + if field.schemaType == ARRAY { + enumType = field.arrayType + } + + field.enums = nil + + for _, e := range strings.Split(enumTag, ",") { + value, err := defineType(enumType, e) + if err != nil { + return err + } + + field.enums = append(field.enums, value) + } + + return nil +} + func (sf *structField) setOneOf(valValue string) { if len(sf.enums) != 0 { return @@ -513,6 +555,7 @@ func (sf *structField) setOneOf(valValue string) { if err != nil { continue } + sf.enums = append(sf.enums, value) } } @@ -522,6 +565,7 @@ func (sf *structField) setMin(valValue string) { if err != nil { return } + switch sf.schemaType { case INTEGER, NUMBER: sf.minimum = &value @@ -539,6 +583,7 @@ func (sf *structField) setMax(valValue string) { if err != nil { return } + switch sf.schemaType { case INTEGER, NUMBER: sf.maximum = &value @@ -558,25 +603,30 @@ const ( // These code copy from // https://github.com/go-playground/validator/blob/d4271985b44b735c6f76abc7a06532ee997f9476/baked_in.go#L207 -// --- +// ---. var oneofValsCache = map[string][]string{} var oneofValsCacheRWLock = sync.RWMutex{} var splitParamsRegex = regexp.MustCompile(`'[^']*'|\S+`) -func parseOneOfParam2(s string) []string { +func parseOneOfParam2(param string) []string { oneofValsCacheRWLock.RLock() - values, ok := oneofValsCache[s] + values, ok := oneofValsCache[param] oneofValsCacheRWLock.RUnlock() + if !ok { oneofValsCacheRWLock.Lock() - values = splitParamsRegex.FindAllString(s, -1) + values = splitParamsRegex.FindAllString(param, -1) + for i := 0; i < len(values); i++ { - values[i] = strings.Replace(values[i], "'", "", -1) + values[i] = strings.ReplaceAll(values[i], "'", "") } - oneofValsCache[s] = values + + oneofValsCache[param] = values + oneofValsCacheRWLock.Unlock() } + return values } -// --- +// ---. diff --git a/format/format.go b/format/format.go index d40454cd8..e881d9d04 100644 --- a/format/format.go +++ b/format/format.go @@ -25,9 +25,6 @@ type Config struct { func (f *Fmt) Build(config *Config) error { log.Println("Formating code.... ") - formater := swag.NewFormater() - if err := formater.FormatAPI(config.SearchDir, config.Excludes, config.MainFile); err != nil { - return err - } - return nil + + return swag.NewFormatter().FormatAPI(config.SearchDir, config.Excludes, config.MainFile) } diff --git a/formater.go b/formatter.go similarity index 62% rename from formater.go rename to formatter.go index 1c903f48c..0b14e99d8 100644 --- a/formater.go +++ b/formatter.go @@ -20,8 +20,8 @@ import ( const splitTag = "&*" -// Formater implements a formater for Go source files. -type Formater struct { +// Formatter implements a formater for Go source files. 
+type Formatter struct { // debugging output goes here debug Debugger @@ -31,23 +31,42 @@ type Formater struct { mainFile string } -// NewFormater create a new formater instance. +// Formater creates a new formatter. +type Formater struct { + *Formatter +} + +// NewFormater Deprecated: Use NewFormatter instead. func NewFormater() *Formater { - formater := &Formater{ + formatter := Formater{ + Formatter: NewFormatter(), + } + + formatter.debug.Printf("warining: NewFormater is deprecated. use NewFormatter instead") + + return &formatter +} + +// NewFormatter create a new formater instance. +func NewFormatter() *Formatter { + formatter := Formatter{ + mainFile: "", debug: log.New(os.Stdout, "", log.LstdFlags), excludes: make(map[string]struct{}), } - return formater + + return &formatter } // FormatAPI format the swag comment. -func (f *Formater) FormatAPI(searchDir, excludeDir, mainFile string) error { +func (f *Formatter) FormatAPI(searchDir, excludeDir, mainFile string) error { searchDirs := strings.Split(searchDir, ",") for _, searchDir := range searchDirs { if _, err := os.Stat(searchDir); os.IsNotExist(err) { return fmt.Errorf("dir: %s does not exist", searchDir) } } + for _, fi := range strings.Split(excludeDir, ",") { fi = strings.TrimSpace(fi) if fi != "" { @@ -61,10 +80,12 @@ func (f *Formater) FormatAPI(searchDir, excludeDir, mainFile string) error { if err != nil { return err } + err = f.FormatMain(absMainAPIFilePath) if err != nil { return err } + f.mainFile = mainFile err = f.formatMultiSearchDir(searchDirs) @@ -75,7 +96,7 @@ func (f *Formater) FormatAPI(searchDir, excludeDir, mainFile string) error { return nil } -func (f *Formater) formatMultiSearchDir(searchDirs []string) error { +func (f *Formatter) formatMultiSearchDir(searchDirs []string) error { for _, searchDir := range searchDirs { f.debug.Printf("Format API Info, search dir:%s", searchDir) @@ -84,10 +105,11 @@ func (f *Formater) formatMultiSearchDir(searchDirs []string) error { return err } } + return nil } -func (f *Formater) visit(path string, fileInfo os.FileInfo, err error) error { +func (f *Formatter) visit(path string, fileInfo os.FileInfo, err error) error { if err := walkWith(f.excludes, false)(path, fileInfo); err != nil { return err } else if fileInfo.IsDir() { @@ -99,6 +121,7 @@ func (f *Formater) visit(path string, fileInfo os.FileInfo, err error) error { // skip if file not has suffix "*.go" return nil } + if strings.HasSuffix(strings.ToLower(path), f.mainFile) { // skip main file return nil @@ -108,16 +131,19 @@ func (f *Formater) visit(path string, fileInfo os.FileInfo, err error) error { if err != nil { return fmt.Errorf("ParseFile error:%+v", err) } + return nil } // FormatMain format the main.go comment. -func (f *Formater) FormatMain(mainFilepath string) error { +func (f *Formatter) FormatMain(mainFilepath string) error { fileSet := token.NewFileSet() + astFile, err := goparser.ParseFile(fileSet, mainFilepath, nil, goparser.ParseComments) if err != nil { return fmt.Errorf("cannot format file, err: %w path : %s ", err, mainFilepath) } + var ( formatedComments = bytes.Buffer{} // CommentCache @@ -130,12 +156,13 @@ func (f *Formater) FormatMain(mainFilepath string) error { } } - return writeFormatedComments(mainFilepath, formatedComments, oldCommentsMap) + return writeFormattedComments(mainFilepath, formatedComments, oldCommentsMap) } // FormatFile format the swag comment in go function. 
-func (f *Formater) FormatFile(filepath string) error { +func (f *Formatter) FormatFile(filepath string) error { fileSet := token.NewFileSet() + astFile, err := goparser.ParseFile(fileSet, filepath, nil, goparser.ParseComments) if err != nil { return fmt.Errorf("cannot format file, err: %w path : %s ", err, filepath) @@ -154,18 +181,19 @@ func (f *Formater) FormatFile(filepath string) error { } } - return writeFormatedComments(filepath, formatedComments, oldCommentsMap) + return writeFormattedComments(filepath, formatedComments, oldCommentsMap) } -func writeFormatedComments(filepath string, formatedComments bytes.Buffer, oldCommentsMap map[string]string) error { +func writeFormattedComments(filepath string, formatedComments bytes.Buffer, oldCommentsMap map[string]string) error { // Replace the file // Read the file srcBytes, err := ioutil.ReadFile(filepath) if err != nil { return fmt.Errorf("cannot open file, err: %w path : %s ", err, filepath) } - replaceSrc := string(srcBytes) - newComments := strings.Split(formatedComments.String(), "\n") + + replaceSrc, newComments := string(srcBytes), strings.Split(formatedComments.String(), "\n") + for _, e := range newComments { commentSplit := strings.Split(e, splitTag) if len(commentSplit) == 2 { @@ -176,11 +204,12 @@ func writeFormatedComments(filepath string, formatedComments bytes.Buffer, oldCo } } } + return writeBack(filepath, []byte(replaceSrc), srcBytes) } -func formatFuncDoc(commentList []*ast.Comment, formatedComments io.Writer, oldCommentsMap map[string]string) { - tabw := tabwriter.NewWriter(formatedComments, 0, 0, 2, ' ', 0) +func formatFuncDoc(commentList []*ast.Comment, formattedComments io.Writer, oldCommentsMap map[string]string) { + tabWriter := tabwriter.NewWriter(formattedComments, 0, 0, 2, ' ', 0) for _, comment := range commentList { commentLine := comment.Text @@ -193,72 +222,96 @@ func formatFuncDoc(commentList []*ast.Comment, formatedComments io.Writer, oldCo // md5 + splitTag + srcCommentLine // eg. 
xxx&*@Description get struct array - _, _ = fmt.Fprintln(tabw, cmd5+splitTag+c) + _, _ = fmt.Fprintln(tabWriter, cmd5+splitTag+c) } } - // format by tabwriter - _ = tabw.Flush() -} - -// Check of @Param @Success @Failure @Response @Header -var specialTagForSplit = map[string]byte{ - paramAttr: 1, - successAttr: 1, - failureAttr: 1, - responseAttr: 1, - headerAttr: 1, -} - -var skipChar = map[byte]byte{ - '"': 1, - '(': 1, - '{': 1, - '[': 1, + // format by tabWriter + _ = tabWriter.Flush() } -var skipCharEnd = map[byte]byte{ - '"': 1, - ')': 1, - '}': 1, - ']': 1, -} +func separatorFinder(comment string, replacer byte) string { + commentBytes, commentLine := []byte(comment), strings.TrimSpace(strings.TrimLeft(comment, "/")) -func separatorFinder(comment string, rp byte) string { - commentBytes := []byte(comment) - commentLine := strings.TrimSpace(strings.TrimLeft(comment, "/")) if len(commentLine) == 0 { return "" } + attribute := strings.Fields(commentLine)[0] attrLen := strings.Index(comment, attribute) + len(attribute) attribute = strings.ToLower(attribute) - var i = attrLen - - if _, ok := specialTagForSplit[attribute]; ok { - var skipFlag bool - for ; i < len(commentBytes); i++ { - if !skipFlag && commentBytes[i] == ' ' { - j := i - for j < len(commentBytes) && commentBytes[j] == ' ' { - j++ - } - commentBytes = replaceRange(commentBytes, i, j, rp) - } - if _, ok := skipChar[commentBytes[i]]; ok && !skipFlag { - skipFlag = true - } else if _, ok := skipCharEnd[commentBytes[i]]; ok && skipFlag { - skipFlag = false + + var ( + length = attrLen + + // Check of @Param @Success @Failure @Response @Header. + specialTagForSplit = map[string]byte{ + paramAttr: 1, + successAttr: 1, + failureAttr: 1, + responseAttr: 1, + headerAttr: 1, + } + ) + + _, ok := specialTagForSplit[attribute] + if ok { + return splitSpecialTags(commentBytes, length, replacer) + } + + for length < len(commentBytes) && commentBytes[length] == ' ' { + length++ + } + + if length >= len(commentBytes) { + return comment + } + + commentBytes = replaceRange(commentBytes, attrLen, length, replacer) + + return string(commentBytes) +} + +func splitSpecialTags(commentBytes []byte, length int, rp byte) string { + var ( + skipFlag bool + skipChar = map[byte]byte{ + '"': 1, + '(': 1, + '{': 1, + '[': 1, + } + + skipCharEnd = map[byte]byte{ + '"': 1, + ')': 1, + '}': 1, + ']': 1, + } + ) + + for ; length < len(commentBytes); length++ { + if !skipFlag && commentBytes[length] == ' ' { + j := length + for j < len(commentBytes) && commentBytes[j] == ' ' { + j++ } + + commentBytes = replaceRange(commentBytes, length, j, rp) } - } else { - for i < len(commentBytes) && commentBytes[i] == ' ' { - i++ + + _, found := skipChar[commentBytes[length]] + if found && !skipFlag { + skipFlag = true + + continue } - if i >= len(commentBytes) { - return comment + + _, found = skipCharEnd[commentBytes[length]] + if found && skipFlag { + skipFlag = false } - commentBytes = replaceRange(commentBytes, attrLen, i, rp) } + return string(commentBytes) } @@ -266,11 +319,15 @@ func replaceRange(s []byte, start, end int, new byte) []byte { if start > end || end < 1 { return s } + if end > len(s) { end = len(s) } + s = append(s[:start], s[end-1:]...) + s[start] = new + return s } @@ -281,23 +338,26 @@ func isSwagComment(comment string) bool { } func isBlankComment(comment string) bool { - lc := strings.TrimSpace(comment) - return len(lc) == 0 + return len(strings.TrimSpace(comment)) == 0 } -// writeBack write to file +// writeBack write to file. 
func writeBack(filepath string, src, old []byte) error { // make a temporary backup before overwriting original - bakname, err := backupFile(filepath+".", old, 0644) + backupName, err := backupFile(filepath+".", old, 0644) if err != nil { return err } + err = ioutil.WriteFile(filepath, src, 0644) if err != nil { - _ = os.Rename(bakname, filepath) + _ = os.Rename(backupName, filepath) + return err } - _ = os.Remove(bakname) + + _ = os.Remove(backupName) + return nil } @@ -306,21 +366,23 @@ const chmodSupported = runtime.GOOS != "windows" // backupFile writes data to a new file named filename with permissions perm, // with 0 && - IsSimplePrimitiveType(prop.Items.Schema.Type[0]): + case prop.Type[0] == ARRAY && prop.Items.Schema != nil && + len(prop.Items.Schema.Type) > 0 && IsSimplePrimitiveType(prop.Items.Schema.Type[0]): + param = createParameter(paramType, prop.Description, name, prop.Type[0], findInSlice(schema.Required, name)) param.SimpleSchema.Type = prop.Type[0] + if operation.parser != nil && operation.parser.collectionFormatInQuery != "" && param.CollectionFormat == "" { param.CollectionFormat = TransToValidCollectionFormat(operation.parser.collectionFormatInQuery) } + param.SimpleSchema.Items = &spec.Items{ SimpleSchema: spec.SimpleSchema{ Type: prop.Items.Schema.Type[0], @@ -309,6 +358,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F continue } + param.Nullable = prop.Nullable param.Format = prop.Format param.Default = prop.Default @@ -339,16 +389,18 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F if err != nil { return err } + param.Schema = schema } default: return fmt.Errorf("%s is not supported paramType", paramType) } - err := operation.parseAndExtractionParamAttribute(commentLine, objectType, refType, ¶m) + err := operation.parseParamAttribute(commentLine, objectType, refType, ¶m) if err != nil { return err } + operation.Operation.Parameters = append(operation.Operation.Parameters, param) return nil @@ -365,8 +417,8 @@ const ( validateTag = "validate" minimumTag = "minimum" maximumTag = "maximum" - minLengthTag = "minlength" - maxLengthTag = "maxlength" + minLengthTag = "minLength" + maxLengthTag = "maxLength" multipleOfTag = "multipleOf" readOnlyTag = "readonly" extensionsTag = "extensions" @@ -398,22 +450,24 @@ var regexAttributes = map[string]*regexp.Regexp{ schemaExampleTag: regexp.MustCompile(`(?i)\s+schemaExample\(.*\)`), } -func (operation *Operation) parseAndExtractionParamAttribute(commentLine, objectType, schemaType string, param *spec.Parameter) error { +func (operation *Operation) parseParamAttribute(comment, objectType, schemaType string, param *spec.Parameter) error { schemaType = TransToValidSchemeType(schemaType) + for attrKey, re := range regexAttributes { - attr, err := findAttr(re, commentLine) + attr, err := findAttr(re, comment) if err != nil { continue } + switch attrKey { case enumsTag: err = setEnumParam(param, attr, objectType, schemaType) case minimumTag, maximumTag: - err = setNumberParam(param, attrKey, schemaType, attr, commentLine) + err = setNumberParam(param, attrKey, schemaType, attr, comment) case defaultTag: err = setDefault(param, schemaType, attr) case minLengthTag, maxLengthTag: - err = setStringParam(param, attrKey, schemaType, attr, commentLine) + err = setStringParam(param, attrKey, schemaType, attr, comment) case formatTag: param.Format = attr case exampleTag: @@ -421,10 +475,11 @@ func (operation *Operation) parseAndExtractionParamAttribute(commentLine, object 
case schemaExampleTag: err = setSchemaExample(param, schemaType, attr) case extensionsTag: - _ = setExtensionParam(param, attr) + param.Extensions = setExtensionParam(attr) case collectionFormatTag: - err = setCollectionFormatParam(param, attrKey, objectType, attr, commentLine) + err = setCollectionFormatParam(param, attrKey, objectType, attr, comment) } + if err != nil { return err } @@ -435,8 +490,8 @@ func (operation *Operation) parseAndExtractionParamAttribute(commentLine, object func findAttr(re *regexp.Regexp, commentLine string) (string, error) { attr := re.FindString(commentLine) - l := strings.Index(attr, "(") - r := strings.Index(attr, ")") + + l, r := strings.Index(attr, "("), strings.Index(attr, ")") if l == -1 || r == -1 { return "", fmt.Errorf("can not find regex=%s, comment=%s", re.String(), commentLine) } @@ -471,12 +526,14 @@ func setNumberParam(param *spec.Parameter, name, schemaType, attr, commentLine s if err != nil { return fmt.Errorf("maximum is allow only a number. comment=%s got=%s", commentLine, attr) } + switch name { case minimumTag: param.Minimum = &n case maximumTag: param.Maximum = &n } + return nil default: return fmt.Errorf("%s is attribute to set to a number. comment=%s got=%s", name, commentLine, schemaType) @@ -503,23 +560,33 @@ func setEnumParam(param *spec.Parameter, attr, objectType, schemaType string) er return nil } -func setExtensionParam(param *spec.Parameter, attr string) error { - param.Extensions = map[string]interface{}{} +func setExtensionParam(attr string) spec.Extensions { + extensions := spec.Extensions{} + for _, val := range splitNotWrapped(attr, ',') { parts := strings.SplitN(val, "=", 2) if len(parts) == 2 { - param.Extensions.Add(parts[0], parts[1]) + extensions.Add(parts[0], parts[1]) + + continue + } + + if len(parts[0]) > 0 && string(parts[0][0]) == "!" { + extensions.Add(parts[0][1:], false) continue } - param.Extensions.Add(parts[0], true) + + extensions.Add(parts[0], true) } - return nil + + return extensions } func setCollectionFormatParam(param *spec.Parameter, name, schemaType, attr, commentLine string) error { if schemaType == ARRAY { param.CollectionFormat = TransToValidCollectionFormat(attr) + return nil } @@ -531,13 +598,12 @@ func setDefault(param *spec.Parameter, schemaType string, value string) error { if err != nil { return nil // Don't set a default value if it's not valid } + param.Default = val + return nil } -// controlCharReplacer replaces \r \n \t in example string values -var controlCharReplacer = strings.NewReplacer(`\r`, "\r", `\n`, "\n", `\t`, "\t") - func setSchemaExample(param *spec.Parameter, schemaType string, value string) error { val, err := defineType(schemaType, value) if err != nil { @@ -550,7 +616,8 @@ func setSchemaExample(param *spec.Parameter, schemaType string, value string) er switch v := val.(type) { case string: - param.Schema.Example = controlCharReplacer.Replace(v) + // replaces \r \n \t in example string values. + param.Schema.Example = strings.NewReplacer(`\r`, "\r", `\n`, "\n", `\t`, "\t").Replace(v) default: param.Schema.Example = val } @@ -563,13 +630,16 @@ func setExample(param *spec.Parameter, schemaType string, value string) error { if err != nil { return nil // Don't set a example value if it's not valid } + param.Example = val + return nil } // defineType enum value define the type (object and array unsupported). 
func defineType(schemaType string, value string) (v interface{}, err error) { schemaType = TransToValidSchemeType(schemaType) + switch schemaType { case STRING: return value, nil @@ -597,8 +667,7 @@ func defineType(schemaType string, value string) (v interface{}, err error) { // ParseTagsComment parses comment for given `tag` comment string. func (operation *Operation) ParseTagsComment(commentLine string) { - tags := strings.Split(commentLine, ",") - for _, tag := range tags { + for _, tag := range strings.Split(commentLine, ",") { operation.Tags = append(operation.Tags, strings.TrimSpace(tag)) } } @@ -617,13 +686,13 @@ func (operation *Operation) ParseProduceComment(commentLine string) error { // `produce` (`Content-Type:` response header) or // `accept` (`Accept:` request header). func parseMimeTypeList(mimeTypeList string, typeList *[]string, format string) error { - mimeTypes := strings.Split(mimeTypeList, ",") - for _, typeName := range mimeTypes { + for _, typeName := range strings.Split(mimeTypeList, ",") { if mimeTypePattern.MatchString(typeName) { *typeList = append(*typeList, typeName) continue } + aliasMimeType, ok := mimeTypeAliases[typeName] if !ok { return fmt.Errorf(format, typeName) @@ -643,6 +712,7 @@ func (operation *Operation) ParseRouterComment(commentLine string) error { if len(matches) != 3 { return fmt.Errorf("can not parse router comment \"%s\"", commentLine) } + signature := RouteProperties{ Path: matches[1], HTTPMethod: strings.ToUpper(matches[2]), @@ -659,35 +729,41 @@ func (operation *Operation) ParseRouterComment(commentLine string) error { // ParseSecurityComment parses comment for given `security` comment string. func (operation *Operation) ParseSecurityComment(commentLine string) error { - //var securityMap map[string][]string = map[string][]string{} + var ( + securityMap = make(map[string][]string) + securitySource = commentLine[strings.Index(commentLine, "@Security")+1:] + ) - var securityMap = make(map[string][]string) - securitySource := commentLine[strings.Index(commentLine, "@Security")+1:] for _, securityOption := range strings.Split(securitySource, "||") { securityOption = strings.TrimSpace(securityOption) - l := strings.Index(securityOption, "[") - r := strings.Index(securityOption, "]") - if !(l == -1 && r == -1) { - scopes := securityOption[l+1 : r] - var s []string + + left, right := strings.Index(securityOption, "["), strings.Index(securityOption, "]") + + if !(left == -1 && right == -1) { + scopes := securityOption[left+1 : right] + + var options []string + for _, scope := range strings.Split(scopes, ",") { - s = append(s, strings.TrimSpace(scope)) + options = append(options, strings.TrimSpace(scope)) } - securityKey := securityOption[0:l] - securityMap[securityKey] = append(securityMap[securityKey], s...) + securityKey := securityOption[0:left] + securityMap[securityKey] = append(securityMap[securityKey], options...) } else { securityKey := strings.TrimSpace(securityOption) securityMap[securityKey] = []string{} } } + operation.Security = append(operation.Security, securityMap) + return nil } // findTypeDef attempts to find the *ast.TypeSpec for a specific type given the // type's name and the package's import path. -// TODO: improve finding external pkg +// TODO: improve finding external pkg. 
func findTypeDef(importPath, typeName string) (*ast.TypeSpec, error) { cwd, err := os.Getwd() if err != nil { @@ -723,7 +799,6 @@ func findTypeDef(importPath, typeName string) (*ast.TypeSpec, error) { } // TODO: possibly cache pkgInfo since it's an expensive operation - for i := range pkgInfo.Files { for _, astDeclaration := range pkgInfo.Files[i].Decls { generalDeclaration, ok := astDeclaration.(*ast.GenDecl) @@ -743,7 +818,7 @@ func findTypeDef(importPath, typeName string) (*ast.TypeSpec, error) { return nil, fmt.Errorf("type spec not found") } -var responsePattern = regexp.MustCompile(`^([\w,]+)[\s]+([\w{}]+)[\s]+([\w\-.\\{}=,\[\]]+)[^"]*(.*)?`) +var responsePattern = regexp.MustCompile(`^([\w,]+)\s+([\w{}]+)\s+([\w\-.\\{}=,\[\]]+)[^"]*(.*)?`) // ResponseType{data1=Type1,data2=Type2}. var combinedPattern = regexp.MustCompile(`^([\w\-./\[\]]+){(.*)}$`) @@ -752,9 +827,9 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) switch { case refType == NIL: return nil, nil - case refType == "interface{}": + case refType == INTERFACE: return PrimitiveSchema(OBJECT), nil - case refType == "any": + case refType == ANY: return PrimitiveSchema(OBJECT), nil case IsGolangPrimitiveType(refType): refType = TransToValidSchemeType(refType) @@ -775,10 +850,12 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) if idx < 0 { return nil, fmt.Errorf("invalid type: %s", refType) } + refType = refType[idx+1:] - if refType == "interface{}" || refType == "any" { + if refType == INTERFACE || refType == ANY { return spec.MapProperty(nil), nil } + schema, err := operation.parseObjectSchema(refType, astFile) if err != nil { return nil, err @@ -801,45 +878,46 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) } } +func parseFields(s string) []string { + nestLevel := 0 + + return strings.FieldsFunc(s, func(char rune) bool { + if char == '{' { + nestLevel++ + + return false + } else if char == '}' { + nestLevel-- + + return false + } + + return char == ',' && nestLevel == 0 + }) +} + func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *ast.File) (*spec.Schema, error) { matches := combinedPattern.FindStringSubmatch(refType) if len(matches) != 3 { return nil, fmt.Errorf("invalid type: %s", refType) } - refType = matches[1] - schema, err := operation.parseObjectSchema(refType, astFile) + + schema, err := operation.parseObjectSchema(matches[1], astFile) if err != nil { return nil, err } - parseFields := func(s string) []string { - n := 0 - - return strings.FieldsFunc(s, func(r rune) bool { - if r == '{' { - n++ - - return false - } else if r == '}' { - n-- - - return false - } - - return r == ',' && n == 0 - }) - } + fields, props := parseFields(matches[2]), map[string]spec.Schema{} - fields := parseFields(matches[2]) - props := map[string]spec.Schema{} for _, field := range fields { - matches := strings.SplitN(field, "=", 2) - if len(matches) == 2 { - schema, err := operation.parseObjectSchema(matches[1], astFile) + keyVal := strings.SplitN(field, "=", 2) + if len(keyVal) == 2 { + schema, err := operation.parseObjectSchema(keyVal[1], astFile) if err != nil { return nil, err } - props[matches[0]] = *schema + + props[keyVal[0]] = *schema } } @@ -861,6 +939,7 @@ func (operation *Operation) parseAPIObjectSchema(schemaType, refType string, ast if !strings.HasPrefix(refType, "[]") { return operation.parseObjectSchema(refType, astFile) } + refType = refType[2:] fallthrough @@ -889,6 +968,7 @@ func (operation 
*Operation) ParseResponseComment(commentLine string, astFile *as } description := strings.Trim(matches[4], "\"") + schema, err := operation.parseAPIObjectSchema(strings.Trim(matches[2], "{}"), matches[3], astFile) if err != nil { return err @@ -900,6 +980,7 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as continue } + code, err := strconv.Atoi(codeStr) if err != nil { return fmt.Errorf("can not parse response comment \"%s\"", commentLine) @@ -924,6 +1005,23 @@ func newHeaderSpec(schemaType, description string) spec.Header { HeaderProps: spec.HeaderProps{ Description: description, }, + VendorExtensible: spec.VendorExtensible{ + Extensions: nil, + }, + CommonValidations: spec.CommonValidations{ + Maximum: nil, + ExclusiveMaximum: false, + Minimum: nil, + ExclusiveMinimum: false, + MaxLength: nil, + MinLength: nil, + Pattern: "", + MaxItems: nil, + MinItems: nil, + UniqueItems: false, + MultipleOf: nil, + Enum: nil, + }, } } @@ -966,6 +1064,7 @@ func (operation *Operation) ParseResponseHeaderComment(commentLine string, _ *as if err != nil { return fmt.Errorf("can not parse response comment \"%s\"", commentLine) } + if operation.Responses.StatusCodeResponses != nil { response, responseExist := operation.Responses.StatusCodeResponses[code] if responseExist { @@ -979,7 +1078,7 @@ func (operation *Operation) ParseResponseHeaderComment(commentLine string, _ *as return nil } -var emptyResponsePattern = regexp.MustCompile(`([\w,]+)[\s]+"(.*)"`) +var emptyResponsePattern = regexp.MustCompile(`([\w,]+)\s+"(.*)"`) // ParseEmptyResponseComment parse only comment out status code and description,eg: @Success 200 "it's ok". func (operation *Operation) ParseEmptyResponseComment(commentLine string) error { @@ -989,6 +1088,7 @@ func (operation *Operation) ParseEmptyResponseComment(commentLine string) error } description := strings.Trim(matches[2], "\"") + for _, codeStr := range strings.Split(matches[1], ",") { if strings.EqualFold(codeStr, defaultTag) { operation.DefaultResponse().WithDescription(description) @@ -1015,6 +1115,7 @@ func (operation *Operation) ParseEmptyResponseOnly(commentLine string) error { continue } + code, err := strconv.Atoi(codeStr) if err != nil { return fmt.Errorf("can not parse response comment \"%s\"", commentLine) @@ -1031,7 +1132,8 @@ func (operation *Operation) DefaultResponse() *spec.Response { if operation.Responses.Default == nil { operation.Responses.Default = &spec.Response{ ResponseProps: spec.ResponseProps{ - Headers: make(map[string]spec.Header), + Description: "", + Headers: make(map[string]spec.Header), }, } } @@ -1044,6 +1146,7 @@ func (operation *Operation) AddResponse(code int, response *spec.Response) { if response.Headers == nil { response.Headers = make(map[string]spec.Header) } + operation.Responses.StatusCodeResponses[code] = *response } @@ -1052,10 +1155,12 @@ func createParameter(paramType, description, paramName, schemaType string, requi // //five possible parameter types. 
query, path, body, header, form result := spec.Parameter{ ParamProps: spec.ParamProps{ - Name: paramName, - Description: description, - Required: required, - In: paramType, + Name: paramName, + Description: description, + Required: required, + In: paramType, + Schema: nil, + AllowEmptyValue: false, }, } @@ -1070,7 +1175,9 @@ func createParameter(paramType, description, paramName, schemaType string, requi } result.SimpleSchema = spec.SimpleSchema{ - Type: schemaType, + Type: schemaType, + Nullable: false, + Format: "", } return result @@ -1086,6 +1193,7 @@ func getCodeExampleForSummary(summaryName string, dirPath string) ([]byte, error if fileInfo.IsDir() { continue } + fileName := fileInfo.Name() if !strings.Contains(fileName, ".json") { @@ -1094,6 +1202,7 @@ func getCodeExampleForSummary(summaryName string, dirPath string) ([]byte, error if strings.Contains(fileName, summaryName) { fullPath := filepath.Join(dirPath, fileName) + commentInfo, err := ioutil.ReadFile(fullPath) if err != nil { return nil, fmt.Errorf("Failed to read code example file %s error: %s ", fullPath, err) diff --git a/operation_test.go b/operation_test.go index e9e43e8c8..d8089f0a2 100644 --- a/operation_test.go +++ b/operation_test.go @@ -1964,7 +1964,7 @@ func TestParseAndExtractionParamAttribute(t *testing.T) { op := NewOperation(nil) numberParam := spec.Parameter{} - err := op.parseAndExtractionParamAttribute( + err := op.parseParamAttribute( " default(1) maximum(100) minimum(0) format(csv)", "", NUMBER, @@ -1976,14 +1976,14 @@ func TestParseAndExtractionParamAttribute(t *testing.T) { assert.Equal(t, "csv", numberParam.SimpleSchema.Format) assert.Equal(t, float64(1), numberParam.Default) - err = op.parseAndExtractionParamAttribute(" minlength(1)", "", NUMBER, nil) + err = op.parseParamAttribute(" minlength(1)", "", NUMBER, nil) assert.Error(t, err) - err = op.parseAndExtractionParamAttribute(" maxlength(1)", "", NUMBER, nil) + err = op.parseParamAttribute(" maxlength(1)", "", NUMBER, nil) assert.Error(t, err) stringParam := spec.Parameter{} - err = op.parseAndExtractionParamAttribute( + err = op.parseParamAttribute( " default(test) maxlength(100) minlength(0) format(csv)", "", STRING, @@ -1993,21 +1993,21 @@ func TestParseAndExtractionParamAttribute(t *testing.T) { assert.Equal(t, int64(0), *stringParam.MinLength) assert.Equal(t, int64(100), *stringParam.MaxLength) assert.Equal(t, "csv", stringParam.SimpleSchema.Format) - err = op.parseAndExtractionParamAttribute(" minimum(0)", "", STRING, nil) + err = op.parseParamAttribute(" minimum(0)", "", STRING, nil) assert.Error(t, err) - err = op.parseAndExtractionParamAttribute(" maximum(0)", "", STRING, nil) + err = op.parseParamAttribute(" maximum(0)", "", STRING, nil) assert.Error(t, err) arrayParram := spec.Parameter{} - err = op.parseAndExtractionParamAttribute(" collectionFormat(tsv)", ARRAY, STRING, &arrayParram) + err = op.parseParamAttribute(" collectionFormat(tsv)", ARRAY, STRING, &arrayParram) assert.Equal(t, "tsv", arrayParram.CollectionFormat) assert.NoError(t, err) - err = op.parseAndExtractionParamAttribute(" collectionFormat(tsv)", STRING, STRING, nil) + err = op.parseParamAttribute(" collectionFormat(tsv)", STRING, STRING, nil) assert.Error(t, err) - err = op.parseAndExtractionParamAttribute(" default(0)", "", ARRAY, nil) + err = op.parseParamAttribute(" default(0)", "", ARRAY, nil) assert.NoError(t, err) } diff --git a/packages.go b/packages.go index c454f0652..dd1a0e6c7 100644 --- a/packages.go +++ b/packages.go @@ -29,13 +29,13 @@ func 
NewPackagesDefinitions() *PackagesDefinitions { } // CollectAstFile collect ast.file. -func (pkgs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) error { - if pkgs.files == nil { - pkgs.files = make(map[*ast.File]*AstFileInfo) +func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) error { + if pkgDefs.files == nil { + pkgDefs.files = make(map[*ast.File]*AstFileInfo) } - if pkgs.packages == nil { - pkgs.packages = make(map[string]*PackageDefinitions) + if pkgDefs.packages == nil { + pkgDefs.packages = make(map[string]*PackageDefinitions) } // return without storing the file if we lack a packageDir @@ -48,23 +48,24 @@ func (pkgs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile return err } - pd, ok := pkgs.packages[packageDir] + dependency, ok := pkgDefs.packages[packageDir] if ok { // return without storing the file if it already exists - _, exists := pd.Files[path] + _, exists := dependency.Files[path] if exists { return nil } - pd.Files[path] = astFile + + dependency.Files[path] = astFile } else { - pkgs.packages[packageDir] = &PackageDefinitions{ + pkgDefs.packages[packageDir] = &PackageDefinitions{ Name: astFile.Name.Name, Files: map[string]*ast.File{path: astFile}, TypeDefinitions: make(map[string]*TypeSpecDef), } } - pkgs.files[astFile] = &AstFileInfo{ + pkgDefs.files[astFile] = &AstFileInfo{ File: astFile, Path: path, PackagePath: packageDir, @@ -96,15 +97,15 @@ func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, f // ParseTypes parse types // @Return parsed definitions. -func (pkgs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, error) { +func (pkgDefs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, error) { parsedSchemas := make(map[*TypeSpecDef]*Schema) - for astFile, info := range pkgs.files { - pkgs.parseTypesFromFile(astFile, info.PackagePath, parsedSchemas) + for astFile, info := range pkgDefs.files { + pkgDefs.parseTypesFromFile(astFile, info.PackagePath, parsedSchemas) } return parsedSchemas, nil } -func (pkgs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) { +func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) { for _, astDeclaration := range astFile.Decls { if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE { for _, astSpec := range generalDeclaration.Specs { @@ -123,29 +124,29 @@ func (pkgs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePa } } - if pkgs.uniqueDefinitions == nil { - pkgs.uniqueDefinitions = make(map[string]*TypeSpecDef) + if pkgDefs.uniqueDefinitions == nil { + pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef) } fullName := typeSpecDef.FullName() - anotherTypeDef, ok := pkgs.uniqueDefinitions[fullName] + anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName] if ok { if typeSpecDef.PkgPath == anotherTypeDef.PkgPath { continue } else { - delete(pkgs.uniqueDefinitions, fullName) + delete(pkgDefs.uniqueDefinitions, fullName) } } else { - pkgs.uniqueDefinitions[fullName] = typeSpecDef + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } - if pkgs.packages[typeSpecDef.PkgPath] == nil { - pkgs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{ + if pkgDefs.packages[typeSpecDef.PkgPath] == nil { + pkgDefs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{ Name: astFile.Name.Name, 
TypeDefinitions: map[string]*TypeSpecDef{typeSpecDef.Name(): typeSpecDef}, } - } else if _, ok = pkgs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok { - pkgs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()] = typeSpecDef + } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok { + pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()] = typeSpecDef } } } @@ -153,11 +154,12 @@ func (pkgs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePa } } -func (pkgs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *TypeSpecDef { - if pkgs.packages == nil { +func (pkgDefs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *TypeSpecDef { + if pkgDefs.packages == nil { return nil } - pd, found := pkgs.packages[pkgPath] + + pd, found := pkgDefs.packages[pkgPath] if found { typeSpec, ok := pd.TypeDefinitions[typeName] if ok { @@ -168,7 +170,7 @@ func (pkgs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) * return nil } -func (pkgs *PackagesDefinitions) loadExternalPackage(importPath string) error { +func (pkgDefs *PackagesDefinitions) loadExternalPackage(importPath string) error { cwd, err := os.Getwd() if err != nil { return err @@ -189,7 +191,7 @@ func (pkgs *PackagesDefinitions) loadExternalPackage(importPath string) error { for _, info := range loaderProgram.AllPackages { pkgPath := strings.TrimPrefix(info.Pkg.Path(), "vendor/") for _, astFile := range info.Files { - pkgs.parseTypesFromFile(astFile, pkgPath, nil) + pkgDefs.parseTypesFromFile(astFile, pkgPath, nil) } } @@ -201,7 +203,7 @@ func (pkgs *PackagesDefinitions) loadExternalPackage(importPath string) error { // @file current ast.File in which to search imports // @fuzzy search for the package path that the last part matches the @pkg if true // @return the package path of a package of @pkg. 
-func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File, fuzzy bool) string { +func (pkgDefs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File, fuzzy bool) string { if file == nil { return "" } @@ -214,6 +216,7 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as matchLastPathPart := func(pkgPath string) bool { paths := strings.Split(pkgPath, "/") + return paths[len(paths)-1] == pkg } @@ -223,26 +226,33 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as if imp.Name.Name == pkg { return strings.Trim(imp.Path.Value, `"`) } + if imp.Name.Name == "_" { hasAnonymousPkg = true } continue } - if pkgs.packages != nil { + + if pkgDefs.packages != nil { path := strings.Trim(imp.Path.Value, `"`) if fuzzy { if matchLastPathPart(path) { return path } - } else if pd, ok := pkgs.packages[path]; ok && pd.Name == pkg { + + continue + } + + pd, ok := pkgDefs.packages[path] + if ok && pd.Name == pkg { return path } } } // match unnamed package - if hasAnonymousPkg && pkgs.packages != nil { + if hasAnonymousPkg && pkgDefs.packages != nil { for _, imp := range file.Imports { if imp.Name == nil { continue @@ -253,7 +263,7 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as if matchLastPathPart(path) { return path } - } else if pd, ok := pkgs.packages[path]; ok && pd.Name == pkg { + } else if pd, ok := pkgDefs.packages[path]; ok && pd.Name == pkg { return path } } @@ -267,12 +277,13 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as // @typeName the name of the target type, if it starts with a package name, find its own package path from imports on top of @file // @file the ast.file in which @typeName is used // @pkgPath the package path of @file. 
-func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef { +func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef { if IsGolangPrimitiveType(typeName) { return nil } + if file == nil { // for test - return pkgs.uniqueDefinitions[typeName] + return pkgDefs.uniqueDefinitions[typeName] } parts := strings.Split(typeName, ".") @@ -290,42 +301,43 @@ func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, p } if !isAliasPkgName(file, parts[0]) { - typeDef, ok := pkgs.uniqueDefinitions[typeName] + typeDef, ok := pkgDefs.uniqueDefinitions[typeName] if ok { return typeDef } } - pkgPath := pkgs.findPackagePathFromImports(parts[0], file, false) + + pkgPath := pkgDefs.findPackagePathFromImports(parts[0], file, false) if len(pkgPath) == 0 { // check if the current package if parts[0] == file.Name.Name { - pkgPath = pkgs.files[file].PackagePath + pkgPath = pkgDefs.files[file].PackagePath } else if parseDependency { // take it as an external package, needs to be loaded - if pkgPath = pkgs.findPackagePathFromImports(parts[0], file, true); len(pkgPath) > 0 { - if err := pkgs.loadExternalPackage(pkgPath); err != nil { + if pkgPath = pkgDefs.findPackagePathFromImports(parts[0], file, true); len(pkgPath) > 0 { + if err := pkgDefs.loadExternalPackage(pkgPath); err != nil { return nil } } } } - return pkgs.findTypeSpec(pkgPath, parts[1]) + return pkgDefs.findTypeSpec(pkgPath, parts[1]) } - typeDef, ok := pkgs.uniqueDefinitions[fullTypeName(file.Name.Name, typeName)] + typeDef, ok := pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, typeName)] if ok { return typeDef } - typeDef = pkgs.findTypeSpec(pkgs.files[file].PackagePath, typeName) + typeDef = pkgDefs.findTypeSpec(pkgDefs.files[file].PackagePath, typeName) if typeDef != nil { return typeDef } for _, imp := range file.Imports { if imp.Name != nil && imp.Name.Name == "." { - typeDef := pkgs.findTypeSpec(strings.Trim(imp.Path.Value, `"`), typeName) + typeDef := pkgDefs.findTypeSpec(strings.Trim(imp.Path.Value, `"`), typeName) if typeDef != nil { return typeDef } diff --git a/parser.go b/parser.go index 980c1d7b2..a390714f4 100644 --- a/parser.go +++ b/parser.go @@ -65,7 +65,7 @@ var ( // ErrFailedConvertPrimitiveType Failed to convert for swag to interpretable type. ErrFailedConvertPrimitiveType = errors.New("swag property: failed convert primitive type") - // ErrSkippedField .swaggo specifies field should be skipped + // ErrSkippedField .swaggo specifies field should be skipped. ErrSkippedField = errors.New("field is skipped by global overrides") ) @@ -142,12 +142,12 @@ type Parser struct { Overrides map[string]string } -// FieldParserFactory create FieldParser +// FieldParserFactory create FieldParser. type FieldParserFactory func(ps *Parser, field *ast.Field) FieldParser -// FieldParser parse struct field +// FieldParser parse struct field. type FieldParser interface { - ShouldSkip() (bool, error) + ShouldSkip() bool FieldName() (string, error) CustomSchema() (*spec.Schema, error) ComplementSchema(schema *spec.Schema) error @@ -161,8 +161,6 @@ type Debugger interface { // New creates a new Parser with default properties. 
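FindTypeSpec above splits a qualified name such as pkg.Type and resolves it through uniqueDefinitions, the current package, or, when parseDependency is enabled, a lazily loaded external package. The constructor it ultimately feeds, New(options ...func(*Parser)), relies on the functional-options pattern; the sketch below shows that pattern on a hypothetical Config type, not swag's real option set:

package main

import "fmt"

// Standalone sketch of functional options: callers tweak defaults
// without a long constructor signature. Config, WithStrict and
// WithMarkdownDir are invented for illustration.
type Config struct {
	Strict      bool
	MarkdownDir string
}

func WithStrict() func(*Config) {
	return func(c *Config) { c.Strict = true }
}

func WithMarkdownDir(dir string) func(*Config) {
	return func(c *Config) { c.MarkdownDir = dir }
}

func NewConfig(options ...func(*Config)) *Config {
	cfg := &Config{} // defaults first, then apply each option in order
	for _, opt := range options {
		opt(cfg)
	}
	return cfg
}

func main() {
	cfg := NewConfig(WithStrict(), WithMarkdownDir("./docs"))
	fmt.Printf("%+v\n", *cfg)
}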
func New(options ...func(*Parser)) *Parser { - // parser.swagger.SecurityDefinitions = - parser := &Parser{ swagger: &spec.Swagger{ SwaggerProps: spec.SwaggerProps{ @@ -177,10 +175,16 @@ func New(options ...func(*Parser)) *Parser { }, Paths: &spec.Paths{ Paths: make(map[string]spec.PathItem), + VendorExtensible: spec.VendorExtensible{ + Extensions: nil, + }, }, Definitions: make(map[string]spec.Schema), SecurityDefinitions: make(map[string]*spec.SecurityScheme), }, + VendorExtensible: spec.VendorExtensible{ + Extensions: nil, + }, }, packages: NewPackagesDefinitions(), debug: log.New(os.Stdout, "", log.LstdFlags), @@ -284,22 +288,22 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st } if parser.ParseDependency { - var t depth.Tree - t.ResolveInternal = true - t.MaxDepth = parseDepth + var tree depth.Tree + tree.ResolveInternal = true + tree.MaxDepth = parseDepth pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath)) if err != nil { return err } - err = t.Resolve(pkgName) + err = tree.Resolve(pkgName) if err != nil { return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err) } - for i := 0; i < len(t.Root.Deps); i++ { - err := parser.getAllGoFileInfoFromDeps(&t.Root.Deps[i]) + for i := 0; i < len(tree.Root.Deps); i++ { + err := parser.getAllGoFileInfoFromDeps(&tree.Root.Deps[i]) if err != nil { return err } @@ -329,7 +333,9 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st func getPkgName(searchDir string) (string, error) { cmd := exec.Command("go", "list", "-f={{.ImportPath}}") cmd.Dir = searchDir + var stdout, stderr strings.Builder + cmd.Stdout = &stdout cmd.Stderr = &stderr @@ -342,7 +348,9 @@ func getPkgName(searchDir string) (string, error) { if outStr[0] == '_' { // will shown like _/{GOPATH}/src/{YOUR_PACKAGE} when NOT enable GO MODULE. 
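getPkgName shells out to the go tool to learn the import path of the main API file's package; the renamed stdout/stderr builders above simply capture its output. A stripped-down version of that call might look like the following (pkgImportPath is a hypothetical helper; it assumes go is on PATH and that dir contains a buildable package):

package main

import (
	"fmt"
	"os/exec"
	"strings"
)

// Sketch of resolving a directory's import path via `go list`.
func pkgImportPath(dir string) (string, error) {
	cmd := exec.Command("go", "list", "-f", "{{.ImportPath}}")
	cmd.Dir = dir

	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		return "", fmt.Errorf("go list failed: %v, %s", err, stderr.String())
	}

	// go list may print several lines; only the first is the current package.
	out := strings.TrimSpace(stdout.String())
	return strings.Split(out, "\n")[0], nil
}

func main() {
	path, err := pkgImportPath(".")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println("import path:", path)
}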
outStr = strings.TrimPrefix(outStr, "_"+build.Default.GOPATH+"/src/") } + f := strings.Split(outStr, "\n") + outStr = f[0] return outStr, nil @@ -370,7 +378,8 @@ func (parser *Parser) ParseGeneralAPIInfo(mainAPIFile string) error { if !isGeneralAPIComment(comments) { continue } - err := parseGeneralAPIInfo(parser, comments) + + err = parseGeneralAPIInfo(parser, comments) if err != nil { return err } @@ -383,14 +392,15 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error { previousAttribute := "" // parsing classic meta data model - for i := 0; i < len(comments); i++ { - commentLine := comments[i] + for line := 0; line < len(comments); line++ { + commentLine := comments[line] attribute := strings.Split(commentLine, " ")[0] value := strings.TrimSpace(commentLine[len(attribute):]) multilineBlock := false if previousAttribute == attribute { multilineBlock = true } + switch strings.ToLower(attribute) { case versionAttr: parser.swagger.Info.Version = value @@ -402,12 +412,14 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error { continue } + parser.swagger.Info.Description = value case "@description.markdown": commentInfo, err := getMarkdownForTag("api", parser.markdownFileDir) if err != nil { return err } + parser.swagger.Info.Description = string(commentInfo) case "@termsofservice": parser.swagger.Info.TermsOfService = value @@ -451,56 +463,66 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error { replaceLastTag(parser.swagger.Tags, tag) case "@tag.description.markdown": tag := parser.swagger.Tags[len(parser.swagger.Tags)-1] + commentInfo, err := getMarkdownForTag(tag.TagProps.Name, parser.markdownFileDir) if err != nil { return err } + tag.TagProps.Description = string(commentInfo) replaceLastTag(parser.swagger.Tags, tag) case "@tag.docs.url": tag := parser.swagger.Tags[len(parser.swagger.Tags)-1] tag.TagProps.ExternalDocs = &spec.ExternalDocumentation{ - URL: value, + URL: value, + Description: "", } + replaceLastTag(parser.swagger.Tags, tag) case "@tag.docs.description": tag := parser.swagger.Tags[len(parser.swagger.Tags)-1] if tag.TagProps.ExternalDocs == nil { return fmt.Errorf("%s needs to come after a @tags.docs.url", attribute) } + tag.TagProps.ExternalDocs.Description = value replaceLastTag(parser.swagger.Tags, tag) case "@securitydefinitions.basic": parser.swagger.SecurityDefinitions[value] = spec.BasicAuth() case "@securitydefinitions.apikey": - attrMap, _, extensions, err := parseSecAttr(attribute, []string{"@in", "@name"}, comments, &i) + attrMap, _, extensions, err := parseSecAttr(attribute, []string{"@in", "@name"}, comments, &line) if err != nil { return err } + parser.swagger.SecurityDefinitions[value] = tryAddDescription(spec.APIKeyAuth(attrMap["@name"], attrMap["@in"]), extensions) case "@securitydefinitions.oauth2.application": - attrMap, scopes, extensions, err := parseSecAttr(attribute, []string{"@tokenurl"}, comments, &i) + attrMap, scopes, extensions, err := parseSecAttr(attribute, []string{"@tokenurl"}, comments, &line) if err != nil { return err } + parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2Application(attrMap["@tokenurl"], scopes, extensions), extensions) case "@securitydefinitions.oauth2.implicit": - attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@authorizationurl"}, comments, &i) + attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@authorizationurl"}, comments, &line) if err != nil { return err } + parser.swagger.SecurityDefinitions[value] = 
tryAddDescription(secOAuth2Implicit(attrs["@authorizationurl"], scopes, ext), ext) case "@securitydefinitions.oauth2.password": - attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@tokenurl"}, comments, &i) + attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@tokenurl"}, comments, &line) if err != nil { return err } + parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2Password(attrs["@tokenurl"], scopes, ext), ext) case "@securitydefinitions.oauth2.accesscode": - attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@tokenurl", "@authorizationurl"}, comments, &i) + attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@tokenurl", "@authorizationurl"}, comments, &line) if err != nil { return err } + parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2AccessToken(attrs["@authorizationurl"], attrs["@tokenurl"], scopes, ext), ext) case "@query.collection.format": parser.collectionFormatInQuery = value @@ -518,17 +540,21 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error { break } } + // if it is present on security def, don't add it again if extExistsInSecurityDef { break } var valueJSON interface{} + split := strings.SplitAfter(commentLine, attribute+" ") if len(split) < 2 { return fmt.Errorf("annotation %s need a value", attribute) } + extensionName := "x-" + strings.SplitAfter(attribute, prefixExtension)[1] + err := json.Unmarshal([]byte(split[1]), &valueJSON) if err != nil { return fmt.Errorf("annotation %s need a valid json value", attribute) @@ -540,10 +566,12 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error { if parser.swagger.Extensions == nil { parser.swagger.Extensions = make(map[string]interface{}) } + parser.swagger.Extensions[attribute[1:]] = valueJSON } } } + previousAttribute = attribute } @@ -556,6 +584,7 @@ func tryAddDescription(spec *spec.SecurityScheme, extensions map[string]interfac spec.Description = str } } + return spec } @@ -592,32 +621,40 @@ func parseSecAttr(context string, search []string, lines []string, index *int) ( for ; *index < len(lines); *index++ { v := lines[*index] + securityAttr := strings.ToLower(strings.Split(v, " ")[0]) for _, findterm := range search { if securityAttr == findterm { attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):]) + continue } } + isExists, err := isExistsScope(securityAttr) if err != nil { return nil, nil, nil, err } + if isExists { scopes[securityAttr[len(scopeAttrPrefix):]] = v[len(securityAttr):] } + if strings.HasPrefix(securityAttr, "@x-") { // Add the custom attribute without the @ extensions[securityAttr[1:]] = strings.TrimSpace(v[len(securityAttr):]) } + // Not mandatory field if securityAttr == "@description" { extensions[securityAttr] = strings.TrimSpace(v[len(securityAttr):]) } + // next securityDefinitions if strings.Index(securityAttr, "@securitydefinitions.") == 0 { // Go back to the previous line and break *index-- + break } } @@ -629,8 +666,12 @@ func parseSecAttr(context string, search []string, lines []string, index *int) ( return attrMap, scopes, extensions, nil } -func secOAuth2Application(tokenURL string, scopes map[string]string, - extensions map[string]interface{}) *spec.SecurityScheme { +type ( + authExtensions map[string]interface{} + authScopes map[string]string +) + +func secOAuth2Application(tokenURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme { securityScheme := spec.OAuth2Application(tokenURL) securityScheme.VendorExtensible.Extensions = 
handleSecuritySchemaExtensions(extensions) for scope, description := range scopes { @@ -640,10 +681,10 @@ func secOAuth2Application(tokenURL string, scopes map[string]string, return securityScheme } -func secOAuth2Implicit(authorizationURL string, scopes map[string]string, - extensions map[string]interface{}) *spec.SecurityScheme { +func secOAuth2Implicit(authorizationURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme { securityScheme := spec.OAuth2Implicit(authorizationURL) securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions) + for scope, description := range scopes { securityScheme.AddScope(scope, description) } @@ -651,10 +692,10 @@ func secOAuth2Implicit(authorizationURL string, scopes map[string]string, return securityScheme } -func secOAuth2Password(tokenURL string, scopes map[string]string, - extensions map[string]interface{}) *spec.SecurityScheme { +func secOAuth2Password(tokenURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme { securityScheme := spec.OAuth2Password(tokenURL) securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions) + for scope, description := range scopes { securityScheme.AddScope(scope, description) } @@ -662,10 +703,10 @@ func secOAuth2Password(tokenURL string, scopes map[string]string, return securityScheme } -func secOAuth2AccessToken(authorizationURL, tokenURL string, - scopes map[string]string, extensions map[string]interface{}) *spec.SecurityScheme { +func secOAuth2AccessToken(authorizationURL, tokenURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme { securityScheme := spec.OAuth2AccessToken(authorizationURL, tokenURL) securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions) + for scope, description := range scopes { securityScheme.AddScope(scope, description) } @@ -673,7 +714,7 @@ func secOAuth2AccessToken(authorizationURL, tokenURL string, return securityScheme } -func handleSecuritySchemaExtensions(providedExtensions map[string]interface{}) spec.Extensions { +func handleSecuritySchemaExtensions(providedExtensions authExtensions) spec.Extensions { var extensions spec.Extensions if len(providedExtensions) > 0 { extensions = make(map[string]interface{}, len(providedExtensions)) @@ -695,6 +736,7 @@ func getMarkdownForTag(tagName string, dirPath string) ([]byte, error) { if fileInfo.IsDir() { continue } + fileName := fileInfo.Name() if !strings.Contains(fileName, ".md") { @@ -703,6 +745,7 @@ func getMarkdownForTag(tagName string, dirPath string) ([]byte, error) { if strings.Contains(fileName, tagName) { fullPath := filepath.Join(dirPath, fileName) + commentInfo, err := ioutil.ReadFile(fullPath) if err != nil { return nil, fmt.Errorf("Failed to read markdown file %s error: %s ", fullPath, err) @@ -749,29 +792,9 @@ func (parser *Parser) ParseRouterAPIInfo(fileName string, astFile *ast.File) err } } - for _, routeProperties := range operation.RouterProperties { - var pathItem spec.PathItem - var ok bool - - pathItem, ok = parser.swagger.Paths.Paths[routeProperties.Path] - if !ok { - pathItem = spec.PathItem{} - } - - op := refRouteMethodOp(&pathItem, routeProperties.HTTPMethod) - - // check if we already have a operation for this path and method - if *op != nil { - err := fmt.Errorf("route %s %s is declared multiple times", routeProperties.HTTPMethod, routeProperties.Path) - if parser.Strict { - return err - } - parser.debug.Printf("warning: %s\n", err) - } - - *op = 
&operation.Operation - - parser.swagger.Paths.Paths[routeProperties.Path] = pathItem + err := processRouterOperation(parser, operation) + if err != nil { + return err } } } @@ -796,14 +819,48 @@ func refRouteMethodOp(item *spec.PathItem, method string) (op **spec.Operation) case http.MethodOptions: op = &item.Options } + return } +func processRouterOperation(parser *Parser, operation *Operation) error { + for _, routeProperties := range operation.RouterProperties { + var ( + pathItem spec.PathItem + ok bool + ) + + pathItem, ok = parser.swagger.Paths.Paths[routeProperties.Path] + if !ok { + pathItem = spec.PathItem{} + } + + op := refRouteMethodOp(&pathItem, routeProperties.HTTPMethod) + + // check if we already have an operation for this path and method + if *op != nil { + err := fmt.Errorf("route %s %s is declared multiple times", routeProperties.HTTPMethod, routeProperties.Path) + if parser.Strict { + return err + } + + parser.debug.Printf("warning: %s\n", err) + } + + *op = &operation.Operation + + parser.swagger.Paths.Paths[routeProperties.Path] = pathItem + } + + return nil +} + func convertFromSpecificToPrimitive(typeName string) (string, error) { name := typeName if strings.ContainsRune(name, '.') { name = strings.Split(name, ".")[1] } + switch strings.ToUpper(name) { case "TIME", "OBJECTID", "UUID": return STRING, nil @@ -832,6 +889,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( if override, ok := parser.Overrides[typeSpecDef.FullPath()]; ok { if override == "" { parser.debug.Printf("Override detected for %s: ignoring", typeSpecDef.FullPath()) + return nil, ErrSkippedField } @@ -841,6 +899,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( if separator == -1 { // treat as a swaggertype tag parts := strings.Split(override, ",") + return BuildCustomSchema(parts) } @@ -850,6 +909,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( schema, ok := parser.parsedSchemas[typeSpecDef] if !ok { var err error + schema, err = parser.ParseDefinition(typeSpecDef) if err != nil { if err == ErrRecursiveParseStruct && ref { @@ -885,8 +945,10 @@ func (parser *Parser) renameRefSchemas() { for _, refURL := range parser.toBeRenamedRefURLs { parts := strings.Split(refURL.Fragment, "/") name := parts[len(parts)-1] + if pkgPath, ok := parser.toBeRenamedSchemas[name]; ok { parts[len(parts)-1] = parser.renameSchema(name, pkgPath) + refURL.Fragment = strings.Join(parts, "/") } } @@ -915,6 +977,7 @@ func (parser *Parser) getRefTypeSchema(typeSpecDef *TypeSpecDef, schema *Schema) } else { parser.existSchemaNames[schema.Name] = schema } + parser.swagger.Definitions[schema.Name] = spec.Schema{} if schema.Schema != nil { @@ -948,8 +1011,8 @@ func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error) typeName := typeSpecDef.FullName() refTypeName := TypeDocName(typeName, typeSpecDef.TypeSpec) - schema, ok := parser.parsedSchemas[typeSpecDef] - if ok { + schema, found := parser.parsedSchemas[typeSpecDef] + if found { parser.debug.Printf("Skipping '%s', already parsed.", typeName) return schema, nil @@ -965,6 +1028,7 @@ func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error) }, ErrRecursiveParseStruct } + parser.structStack = append(parser.structStack, typeSpecDef) parser.debug.Printf("Generating %s", typeName) @@ -978,20 +1042,20 @@ func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error) fillDefinitionDescription(definition, 
typeSpecDef.File, typeSpecDef) } - s := Schema{ + sch := Schema{ Name: refTypeName, PkgPath: typeSpecDef.PkgPath, Schema: definition, } - parser.parsedSchemas[typeSpecDef] = &s + parser.parsedSchemas[typeSpecDef] = &sch // update an empty schema as a result of recursion - s2, ok := parser.outputSchemas[typeSpecDef] - if ok { + s2, found := parser.outputSchemas[typeSpecDef] + if found { parser.swagger.Definitions[s2.Name] = *definition } - return &s, nil + return &sch, nil } func fullTypeName(pkgName, typeName string) string { @@ -1034,13 +1098,16 @@ func extractDeclarationDescription(commentGroups ...*ast.CommentGroup) string { } isHandlingDescription := false + for _, comment := range commentGroup.List { commentText := strings.TrimSpace(strings.TrimLeft(comment.Text, "/")) attribute := strings.Split(commentText, " ")[0] + if strings.ToLower(attribute) != descriptionAttr { if !isHandlingDescription { continue } + break } @@ -1108,8 +1175,8 @@ func (parser *Parser) parseTypeExpr(file *ast.File, typeExpr ast.Expr, ref bool) } func (parser *Parser) parseStruct(file *ast.File, fields *ast.FieldList) (*spec.Schema, error) { - required := make([]string, 0) - properties := make(map[string]spec.Schema) + required, properties := make([]string, 0), make(map[string]spec.Schema) + for _, field := range fields.List { fieldProps, requiredFromAnon, err := parser.parseStructField(file, field) if err != nil { @@ -1119,10 +1186,13 @@ func (parser *Parser) parseStruct(file *ast.File, fields *ast.FieldList) (*spec. return nil, err } + if len(fieldProps) == 0 { continue } + required = append(required, requiredFromAnon...) + for k, v := range fieldProps { properties[k] = v } @@ -1152,10 +1222,12 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st if err != nil { return nil, nil, err } + schema, err := parser.getTypeSchema(typeName, file, false) if err != nil { return nil, nil, err } + if len(schema.Type) > 0 && schema.Type[0] == OBJECT { if len(schema.Properties) == 0 { return nil, nil, nil @@ -1175,11 +1247,7 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st ps := parser.fieldParserFactory(parser, field) - ok, err := ps.ShouldSkip() - if err != nil { - return nil, nil, err - } - if ok { + if ps.ShouldSkip() { return nil, nil, nil } @@ -1192,6 +1260,7 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st if err != nil { return nil, nil, err } + if schema == nil { typeName, err := getFieldType(field.Type) if err == nil { @@ -1201,6 +1270,7 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st // unnamed type schema, err = parser.parseTypeExpr(file, field.Type, false) } + if err != nil { return nil, nil, err } @@ -1212,10 +1282,12 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st } var tagRequired []string + required, err := ps.IsRequired() if err != nil { return nil, nil, err } + if required { tagRequired = append(tagRequired, fieldName) } @@ -1234,7 +1306,6 @@ func getFieldType(field ast.Expr) (string, error) { } return fullTypeName(packageName, fieldType.Sel.Name), nil - case *ast.StarExpr: fullName, err := getFieldType(fieldType.X) if err != nil { @@ -1252,6 +1323,7 @@ func (parser *Parser) GetSchemaTypePath(schema *spec.Schema, depth int) []string if schema == nil || depth == 0 { return nil } + name := schema.Ref.String() if name != "" { if pos := strings.LastIndexByte(name, '/'); pos >= 0 { @@ -1263,10 +1335,12 @@ func (parser *Parser) 
GetSchemaTypePath(schema *spec.Schema, depth int) []string return nil } + if len(schema.Type) > 0 { switch schema.Type[0] { case ARRAY: depth-- + s := []string{schema.Type[0]} return append(s, parser.GetSchemaTypePath(schema.Items.Schema, depth)...) @@ -1274,6 +1348,7 @@ func (parser *Parser) GetSchemaTypePath(schema *spec.Schema, depth int) []string if schema.AdditionalProperties != nil && schema.AdditionalProperties.Schema != nil { // for map depth-- + s := []string{schema.Type[0]} return append(s, parser.GetSchemaTypePath(schema.AdditionalProperties.Schema, depth)...) @@ -1324,6 +1399,7 @@ func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{ if err != nil { return nil, err } + result = append(result, v) } @@ -1334,7 +1410,9 @@ func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{ } values := strings.Split(exampleValue, ",") + result := map[string]interface{}{} + for _, value := range values { mapData := strings.Split(value, ":") @@ -1343,10 +1421,14 @@ func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{ if err != nil { return nil, err } + result[mapData[0]] = v - } else { - return nil, fmt.Errorf("example value %s should format: key:value", exampleValue) + + continue + } + + return nil, fmt.Errorf("example value %s should format: key:value", exampleValue) } return result, nil @@ -1358,9 +1440,12 @@ func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{ // GetAllGoFileInfo gets all Go source files information for given searchDir. func (parser *Parser) getAllGoFileInfo(packageDir, searchDir string) error { return filepath.Walk(searchDir, func(path string, f os.FileInfo, _ error) error { - if err := parser.Skip(path, f); err != nil { + err := parser.Skip(path, f) + if err != nil { return err - } else if f.IsDir() { + } + + if f.IsDir() { return nil } @@ -1383,7 +1468,9 @@ func (parser *Parser) getAllGoFileInfoFromDeps(pkg *depth.Pkg) error { if pkg.Raw == nil && pkg.Name == "C" { return nil } + srcDir := pkg.Raw.Dir + files, err := ioutil.ReadDir(srcDir) // only parsing files in the dir(don't contain sub dir files) if err != nil { return err @@ -1434,24 +1521,29 @@ func (parser *Parser) checkOperationIDUniqueness() error { for path, item := range parser.swagger.Paths.Paths { var method, id string + for method = range allMethod { op := refRouteMethodOp(&item, method) if *op != nil { id = (**op).ID + break } } + if id == "" { continue } current := fmt.Sprintf("%s %s", method, path) + previous, ok := operationsIds[id] if ok { return fmt.Errorf( "duplicated @id annotation '%s' found in '%s', previously declared in: '%s'", id, current, previous) } + operationsIds[id] = current } diff --git a/parser_test.go b/parser_test.go index 908d54a7f..392a04fe2 100644 --- a/parser_test.go +++ b/parser_test.go @@ -212,7 +212,9 @@ func TestParser_ParseGeneralApiInfo(t *testing.T) { }` gopath := os.Getenv("GOPATH") assert.NotNil(t, gopath) + p := New() + err := p.ParseGeneralAPIInfo("testdata/main.go") assert.NoError(t, err) @@ -295,7 +297,9 @@ func TestParser_ParseGeneralApiInfoTemplated(t *testing.T) { }` gopath := os.Getenv("GOPATH") assert.NotNil(t, gopath) + p := New() + err := p.ParseGeneralAPIInfo("testdata/templated.go") assert.NoError(t, err) @@ -311,7 +315,9 @@ func TestParser_ParseGeneralApiInfoExtensions(t *testing.T) { expected := "annotation @x-google-endpoints need a valid json value" gopath := os.Getenv("GOPATH") assert.NotNil(t, gopath) + p := New() + err := 
p.ParseGeneralAPIInfo("testdata/extensionsFail1.go") if assert.Error(t, err) { assert.Equal(t, expected, err.Error()) @@ -325,7 +331,9 @@ func TestParser_ParseGeneralApiInfoExtensions(t *testing.T) { expected := "annotation @x-google-endpoints need a value" gopath := os.Getenv("GOPATH") assert.NotNil(t, gopath) + p := New() + err := p.ParseGeneralAPIInfo("testdata/extensionsFail2.go") if assert.Error(t, err) { assert.Equal(t, expected, err.Error()) @@ -350,7 +358,9 @@ func TestParser_ParseGeneralApiInfoWithOpsInSameFile(t *testing.T) { gopath := os.Getenv("GOPATH") assert.NotNil(t, gopath) + p := New() + err := p.ParseGeneralAPIInfo("testdata/single_file_api/main.go") assert.NoError(t, err) @@ -387,6 +397,7 @@ func TestParser_ParseGeneralAPIInfoMarkdown(t *testing.T) { assert.Equal(t, expected, string(b)) p = New() + err = p.ParseGeneralAPIInfo(mainAPIFile) assert.Error(t, err) } @@ -3332,7 +3343,8 @@ func TestDefineTypeOfExample(t *testing.T) { example, err = defineTypeOfExample("array", "string", "one,two,three") assert.NoError(t, err) - arr := []string{} + + var arr []string for _, v := range example.([]interface{}) { arr = append(arr, v.(string)) diff --git a/schema.go b/schema.go index 0e72f65d7..a23d21b36 100644 --- a/schema.go +++ b/schema.go @@ -26,6 +26,8 @@ const ( STRING = "string" // FUNC represent a function value. FUNC = "func" + // INTERFACE represent a interface value. + INTERFACE = "interface{}" // ANY represent a any value. ANY = "any" // NIL represent a empty value. @@ -133,6 +135,7 @@ func TypeDocName(pkgName string, spec *ast.TypeSpec) string { } } } + if spec.Name != nil { return fullTypeName(strings.Split(pkgName, ".")[0], spec.Name.Name) } @@ -168,6 +171,7 @@ func BuildCustomSchema(types []string) (*spec.Schema, error) { if len(types) == 1 { return nil, errors.New("need array item type after array") } + schema, err := BuildCustomSchema(types[1:]) if err != nil { return nil, err @@ -178,6 +182,7 @@ func BuildCustomSchema(types []string) (*spec.Schema, error) { if len(types) == 1 { return PrimitiveSchema(types[0]), nil } + schema, err := BuildCustomSchema(types[1:]) if err != nil { return nil, err diff --git a/schema_test.go b/schema_test.go index dbb7b614c..0fe60e816 100644 --- a/schema_test.go +++ b/schema_test.go @@ -85,8 +85,10 @@ func TestIsSimplePrimitiveType(t *testing.T) { func TestBuildCustomSchema(t *testing.T) { t.Parallel() - var schema *spec.Schema - var err error + var ( + schema *spec.Schema + err error + ) schema, err = BuildCustomSchema([]string{}) assert.NoError(t, err) diff --git a/spec.go b/spec.go index 9e0ec1ad0..3a727c940 100644 --- a/spec.go +++ b/spec.go @@ -21,31 +21,33 @@ type Spec struct { // ReadDoc parses SwaggerTemplate into swagger document. 
func (i *Spec) ReadDoc() string { - i.Description = strings.Replace(i.Description, "\n", "\\n", -1) + i.Description = strings.ReplaceAll(i.Description, "\n", "\\n") - t, err := template.New("swagger_info").Funcs(template.FuncMap{ + tpl, err := template.New("swagger_info").Funcs(template.FuncMap{ "marshal": func(v interface{}) string { a, _ := json.Marshal(v) + return string(a) }, "escape": func(v interface{}) string { // escape tabs - str := strings.Replace(v.(string), "\t", "\\t", -1) + var str = strings.ReplaceAll(v.(string), "\t", "\\t") // replace " with \", and if that results in \\", replace that with \\\" - str = strings.Replace(str, "\"", "\\\"", -1) - return strings.Replace(str, "\\\\\"", "\\\\\\\"", -1) + str = strings.ReplaceAll(str, "\"", "\\\"") + + return strings.ReplaceAll(str, "\\\\\"", "\\\\\\\"") }, }).Parse(i.SwaggerTemplate) if err != nil { return i.SwaggerTemplate } - var tpl bytes.Buffer - if err = t.Execute(&tpl, i); err != nil { + var doc bytes.Buffer + if err = tpl.Execute(&doc, i); err != nil { return i.SwaggerTemplate } - return tpl.String() + return doc.String() } // InstanceName returns Spec instance name. diff --git a/spec_test.go b/spec_test.go index 00ab37007..75eb7f37d 100644 --- a/spec_test.go +++ b/spec_test.go @@ -1,8 +1,9 @@ package swag import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestSpec_InstanceName(t *testing.T) { @@ -16,6 +17,7 @@ func TestSpec_InstanceName(t *testing.T) { InfoInstanceName string SwaggerTemplate string } + tests := []struct { name string fields fields @@ -32,9 +34,10 @@ func TestSpec_InstanceName(t *testing.T) { want: "TestInstanceName1", }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - i := &Spec{ + doc := Spec{ Version: tt.fields.Version, Host: tt.fields.Host, BasePath: tt.fields.BasePath, @@ -44,7 +47,8 @@ func TestSpec_InstanceName(t *testing.T) { InfoInstanceName: tt.fields.InfoInstanceName, SwaggerTemplate: tt.fields.SwaggerTemplate, } - assert.Equal(t, tt.want, i.InstanceName()) + + assert.Equal(t, tt.want, doc.InstanceName()) }) } } @@ -60,6 +64,7 @@ func TestSpec_ReadDoc(t *testing.T) { InfoInstanceName string SwaggerTemplate string } + tests := []struct { name string fields fields @@ -128,9 +133,10 @@ func TestSpec_ReadDoc(t *testing.T) { want: "{{ .Schemesa }}", }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - i := &Spec{ + doc := Spec{ Version: tt.fields.Version, Host: tt.fields.Host, BasePath: tt.fields.BasePath, @@ -140,7 +146,8 @@ func TestSpec_ReadDoc(t *testing.T) { InfoInstanceName: tt.fields.InfoInstanceName, SwaggerTemplate: tt.fields.SwaggerTemplate, } - assert.Equal(t, tt.want, i.ReadDoc()) + + assert.Equal(t, tt.want, doc.ReadDoc()) }) } } diff --git a/swagger.go b/swagger.go index c00feb22b..5ffbab63e 100644 --- a/swagger.go +++ b/swagger.go @@ -23,6 +23,7 @@ type Swagger interface { func Register(name string, swagger Swagger) { swaggerMu.Lock() defer swaggerMu.Unlock() + if swagger == nil { panic("swagger is nil") } @@ -34,6 +35,7 @@ func Register(name string, swagger Swagger) { if _, ok := swags[name]; ok { panic("Register called twice for swag: " + name) } + swags[name] = swagger }
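The swagger.go hunk keeps Register as a guarded package-level registry: a mutex around a map, panicking on nil or duplicate registrations so misconfiguration surfaces at startup rather than at serve time. Below is a self-contained sketch of the same pattern, with a hypothetical Doc interface standing in for swag.Swagger:

package main

import (
	"fmt"
	"sync"
)

// Doc is a stand-in for the registered document type; the real swag
// interface exposes ReadDoc through its Swagger type.
type Doc interface {
	ReadDoc() string
}

var (
	registryMu sync.Mutex
	registry   = map[string]Doc{}
)

// register guards the map with a mutex and fails loudly on misuse,
// mirroring the checks in swag.Register.
func register(name string, doc Doc) {
	registryMu.Lock()
	defer registryMu.Unlock()

	if doc == nil {
		panic("register: doc is nil")
	}
	if _, exists := registry[name]; exists {
		panic("register: called twice for " + name)
	}
	registry[name] = doc
}

type staticDoc string

func (d staticDoc) ReadDoc() string { return string(d) }

func main() {
	register("v1", staticDoc(`{"swagger":"2.0"}`))
	fmt.Println(registry["v1"].ReadDoc())
}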