From 6233fbc08ea58da4f846ae3c4fd4bfd6feabb55b Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Tue, 1 Jun 2021 10:52:35 +0200 Subject: [PATCH] Fix header .Add functions (#1036) These functions should take the headers that are handled differently into account. --- fasthttpadaptor/adaptor_test.go | 6 +- header.go | 212 ++++++++++++++++++++++---------- header_test.go | 47 +++++-- 3 files changed, 184 insertions(+), 81 deletions(-) diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go index f98d0907b0..e252f3985e 100644 --- a/fasthttpadaptor/adaptor_test.go +++ b/fasthttpadaptor/adaptor_test.go @@ -20,7 +20,6 @@ func TestNewFastHTTPHandler(t *testing.T) { expectedRequestURI := "/foo/bar?baz=123" expectedBody := "body 123 foo bar baz" expectedContentLength := len(expectedBody) - expectedTransferEncoding := "encoding" expectedHost := "foobar.com" expectedRemoteAddr := "1.2.3.4:6789" expectedHeader := map[string]string{ @@ -56,8 +55,8 @@ func TestNewFastHTTPHandler(t *testing.T) { if r.ContentLength != int64(expectedContentLength) { t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength) } - if len(r.TransferEncoding) != 1 || r.TransferEncoding[0] != expectedTransferEncoding { - t.Fatalf("unexpected transferEncoding %q. Expecting %q", r.TransferEncoding, expectedTransferEncoding) + if len(r.TransferEncoding) != 0 { + t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding) } if r.Host != expectedHost { t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost) @@ -101,7 +100,6 @@ func TestNewFastHTTPHandler(t *testing.T) { req.Header.SetMethod(expectedMethod) req.SetRequestURI(expectedRequestURI) req.Header.SetHost(expectedHost) - req.Header.Add(fasthttp.HeaderTransferEncoding, expectedTransferEncoding) req.BodyWriter().Write([]byte(expectedBody)) // nolint:errcheck for k, v := range expectedHeader { req.Header.Set(k, v) diff --git a/header.go b/header.go index 2680de5808..34e2de94e1 100644 --- a/header.go +++ b/header.go @@ -917,37 +917,159 @@ func (h *RequestHeader) del(key []byte) { h.h = delAllArgsBytes(h.h, key) } +// setSpecialHeader handles special headers and return true when a header is processed. +func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool { + if len(key) == 0 { + return false + } + + switch key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(strContentType, key) { + h.SetContentTypeBytes(value) + return true + } else if caseInsensitiveCompare(strContentLength, key) { + if contentLength, err := parseContentLength(value); err == nil { + h.contentLength = contentLength + h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) + } + return true + } else if caseInsensitiveCompare(strConnection, key) { + if bytes.Equal(strClose, value) { + h.SetConnectionClose() + } else { + h.ResetConnectionClose() + h.h = setArgBytes(h.h, key, value, argsHasValue) + } + return true + } + case 's': + if caseInsensitiveCompare(strServer, key) { + h.SetServerBytes(value) + return true + } else if caseInsensitiveCompare(strSetCookie, key) { + var kv *argsKV + h.cookies, kv = allocArg(h.cookies) + kv.key = getCookieKey(kv.key, value) + kv.value = append(kv.value[:0], value...) + return true + } + case 't': + if caseInsensitiveCompare(strTransferEncoding, key) { + // Transfer-Encoding is managed automatically. + return true + } + case 'd': + if caseInsensitiveCompare(strDate, key) { + // Date is managed automatically. + return true + } + } + + return false +} + +// setSpecialHeader handles special headers and return true when a header is processed. +func (h *RequestHeader) setSpecialHeader(key, value []byte) bool { + if len(key) == 0 { + return false + } + + switch key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(strContentType, key) { + h.SetContentTypeBytes(value) + return true + } else if caseInsensitiveCompare(strContentLength, key) { + if contentLength, err := parseContentLength(value); err == nil { + h.contentLength = contentLength + h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) + } + return true + } else if caseInsensitiveCompare(strConnection, key) { + if bytes.Equal(strClose, value) { + h.SetConnectionClose() + } else { + h.ResetConnectionClose() + h.h = setArgBytes(h.h, key, value, argsHasValue) + } + return true + } else if caseInsensitiveCompare(strCookie, key) { + h.collectCookies() + h.cookies = parseRequestCookies(h.cookies, value) + return true + } + case 't': + if caseInsensitiveCompare(strTransferEncoding, key) { + // Transfer-Encoding is managed automatically. + return true + } + case 'h': + if caseInsensitiveCompare(strHost, key) { + h.SetHostBytes(value) + return true + } + case 'u': + if caseInsensitiveCompare(strUserAgent, key) { + h.SetUserAgentBytes(value) + return true + } + } + + return false +} + // Add adds the given 'key: value' header. // // Multiple headers with the same key may be added with this function. // Use Set for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) Add(key, value string) { - k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) - h.h = appendArg(h.h, b2s(k), value, argsHasValue) + h.AddBytesKV(s2b(key), s2b(value)) } // AddBytesK adds the given 'key: value' header. // // Multiple headers with the same key may be added with this function. // Use SetBytesK for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) AddBytesK(key []byte, value string) { - h.Add(b2s(key), value) + h.AddBytesKV(key, s2b(value)) } // AddBytesV adds the given 'key: value' header. // // Multiple headers with the same key may be added with this function. // Use SetBytesV for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) AddBytesV(key string, value []byte) { - h.Add(key, b2s(value)) + h.AddBytesKV(s2b(key), value) } // AddBytesKV adds the given 'key: value' header. // // Multiple headers with the same key may be added with this function. // Use SetBytesKV for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) AddBytesKV(key, value []byte) { - h.Add(b2s(key), b2s(value)) + if h.setSpecialHeader(key, value) { + return + } + + k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing) + h.h = appendArgBytes(h.h, k, value, argsHasValue) } // Set sets the given 'key: value' header. @@ -986,35 +1108,11 @@ func (h *ResponseHeader) SetBytesKV(key, value []byte) { // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. func (h *ResponseHeader) SetCanonical(key, value []byte) { - switch string(key) { - case HeaderContentType: - h.SetContentTypeBytes(value) - case HeaderServer: - h.SetServerBytes(value) - case HeaderSetCookie: - var kv *argsKV - h.cookies, kv = allocArg(h.cookies) - kv.key = getCookieKey(kv.key, value) - kv.value = append(kv.value[:0], value...) - case HeaderContentLength: - if contentLength, err := parseContentLength(value); err == nil { - h.contentLength = contentLength - h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) - } - case HeaderConnection: - if bytes.Equal(strClose, value) { - h.SetConnectionClose() - } else { - h.ResetConnectionClose() - h.h = setArgBytes(h.h, key, value, argsHasValue) - } - case HeaderTransferEncoding: - // Transfer-Encoding is managed automatically. - case HeaderDate: - // Date is managed automatically. - default: - h.h = setArgBytes(h.h, key, value, argsHasValue) + if h.setSpecialHeader(key, value) { + return } + + h.h = setArgBytes(h.h, key, value, argsHasValue) } // SetCookie sets the given response cookie. @@ -1123,8 +1221,7 @@ func (h *RequestHeader) DelAllCookies() { // Multiple headers with the same key may be added with this function. // Use Set for setting a single header for the given key. func (h *RequestHeader) Add(key, value string) { - k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) - h.h = appendArg(h.h, b2s(k), value, argsHasValue) + h.AddBytesKV(s2b(key), s2b(value)) } // AddBytesK adds the given 'key: value' header. @@ -1132,7 +1229,7 @@ func (h *RequestHeader) Add(key, value string) { // Multiple headers with the same key may be added with this function. // Use SetBytesK for setting a single header for the given key. func (h *RequestHeader) AddBytesK(key []byte, value string) { - h.Add(b2s(key), value) + h.AddBytesKV(key, s2b(value)) } // AddBytesV adds the given 'key: value' header. @@ -1140,15 +1237,24 @@ func (h *RequestHeader) AddBytesK(key []byte, value string) { // Multiple headers with the same key may be added with this function. // Use SetBytesV for setting a single header for the given key. func (h *RequestHeader) AddBytesV(key string, value []byte) { - h.Add(key, b2s(value)) + h.AddBytesKV(s2b(key), value) } // AddBytesKV adds the given 'key: value' header. // // Multiple headers with the same key may be added with this function. // Use SetBytesKV for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Cookie, +// Transfer-Encoding, Host and User-Agent headers can only be set once +// and will overwrite the previous value. func (h *RequestHeader) AddBytesKV(key, value []byte) { - h.Add(b2s(key), b2s(value)) + if h.setSpecialHeader(key, value) { + return + } + + k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing) + h.h = appendArgBytes(h.h, k, value, argsHasValue) } // Set sets the given 'key: value' header. @@ -1187,33 +1293,11 @@ func (h *RequestHeader) SetBytesKV(key, value []byte) { // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. func (h *RequestHeader) SetCanonical(key, value []byte) { - switch string(key) { - case HeaderHost: - h.SetHostBytes(value) - case HeaderContentType: - h.SetContentTypeBytes(value) - case HeaderUserAgent: - h.SetUserAgentBytes(value) - case HeaderCookie: - h.collectCookies() - h.cookies = parseRequestCookies(h.cookies, value) - case HeaderContentLength: - if contentLength, err := parseContentLength(value); err == nil { - h.contentLength = contentLength - h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) - } - case HeaderConnection: - if bytes.Equal(strClose, value) { - h.SetConnectionClose() - } else { - h.ResetConnectionClose() - h.h = setArgBytes(h.h, key, value, argsHasValue) - } - case HeaderTransferEncoding: - // Transfer-Encoding is managed automatically. - default: - h.h = setArgBytes(h.h, key, value, argsHasValue) + if h.setSpecialHeader(key, value) { + return } + + h.h = setArgBytes(h.h, key, value, argsHasValue) } // Peek returns header value for the given key. diff --git a/header_test.go b/header_test.go index 354ebb4866..7e114d5a1e 100644 --- a/header_test.go +++ b/header_test.go @@ -13,6 +13,26 @@ import ( "testing" ) +func TestResponseHeaderAddContentType(t *testing.T) { + t.Parallel() + + var h ResponseHeader + h.Add("Content-Type", "test") + + got := string(h.Peek("Content-Type")) + expected := "test" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } + + var buf bytes.Buffer + h.WriteTo(&buf) //nolint:errcheck + + if n := strings.Count(buf.String(), "Content-Type: "); n != 1 { + t.Errorf("Content-Type occured %d times", n) + } +} + func TestResponseHeaderMultiLineValue(t *testing.T) { s := "HTTP/1.1 200 OK\r\n" + "EmptyValue1:\r\n" + @@ -331,7 +351,9 @@ func TestResponseHeaderAdd(t *testing.T) { m := make(map[string]struct{}) var h ResponseHeader h.Add("aaa", "bbb") + h.Add("content-type", "xxx") m["bbb"] = struct{}{} + m["xxx"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) @@ -343,12 +365,11 @@ func TestResponseHeaderAdd(t *testing.T) { h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "Content-Type": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } delete(m, string(v)) - case "Content-Type": default: t.Fatalf("unexpected key found: %q", k) } @@ -366,15 +387,14 @@ func TestResponseHeaderAdd(t *testing.T) { h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "Content-Type": m[string(v)] = struct{}{} - case "Content-Type": default: t.Fatalf("unexpected key found: %q", k) } }) - if len(m) != 11 { - t.Fatalf("unexpected number of headers: %d. Expecting 11", len(m)) + if len(m) != 12 { + t.Fatalf("unexpected number of headers: %d. Expecting 12", len(m)) } } @@ -382,19 +402,21 @@ func TestRequestHeaderAdd(t *testing.T) { m := make(map[string]struct{}) var h RequestHeader h.Add("aaa", "bbb") + h.Add("user-agent", "xxx") m["bbb"] = struct{}{} + m["xxx"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } - if h.Len() != 11 { - t.Fatalf("unexpected header len %d. Expecting 11", h.Len()) + if h.Len() != 12 { + t.Fatalf("unexpected header len %d. Expecting 12", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "User-Agent": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } @@ -416,15 +438,14 @@ func TestRequestHeaderAdd(t *testing.T) { h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "User-Agent": m[string(v)] = struct{}{} - case "User-Agent": default: t.Fatalf("unexpected key found: %q", k) } }) - if len(m) != 11 { - t.Fatalf("unexpected number of headers: %d. Expecting 11", len(m)) + if len(m) != 12 { + t.Fatalf("unexpected number of headers: %d. Expecting 12", len(m)) } s1 := h1.String() if s != s1 {