Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gzhttp: Add zstd to transport #400

Merged
merged 6 commits into from Mar 1, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 50 additions & 4 deletions gzhttp/transport.go
Expand Up @@ -10,17 +10,19 @@ import (
"sync"

"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
)

// Transport will wrap a transport with a custom gzip handler
// Transport will wrap a transport with a custom handler
// that will request gzip and automatically decompress it.
// Using this is significantly faster than using the default transport.
func Transport(parent http.RoundTripper) http.RoundTripper {
return gzRoundtripper{parent: parent}
return gzRoundtripper{parent: parent, withZstd: true}
}

type gzRoundtripper struct {
parent http.RoundTripper
parent http.RoundTripper
withZstd bool
}

func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
Expand All @@ -41,7 +43,12 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
requestedGzip = true
req.Header.Set("Accept-Encoding", "gzip")
if g.withZstd {
// Swap when we want zstd to default.
req.Header.Set("Accept-Encoding", "gzip,zstd")
} else {
req.Header.Set("Accept-Encoding", "gzip")
}
}
resp, err := g.parent.RoundTrip(req)
if err != nil || !requestedGzip {
Expand All @@ -54,6 +61,16 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp.ContentLength = -1
resp.Uncompressed = true
}
if g.withZstd {
if asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
resp.Body = &zstdReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}
}

return resp, nil
}

Expand Down Expand Up @@ -114,3 +131,32 @@ func lower(b byte) byte {
}
return b
}

// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
type zstdReader struct {
body io.ReadCloser // underlying HTTP/1 response body framing
zr *zstd.Decoder // lazily-initialized gzip reader
zerr error // any error from zstd.NewReader; sticky
}

func (zr *zstdReader) Read(p []byte) (n int, err error) {
if zr.zr == nil {
if zr.zerr == nil {
zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20))
}
if zr.zerr != nil {
return 0, zr.zerr
}
}

return zr.zr.Read(p)
}

func (zr *zstdReader) Close() error {
if zr.zr != nil {
zr.zr.Close()
zr.zr = nil
}
return zr.body.Close()
klauspost marked this conversation as resolved.
Show resolved Hide resolved
}