diff --git a/config/http_config.go b/config/http_config.go index 4b872417..d8b9b34b 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -418,6 +418,9 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT http2t.ReadIdleTimeout = time.Minute } + // hostnameRoundTripper sets the http.Request Host to the value set as the TLS server name + rt = newHostnameRoundTripper(tlsConfig, rt) + // If a authorization_credentials is provided, create a round tripper that will set the // Authorization header correctly on each request. if cfg.Authorization != nil && len(cfg.Authorization.Credentials) > 0 { @@ -457,6 +460,28 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT) } +type hostnameRoundTripper struct { + tlsConfig *tls.Config + rt http.RoundTripper +} + +func newHostnameRoundTripper(tlsConfig *tls.Config, rt http.RoundTripper) http.RoundTripper { + return &hostnameRoundTripper{ + tlsConfig: tlsConfig, + rt: rt, + } +} + +func (rt *hostnameRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = cloneRequest(req) + if rt.tlsConfig.ServerName != "" { + hostParts := strings.Split(req.Host, ":") + hostParts[0] = rt.tlsConfig.ServerName + req.Host = strings.Join(hostParts, ":") + } + return rt.rt.RoundTrip(req) +} + type authorizationCredentialsRoundTripper struct { authType string authCredentials Secret diff --git a/config/http_config_test.go b/config/http_config_test.go index cbaaba0a..0ef6bfeb 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -37,7 +37,7 @@ import ( "testing" "time" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) const ( @@ -179,6 +179,27 @@ func TestNewClientFromConfig(t *testing.T) { handler: func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, ExpectedMessage) }, + }, { + clientConfig: HTTPClientConfig{ + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "test-domain.com", + InsecureSkipVerify: true}, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + srvAddr := r.Context().Value(http.LocalAddrContextKey).(net.Addr) + srvPort := strings.Split(srvAddr.String(), ":")[1] + + expectedHostHeader := "test-domain.com:" + srvPort + actualHostHeader := r.Host + if actualHostHeader != expectedHostHeader { + fmt.Fprintf(w, "The expected Host header (%s) differs from the obtained Host header (%s)", + expectedHostHeader, actualHostHeader) + } + fmt.Fprint(w, ExpectedMessage) + }, }, { clientConfig: HTTPClientConfig{ BearerToken: BearerToken, @@ -512,7 +533,7 @@ func TestCustomIdleConnTimeout(t *testing.T) { t.Fatalf("Can't create a round-tripper from this config: %+v", cfg) } - transport, ok := rt.(*http.Transport) + transport, ok := rt.(*hostnameRoundTripper).rt.(*http.Transport) if !ok { t.Fatalf("Unexpected transport: %+v", transport) }