diff --git a/client.go b/client.go index 962c06a3..c4b62fbc 100644 --- a/client.go +++ b/client.go @@ -348,8 +348,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application diff --git a/client_server_test.go b/client_server_test.go index 7e7636f4..5fd2c85a 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -481,6 +481,23 @@ func TestBadMethod(t *testing.T) { } } +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + func TestHandshake(t *testing.T) { s := newServer(t) defer s.Close()