Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openapi3filter: RegisterBodyDecoder for application/zip #730

Merged
merged 17 commits into from
Jan 4, 2023
56 changes: 52 additions & 4 deletions openapi3filter/req_resp_decoder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package openapi3filter

import (
"archive/zip"
"bytes"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -1004,15 +1006,16 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef,
}

func init() {
RegisterBodyDecoder("text/plain", plainBodyDecoder)
RegisterBodyDecoder("application/json", jsonBodyDecoder)
RegisterBodyDecoder("application/json-patch+json", jsonBodyDecoder)
RegisterBodyDecoder("application/x-yaml", yamlBodyDecoder)
RegisterBodyDecoder("application/yaml", yamlBodyDecoder)
RegisterBodyDecoder("application/octet-stream", FileBodyDecoder)
RegisterBodyDecoder("application/problem+json", jsonBodyDecoder)
RegisterBodyDecoder("application/x-www-form-urlencoded", urlencodedBodyDecoder)
RegisterBodyDecoder("application/x-yaml", yamlBodyDecoder)
RegisterBodyDecoder("application/yaml", yamlBodyDecoder)
RegisterBodyDecoder("application/zip", ZipFileBodyDecoder)
fenollp marked this conversation as resolved.
Show resolved Hide resolved
RegisterBodyDecoder("multipart/form-data", multipartBodyDecoder)
RegisterBodyDecoder("application/octet-stream", FileBodyDecoder)
RegisterBodyDecoder("text/plain", plainBodyDecoder)
}

func plainBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) {
Expand Down Expand Up @@ -1221,3 +1224,48 @@ func FileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema
}
return string(data), nil
}

// ZipFileBodyDecoder is a body decoder that decodes a zip file body to a string.
func ZipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) {
buff := bytes.NewBuffer([]byte{})
size, err := io.Copy(buff, body)
if err != nil {
return nil, err
}

zr, err := zip.NewReader(bytes.NewReader(buff.Bytes()), size)
fenollp marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}

const bufferSize = 256
content := make([]byte, 0, bufferSize*len(zr.File))
buffer := make([]byte, bufferSize)

for _, f := range zr.File {
func() {
rc, err := f.Open()
if err != nil {
panic(err)
fenollp marked this conversation as resolved.
Show resolved Hide resolved
}
defer func() {
_ = rc.Close()
}()

for {
n, err := rc.Read(buffer)
if 0 < n {
content = append(content, buffer...)
}
if err == io.EOF {
break
}
if err != nil {
panic(err)
}
}
}()
}

return string(content), nil
}
116 changes: 116 additions & 0 deletions openapi3filter/zip_file_upload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package openapi3filter_test

import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"testing"

"github.com/stretchr/testify/require"

"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/getkin/kin-openapi/routers/gorillamux"
)

func TestValidateZipFileUpload(t *testing.T) {
const spec = `
openapi: 3.0.0
info:
title: 'Validator'
version: 0.0.1
paths:
/test:
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
required:
- file
properties:
file:
type: string
format: binary
responses:
'200':
description: Created
`

loader := openapi3.NewLoader()
doc, err := loader.LoadFromData([]byte(spec))
require.NoError(t, err)

err = doc.Validate(loader.Context)
require.NoError(t, err)

router, err := gorillamux.NewRouter(doc)
require.NoError(t, err)

tests := []struct {
zipData []byte
wantErr bool
}{
{
[]byte{
0x50, 0x4b, 0x03, 0x04, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7c, 0x7d, 0x23, 0x56, 0xcd, 0xfd, 0x67, 0xf8, 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x09, 0x00, 0x1c, 0x00, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2e, 0x74, 0x78, 0x74, 0x55, 0x54, 0x09, 0x00, 0x03, 0xac, 0xce, 0xb3, 0x63, 0xaf, 0xce, 0xb3, 0x63, 0x75, 0x78, 0x0b, 0x00, 0x01, 0x04, 0xf7, 0x01, 0x00, 0x00, 0x04, 0x14, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x0a, 0x50, 0x4b, 0x01, 0x02, 0x1e, 0x03, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7c, 0x7d, 0x23, 0x56, 0xcd, 0xfd, 0x67, 0xf8, 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x09, 0x00, 0x18, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xa4, 0x81, 0x00, 0x00, 0x00, 0x00, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2e, 0x74, 0x78, 0x74, 0x55, 0x54, 0x05, 0x00, 0x03, 0xac, 0xce, 0xb3, 0x63, 0x75, 0x78, 0x0b, 0x00, 0x01, 0x04, 0xf7, 0x01, 0x00, 0x00, 0x04, 0x14, 0x00, 0x00, 0x00, 0x50, 0x4b, 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x4f, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00, 0x00,
},
false,
},
{
[]byte{
0x50, 0x4b, 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}, // No entry
true,
},
}
for _, tt := range tests {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)

{ // Add file data
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", `form-data; name="file"; filename="hello.zip"`)
h.Set("Content-Type", "application/zip")

fw, err := writer.CreatePart(h)
require.NoError(t, err)
_, err = io.Copy(fw, bytes.NewReader(tt.zipData))

require.NoError(t, err)
}

writer.Close()

req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes()))
require.NoError(t, err)

req.Header.Set("Content-Type", writer.FormDataContentType())

route, pathParams, err := router.FindRoute(req)
require.NoError(t, err)

if err = openapi3filter.ValidateRequestBody(
context.Background(),
&openapi3filter.RequestValidationInput{
Request: req,
PathParams: pathParams,
Route: route,
},
route.Operation.RequestBody.Value,
); err != nil {
if !tt.wantErr {
t.Errorf("got %v", err)
}
continue
}
if tt.wantErr {
t.Errorf("want err")
}
}
}