Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to set Host header for Client #1169

Merged
merged 4 commits into from Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 11 additions & 3 deletions http.go
Expand Up @@ -56,6 +56,9 @@ type Request struct {
// Request timeout. Usually set by DoDeadline or DoTimeout
// if <= 0, means not set
timeout time.Duration

// Use Host header (request.Header.SetHost) instead of the host from SetRequestURI, SetHost, or URI().SetHost
UseHostHeader bool
}

// Response represents HTTP response.
Expand Down Expand Up @@ -1357,10 +1360,15 @@ func (req *Request) Write(w *bufio.Writer) error {
if len(req.Header.Host()) == 0 || req.parsedURI {
uri := req.URI()
host := uri.Host()
if len(host) == 0 {
return errRequestHostRequired
if len(req.Header.Host()) == 0 {
if len(host) == 0 {
return errRequestHostRequired
} else {
req.Header.SetHostBytes(host)
}
} else if !req.UseHostHeader {
req.Header.SetHostBytes(host)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the if len(host) == 0 { return errRequestHostRequired check then also be inside the if req.AllowToChangeHostHeader?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. The uri.Host() used as connection address at least for Client type. If host will be empty but Host header is not, connection will be broken. I've checked such code in http.go

	if len(req.Header.Host()) == 0 || !req.UseHostHeader {
		req.Header.SetHostBytes(host)
	} else if len(host) == 0 {
		return errRequestHostRequired
	}

And such usage for this code

	req.UseHostHeader = true
	req.SetHost("")
	req.Header.SetHost(host)

There is result - lookup : no such host

We can use headers for setting uri.Host() but I am not sure that it is correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've check code again. It is not important for Client, because error raised inside dialAddr. But I am still not sure that it is correct allow uri.Host() to be empty if Host header presents.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But shouldn't len(host) == 0 be allowed when req.UseHostHeader === true? (when it's used outside of Client). It makes more sense to me to put that inside the if?

Sorry I'm replying bit slow, I'm traveling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure how it should work in other cases. But when used in Client len(host) == 0 does not work. Is it ok for you if I change code to this?

if len(req.Header.Host()) == 0 || !req.UseHostHeader {
	req.Header.SetHostBytes(host)
} else if len(host) == 0 {
	return errRequestHostRequired
} 

In my case it does not important at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've checked logic again and pushed a new version. errRequestHostRequired should be returned only if both uri.Host() and req.Header.Host() are empty. And it does not depend from UseHostHeader. If Header.Host() is empty we always need to set header from uri.Host(). And only when both are not empty we need to check UseHostHeader.

}
req.Header.SetHostBytes(host)
req.Header.SetRequestURIBytes(uri.RequestURI())

if len(uri.username) > 0 {
Expand Down
31 changes: 26 additions & 5 deletions http_test.go
Expand Up @@ -58,7 +58,7 @@ func TestIssue875(t *testing.T) {
expectedLocation string
}

var testcases = []testcase{
testcases := []testcase{
{
uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
Expand Down Expand Up @@ -117,7 +117,6 @@ func TestRequestCopyTo(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}
testRequestCopyTo(t, &req)

}

func TestResponseCopyTo(t *testing.T) {
Expand All @@ -134,7 +133,6 @@ func TestResponseCopyTo(t *testing.T) {
resp.Header.SetStatusCode(200)
resp.SetBodyString("test")
testResponseCopyTo(t, &resp)

}

func testRequestCopyTo(t *testing.T, src *Request) {
Expand Down Expand Up @@ -594,6 +592,31 @@ func TestRequestUpdateURI(t *testing.T) {
}
}

func TestUseHostHeader(t *testing.T) {
t.Parallel()

var r Request
r.UseHostHeader = true
r.Header.SetHost("aaa.bbb")
r.SetRequestURI("/lkjkl/kjl")

// Modify request uri and host via URI() object and make sure
// the requestURI and Host header are properly updated
u := r.URI()
u.SetPath("/123/432.html")
u.SetHost("foobar.com")
a := u.QueryArgs()
a.Set("aaa", "bcse")

s := r.String()
if !strings.HasPrefix(s, "GET /123/432.html?aaa=bcse") {
t.Fatalf("cannot find %q in %q", "GET /123/432.html?aaa=bcse", s)
}
if !strings.Contains(s, "\r\nHost: aaa.bbb\r\n") {
t.Fatalf("cannot find %q in %q", "\r\nHost: aaa.bbb\r\n", s)
}
}

func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1053,7 +1076,6 @@ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) {
if string(formData) != string(r.Body()) {
t.Fatalf("The body given must equal the body in the Request")
}

}

func TestRequestMayContinue(t *testing.T) {
Expand Down Expand Up @@ -2183,7 +2205,6 @@ Content-Type: application/json
`, "\n", "\r\n", -1)
mr := multipart.NewReader(strings.NewReader(s), "foo")
form, err := mr.ReadForm(1024)

if err != nil {
t.Fatalf("unexpected error: %s", err)
}
Expand Down