Skip to content

Commit

Permalink
openapi3filter: support deepObject with nested objects and array para…
Browse files Browse the repository at this point in the history
…meters (#911)

* support arrays

* support for nested objects

* update

* dont check found on error and parse primitives

* update

* make delimiters internal

* equal error

* ParseError

* additionalProperties attempt

* clean

* single delimiter

* implement nested object array

* extra test

* FIXME - nested array type check

* validate array item types - should refactor

* extra index check

* refactor

* add tests

* complete path for parse errors and remove panic

* remove comment

* full coverage on makeObject

* exit early on parse error
  • Loading branch information
danicc097 committed Feb 20, 2024
1 parent 3bbab36 commit 05ccac2
Show file tree
Hide file tree
Showing 2 changed files with 365 additions and 53 deletions.
150 changes: 132 additions & 18 deletions openapi3filter/req_resp_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,12 @@ func (d *urlValuesDecoder) parseValue(v string, schema *openapi3.SchemaRef) (int
}

return parsePrimitive(v, schema)

}

const (
urlDecoderDelimiter = "\x1F" // should not conflict with URL characters
)

func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.SerializationMethod, schema *openapi3.SchemaRef) (map[string]interface{}, bool, error) {
var propsFn func(url.Values) (map[string]string, error)
switch sm.Style {
Expand Down Expand Up @@ -646,12 +649,20 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization
propsFn = func(params url.Values) (map[string]string, error) {
props := make(map[string]string)
for key, values := range params {
groups := regexp.MustCompile(fmt.Sprintf("%s\\[(.+?)\\]", param)).FindAllStringSubmatch(key, -1)
if len(groups) == 0 {
matches := regexp.MustCompile(`\[(.*?)\]`).FindAllStringSubmatch(key, -1)
switch l := len(matches); {
case l == 0:
// A query parameter's name does not match the required format, so skip it.
continue
case l == 1:
props[matches[0][1]] = strings.Join(values, urlDecoderDelimiter)
case l > 1:
kk := []string{}
for _, m := range matches {
kk = append(kk, m[1])
}
props[strings.Join(kk, urlDecoderDelimiter)] = strings.Join(values, urlDecoderDelimiter)
}
props[groups[0][1]] = values[0]
}
if len(props) == 0 {
// HTTP request does not contain query parameters encoded by rules of style "deepObject".
Expand All @@ -662,7 +673,6 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization
default:
return nil, false, invalidSerializationMethodErr(sm)
}

props, err := propsFn(d.values)
if err != nil {
return nil, false, err
Expand All @@ -671,16 +681,30 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization
return nil, false, nil
}

// check the props
val, err := makeObject(props, schema)
if err != nil {
return nil, false, err
}

found := false
for propName := range schema.Value.Properties {
if _, ok := props[propName]; ok {
found = true
break
}

if schema.Value.Type == "array" || schema.Value.Type == "object" {
for k := range props {
path := strings.Split(k, urlDecoderDelimiter)
if _, ok := deepGet(val, path...); ok {
found = true
break
}
}
}
}
val, err := makeObject(props, schema)
return val, found, err

return val, found, nil
}

// headerParamDecoder decodes values of header parameters.
Expand Down Expand Up @@ -845,24 +869,115 @@ func propsFromString(src, propDelim, valueDelim string) (map[string]string, erro
return props, nil
}

func deepGet(m map[string]interface{}, keys ...string) (interface{}, bool) {
for _, key := range keys {
val, ok := m[key]
if !ok {
return nil, false
}
if m, ok = val.(map[string]interface{}); !ok {
return val, true
}
}
return m, true
}

func deepSet(m map[string]interface{}, keys []string, value interface{}) {
for i := 0; i < len(keys)-1; i++ {
key := keys[i]
if _, ok := m[key]; !ok {
m[key] = make(map[string]interface{})
}
m = m[key].(map[string]interface{})
}
m[keys[len(keys)-1]] = value
}

func findNestedSchema(parentSchema *openapi3.SchemaRef, keys []string) (*openapi3.SchemaRef, error) {
currentSchema := parentSchema
for _, key := range keys {
propertySchema, ok := currentSchema.Value.Properties[key]
if !ok {
if currentSchema.Value.AdditionalProperties.Schema == nil {
return nil, fmt.Errorf("nested schema for key %q not found", key)
}
currentSchema = currentSchema.Value.AdditionalProperties.Schema
continue
}
currentSchema = propertySchema
}
return currentSchema, nil
}

// makeObject returns an object that contains properties from props.
// A value of every property is parsed as a primitive value.
// The function returns an error when an error happened while parse object's properties.
func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string]interface{}, error) {
obj := make(map[string]interface{})
for propName, propSchema := range schema.Value.Properties {
value, err := parsePrimitive(props[propName], propSchema)
if err != nil {
if v, ok := err.(*ParseError); ok {
return nil, &ParseError{path: []interface{}{propName}, Cause: v}
switch propSchema.Value.Type {
case "array":
vals := strings.Split(props[propName], urlDecoderDelimiter)
for _, v := range vals {
_, err := parsePrimitive(v, propSchema.Value.Items)
if err != nil {
return nil, handlePropParseError([]string{propName}, err)
}
}
return nil, fmt.Errorf("property %q: %w", propName, err)
obj[propName] = vals
case "object":
for prop := range props {
if !strings.HasPrefix(prop, propName+urlDecoderDelimiter) {
continue
}
mapKeys := strings.Split(prop, urlDecoderDelimiter)
nestedSchema, err := findNestedSchema(schema, mapKeys)
if err != nil {
return nil, &ParseError{path: pathFromKeys(mapKeys), Reason: err.Error()}
}
if nestedSchema.Value.Type == "array" {
vals := strings.Split(props[prop], urlDecoderDelimiter)
for _, v := range vals {
_, err := parsePrimitive(v, nestedSchema.Value.Items)
if err != nil {
return nil, handlePropParseError(mapKeys, err)
}
}
deepSet(obj, mapKeys, vals)
continue
}
value, err := parsePrimitive(props[prop], nestedSchema)
if err != nil {
return nil, handlePropParseError(mapKeys, err)
}
deepSet(obj, mapKeys, value)
}
default:
value, err := parsePrimitive(props[propName], propSchema)
if err != nil {
return nil, handlePropParseError([]string{propName}, err)
}
obj[propName] = value
}
obj[propName] = value
}
return obj, nil
}

func handlePropParseError(path []string, err error) error {
if v, ok := err.(*ParseError); ok {
return &ParseError{path: pathFromKeys(path), Cause: v}
}
return fmt.Errorf("property %q: %w", strings.Join(path, "."), err)
}

func pathFromKeys(kk []string) []interface{} {
path := make([]interface{}, len(kk))
for i, v := range kk {
path[i] = v
}
return path
}

// parseArray returns an array that contains items from a raw array.
// Every item is parsed as a primitive value.
// The function returns an error when an error happened while parse array's items.
Expand Down Expand Up @@ -923,7 +1038,7 @@ func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error)
case "string":
return raw, nil
default:
panic(fmt.Sprintf("schema has non primitive type %q", schema.Value.Type))
return nil, &ParseError{Kind: KindOther, Value: raw, Reason: "schema has non primitive type " + schema.Value.Type}
}
}

Expand Down Expand Up @@ -1180,10 +1295,10 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S
if anyProperties := schema.Value.AdditionalProperties.Has; anyProperties != nil {
switch *anyProperties {
case true:
//additionalProperties: true
// additionalProperties: true
continue
default:
//additionalProperties: false
// additionalProperties: false
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
}
}
Expand Down Expand Up @@ -1300,7 +1415,6 @@ func zipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Sch

return nil
}()

if err != nil {
return nil, err
}
Expand Down

0 comments on commit 05ccac2

Please sign in to comment.