Skip to content

Commit

Permalink
gzhttp: Add zstd to transport (#400)
Browse files Browse the repository at this point in the history
```
goos: windows
goarch: amd64
pkg: github.com/klauspost/compress/gzhttp
cpu: AMD Ryzen 9 3950X 16-Core Processor            
BenchmarkTransport
BenchmarkTransport/gzhttp-32         	    2179	    503006 ns/op	 259.61 MB/s	        25.67 pct	    9623 B/op	      74 allocs/op
BenchmarkTransport/stdlib-32         	    2001	    596271 ns/op	 219.00 MB/s	        25.67 pct	   52275 B/op	      92 allocs/op
BenchmarkTransport/zstd-32           	    3404	    343757 ns/op	 379.87 MB/s	        24.44 pct	    5358 B/op	      69 allocs/op
BenchmarkTransport/gzhttp-par-32     	   47127	     25402 ns/op	5140.75 MB/s	        25.67 pct	    9598 B/op	      72 allocs/op
BenchmarkTransport/stdlib-par-32     	   39920	     30834 ns/op	4235.03 MB/s	        25.67 pct	   52269 B/op	      90 allocs/op
BenchmarkTransport/zstd-par-32       	   68941	     17277 ns/op	7558.02 MB/s	        24.44 pct	    5436 B/op	      67 allocs/op

PASS

Process finished with the exit code 0
```

* [x] Tests added.
  • Loading branch information
klauspost committed Mar 1, 2022
1 parent 2982376 commit bf9102c
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 18 deletions.
3 changes: 3 additions & 0 deletions gzhttp/gzip_test.go
Expand Up @@ -1133,6 +1133,9 @@ func newTestHandler(body []byte) http.Handler {
case "/gzipped":
w.Header().Set("Content-Encoding", "gzip")
w.Write(body)
case "/zstd":
w.Header().Set("Content-Encoding", "zstd")
w.Write(body)
default:
w.Write(body)
}
Expand Down
115 changes: 105 additions & 10 deletions gzhttp/transport.go
Expand Up @@ -7,24 +7,58 @@ package gzhttp
import (
"io"
"net/http"
"strings"
"sync"

"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
)

// Transport will wrap a transport with a custom gzip handler
// Transport will wrap a transport with a custom handler
// that will request gzip and automatically decompress it.
// Using this is significantly faster than using the default transport.
func Transport(parent http.RoundTripper) http.RoundTripper {
return gzRoundtripper{parent: parent}
func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
g := gzRoundtripper{parent: parent, withZstd: true, withGzip: true}
for _, o := range opts {
o(&g)
}
var ae []string
if g.withZstd {
ae = append(ae, "zstd")
}
if g.withGzip {
ae = append(ae, "gzip")
}
g.acceptEncoding = strings.Join(ae, ",")
return &g
}

type transportOption func(c *gzRoundtripper)

// TransportEnableZstd will send Zstandard as a compression option to the server.
// Enabled by default, but may be disabled if future problems arise.
func TransportEnableZstd(b bool) transportOption {
return func(c *gzRoundtripper) {
c.withZstd = b
}
}

// TransportEnableGzip will send Gzip as a compression option to the server.
// Enabled by default.
func TransportEnableGzip(b bool) transportOption {
return func(c *gzRoundtripper) {
c.withGzip = b
}
}

type gzRoundtripper struct {
parent http.RoundTripper
parent http.RoundTripper
acceptEncoding string
withZstd, withGzip bool
}

func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
var requestedGzip bool
func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
var requestedComp bool
if req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
req.Method != "HEAD" {
Expand All @@ -40,20 +74,31 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
requestedGzip = true
req.Header.Set("Accept-Encoding", "gzip")
requestedComp = len(g.acceptEncoding) > 0
req.Header.Set("Accept-Encoding", g.acceptEncoding)
}

resp, err := g.parent.RoundTrip(req)
if err != nil || !requestedGzip {
if err != nil || !requestedComp {
return resp, err
}
if asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {

// Decompress
if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
resp.Body = &gzipReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}
if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
resp.Body = &zstdReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}

return resp, nil
}

Expand Down Expand Up @@ -114,3 +159,53 @@ func lower(b byte) byte {
}
return b
}

// zstdReaderPool pools zstd decoders.
var zstdReaderPool sync.Pool

// zstdReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
type zstdReader struct {
body io.ReadCloser // underlying HTTP/1 response body framing
zr *zstd.Decoder // lazily-initialized gzip reader
zerr error // any error from zstd.NewReader; sticky
}

func (zr *zstdReader) Read(p []byte) (n int, err error) {
if zr.zerr != nil {
return 0, zr.zerr
}
if zr.zr == nil {
if zr.zerr == nil {
reader, ok := zstdReaderPool.Get().(*zstd.Decoder)
if ok {
zr.zerr = reader.Reset(zr.body)
zr.zr = reader
} else {
zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20), zstd.WithDecoderConcurrency(1))
}
}
if zr.zerr != nil {
return 0, zr.zerr
}
}
n, err = zr.zr.Read(p)
if err != nil {
// Usually this will be io.EOF,
// stash the decoder and keep the error.
zr.zr.Reset(nil)
zstdReaderPool.Put(zr.zr)
zr.zr = nil
zr.zerr = err
}
return
}

func (zr *zstdReader) Close() error {
if zr.zr != nil {
zr.zr.Reset(nil)
zstdReaderPool.Put(zr.zr)
zr.zr = nil
}
return zr.body.Close()
}

0 comments on commit bf9102c

Please sign in to comment.