Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comprehensive host test #429

Merged
merged 1 commit into from Sep 24, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
273 changes: 194 additions & 79 deletions client_server_test.go
Expand Up @@ -11,8 +11,10 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/cookiejar"
Expand Down Expand Up @@ -42,17 +44,12 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second,
}

var cstDialerWithoutHandshakeTimeout = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

type cstHandler struct{ *testing.T }

type cstServer struct {
*httptest.Server
URL string
t *testing.T
}

const (
Expand Down Expand Up @@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
sendRecv(t, ws)
}

func TestDialTLS(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
Expand All @@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
certs.AddCert(root)
}
}

d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

func xTestDialTLSBadCert(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t)
defer s.Close()

ws, _, err := cstDialer.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
return certs
}

func TestDialTLSNoVerify(t *testing.T) {
func TestDialTLS(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

d := cstDialer
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
Expand Down Expand Up @@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
s := newServer(t)
defer s.Close()

d := cstDialerWithoutHandshakeTimeout
d := cstDialer
d.HandshakeTimeout = 0
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
netDialer := &net.Dialer{}
c, err := netDialer.DialContext(ctx, n, a)
Expand Down Expand Up @@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
}
}

// TestHostHeader confirms that the host header provided in the call to Dial is
// sent to the server.
func TestHostHeader(t *testing.T) {
s := newServer(t)
defer s.Close()
type testLogWriter struct {
t *testing.T
}

specifiedHost := make(chan string, 1)
origHandler := s.Server.Config.Handler
func (w testLogWriter) Write(p []byte) (int, error) {
w.t.Logf("%s", p)
return len(p), nil
}

// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
specifiedHost <- r.Host
origHandler.ServeHTTP(w, r)
})
// TestHost tests handling of host names and confirms that it matches net/http.
func TestHost(t *testing.T) {

ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})

server := httptest.NewServer(handler)
defer server.Close()

tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()

addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}

// Avoid log noise from net/http server by logging to testing.T
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
tlsServer.Config.ErrorLog = server.Config.ErrorLog

cas := rootCAs(t, tlsServer)

tests := []struct {
fail bool // true if dial / get should fail
server *httptest.Server // server to use
url string // host for request URI
header string // optional request host header
tls string // optiona host for tls ServerName
wantAddr string // expected host for dial
wantHeader string // expected request header on server
insecureSkipVerify bool
}{
{
server: server,
url: addrs[server],
wantAddr: addrs[server],
wantHeader: addrs[server],
},
{
server: tlsServer,
url: addrs[tlsServer],
wantAddr: addrs[tlsServer],
wantHeader: addrs[tlsServer],
},

{
server: server,
url: addrs[server],
header: "badhost.com",
wantAddr: addrs[server],
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: addrs[tlsServer],
header: "badhost.com",
wantAddr: addrs[tlsServer],
wantHeader: "badhost.com",
},

{
server: server,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:80",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:443",
wantHeader: "badhost.com",
},

if gotHost := <-specifiedHost; gotHost != "testhost" {
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
{
server: server,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:80",
wantHeader: "example.com",
},
{
fail: true,
server: tlsServer,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:443",
},
{
server: tlsServer,
url: "badhost.com",
insecureSkipVerify: true,
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "badhost.com",
tls: "example.com",
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
}

sendRecv(t, ws)
for i, tt := range tests {

tls := &tls.Config{
RootCAs: cas,
ServerName: tt.tls,
InsecureSkipVerify: tt.insecureSkipVerify,
}

var gotAddr string
dialer := Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
gotAddr = addr
return net.Dial(network, addrs[tt.server])
},
TLSClientConfig: tls,
}

// Test websocket dial

h := http.Header{}
if tt.header != "" {
h.Set("Host", tt.header)
}
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
if err == nil {
c.Close()
}

check := func(protos map[*httptest.Server]string) {
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
if gotAddr != tt.wantAddr {
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
}
switch {
case tt.fail && err == nil:
t.Errorf("%s: unexpected success", name)
case !tt.fail && err != nil:
t.Errorf("%s: unexpected error %v", name, err)
case !tt.fail && err == nil:
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
}
}
}

check(wsProtos)

// Confirm that net/http has same result

transport := &http.Transport{
Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig,
}
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" {
req.Host = tt.header
}
client := &http.Client{Transport: transport}
resp, err = client.Do(req)
if err == nil {
resp.Body.Close()
}
transport.CloseIdleConnections()
check(httpProtos)
}
}

func TestDialCompression(t *testing.T) {
Expand Down Expand Up @@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}

d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}

ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil {
Expand Down Expand Up @@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
s := newTLSServer(t)
defer s.Close()

certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}

d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}

ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil {
Expand Down