Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gzhttp: Add zstd to transport #400

Merged
merged 6 commits into from Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions gzhttp/gzip_test.go
Expand Up @@ -1133,6 +1133,9 @@ func newTestHandler(body []byte) http.Handler {
case "/gzipped":
w.Header().Set("Content-Encoding", "gzip")
w.Write(body)
case "/zstd":
w.Header().Set("Content-Encoding", "zstd")
w.Write(body)
default:
w.Write(body)
}
Expand Down
115 changes: 105 additions & 10 deletions gzhttp/transport.go
Expand Up @@ -7,24 +7,58 @@ package gzhttp
import (
"io"
"net/http"
"strings"
"sync"

"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
)

// Transport will wrap a transport with a custom gzip handler
// Transport will wrap a transport with a custom handler
// that will request gzip and automatically decompress it.
// Using this is significantly faster than using the default transport.
func Transport(parent http.RoundTripper) http.RoundTripper {
return gzRoundtripper{parent: parent}
func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
g := gzRoundtripper{parent: parent, withZstd: true, withGzip: true}
for _, o := range opts {
o(&g)
}
var ae []string
if g.withZstd {
ae = append(ae, "zstd")
}
if g.withGzip {
ae = append(ae, "gzip")
}
g.acceptEncoding = strings.Join(ae, ",")
return &g
}

type transportOption func(c *gzRoundtripper)

// TransportEnableZstd will send Zstandard as a compression option to the server.
// Enabled by default, but may be disabled if future problems arise.
func TransportEnableZstd(b bool) transportOption {
return func(c *gzRoundtripper) {
c.withZstd = b
}
}

// TransportEnableGzip will send Gzip as a compression option to the server.
// Enabled by default.
func TransportEnableGzip(b bool) transportOption {
return func(c *gzRoundtripper) {
c.withGzip = b
}
}

type gzRoundtripper struct {
parent http.RoundTripper
parent http.RoundTripper
acceptEncoding string
withZstd, withGzip bool
}

func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
var requestedGzip bool
func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
var requestedComp bool
if req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
req.Method != "HEAD" {
Expand All @@ -40,20 +74,31 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
requestedGzip = true
req.Header.Set("Accept-Encoding", "gzip")
requestedComp = len(g.acceptEncoding) > 0
req.Header.Set("Accept-Encoding", g.acceptEncoding)
}

resp, err := g.parent.RoundTrip(req)
if err != nil || !requestedGzip {
if err != nil || !requestedComp {
return resp, err
}
if asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {

// Decompress
if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
resp.Body = &gzipReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}
if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
resp.Body = &zstdReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}

return resp, nil
}

Expand Down Expand Up @@ -114,3 +159,53 @@ func lower(b byte) byte {
}
return b
}

// zstdReaderPool pools zstd decoders.
var zstdReaderPool sync.Pool

// zstdReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
type zstdReader struct {
body io.ReadCloser // underlying HTTP/1 response body framing
zr *zstd.Decoder // lazily-initialized gzip reader
zerr error // any error from zstd.NewReader; sticky
}

func (zr *zstdReader) Read(p []byte) (n int, err error) {
if zr.zerr != nil {
return 0, zr.zerr
}
if zr.zr == nil {
if zr.zerr == nil {
reader, ok := zstdReaderPool.Get().(*zstd.Decoder)
if ok {
zr.zerr = reader.Reset(zr.body)
zr.zr = reader
} else {
zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20), zstd.WithDecoderConcurrency(1))
}
}
if zr.zerr != nil {
return 0, zr.zerr
}
}
n, err = zr.zr.Read(p)
if err != nil {
// Usually this will be io.EOF,
// stash the decoder and keep the error.
zr.zr.Reset(nil)
zstdReaderPool.Put(zr.zr)
zr.zr = nil
zr.zerr = err
}
return
}

func (zr *zstdReader) Close() error {
if zr.zr != nil {
zr.zr.Reset(nil)
zstdReaderPool.Put(zr.zr)
zr.zr = nil
}
return zr.body.Close()
klauspost marked this conversation as resolved.
Show resolved Hide resolved
}