Skip to content

Commit

Permalink
fix: split extensions param respecting open/close brackets (#1107)
Browse files Browse the repository at this point in the history
  • Loading branch information
helder-jaspion committed Feb 7, 2022
1 parent 24209aa commit 3a778dc
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 2 deletions.
42 changes: 41 additions & 1 deletion field_parser.go
Expand Up @@ -146,6 +146,46 @@ type structField struct {
unique bool
}

// splitNotWrapped slices s into all substrings separated by sep if sep is not
// wrapped by brackets and returns a slice of the substrings between those separators.
func splitNotWrapped(s string, sep rune) []string {
openCloseMap := map[rune]rune{
'(': ')',
'[': ']',
'{': '}',
}

result := make([]string, 0)
current := ""
var openCount = 0
var openChar rune
for _, char := range s {
if openChar == 0 && openCloseMap[char] != 0 {
openChar = char
openCount++
current += string(char)
} else if char == openChar {
openCount++
current = current + string(char)
} else if openCount > 0 && char == openCloseMap[openChar] {
openCount--
current += string(char)
} else if openCount == 0 && char == sep {
result = append(result, current)
openChar = 0
current = ""
} else {
current += string(char)
}
}

if current != "" {
result = append(result, current)
}

return result
}

func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error {
types := ps.p.GetSchemaTypePath(schema, 2)
if len(types) == 0 {
Expand Down Expand Up @@ -207,7 +247,7 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error {
extensionsTag := ps.tag.Get(extensionsTag)
if extensionsTag != "" {
structField.extensions = map[string]interface{}{}
for _, val := range strings.Split(extensionsTag, ",") {
for _, val := range splitNotWrapped(extensionsTag, ',') {
parts := strings.SplitN(val, "=", 2)
if len(parts) == 2 {
structField.extensions[parts[0]] = parts[1]
Expand Down
16 changes: 16 additions & 0 deletions field_parser_test.go
Expand Up @@ -83,7 +83,23 @@ func TestDefaultFieldParser(t *testing.T) {
})

t.Run("Extensions tag", func(t *testing.T) {
t.Parallel()

schema := spec.Schema{}
schema.Type = []string{"int"}
schema.Extensions = map[string]interface{}{}
err := newTagBaseFieldParser(
&Parser{},
&ast.Field{Tag: &ast.BasicLit{
Value: `json:"test" extensions:"x-nullable,x-abc=def,!x-omitempty,x-example=[0, 9],x-example2={çãíœ, (bar=(abc, def)), [0,9]}"`,
}},
).ComplementSchema(&schema)
assert.NoError(t, err)
assert.Equal(t, true, schema.Extensions["x-nullable"])
assert.Equal(t, "def", schema.Extensions["x-abc"])
assert.Equal(t, false, schema.Extensions["x-omitempty"])
assert.Equal(t, "[0, 9]", schema.Extensions["x-example"])
assert.Equal(t, "{çãíœ, (bar=(abc, def)), [0,9]}", schema.Extensions["x-example2"])
})

t.Run("Enums tag", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion operation.go
Expand Up @@ -500,7 +500,7 @@ func setEnumParam(param *spec.Parameter, attr, objectType, schemaType string) er

func setExtensionParam(param *spec.Parameter, attr string) error {
param.Extensions = map[string]interface{}{}
for _, val := range strings.Split(attr, ",") {
for _, val := range splitNotWrapped(attr, ',') {
parts := strings.SplitN(val, "=", 2)
if len(parts) == 2 {
param.Extensions.Add(parts[0], parts[1])
Expand Down
57 changes: 57 additions & 0 deletions parser_test.go
Expand Up @@ -2959,6 +2959,63 @@ func Fun() {
assert.Equal(t, expected, string(b))
}

func TestParseParamCommentExtension(t *testing.T) {
t.Parallel()

src := `
package main
// @Param request query string true "query params" extensions(x-example=[0, 9],x-foo=bar)
// @Success 200
// @Router /test [get]
func Fun() {
}
`
expected := `{
"info": {
"contact": {}
},
"paths": {
"/test": {
"get": {
"parameters": [
{
"type": "string",
"x-example": "[0, 9]",
"x-foo": "bar",
"description": "query params",
"name": "request",
"in": "query",
"required": true
}
],
"responses": {
"200": {
"description": ""
}
}
}
}
}
}`

f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments)
assert.NoError(t, err)

p := New()
_ = p.packages.CollectAstFile("api", "api/api.go", f)

_, err = p.packages.ParseTypes()
assert.NoError(t, err)

err = p.ParseRouterAPIInfo("", f)
assert.NoError(t, err)

b, _ := json.MarshalIndent(p.swagger, "", " ")
assert.JSONEq(t, expected, string(b))
}

func TestParseRenamedStructDefinition(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 3a778dc

Please sign in to comment.