From 6fe3b660f9369d3bd49c69abce601ae927a11f18 Mon Sep 17 00:00:00 2001 From: Pierre Fenoll Date: Wed, 12 Oct 2022 15:40:32 +0200 Subject: [PATCH] A more lenient fix to #624, thanks to @orensolo Signed-off-by: Pierre Fenoll --- openapi3filter/issue624_test.go | 25 +++++++++++++++---------- openapi3filter/req_resp_decoder.go | 18 +++++++++--------- openapi3filter/validate_request.go | 22 +++++++--------------- openapi3filter/validation_test.go | 13 ++++++------- 4 files changed, 37 insertions(+), 41 deletions(-) diff --git a/openapi3filter/issue624_test.go b/openapi3filter/issue624_test.go index d93682e52..1fdbdea34 100644 --- a/openapi3filter/issue624_test.go +++ b/openapi3filter/issue624_test.go @@ -48,17 +48,22 @@ paths: router, err := gorillamux.NewRouter(doc) require.NoError(t, err) - httpReq, err := http.NewRequest(http.MethodGet, `/items?test=test1`, nil) - require.NoError(t, err) - route, pathParams, err := router.FindRoute(httpReq) - require.NoError(t, err) + for _, testcase := range []string{`test1`, `test[1`} { + t.Run(testcase, func(t *testing.T) { + httpReq, err := http.NewRequest(http.MethodGet, `/items?test=`+testcase, nil) + require.NoError(t, err) - requestValidationInput := &RequestValidationInput{ - Request: httpReq, - PathParams: pathParams, - Route: route, + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + } + err = ValidateRequest(ctx, requestValidationInput) + require.NoError(t, err) + }) } - err = ValidateRequest(ctx, requestValidationInput) - require.NoError(t, err) } diff --git a/openapi3filter/req_resp_decoder.go b/openapi3filter/req_resp_decoder.go index 515d09cd1..ca8194432 100644 --- a/openapi3filter/req_resp_decoder.go +++ b/openapi3filter/req_resp_decoder.go @@ -110,8 +110,11 @@ func invalidSerializationMethodErr(sm *openapi3.SerializationMethod) error { // Decodes a parameter defined via the content property as an object. It uses // the user specified decoder, or our build-in decoder for application/json func decodeContentParameter(param *openapi3.Parameter, input *RequestValidationInput) ( - value interface{}, schema *openapi3.Schema, found bool, err error) { - + value interface{}, + schema *openapi3.Schema, + found bool, + err error, +) { var paramValues []string switch param.In { case openapi3.ParameterInPath: @@ -186,12 +189,9 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string) } outSchema = mt.Schema.Value - unmarshal := func(encoded string) (decoded interface{}, err error) { + unmarshal := func(encoded string, paramSchema *openapi3.SchemaRef) (decoded interface{}, err error) { if err = json.Unmarshal([]byte(encoded), &decoded); err != nil { - const specialJSONChars = `[]{}":,` - if !strings.ContainsAny(encoded, specialJSONChars) { - // A string in a query parameter is not serialized with (double) quotes - // as JSON would expect, so let's fallback to that. + if paramSchema != nil && paramSchema.Value.Type != "object" { decoded, err = encoded, nil } } @@ -199,7 +199,7 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string) } if len(values) == 1 { - if outValue, err = unmarshal(values[0]); err != nil { + if outValue, err = unmarshal(values[0], mt.Schema); err != nil { err = fmt.Errorf("error unmarshaling parameter %q", param.Name) return } @@ -207,7 +207,7 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string) outArray := make([]interface{}, 0, len(values)) for _, v := range values { var item interface{} - if item, err = unmarshal(v); err != nil { + if item, err = unmarshal(v, outSchema.Items); err != nil { err = fmt.Errorf("error unmarshaling parameter %q", param.Name) return } diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index b09987f74..4b0bd3413 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -28,11 +28,8 @@ var ErrInvalidEmptyValue = errors.New("empty value is not allowed") // // Note: One can tune the behavior of uniqueItems: true verification // by registering a custom function with openapi3.RegisterArrayUniqueItemsChecker -func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { - var ( - err error - me openapi3.MultiError - ) +func ValidateRequest(ctx context.Context, input *RequestValidationInput) (err error) { + var me openapi3.MultiError options := input.Options if options == nil { @@ -52,9 +49,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { } if security != nil { if err = ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError { - return err + return } - if err != nil { me = append(me, err) } @@ -70,9 +66,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { } if err = ValidateParameter(ctx, input, parameter); err != nil && !options.MultiError { - return err + return } - if err != nil { me = append(me, err) } @@ -81,9 +76,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { // For each parameter of the Operation for _, parameter := range operationParameters { if err = ValidateParameter(ctx, input, parameter.Value); err != nil && !options.MultiError { - return err + return } - if err != nil { me = append(me, err) } @@ -93,9 +87,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { requestBody := operation.RequestBody if requestBody != nil && !options.ExcludeRequestBody { if err = ValidateRequestBody(ctx, input, requestBody.Value); err != nil && !options.MultiError { - return err + return } - if err != nil { me = append(me, err) } @@ -104,8 +97,7 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { if len(me) > 0 { return me } - - return nil + return } // ValidateParameter validates a parameter's value by JSON schema. diff --git a/openapi3filter/validation_test.go b/openapi3filter/validation_test.go index cd1fa8990..cdbeb1262 100644 --- a/openapi3filter/validation_test.go +++ b/openapi3filter/validation_test.go @@ -198,7 +198,7 @@ func TestFilter(t *testing.T) { } err = ValidateResponse(context.Background(), responseValidationInput) require.NoError(t, err) - return err + return nil } expect := func(req ExampleRequest, resp ExampleResponse) error { return expectWithDecoder(req, resp, nil) @@ -207,13 +207,12 @@ func TestFilter(t *testing.T) { resp := ExampleResponse{ Status: 200, } - // Test paths + // Test paths req := ExampleRequest{ Method: "POST", URL: "http://example.com/api/prefix/v/suffix", } - err = expect(req, resp) require.NoError(t, err) @@ -328,7 +327,7 @@ func TestFilter(t *testing.T) { // enough. req = ExampleRequest{ Method: "POST", - URL: "http://example.com/api/prefix/v/suffix?contentArg={\"name\":\"bob\", \"id\":\"a\"}", + URL: `http://example.com/api/prefix/v/suffix?contentArg={"name":"bob", "id":"a"}`, } err = expect(req, resp) require.NoError(t, err) @@ -336,7 +335,7 @@ func TestFilter(t *testing.T) { // Now it should fail due the ID being too long req = ExampleRequest{ Method: "POST", - URL: "http://example.com/api/prefix/v/suffix?contentArg={\"name\":\"bob\", \"id\":\"EXCEEDS_MAX_LENGTH\"}", + URL: `http://example.com/api/prefix/v/suffix?contentArg={"name":"bob", "id":"EXCEEDS_MAX_LENGTH"}`, } err = expect(req, resp) require.IsType(t, &RequestError{}, err) @@ -351,7 +350,7 @@ func TestFilter(t *testing.T) { req = ExampleRequest{ Method: "POST", - URL: "http://example.com/api/prefix/v/suffix?contentArg2={\"name\":\"bob\", \"id\":\"a\"}", + URL: `http://example.com/api/prefix/v/suffix?contentArg2={"name":"bob", "id":"a"}`, } err = expectWithDecoder(req, resp, customDecoder) require.NoError(t, err) @@ -359,7 +358,7 @@ func TestFilter(t *testing.T) { // Now it should fail due the ID being too long req = ExampleRequest{ Method: "POST", - URL: "http://example.com/api/prefix/v/suffix?contentArg2={\"name\":\"bob\", \"id\":\"EXCEEDS_MAX_LENGTH\"}", + URL: `http://example.com/api/prefix/v/suffix?contentArg2={"name":"bob", "id":"EXCEEDS_MAX_LENGTH"}`, } err = expectWithDecoder(req, resp, customDecoder) require.IsType(t, &RequestError{}, err)