Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use and update GetBody() member of request #704

Merged
merged 1 commit into from Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 21 additions & 2 deletions openapi3filter/validate_request.go
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"sort"
Expand Down Expand Up @@ -216,7 +217,19 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
}
}
// Put the data back into the input
req.Body = ioutil.NopCloser(bytes.NewReader(data))
req.Body = nil
if req.GetBody != nil {
if req.Body, err = req.GetBody(); err != nil {
req.Body = nil
}
}
if req.Body == nil {
req.ContentLength = int64(len(data))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
req.Body, _ = req.GetBody() // no error return
}
}

if len(data) == 0 {
Expand Down Expand Up @@ -292,8 +305,14 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
}
}
// Put the data back into the input
req.Body = ioutil.NopCloser(bytes.NewReader(data))
if req.Body != nil {
req.Body.Close()
}
req.ContentLength = int64(len(data))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
req.Body, _ = req.GetBody() // no error return
}

return nil
Expand Down
6 changes: 6 additions & 0 deletions openapi3filter/validate_request_test.go
Expand Up @@ -212,6 +212,12 @@ components:
assert.Equal(t, contentLen, bodySize, "expect ContentLength %d to equal body size %d", contentLen, bodySize)
bodyModified := originalBodySize != bodySize
assert.Equal(t, bodyModified, tc.expectedModification, "expect request body modification happened: %t, expected %t", bodyModified, tc.expectedModification)

validationInput.Request.Body, err = validationInput.Request.GetBody()
assert.NoError(t, err, "unable to re-generate body by GetBody(): %v", err)
body2, err := io.ReadAll(validationInput.Request.Body)
assert.NoError(t, err, "unable to read request body: %v", err)
assert.Equal(t, body, body2, "body by GetBody() is not matched")
})
}
}