From 873e67e4d598af0769fdf077ff074dfe82add46a Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Thu, 20 Aug 2020 16:43:18 +0300 Subject: [PATCH] Fix how the client checks for presence of Upgrade: websocket, Connection: upgrade (#604) The values of the `Upgrade` and `Connection` response headers can contain multiple tokens, for example Connection: upgrade, keep-alive The WebSocket RFC describes the checking of these as follows: 2. If the response lacks an |Upgrade| header field or the |Upgrade| header field contains a value that is not an ASCII case- insensitive match for the value "websocket", the client MUST _Fail the WebSocket Connection_. 3. If the response lacks a |Connection| header field or the |Connection| header field doesn't contain a token that is an ASCII case-insensitive match for the value "Upgrade", the client MUST _Fail the WebSocket Connection_. It is careful to note "contains a value", "contains a token". Previously, the client would reject with "bad handshake" if the header doesn't contain exactly the value it looks for. Change the checks to use `tokenListContainsValue` instead, which is incidentally what the server is already doing for similar checks. --- client.go | 4 ++-- client_server_test.go | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 962c06a3..c4b62fbc 100644 --- a/client.go +++ b/client.go @@ -348,8 +348,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application diff --git a/client_server_test.go b/client_server_test.go index 7e7636f4..5fd2c85a 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -481,6 +481,23 @@ func TestBadMethod(t *testing.T) { } } +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + func TestHandshake(t *testing.T) { s := newServer(t) defer s.Close()