From ba47f30a14555e4325410154bb36fe51463ff7ba Mon Sep 17 00:00:00 2001 From: sorintm <112782063+sorintm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:08:00 +0100 Subject: [PATCH] Expose request/response validation options in the middleware Validator (#608) --- openapi3filter/middleware.go | 10 ++++++++++ openapi3filter/middleware_test.go | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/openapi3filter/middleware.go b/openapi3filter/middleware.go index 3709faf9b..3bcb9db43 100644 --- a/openapi3filter/middleware.go +++ b/openapi3filter/middleware.go @@ -16,6 +16,7 @@ type Validator struct { errFunc ErrFunc logFunc LogFunc strict bool + options Options } // ErrFunc handles errors that may occur during validation. @@ -106,6 +107,13 @@ func Strict(strict bool) ValidatorOption { } } +// ValidationOptions sets request/response validation options on the validator. +func ValidationOptions(options Options) ValidatorOption { + return func(v *Validator) { + v.options = options + } +} + // Middleware returns an http.Handler which wraps the given handler with // request and response validation. func (v *Validator) Middleware(h http.Handler) http.Handler { @@ -120,6 +128,7 @@ func (v *Validator) Middleware(h http.Handler) http.Handler { Request: r, PathParams: pathParams, Route: route, + Options: &v.options, } if err = ValidateRequest(r.Context(), requestValidationInput); err != nil { v.logFunc("invalid request", err) @@ -141,6 +150,7 @@ func (v *Validator) Middleware(h http.Handler) http.Handler { Status: wr.statusCode(), Header: wr.Header(), Body: ioutil.NopCloser(bytes.NewBuffer(wr.bodyContents())), + Options: &v.options, }); err != nil { v.logFunc("invalid response", err) if v.strict { diff --git a/openapi3filter/middleware_test.go b/openapi3filter/middleware_test.go index c6a5e9bc2..ff6059c9d 100644 --- a/openapi3filter/middleware_test.go +++ b/openapi3filter/middleware_test.go @@ -328,6 +328,27 @@ func TestValidator(t *testing.T) { body: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, }, strict: false, + }, { + name: "POST response status code not in spec (return 200, spec only has 201)", + handler: validatorTestHandler{ + postBody: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + errStatusCode: 200, + errBody: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }.withDefaults(), + options: []openapi3filter.ValidatorOption{openapi3filter.ValidationOptions(openapi3filter.Options{ + IncludeResponseStatus: true, + })}, + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10}`, + contentType: "application/json", + }, + response: testResponse{ + statusCode: 200, + body: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }, + strict: false, }} for i, test := range tests { t.Logf("test#%d: %s", i, test.name)