From b657deb7dda3dfe4c736b6caba31f53eab0d09ce Mon Sep 17 00:00:00 2001 From: Lluis Campos Date: Fri, 10 Dec 2021 16:15:11 +0100 Subject: [PATCH] MEN-5273: proxy: Fix websocket connection for advanced auth settings By switching to the "enhanced" API for websocket.Dialer from mendersoftware's fork. There is a limitation in current gorilla/websocket.Dialer API in that the user cannot specify a dial method for TLS/TCP connections. The TLS handshake is always done by the library based on user's TLSClientConfig, but that is not enough for Mender as we need it to be done via OpenSSL (aka our dial wrapper for TLS) so that advance auth features like getting the keys from HSM. This commit switches to mendersoftware's fork and modifies the code accordingly (one line change!). The patch has been submitted upstream. See: * https://github.com/gorilla/websocket/issues/745 * https://github.com/gorilla/websocket/pull/746 Changelog: None No changelog, commit 84204a3 claims to support websockets, this commit just fixes a bug there which has not been released. Signed-off-by: Lluis Campos --- app/proxy/proxy_ws.go | 2 +- app/proxy/proxy_ws_test.go | 153 ++++++++++++++++-- client/client.go | 5 +- go.mod | 2 + go.sum | 4 +- vendor/github.com/gorilla/websocket/client.go | 58 +++++-- .../gorilla/websocket/client_clone.go | 1 + .../gorilla/websocket/client_clone_legacy.go | 1 + .../gorilla/websocket/conn_write.go | 1 + .../gorilla/websocket/conn_write_legacy.go | 1 + vendor/github.com/gorilla/websocket/mask.go | 1 + .../github.com/gorilla/websocket/mask_safe.go | 1 + vendor/github.com/gorilla/websocket/server.go | 4 +- vendor/github.com/gorilla/websocket/trace.go | 1 + .../github.com/gorilla/websocket/trace_17.go | 1 + vendor/modules.txt | 3 +- 16 files changed, 206 insertions(+), 33 deletions(-) diff --git a/app/proxy/proxy_ws.go b/app/proxy/proxy_ws.go index 02e1aa652..466556d04 100644 --- a/app/proxy/proxy_ws.go +++ b/app/proxy/proxy_ws.go @@ -84,7 +84,7 @@ func (pc *proxyControllerInner) DoWsUpgrade(w http.ResponseWriter, r *http.Reque connBackend, resp, err := pc.wsDialer.Dial(wsUrl.String(), requestHeader) if err != nil { - log.Errorf("couldn't dial to remote backend url %s", err) + log.Errorf("couldn't dial to remote backend url %q, err: %s", wsUrl.String(), err.Error()) if resp != nil { // WebSocket handshake failed, reply the client with backend's resp if err := copyResponse(w, resp); err != nil { diff --git a/app/proxy/proxy_ws_test.go b/app/proxy/proxy_ws_test.go index c02239542..99d67732a 100644 --- a/app/proxy/proxy_ws_test.go +++ b/app/proxy/proxy_ws_test.go @@ -14,6 +14,9 @@ package proxy import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" "net/http" "reflect" "runtime" @@ -26,21 +29,33 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/mendersoftware/mender/client" cltest "github.com/mendersoftware/mender/client/test" + "github.com/mendersoftware/mender/conf" ) -func prepareProxyWsTest( - t *testing.T, - srv *cltest.ClientTestWsServer, -) (*ProxyController, *websocket.Conn) { +func prepareProxyWsTest(t *testing.T, srv *cltest.ClientTestWsServer) *ProxyController { + + wsDialer, err := client.NewWebsocketDialer(client.Config{}) + require.NoError(t, err) + proxyController, err := NewProxyController( &http.Client{}, - nil, + wsDialer, srv.TestServer.URL, "SecretJwtToken", ) require.NoError(t, err) + return proxyController +} + +func connectProxyWsTest( + t *testing.T, + srv *cltest.ClientTestWsServer, + proxyController *ProxyController, +) *websocket.Conn { + proxyServerUrl := proxyController.GetServerUrl() require.Contains(t, proxyServerUrl, "http://localhost") @@ -51,14 +66,25 @@ func prepareProxyWsTest( require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) - return proxyController, conn + return conn } -func TestProxyWsConnect(t *testing.T) { - srv := cltest.NewClientTestWsServer() - defer srv.StopWs() - defer srv.Close() +func prepareAndConnectProxyWsTest( + t *testing.T, + srv *cltest.ClientTestWsServer, +) (*ProxyController, *websocket.Conn) { + proxyController := prepareProxyWsTest(t, srv) + conn := connectProxyWsTest(t, srv, proxyController) + + return proxyController, conn +} + +func runTestSendReceiveWs( + t *testing.T, + srv *cltest.ClientTestWsServer, + proxyController *ProxyController, +) { // Expectations for the test srv.Connect.SendMessages = append( srv.Connect.SendMessages, @@ -82,8 +108,8 @@ func TestProxyWsConnect(t *testing.T) { {MsgType: websocket.TextMessage, Msg: []byte("hello-world")}, } - proxyController, conn := prepareProxyWsTest(t, srv) - defer proxyController.Stop() + conn := connectProxyWsTest(t, srv, proxyController) + defer conn.Close() wg := sync.WaitGroup{} @@ -150,6 +176,17 @@ func TestProxyWsConnect(t *testing.T) { ) } +func TestProxyWsConnect(t *testing.T) { + srv := cltest.NewClientTestWsServer() + defer srv.StopWs() + defer srv.Close() + + proxyController := prepareProxyWsTest(t, srv) + defer proxyController.Stop() + + runTestSendReceiveWs(t, srv, proxyController) +} + func TestProxyWsWebSocketProtocolHeader(t *testing.T) { srv := cltest.NewClientTestWsServer() defer srv.StopWs() @@ -195,7 +232,7 @@ func TestProxyWsTooMany(t *testing.T) { defer srv.StopWs() defer srv.Close() - proxyController, conn := prepareProxyWsTest(t, srv) + proxyController, conn := prepareAndConnectProxyWsTest(t, srv) defer proxyController.Stop() defer conn.Close() @@ -218,7 +255,7 @@ func TestProxyWsStop(t *testing.T) { defer srv.StopWs() defer srv.Close() - proxyController, conn := prepareProxyWsTest(t, srv) + proxyController, conn := prepareAndConnectProxyWsTest(t, srv) defer proxyController.Stop() defer conn.Close() @@ -309,3 +346,91 @@ func TestProxyWsGoroutines(t *testing.T) { 1*time.Millisecond, ) } + +func TestProxyWsConnectCustomCert(t *testing.T) { + serverCert, err := tls.LoadX509KeyPair( + "../../client/test/server.crt", + "../../client/test/server.key", + ) + require.NoError(t, err) + + tc := tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + + srv := cltest.NewClientTestWsServer(&tc) + defer srv.StopWs() + defer srv.Close() + + conffromfile := conf.MenderConfigFromFile{ + ServerCertificate: "../../client/test/server.crt", + } + testconf := &conf.MenderConfig{MenderConfigFromFile: conffromfile} + httpConfig := testconf.GetHttpConfig() + + api, err := client.New(httpConfig) + require.NoError(t, err) + + wsDialer, err := client.NewWebsocketDialer(httpConfig) + require.NoError(t, err) + + proxyController, err := NewProxyController( + api, + wsDialer, + srv.TestServer.URL, + "SecretJwtToken", + ) + require.NoError(t, err) + defer proxyController.Stop() + + runTestSendReceiveWs(t, srv, proxyController) +} +func TestProxyWsConnectMutualTLS(t *testing.T) { + serverCert, err := tls.LoadX509KeyPair( + "../../client/test/server.crt", + "../../client/test/server.key", + ) + require.NoError(t, err) + + clientClientCertPool := x509.NewCertPool() + pb, err := ioutil.ReadFile("../../client/testdata/client.crt") + require.NoError(t, err) + clientClientCertPool.AppendCertsFromPEM(pb) + + tc := tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: clientClientCertPool, + } + + srv := cltest.NewClientTestWsServer(&tc) + defer srv.StopWs() + defer srv.Close() + + conffromfile := conf.MenderConfigFromFile{ + ServerCertificate: "../../client/test/server.crt", + HttpsClient: client.HttpsClient{ + Certificate: "../../client/testdata/client.crt", + Key: "../../client/testdata/client-cert.key", + }, + } + testconf := &conf.MenderConfig{MenderConfigFromFile: conffromfile} + httpConfig := testconf.GetHttpConfig() + + api, err := client.New(httpConfig) + require.NoError(t, err) + + wsDialer, err := client.NewWebsocketDialer(httpConfig) + require.NoError(t, err) + + proxyController, err := NewProxyController( + api, + wsDialer, + srv.TestServer.URL, + "SecretJwtToken", + ) + require.NoError(t, err) + defer proxyController.Stop() + + runTestSendReceiveWs(t, srv, proxyController) +} diff --git a/client/client.go b/client/client.go index 206729239..3ee6bb786 100644 --- a/client/client.go +++ b/client/client.go @@ -489,8 +489,7 @@ func loadClientTrust(ctx *openssl.Ctx, conf *Config) (*openssl.Ctx, error) { return ctx, nil } -func dialOpenSSL(ctx *openssl.Ctx, conf *Config, network string, addr string) (net.Conn, error) { - +func dialOpenSSL(ctx *openssl.Ctx, conf *Config, _ string, addr string) (net.Conn, error) { flags := openssl.DialFlags(0) if conf.NoVerify { @@ -694,7 +693,7 @@ func newWebsocketDialerTLS(conf Config) (*websocket.Dialer, error) { } dialer := websocket.Dialer{ - NetDial: func(network string, addr string) (net.Conn, error) { + NetDialTLS: func(network string, addr string) (net.Conn, error) { return dialOpenSSL(ctx, &conf, network, addr) }, } diff --git a/go.mod b/go.mod index 4a83e0a01..d347b584c 100644 --- a/go.mod +++ b/go.mod @@ -21,3 +21,5 @@ require ( ) replace github.com/urfave/cli/v2 => github.com/mendersoftware/cli/v2 v2.1.1-minimal + +replace github.com/gorilla/websocket => github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918 diff --git a/go.sum b/go.sum index 05e9840c2..8d3b99913 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/godbus/dbus v4.1.0+incompatible h1:WqqLRTsQic3apZUK9qC5sGNfXthmPXzUZ7nQPrNITa4= github.com/godbus/dbus v4.1.0+incompatible/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/klauspost/compress v1.10.5 h1:7q6vHIqubShURwQz8cQK6yIe/xC3IF0Vm7TGfqjewrc= github.com/klauspost/compress v1.10.5/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= @@ -34,6 +32,8 @@ github.com/mendersoftware/openssl v1.1.0 h1:eRiG3CwzkMIna1xrTE9SiX9lrsme9irlb6i5 github.com/mendersoftware/openssl v1.1.0/go.mod h1:tikEC94q+Y0TU6r19L6mHzwruoTNYPEkrQPvsHEcQyU= github.com/mendersoftware/progressbar v0.0.3 h1:AUdBGPvXO0l9i39rmXKZbEAPet2FzBeiG8b30D5/2Vc= github.com/mendersoftware/progressbar v0.0.3/go.mod h1:NYaLNLhy3UXkRweGjhR3We3Q1ngmUmOWjC3+m8EzwjE= +github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918 h1:bxs2j1011PQiBPAP127cmBdAnw+zzq65tWOUeCFxVXU= +github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/minio/sha256-simd v0.1.1 h1:5QHSlgo3nt5yKOJrC7W8w7X+NFl8cMPZm96iu8kKUJU= github.com/minio/sha256-simd v0.1.1/go.mod h1:B5e1o+1/KgNmWrSQK08Y6Z1Vb5pwIktudl0J58iy0KM= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go index 962c06a39..7a6f09f7b 100644 --- a/vendor/github.com/gorilla/websocket/client.go +++ b/vendor/github.com/gorilla/websocket/client.go @@ -54,9 +54,21 @@ type Dialer struct { NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, net.DialContext is used. + // NetDialContext is nil, NetDial is used. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // NetDialTLS specifies the dial function for creating TLS/TCP connections. If + // NetDialTLS is nil, net.Dial is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + NetDialTLS func(network, addr string) (net.Conn, error) + + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialTLS is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -65,6 +77,8 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. @@ -237,13 +251,34 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Get network dial function. var netDial func(network, add string) (net.Conn, error) - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } - } else if d.NetDial != nil { - netDial = d.NetDial - } else { + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialTLS != nil { + netDial = d.NetDialTLS + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { netDialer := &net.Dialer{} netDial = func(network, addr string) (net.Conn, error) { return netDialer.DialContext(ctx, network, addr) @@ -304,7 +339,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" { + if u.Scheme == "https" && d.NetDialTLSContext == nil && d.NetDialTLS == nil { + // If either NetDialTLS or NetDialTLSContext are set, assume that + // the TLS handshake has already been done + cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort @@ -348,8 +386,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application diff --git a/vendor/github.com/gorilla/websocket/client_clone.go b/vendor/github.com/gorilla/websocket/client_clone.go index 4f0d94372..4179c7a07 100644 --- a/vendor/github.com/gorilla/websocket/client_clone.go +++ b/vendor/github.com/gorilla/websocket/client_clone.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.8 // +build go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/client_clone_legacy.go b/vendor/github.com/gorilla/websocket/client_clone_legacy.go index babb007fb..7e241a88d 100644 --- a/vendor/github.com/gorilla/websocket/client_clone_legacy.go +++ b/vendor/github.com/gorilla/websocket/client_clone_legacy.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !go1.8 // +build !go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/conn_write.go b/vendor/github.com/gorilla/websocket/conn_write.go index a509a21f8..497467adb 100644 --- a/vendor/github.com/gorilla/websocket/conn_write.go +++ b/vendor/github.com/gorilla/websocket/conn_write.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.8 // +build go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/conn_write_legacy.go b/vendor/github.com/gorilla/websocket/conn_write_legacy.go index 37edaff5a..8501a2334 100644 --- a/vendor/github.com/gorilla/websocket/conn_write_legacy.go +++ b/vendor/github.com/gorilla/websocket/conn_write_legacy.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !go1.8 // +build !go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/mask.go b/vendor/github.com/gorilla/websocket/mask.go index 577fce9ef..d0742bf2a 100644 --- a/vendor/github.com/gorilla/websocket/mask.go +++ b/vendor/github.com/gorilla/websocket/mask.go @@ -2,6 +2,7 @@ // this source code is governed by a BSD-style license that can be found in the // LICENSE file. +//go:build !appengine // +build !appengine package websocket diff --git a/vendor/github.com/gorilla/websocket/mask_safe.go b/vendor/github.com/gorilla/websocket/mask_safe.go index 2aac060e5..36250ca7c 100644 --- a/vendor/github.com/gorilla/websocket/mask_safe.go +++ b/vendor/github.com/gorilla/websocket/mask_safe.go @@ -2,6 +2,7 @@ // this source code is governed by a BSD-style license that can be found in the // LICENSE file. +//go:build appengine // +build appengine package websocket diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go index 887d55891..152ebf898 100644 --- a/vendor/github.com/gorilla/websocket/server.go +++ b/vendor/github.com/gorilla/websocket/server.go @@ -115,8 +115,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade -// request. Use the responseHeader to specify cookies (Set-Cookie) and the -// application negotiated subprotocol (Sec-WebSocket-Protocol). +// request. Use the responseHeader to specify cookies (Set-Cookie). To specify +// subprotocols supported by the server, set Upgrader.Subprotocols directly. // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. diff --git a/vendor/github.com/gorilla/websocket/trace.go b/vendor/github.com/gorilla/websocket/trace.go index 834f122a0..246a5d33d 100644 --- a/vendor/github.com/gorilla/websocket/trace.go +++ b/vendor/github.com/gorilla/websocket/trace.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/trace_17.go b/vendor/github.com/gorilla/websocket/trace_17.go index 77d05a0b5..f4be940ad 100644 --- a/vendor/github.com/gorilla/websocket/trace_17.go +++ b/vendor/github.com/gorilla/websocket/trace_17.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package websocket diff --git a/vendor/modules.txt b/vendor/modules.txt index e36f25dcb..dd79fb683 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -8,7 +8,7 @@ github.com/davecgh/go-spew/spew # github.com/godbus/dbus v4.1.0+incompatible ## explicit github.com/godbus/dbus -# github.com/gorilla/websocket v1.4.2 +# github.com/gorilla/websocket v1.4.2 => github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918 ## explicit github.com/gorilla/websocket # github.com/klauspost/compress v1.10.5 @@ -74,3 +74,4 @@ golang.org/x/term # gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 gopkg.in/yaml.v3 # github.com/urfave/cli/v2 => github.com/mendersoftware/cli/v2 v2.1.1-minimal +# github.com/gorilla/websocket => github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918