diff --git a/http.go b/http.go index c8c2eb439e..6102f65757 100644 --- a/http.go +++ b/http.go @@ -56,6 +56,9 @@ type Request struct { // Request timeout. Usually set by DoDeadline or DoTimeout // if <= 0, means not set timeout time.Duration + + // Use Host header (request.Header.SetHost) instead of the host from SetRequestURI, SetHost, or URI().SetHost + UseHostHeader bool } // Response represents HTTP response. @@ -1357,10 +1360,15 @@ func (req *Request) Write(w *bufio.Writer) error { if len(req.Header.Host()) == 0 || req.parsedURI { uri := req.URI() host := uri.Host() - if len(host) == 0 { - return errRequestHostRequired + if len(req.Header.Host()) == 0 { + if len(host) == 0 { + return errRequestHostRequired + } else { + req.Header.SetHostBytes(host) + } + } else if !req.UseHostHeader { + req.Header.SetHostBytes(host) } - req.Header.SetHostBytes(host) req.Header.SetRequestURIBytes(uri.RequestURI()) if len(uri.username) > 0 { diff --git a/http_test.go b/http_test.go index 8bd050b98e..2259611ebf 100644 --- a/http_test.go +++ b/http_test.go @@ -58,7 +58,7 @@ func TestIssue875(t *testing.T) { expectedLocation string } - var testcases = []testcase{ + testcases := []testcase{ { uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n", @@ -117,7 +117,6 @@ func TestRequestCopyTo(t *testing.T) { t.Fatalf("unexpected error: %s", err) } testRequestCopyTo(t, &req) - } func TestResponseCopyTo(t *testing.T) { @@ -134,7 +133,6 @@ func TestResponseCopyTo(t *testing.T) { resp.Header.SetStatusCode(200) resp.SetBodyString("test") testResponseCopyTo(t, &resp) - } func testRequestCopyTo(t *testing.T, src *Request) { @@ -594,6 +592,31 @@ func TestRequestUpdateURI(t *testing.T) { } } +func TestUseHostHeader(t *testing.T) { + t.Parallel() + + var r Request + r.UseHostHeader = true + r.Header.SetHost("aaa.bbb") + r.SetRequestURI("/lkjkl/kjl") + + // Modify request uri and host via URI() object and make sure + // the requestURI and Host header are properly updated + u := r.URI() + u.SetPath("/123/432.html") + u.SetHost("foobar.com") + a := u.QueryArgs() + a.Set("aaa", "bcse") + + s := r.String() + if !strings.HasPrefix(s, "GET /123/432.html?aaa=bcse") { + t.Fatalf("cannot find %q in %q", "GET /123/432.html?aaa=bcse", s) + } + if !strings.Contains(s, "\r\nHost: aaa.bbb\r\n") { + t.Fatalf("cannot find %q in %q", "\r\nHost: aaa.bbb\r\n", s) + } +} + func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) { t.Parallel() @@ -1053,7 +1076,6 @@ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) { if string(formData) != string(r.Body()) { t.Fatalf("The body given must equal the body in the Request") } - } func TestRequestMayContinue(t *testing.T) { @@ -2183,7 +2205,6 @@ Content-Type: application/json `, "\n", "\r\n", -1) mr := multipart.NewReader(strings.NewReader(s), "foo") form, err := mr.ReadForm(1024) - if err != nil { t.Fatalf("unexpected error: %s", err) }