diff --git a/header.go b/header.go index e7616df688..e960177715 100644 --- a/header.go +++ b/header.go @@ -371,6 +371,21 @@ func (h *RequestHeader) SetContentTypeBytes(contentType []byte) { h.contentType = append(h.contentType[:0], contentType...) } +// ContentEncoding returns Content-Encoding header value. +func (h *RequestHeader) ContentEncoding() []byte { + return peekArgBytes(h.h, strContentEncoding) +} + +// SetContentEncoding sets Content-Encoding header value. +func (h *RequestHeader) SetContentEncoding(contentEncoding string) { + h.SetBytesK(strContentEncoding, contentEncoding) +} + +// SetContentEncodingBytes sets Content-Encoding header value. +func (h *RequestHeader) SetContentEncodingBytes(contentEncoding []byte) { + h.setNonSpecial(strContentEncoding, contentEncoding) +} + // SetMultipartFormBoundary sets the following Content-Type: // 'multipart/form-data; boundary=...' // where ... is substituted by the given boundary. diff --git a/http.go b/http.go index 9db9be5815..0cf1bfed1f 100644 --- a/http.go +++ b/http.go @@ -486,6 +486,48 @@ func inflateData(p []byte) ([]byte, error) { return bb.B, nil } +var ErrContentEncodingUnsupported = errors.New("unsupported Content-Encoding") + +// BodyUncompressed returns body data and if needed decompress it from gzip, deflate or Brotli. +// +// This method may be used if the response header contains +// 'Content-Encoding' for reading uncompressed request body. +// Use Body for reading the raw request body. +func (req *Request) BodyUncompressed() ([]byte, error) { + switch string(req.Header.ContentEncoding()) { + case "": + return req.Body(), nil + case "deflate": + return req.BodyInflate() + case "gzip": + return req.BodyGunzip() + case "br": + return req.BodyUnbrotli() + default: + return nil, ErrContentEncodingUnsupported + } +} + +// BodyUncompressed returns body data and if needed decompress it from gzip, deflate or Brotli. +// +// This method may be used if the response header contains +// 'Content-Encoding' for reading uncompressed response body. +// Use Body for reading the raw response body. +func (resp *Response) BodyUncompressed() ([]byte, error) { + switch string(resp.Header.ContentEncoding()) { + case "": + return resp.Body(), nil + case "deflate": + return resp.BodyInflate() + case "gzip": + return resp.BodyGunzip() + case "br": + return resp.BodyUnbrotli() + default: + return nil, ErrContentEncodingUnsupported + } +} + // BodyWriteTo writes request body to w. func (req *Request) BodyWriteTo(w io.Writer) error { if req.bodyStream != nil { diff --git a/http_test.go b/http_test.go index 8d99cdd9ed..7d6c8d5b0e 100644 --- a/http_test.go +++ b/http_test.go @@ -347,6 +347,12 @@ func testResponseBodyStreamDeflate(t *testing.T, body []byte, bodySize int) { if !bytes.Equal(respBody, body) { t.Fatalf("unexpected body: %q. Expecting %q", respBody, body) } + // check for invalid + resp.SetBodyRaw([]byte("invalid")) + _, errDeflate := resp.BodyInflate() + if errDeflate == nil || errDeflate.Error() != "zlib: invalid header" { + t.Fatalf("expected error: 'zlib: invalid header' but was %v", errDeflate) + } } func testResponseBodyStreamGzip(t *testing.T, body []byte, bodySize int) { @@ -375,6 +381,12 @@ func testResponseBodyStreamGzip(t *testing.T, body []byte, bodySize int) { if !bytes.Equal(respBody, body) { t.Fatalf("unexpected body: %q. Expecting %q", respBody, body) } + // check for invalid + resp.SetBodyRaw([]byte("invalid")) + _, errUnzip := resp.BodyGunzip() + if errUnzip == nil || errUnzip.Error() != "unexpected EOF" { + t.Fatalf("expected error: 'unexpected EOF' but was %v", errUnzip) + } } func TestResponseWriteGzipNilBody(t *testing.T) { @@ -405,6 +417,46 @@ func TestResponseWriteDeflateNilBody(t *testing.T) { } } +func TestResponseBodyUncompressed(t *testing.T) { + body := "body" + var r Response + r.SetBodyStream(bytes.NewReader([]byte(body)), len(body)) + + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := r.WriteDeflate(bw); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var resp Response + br := bufio.NewReader(w) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce := resp.Header.ContentEncoding() + if string(ce) != "deflate" { + t.Fatalf("unexpected Content-Encoding: %s", ce) + } + respBody, err := resp.BodyUncompressed() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(respBody) != body { + t.Fatalf("unexpected body: %q. Expecting %q", respBody, body) + } + + // check for invalid encoding + resp.Header.SetContentEncoding("invalid") + _, decodeErr := resp.BodyUncompressed() + if decodeErr != ErrContentEncodingUnsupported { + t.Fatalf("unexpected error: %v", decodeErr) + } +} + func TestResponseSwapBodySerial(t *testing.T) { t.Parallel() @@ -1145,8 +1197,8 @@ func TestRequestReadGzippedBody(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if string(r.Header.Peek(HeaderContentEncoding)) != "gzip" { - t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek(HeaderContentEncoding), "gzip") + if string(r.Header.ContentEncoding()) != "gzip" { + t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.ContentEncoding(), "gzip") } if r.Header.ContentLength() != len(body) { t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body))