Skip to content

Commit

Permalink
Add comprehensive host test (#429)
Browse files Browse the repository at this point in the history
Add table driven test for handling of host in request URL, request
header and TLS server name. In addition to testing various uses of host
names, this test also confirms that host names are handled the same as
the net/http client.

The new table driven test replaces TestDialTLS, TestDialTLSNoverify,
TestDialTLSBadCert and TestHostHeader.

Eliminate duplicated code for constructing root CA.
  • Loading branch information
Steven Scott authored and garyburd committed Sep 24, 2018
1 parent 66b9c49 commit cdd40f5
Showing 1 changed file with 194 additions and 79 deletions.
273 changes: 194 additions & 79 deletions client_server_test.go
Expand Up @@ -11,8 +11,10 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/cookiejar"
Expand Down Expand Up @@ -42,17 +44,12 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second,
}

var cstDialerWithoutHandshakeTimeout = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

type cstHandler struct{ *testing.T }

type cstServer struct {
*httptest.Server
URL string
t *testing.T
}

const (
Expand Down Expand Up @@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
sendRecv(t, ws)
}

func TestDialTLS(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
Expand All @@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
certs.AddCert(root)
}
}

d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

func xTestDialTLSBadCert(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t)
defer s.Close()

ws, _, err := cstDialer.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
return certs
}

func TestDialTLSNoVerify(t *testing.T) {
func TestDialTLS(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

d := cstDialer
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
Expand Down Expand Up @@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
s := newServer(t)
defer s.Close()

d := cstDialerWithoutHandshakeTimeout
d := cstDialer
d.HandshakeTimeout = 0
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
netDialer := &net.Dialer{}
c, err := netDialer.DialContext(ctx, n, a)
Expand Down Expand Up @@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
}
}

// TestHostHeader confirms that the host header provided in the call to Dial is
// sent to the server.
func TestHostHeader(t *testing.T) {
s := newServer(t)
defer s.Close()
type testLogWriter struct {
t *testing.T
}

specifiedHost := make(chan string, 1)
origHandler := s.Server.Config.Handler
func (w testLogWriter) Write(p []byte) (int, error) {
w.t.Logf("%s", p)
return len(p), nil
}

// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
specifiedHost <- r.Host
origHandler.ServeHTTP(w, r)
})
// TestHost tests handling of host names and confirms that it matches net/http.
func TestHost(t *testing.T) {

ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})

server := httptest.NewServer(handler)
defer server.Close()

tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()

addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}

// Avoid log noise from net/http server by logging to testing.T
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
tlsServer.Config.ErrorLog = server.Config.ErrorLog

cas := rootCAs(t, tlsServer)

tests := []struct {
fail bool // true if dial / get should fail
server *httptest.Server // server to use
url string // host for request URI
header string // optional request host header
tls string // optiona host for tls ServerName
wantAddr string // expected host for dial
wantHeader string // expected request header on server
insecureSkipVerify bool
}{
{
server: server,
url: addrs[server],
wantAddr: addrs[server],
wantHeader: addrs[server],
},
{
server: tlsServer,
url: addrs[tlsServer],
wantAddr: addrs[tlsServer],
wantHeader: addrs[tlsServer],
},

{
server: server,
url: addrs[server],
header: "badhost.com",
wantAddr: addrs[server],
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: addrs[tlsServer],
header: "badhost.com",
wantAddr: addrs[tlsServer],
wantHeader: "badhost.com",
},

{
server: server,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:80",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:443",
wantHeader: "badhost.com",
},

if gotHost := <-specifiedHost; gotHost != "testhost" {
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
{
server: server,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:80",
wantHeader: "example.com",
},
{
fail: true,
server: tlsServer,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:443",
},
{
server: tlsServer,
url: "badhost.com",
insecureSkipVerify: true,
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "badhost.com",
tls: "example.com",
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
}

sendRecv(t, ws)
for i, tt := range tests {

tls := &tls.Config{
RootCAs: cas,
ServerName: tt.tls,
InsecureSkipVerify: tt.insecureSkipVerify,
}

var gotAddr string
dialer := Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
gotAddr = addr
return net.Dial(network, addrs[tt.server])
},
TLSClientConfig: tls,
}

// Test websocket dial

h := http.Header{}
if tt.header != "" {
h.Set("Host", tt.header)
}
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
if err == nil {
c.Close()
}

check := func(protos map[*httptest.Server]string) {
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
if gotAddr != tt.wantAddr {
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
}
switch {
case tt.fail && err == nil:
t.Errorf("%s: unexpected success", name)
case !tt.fail && err != nil:
t.Errorf("%s: unexpected error %v", name, err)
case !tt.fail && err == nil:
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
}
}
}

check(wsProtos)

// Confirm that net/http has same result

transport := &http.Transport{
Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig,
}
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" {
req.Host = tt.header
}
client := &http.Client{Transport: transport}
resp, err = client.Do(req)
if err == nil {
resp.Body.Close()
}
transport.CloseIdleConnections()
check(httpProtos)
}
}

func TestDialCompression(t *testing.T) {
Expand Down Expand Up @@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}

d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}

ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil {
Expand Down Expand Up @@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}

d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}

ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil {
Expand Down

0 comments on commit cdd40f5

Please sign in to comment.