diff --git a/pkg/app/fs.go b/pkg/app/fs.go index 644a9e93c..7b468ba19 100644 --- a/pkg/app/fs.go +++ b/pkg/app/fs.go @@ -899,7 +899,7 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { hdr := &ctx.Response.Header if ff.compressed { - hdr.SetCanonical(bytestr.StrContentEncoding, bytestr.StrGzip) + hdr.SetContentEncodingBytes(bytestr.StrGzip) } statusCode := consts.StatusOK diff --git a/pkg/app/fs_test.go b/pkg/app/fs_test.go index 4e1a02bf6..389fe9e40 100644 --- a/pkg/app/fs_test.go +++ b/pkg/app/fs_test.go @@ -160,7 +160,7 @@ func TestServeFileHead(t *testing.T) { t.Fatalf("unexpected error: %s", err) } - ce := r.Header.Peek(consts.HeaderContentEncoding) + ce := r.Header.ContentEncoding() if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } @@ -250,8 +250,7 @@ func TestServeFileCompressed(t *testing.T) { if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } - - ce := r.Header.Peek(consts.HeaderContentEncoding) + ce := r.Header.ContentEncoding() if string(ce) != "gzip" { t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip") } @@ -287,7 +286,7 @@ func TestServeFileUncompressed(t *testing.T) { t.Fatalf("unexpected error: %s", err) } - ce := r.Header.Peek(consts.HeaderContentEncoding) + ce := r.Header.ContentEncoding() if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } @@ -534,7 +533,7 @@ func testFSCompress(t *testing.T, h HandlerFunc, filePath string) { if r.StatusCode() != consts.StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r.StatusCode(), consts.StatusOK, filePath) } - ce := r.Header.Peek(consts.HeaderContentEncoding) + ce := r.Header.ContentEncoding() if string(ce) != "" { t.Fatalf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) } @@ -553,7 +552,7 @@ func testFSCompress(t *testing.T, h HandlerFunc, filePath string) { if r.StatusCode() != consts.StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r.StatusCode(), consts.StatusOK, filePath) } - ce = r.Header.Peek(consts.HeaderContentEncoding) + ce = r.Header.ContentEncoding() if string(ce) != "gzip" { t.Fatalf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath) } diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index a3f94bc1c..dcd3c9aa8 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -500,18 +500,21 @@ func verifyResponse(t *testing.T, zr network.Reader, expectedStatusCode int, exp if !bytes.Equal(r.Body(), []byte(expectedBody)) { t.Fatalf("Unexpected body %q. Expected %q", r.Body(), []byte(expectedBody)) } - verifyResponseHeader(t, &r.Header, expectedStatusCode, len(r.Body()), expectedContentType) + verifyResponseHeader(t, &r.Header, expectedStatusCode, len(r.Body()), expectedContentType, "") } -func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType string) { +func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding string) { if h.StatusCode() != expectedStatusCode { t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode) } if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength) } - if string(h.Peek(consts.HeaderContentType)) != expectedContentType { - t.Fatalf("Unexpected content type %q. Expected %q", h.Peek(consts.HeaderContentType), expectedContentType) + if string(h.ContentType()) != expectedContentType { + t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType) + } + if string(h.ContentEncoding()) != expectedContentEncoding { + t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding) } } diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index a01966aa3..56f553bcc 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -117,6 +117,7 @@ type ResponseHeader struct { statusCode int contentLength int contentLengthBytes []byte + contentEncoding []byte contentType []byte server []byte @@ -224,6 +225,7 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { dst.statusCode = h.statusCode dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) + dst.contentEncoding = append(dst.contentEncoding[:0], h.contentEncoding...) dst.contentType = append(dst.contentType[:0], h.contentType...) dst.server = append(dst.server[:0], h.server...) dst.h = copyArgs(dst.h, h.h) @@ -258,6 +260,10 @@ func (h *ResponseHeader) VisitAll(f func(key, value []byte)) { if len(contentType) > 0 { f(bytestr.StrContentType, contentType) } + contentEncoding := h.ContentEncoding() + if len(contentEncoding) > 0 { + f(bytestr.StrContentEncoding, contentEncoding) + } server := h.Server() if len(server) > 0 { f(bytestr.StrServer, server) @@ -474,6 +480,7 @@ func (h *ResponseHeader) ResetSkipNormalize() { h.statusCode = 0 h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] + h.contentEncoding = h.contentEncoding[:0] h.contentType = h.contentType[:0] h.server = h.server[:0] @@ -663,6 +670,8 @@ func (h *ResponseHeader) peek(key []byte) []byte { switch string(key) { case consts.HeaderContentType: return h.ContentType() + case consts.HeaderContentEncoding: + return h.ContentEncoding() case consts.HeaderServer: return h.Server() case consts.HeaderConnection: @@ -684,6 +693,21 @@ func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) { h.contentType = append(h.contentType[:0], contentType...) } +// ContentEncoding returns Content-Encoding header value. +func (h *ResponseHeader) ContentEncoding() []byte { + return h.contentEncoding +} + +// SetContentEncoding sets Content-Encoding header value. +func (h *ResponseHeader) SetContentEncoding(contentEncoding string) { + h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...) +} + +// SetContentEncodingBytes sets Content-Encoding header value. +func (h *ResponseHeader) SetContentEncodingBytes(contentEncoding []byte) { + h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...) +} + func (h *ResponseHeader) SetContentLengthBytes(contentLength []byte) { h.contentLengthBytes = append(h.contentLengthBytes[:0], contentLength...) } @@ -746,7 +770,10 @@ func (h *ResponseHeader) AppendBytes(dst []byte) []byte { dst = appendHeaderLine(dst, bytestr.StrContentType, contentType) } } - + contentEncoding := h.ContentEncoding() + if len(contentEncoding) > 0 { + dst = appendHeaderLine(dst, bytestr.StrContentEncoding, contentEncoding) + } if len(h.contentLengthBytes) > 0 { dst = appendHeaderLine(dst, bytestr.StrContentLength, h.contentLengthBytes) } @@ -833,6 +860,8 @@ func (h *ResponseHeader) del(key []byte) { switch string(key) { case consts.HeaderContentType: h.contentType = h.contentType[:0] + case consts.HeaderContentEncoding: + h.contentEncoding = h.contentEncoding[:0] case consts.HeaderServer: h.server = h.server[:0] case consts.HeaderSetCookie: @@ -1506,6 +1535,9 @@ func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool { h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) } return true + } else if utils.CaseInsensitiveCompare(bytestr.StrContentEncoding, key) { + h.SetContentEncodingBytes(value) + return true } else if utils.CaseInsensitiveCompare(bytestr.StrConnection, key) { if bytes.Equal(bytestr.StrClose, value) { h.SetConnectionClose(true) diff --git a/pkg/protocol/header_test.go b/pkg/protocol/header_test.go index 10db9036f..14446d56e 100644 --- a/pkg/protocol/header_test.go +++ b/pkg/protocol/header_test.go @@ -101,6 +101,12 @@ func TestSetContentLengthBytes(t *testing.T) { assert.DeepEqual(t, rh.contentLengthBytes, []byte("foo")) } +func TestSetContentEncoding(t *testing.T) { + rh := ResponseHeader{} + rh.SetContentEncoding("gzip") + assert.DeepEqual(t, rh.contentEncoding, []byte("gzip")) +} + func Test_peekRawHeader(t *testing.T) { s := "Expect: 100-continue\r\nUser-Agent: foo\r\nHost: 127.0.0.1\r\nConnection: Keep-Alive\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" assert.DeepEqual(t, []byte("127.0.0.1"), peekRawHeader([]byte(s), []byte("Host"))) @@ -250,6 +256,7 @@ func TestResponseHeaderDel(t *testing.T) { h.Set("aaa", "bbb") h.Set(consts.HeaderConnection, "keep-alive") h.Set(consts.HeaderContentType, "aaa") + h.Set(consts.HeaderContentEncoding, "gzip") h.Set(consts.HeaderServer, "aaabbb") h.Set(consts.HeaderContentLength, "1123") @@ -264,6 +271,7 @@ func TestResponseHeaderDel(t *testing.T) { h.Del(consts.HeaderServer) h.Del("content-length") h.Del("set-cookie") + h.Del("content-encoding") hv := h.Peek("aaa") if string(hv) != "bbb" { @@ -281,6 +289,10 @@ func TestResponseHeaderDel(t *testing.T) { if string(hv) != string(bytestr.DefaultContentType) { t.Fatalf("unexpected content-type: %q. Expecting %q", hv, bytestr.DefaultContentType) } + hv = h.Peek(consts.HeaderContentEncoding) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } hv = h.Peek(consts.HeaderServer) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) @@ -374,20 +386,22 @@ func TestResponseHeaderAdd(t *testing.T) { var h ResponseHeader h.Add("aaa", "bbb") h.Add("content-type", "xxx") + h.SetContentEncoding("gzip") m["bbb"] = struct{}{} m["xxx"] = struct{}{} + m["gzip"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } - if h.Len() != 12 { - t.Fatalf("unexpected header len %d. Expecting 12", h.Len()) + if h.Len() != 13 { + t.Fatalf("unexpected header len %d. Expecting 13", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar", "Content-Type": + case "Aaa", "Foo-Bar", "Content-Type", "Content-Encoding": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } @@ -452,6 +466,23 @@ func TestResponseHeaderAddContentType(t *testing.T) { } } +func TestResponseHeaderAddContentEncoding(t *testing.T) { + t.Parallel() + + var h ResponseHeader + h.Add("Content-Encoding", "test") + + got := string(h.ContentEncoding()) + expected := "test" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } + + if n := strings.Count(string(h.Header()), "Content-Encoding: "); n != 1 { + t.Errorf("Content-Encoding occurred %d times", n) + } +} + func TestRequestHeaderAddContentType(t *testing.T) { t.Parallel() diff --git a/pkg/protocol/http1/resp/header.go b/pkg/protocol/http1/resp/header.go index fbabffdeb..0944ef7a6 100644 --- a/pkg/protocol/http1/resp/header.go +++ b/pkg/protocol/http1/resp/header.go @@ -141,6 +141,10 @@ func parseHeaders(h *protocol.ResponseHeader, buf []byte) (int, error) { h.SetContentTypeBytes(s.Value) continue } + if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentEncoding) { + h.SetContentEncodingBytes(s.Value) + continue + } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentLength) { var contentLength int if h.ContentLength() != -1 { diff --git a/pkg/protocol/http1/resp/response_test.go b/pkg/protocol/http1/resp/response_test.go index f23f6dfeb..1b4eefb7f 100644 --- a/pkg/protocol/http1/resp/response_test.go +++ b/pkg/protocol/http1/resp/response_test.go @@ -167,7 +167,7 @@ func testResponseReadSuccess(t *testing.T, resp *protocol.Response, response str t.Fatalf("Unexpected error: %s", err) } - verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType) + verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, "") if !bytes.Equal(resp.Body(), []byte(expectedBody)) { t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody)) } @@ -387,7 +387,8 @@ func TestResponseReadLimitBody(t *testing.T) { testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 10) testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 100) testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 9) - + // response with content-encoding + testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Encoding: gzip\r\n\r\n9876543210", 10) // chunked response testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9) testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 100) @@ -404,32 +405,35 @@ func TestResponseReadWithoutBody(t *testing.T) { var resp protocol.Response - testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Length: 1235\r\n\r\nfoobar", false, - consts.StatusNotModified, 1235, "aa", "foobar") + testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Encoding: gzip\r\nContent-Length: 1235\r\n\r\nfoobar", false, + consts.StatusNotModified, 1235, "aa", "foobar", "gzip") - testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false, - consts.StatusNoContent, -1, "aab", "123\r\nss") + testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false, + consts.StatusNoContent, -1, "aab", "123\r\nss", "deflate") - testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false, - 123, 3434, "xxx", "aaaa") + testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Encoding: gzip\r\nContent-Length: 3434\r\n\r\naaaa", false, + 123, 3434, "xxx", "aaaa", "gzip") - testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nxxxx", true, - consts.StatusOK, 123, "text/xml", "xxxx") + testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Encoding: deflate\r\nContent-Length: 123\r\n\r\nxxxx", true, + consts.StatusOK, 123, "text/xml", "xxxx", "deflate") // '100 Continue' must be skipped. - testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\nfoobar", true, - 329, 894, "qwe", "foobar") + testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Encoding: gzip\r\nContent-Length: 894\r\n\r\nfoobar", true, + 329, 894, "qwe", "foobar", "gzip") } -func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType string) { +func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding string) { if h.StatusCode() != expectedStatusCode { t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode) } if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength) } - if string(h.Peek(consts.HeaderContentType)) != expectedContentType { - t.Fatalf("Unexpected content type %q. Expected %q", h.Peek(consts.HeaderContentType), expectedContentType) + if string(h.ContentType()) != expectedContentType { + t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType) + } + if string(h.ContentEncoding()) != expectedContentEncoding { + t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding) } } @@ -478,7 +482,7 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, } func testResponseReadWithoutBody(t *testing.T, resp *protocol.Response, s string, skipBody bool, - expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string, + expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer, expectedContentEncoding string, ) { zr := mock.NewZeroCopyReader(s) resp.SkipBody = skipBody @@ -489,7 +493,7 @@ func testResponseReadWithoutBody(t *testing.T, resp *protocol.Response, s string if len(resp.Body()) != 0 { t.Fatalf("Unexpected response body %q. Expected %q. response=%q", resp.Body(), "", s) } - verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType) + verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, expectedContentEncoding) assert.VerifyTrailer(t, zr, expectedTrailer) // verify that ordinal response is read after null-body response