From f1360474effe1390d6e49ce5ffcfd831d1614260 Mon Sep 17 00:00:00 2001 From: ShouheiNishi <96609867+ShouheiNishi@users.noreply.github.com> Date: Wed, 14 Dec 2022 17:25:06 +0900 Subject: [PATCH] Use and update GetBody() member of request (#704) --- openapi3filter/validate_request.go | 23 +++++++++++++++++++++-- openapi3filter/validate_request_test.go | 6 ++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index 4acb9ff1f..8a747724e 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "net/http" "sort" @@ -216,7 +217,19 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req } } // Put the data back into the input - req.Body = ioutil.NopCloser(bytes.NewReader(data)) + req.Body = nil + if req.GetBody != nil { + if req.Body, err = req.GetBody(); err != nil { + req.Body = nil + } + } + if req.Body == nil { + req.ContentLength = int64(len(data)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } + req.Body, _ = req.GetBody() // no error return + } } if len(data) == 0 { @@ -292,8 +305,14 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req } } // Put the data back into the input - req.Body = ioutil.NopCloser(bytes.NewReader(data)) + if req.Body != nil { + req.Body.Close() + } req.ContentLength = int64(len(data)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } + req.Body, _ = req.GetBody() // no error return } return nil diff --git a/openapi3filter/validate_request_test.go b/openapi3filter/validate_request_test.go index 450ee5988..8da550ce0 100644 --- a/openapi3filter/validate_request_test.go +++ b/openapi3filter/validate_request_test.go @@ -212,6 +212,12 @@ components: assert.Equal(t, contentLen, bodySize, "expect ContentLength %d to equal body size %d", contentLen, bodySize) bodyModified := originalBodySize != bodySize assert.Equal(t, bodyModified, tc.expectedModification, "expect request body modification happened: %t, expected %t", bodyModified, tc.expectedModification) + + validationInput.Request.Body, err = validationInput.Request.GetBody() + assert.NoError(t, err, "unable to re-generate body by GetBody(): %v", err) + body2, err := io.ReadAll(validationInput.Request.Body) + assert.NoError(t, err, "unable to read request body: %v", err) + assert.Equal(t, body, body2, "body by GetBody() is not matched") }) } }