diff --git a/field_parser.go b/field_parser.go index faaa53ec9..9b24e7872 100644 --- a/field_parser.go +++ b/field_parser.go @@ -96,13 +96,25 @@ func (ps *tagBaseFieldParser) FieldName() (string, error) { } } -func (ps *tagBaseFieldParser) FormName() string { +func (ps *tagBaseFieldParser) firstTagValue(tag string) string { if ps.field.Tag != nil { - return strings.TrimRight(strings.TrimSpace(strings.Split(ps.tag.Get(formTag), ",")[0]), "[]") + return strings.TrimRight(strings.TrimSpace(strings.Split(ps.tag.Get(tag), ",")[0]), "[]") } return "" } +func (ps *tagBaseFieldParser) FormName() string { + return ps.firstTagValue(formTag) +} + +func (ps *tagBaseFieldParser) HeaderName() string { + return ps.firstTagValue(headerTag) +} + +func (ps *tagBaseFieldParser) PathName() string { + return ps.firstTagValue(uriTag) +} + func toSnakeCase(in string) string { var ( runes = []rune(in) diff --git a/operation.go b/operation.go index 8cf7d5b95..169510ffc 100644 --- a/operation.go +++ b/operation.go @@ -286,16 +286,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F param := createParameter(paramType, description, name, objectType, refType, required, enums, operation.parser.collectionFormatInQuery) switch paramType { - case "path", "header": - switch objectType { - case ARRAY: - if !IsPrimitiveType(refType) { - return fmt.Errorf("%s is not supported array type for %s", refType, paramType) - } - case OBJECT: - return fmt.Errorf("%s is not supported type for %s", refType, paramType) - } - case "query", "formData": + case "path", "header", "query", "formData": switch objectType { case ARRAY: if !IsPrimitiveType(refType) && !(refType == "file" && paramType == "formData") { @@ -324,11 +315,14 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F } } - var formName = name - if item.Schema.Extensions != nil { - if nameVal, ok := item.Schema.Extensions[formTag]; ok { - formName = nameVal.(string) - } + nameOverrideType := paramType + // query also uses formData tags + if paramType == "query" { + nameOverrideType = "formData" + } + // load overridden type specific name from extensions if exists + if nameVal, ok := item.Schema.Extensions[nameOverrideType]; ok { + name = nameVal.(string) } switch { @@ -346,10 +340,10 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F if !IsSimplePrimitiveType(itemSchema.Type[0]) { continue } - param = createParameter(paramType, prop.Description, formName, prop.Type[0], itemSchema.Type[0], findInSlice(schema.Required, name), itemSchema.Enum, operation.parser.collectionFormatInQuery) + param = createParameter(paramType, prop.Description, name, prop.Type[0], itemSchema.Type[0], findInSlice(schema.Required, item.Name), itemSchema.Enum, operation.parser.collectionFormatInQuery) case IsSimplePrimitiveType(prop.Type[0]): - param = createParameter(paramType, prop.Description, formName, PRIMITIVE, prop.Type[0], findInSlice(schema.Required, name), nil, operation.parser.collectionFormatInQuery) + param = createParameter(paramType, prop.Description, name, PRIMITIVE, prop.Type[0], findInSlice(schema.Required, item.Name), nil, operation.parser.collectionFormatInQuery) default: operation.parser.debug.Printf("skip field [%s] in %s is not supported type for %s", name, refType, paramType) continue @@ -406,6 +400,8 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F const ( formTag = "form" jsonTag = "json" + uriTag = "uri" + headerTag = "header" bindingTag = "binding" defaultTag = "default" enumsTag = "enums" diff --git a/operation_test.go b/operation_test.go index e9214d71a..fb43a446e 100644 --- a/operation_test.go +++ b/operation_test.go @@ -2,6 +2,7 @@ package swag import ( "encoding/json" + "fmt" "go/ast" goparser "go/parser" "go/token" @@ -1177,11 +1178,17 @@ func TestOperation_ParseParamComment(t *testing.T) { t.Parallel() for _, paramType := range []string{"header", "path", "query", "formData"} { t.Run(paramType, func(t *testing.T) { + // unknown object returns error assert.Error(t, NewOperation(nil).ParseComment(`@Param some_object `+paramType+` main.Object true "Some Object"`, nil)) + + // verify objects are supported here + o := NewOperation(nil) + o.parser.addTestType("main.TestObject") + err := o.ParseComment(`@Param some_object `+paramType+` main.TestObject true "Some Object"`, nil) + assert.NoError(t, err) }) } }) - } // Test ParseParamComment Query Params @@ -2067,6 +2074,146 @@ func TestParseParamCommentByExtensions(t *testing.T) { assert.Equal(t, expected, string(b)) } +func TestParseParamStructCodeExample(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + ast, err := goparser.ParseFile(fset, "operation_test.go", `package swag + import structs "github.com/swaggo/swag/testdata/param_structs" + `, goparser.ParseComments) + assert.NoError(t, err) + + parser := New() + err = parser.parseFile("github.com/swaggo/swag/testdata/param_structs", "testdata/param_structs/structs.go", nil, ParseModels) + assert.NoError(t, err) + _, err = parser.packages.ParseTypes() + assert.NoError(t, err) + + validateParameters := func(operation *Operation, params ...spec.Parameter) { + assert.Equal(t, len(params), len(operation.Parameters)) + + for _, param := range params { + found := false + for _, p := range operation.Parameters { + if p.Name == param.Name { + assert.Equal(t, param.ParamProps, p.ParamProps) + assert.Equal(t, param.CommonValidations, p.CommonValidations) + assert.Equal(t, param.SimpleSchema, p.SimpleSchema) + found = true + break + } + } + assert.True(t, found, "found parameter %s", param.Name) + } + } + + // values used in validation checks + max := float64(10) + maxLen := int64(10) + min := float64(0) + + // query and form behave the same + for _, param := range []string{"query", "formData"} { + t.Run(param+" struct", func(t *testing.T) { + operation := NewOperation(parser) + comment := fmt.Sprintf(`@Param model %s structs.FormModel true "query params"`, param) + err = operation.ParseComment(comment, ast) + assert.NoError(t, err) + + validateParameters(operation, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "f", + Description: "", + In: param, + Required: true, + }, + CommonValidations: spec.CommonValidations{ + MaxLength: &maxLen, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + }, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "b", + Description: "B is another field", + In: param, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "boolean", + }, + }) + }) + } + + t.Run("header struct", func(t *testing.T) { + operation := NewOperation(parser) + comment := `@Param auth header structs.AuthHeader true "auth header"` + err = operation.ParseComment(comment, ast) + assert.NoError(t, err) + + validateParameters(operation, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "X-Auth-Token", + Description: "Token is the auth token", + In: "header", + Required: true, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + }, spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "anotherHeader", + Description: "AnotherHeader is another header", + In: "header", + }, + CommonValidations: spec.CommonValidations{ + Maximum: &max, + Minimum: &min, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "integer", + }, + }) + }) + + t.Run("path struct", func(t *testing.T) { + operation := NewOperation(parser) + comment := `@Param path path structs.PathModel true "path params"` + err = operation.ParseComment(comment, ast) + assert.NoError(t, err) + + validateParameters(operation, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "id", + Description: "ID is the id", + In: "path", + Required: true, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "integer", + }, + }, spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "name", + Description: "", + In: "path", + }, + CommonValidations: spec.CommonValidations{ + MaxLength: &maxLen, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + }) + }) +} + func TestParseIdComment(t *testing.T) { t.Parallel() diff --git a/parser.go b/parser.go index a97da1d59..6832d5886 100644 --- a/parser.go +++ b/parser.go @@ -189,6 +189,8 @@ type FieldParser interface { ShouldSkip() bool FieldName() (string, error) FormName() string + HeaderName() string + PathName() string CustomSchema() (*spec.Schema, error) ComplementSchema(schema *spec.Schema) error IsRequired() (bool, error) @@ -1506,11 +1508,17 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st tagRequired = append(tagRequired, fieldName) } + if schema.Extensions == nil { + schema.Extensions = make(spec.Extensions) + } if formName := ps.FormName(); len(formName) > 0 { - if schema.Extensions == nil { - schema.Extensions = make(spec.Extensions) - } - schema.Extensions[formTag] = formName + schema.Extensions["formData"] = formName + } + if headerName := ps.HeaderName(); len(headerName) > 0 { + schema.Extensions["header"] = headerName + } + if pathName := ps.PathName(); len(pathName) > 0 { + schema.Extensions["path"] = pathName } return map[string]spec.Schema{fieldName: *schema}, tagRequired, nil diff --git a/testdata/param_structs/structs.go b/testdata/param_structs/structs.go new file mode 100644 index 000000000..2c8673a5d --- /dev/null +++ b/testdata/param_structs/structs.go @@ -0,0 +1,20 @@ +package structs + +type FormModel struct { + Foo string `form:"f" binding:"required" validate:"max=10"` + // B is another field + B bool +} + +type AuthHeader struct { + // Token is the auth token + Token string `header:"X-Auth-Token" binding:"required"` + // AnotherHeader is another header + AnotherHeader int `validate:"gte=0,lte=10"` +} + +type PathModel struct { + // ID is the id + Identifier int `uri:"id" binding:"required"` + Name string `validate:"max=10"` +}