diff --git a/field_parser.go b/field_parser.go index b203226ad..b5cf26242 100644 --- a/field_parser.go +++ b/field_parser.go @@ -146,6 +146,46 @@ type structField struct { unique bool } +// splitNotWrapped slices s into all substrings separated by sep if sep is not +// wrapped by brackets and returns a slice of the substrings between those separators. +func splitNotWrapped(s string, sep rune) []string { + openCloseMap := map[rune]rune{ + '(': ')', + '[': ']', + '{': '}', + } + + result := make([]string, 0) + current := "" + var openCount = 0 + var openChar rune + for _, char := range s { + if openChar == 0 && openCloseMap[char] != 0 { + openChar = char + openCount++ + current += string(char) + } else if char == openChar { + openCount++ + current = current + string(char) + } else if openCount > 0 && char == openCloseMap[openChar] { + openCount-- + current += string(char) + } else if openCount == 0 && char == sep { + result = append(result, current) + openChar = 0 + current = "" + } else { + current += string(char) + } + } + + if current != "" { + result = append(result, current) + } + + return result +} + func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error { types := ps.p.GetSchemaTypePath(schema, 2) if len(types) == 0 { @@ -207,7 +247,7 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error { extensionsTag := ps.tag.Get(extensionsTag) if extensionsTag != "" { structField.extensions = map[string]interface{}{} - for _, val := range strings.Split(extensionsTag, ",") { + for _, val := range splitNotWrapped(extensionsTag, ',') { parts := strings.SplitN(val, "=", 2) if len(parts) == 2 { structField.extensions[parts[0]] = parts[1] diff --git a/field_parser_test.go b/field_parser_test.go index a3a21467e..1b0e1d639 100644 --- a/field_parser_test.go +++ b/field_parser_test.go @@ -83,7 +83,23 @@ func TestDefaultFieldParser(t *testing.T) { }) t.Run("Extensions tag", func(t *testing.T) { + t.Parallel() + schema := spec.Schema{} + schema.Type = []string{"int"} + schema.Extensions = map[string]interface{}{} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" extensions:"x-nullable,x-abc=def,!x-omitempty,x-example=[0, 9],x-example2={çãíœ, (bar=(abc, def)), [0,9]}"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, true, schema.Extensions["x-nullable"]) + assert.Equal(t, "def", schema.Extensions["x-abc"]) + assert.Equal(t, false, schema.Extensions["x-omitempty"]) + assert.Equal(t, "[0, 9]", schema.Extensions["x-example"]) + assert.Equal(t, "{çãíœ, (bar=(abc, def)), [0,9]}", schema.Extensions["x-example2"]) }) t.Run("Enums tag", func(t *testing.T) { diff --git a/operation.go b/operation.go index 642df7777..16d748891 100644 --- a/operation.go +++ b/operation.go @@ -500,7 +500,7 @@ func setEnumParam(param *spec.Parameter, attr, objectType, schemaType string) er func setExtensionParam(param *spec.Parameter, attr string) error { param.Extensions = map[string]interface{}{} - for _, val := range strings.Split(attr, ",") { + for _, val := range splitNotWrapped(attr, ',') { parts := strings.SplitN(val, "=", 2) if len(parts) == 2 { param.Extensions.Add(parts[0], parts[1]) diff --git a/parser_test.go b/parser_test.go index 223b9f51d..0799ea242 100644 --- a/parser_test.go +++ b/parser_test.go @@ -2959,6 +2959,63 @@ func Fun() { assert.Equal(t, expected, string(b)) } +func TestParseParamCommentExtension(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Param request query string true "query params" extensions(x-example=[0, 9],x-foo=bar) +// @Success 200 +// @Router /test [get] +func Fun() { + +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "get": { + "parameters": [ + { + "type": "string", + "x-example": "[0, 9]", + "x-foo": "bar", + "description": "query params", + "name": "request", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "" + } + } + } + } + } +}` + + f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + assert.NoError(t, err) + + p := New() + _ = p.packages.CollectAstFile("api", "api/api.go", f) + + _, err = p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.ParseRouterAPIInfo("", f) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.JSONEq(t, expected, string(b)) +} + func TestParseRenamedStructDefinition(t *testing.T) { t.Parallel()