Skip to content

Commit

Permalink
gzhttp: Always delete HeaderNoCompression (#683)
Browse files Browse the repository at this point in the history
* gzhttp: Always delete `HeaderNoCompression`

Also when it cannot be gzipped.

* Also remove header when starting to write
  • Loading branch information
klauspost committed Oct 26, 2022
1 parent 1619336 commit 9559b03
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 7 deletions.
78 changes: 72 additions & 6 deletions gzhttp/compress.go
Expand Up @@ -100,12 +100,13 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) {
}
w.buf = append(w.buf, b[:toAdd]...)
remain := b[toAdd:]
hdr := w.Header()

// Only continue if they didn't already choose an encoding or a known unhandled content length or type.
if len(w.Header()[HeaderNoCompression]) == 0 && w.Header().Get(contentEncoding) == "" && w.Header().Get(contentRange) == "" {
if len(hdr[HeaderNoCompression]) == 0 && hdr.Get(contentEncoding) == "" && hdr.Get(contentRange) == "" {
// Check more expensive parts now.
cl, _ := atoi(w.Header().Get(contentLength))
ct := w.Header().Get(contentType)
cl, _ := atoi(hdr.Get(contentLength))
ct := hdr.Get(contentType)
if cl == 0 || cl >= w.minSize && (ct == "" || w.contentTypeFilter(ct)) {
// If the current buffer is less than minSize and a Content-Length isn't set, then wait until we have more data.
if len(w.buf) < w.minSize && cl == 0 {
Expand All @@ -121,8 +122,8 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) {

// Handles the intended case of setting a nil Content-Type (as for http/server or http/fs)
// Set the header only if the key does not exist
if _, ok := w.Header()[contentType]; w.setContentType && !ok {
w.Header().Set(contentType, ct)
if _, ok := hdr[contentType]; w.setContentType && !ok {
hdr.Set(contentType, ct)
}

// If the Content-Type is acceptable to GZIP, initialize the GZIP writer.
Expand Down Expand Up @@ -388,7 +389,8 @@ func NewWrapper(opts ...option) (func(http.Handler) http.HandlerFunc, error) {
h.ServeHTTP(gw, r)
}
} else {
h.ServeHTTP(w, r)
h.ServeHTTP(newNoCompressResponseWriter(w), r)
w.Header().Del(HeaderNoCompression)
}
}
}, nil
Expand Down Expand Up @@ -743,3 +745,67 @@ func atoi(s string) (int, bool) {
i64, err := strconv.ParseInt(s, 10, 0)
return int(i64), err == nil
}

// newNoCompressResponseWriter will return a response writer that
// cleans up compression artifacts.
// Depending on whether http.Hijacker is supported the returned will as well.
func newNoCompressResponseWriter(w http.ResponseWriter) http.ResponseWriter {
n := &noCompressResponseWriter{hw: w}
if hj, ok := w.(http.Hijacker); ok {
x := struct {
http.ResponseWriter
http.Hijacker
http.Flusher
}{
ResponseWriter: n,
Hijacker: hj,
Flusher: n,
}
return x
}

return n
}

// noCompressResponseWriter filters out HeaderNoCompression.
type noCompressResponseWriter struct {
hw http.ResponseWriter
hdrCleaned bool
}

func (n *noCompressResponseWriter) CloseNotify() <-chan bool {
if cn, ok := n.hw.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
return nil
}

func (n *noCompressResponseWriter) Flush() {
if !n.hdrCleaned {
n.hw.Header().Del(HeaderNoCompression)
n.hdrCleaned = true
}
if f, ok := n.hw.(http.Flusher); ok {
f.Flush()
}
}

func (n *noCompressResponseWriter) Header() http.Header {
return n.hw.Header()
}

func (n *noCompressResponseWriter) Write(bytes []byte) (int, error) {
if !n.hdrCleaned {
n.hw.Header().Del(HeaderNoCompression)
n.hdrCleaned = true
}
return n.hw.Write(bytes)
}

func (n *noCompressResponseWriter) WriteHeader(statusCode int) {
if !n.hdrCleaned {
n.hw.Header().Del(HeaderNoCompression)
n.hdrCleaned = true
}
n.hw.WriteHeader(statusCode)
}
66 changes: 65 additions & 1 deletion gzhttp/compress_test.go
Expand Up @@ -748,9 +748,9 @@ func TestContentTypes(t *testing.T) {
})
t.Run("disable-"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
w.WriteHeader(http.StatusOK)
w.Write(testBody)
})

Expand All @@ -765,6 +765,70 @@ func TestContentTypes(t *testing.T) {

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
t.Run("head-req"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
w.WriteHeader(http.StatusOK)
})

wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes))
assertNil(t, err)

req, _ := http.NewRequest("HEAD", "/whatever", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
t.Run("head-req-no-ok"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
})

wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes))
assertNil(t, err)

req, _ := http.NewRequest("HEAD", "/whatever", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
t.Run("req-no-ok-write"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
w.Write(testBody)
})

wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes))
assertNil(t, err)

req, _ := http.NewRequest("GET", "/whatever", nil)
req.Header.Set("Accept-Encoding", "")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
}
}
Expand Down

0 comments on commit 9559b03

Please sign in to comment.