diff --git a/app/proxy/proxy_ws.go b/app/proxy/proxy_ws.go index 02e1aa652..466556d04 100644 --- a/app/proxy/proxy_ws.go +++ b/app/proxy/proxy_ws.go @@ -84,7 +84,7 @@ func (pc *proxyControllerInner) DoWsUpgrade(w http.ResponseWriter, r *http.Reque connBackend, resp, err := pc.wsDialer.Dial(wsUrl.String(), requestHeader) if err != nil { - log.Errorf("couldn't dial to remote backend url %s", err) + log.Errorf("couldn't dial to remote backend url %q, err: %s", wsUrl.String(), err.Error()) if resp != nil { // WebSocket handshake failed, reply the client with backend's resp if err := copyResponse(w, resp); err != nil { diff --git a/app/proxy/proxy_ws_test.go b/app/proxy/proxy_ws_test.go index c02239542..99d67732a 100644 --- a/app/proxy/proxy_ws_test.go +++ b/app/proxy/proxy_ws_test.go @@ -14,6 +14,9 @@ package proxy import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" "net/http" "reflect" "runtime" @@ -26,21 +29,33 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/mendersoftware/mender/client" cltest "github.com/mendersoftware/mender/client/test" + "github.com/mendersoftware/mender/conf" ) -func prepareProxyWsTest( - t *testing.T, - srv *cltest.ClientTestWsServer, -) (*ProxyController, *websocket.Conn) { +func prepareProxyWsTest(t *testing.T, srv *cltest.ClientTestWsServer) *ProxyController { + + wsDialer, err := client.NewWebsocketDialer(client.Config{}) + require.NoError(t, err) + proxyController, err := NewProxyController( &http.Client{}, - nil, + wsDialer, srv.TestServer.URL, "SecretJwtToken", ) require.NoError(t, err) + return proxyController +} + +func connectProxyWsTest( + t *testing.T, + srv *cltest.ClientTestWsServer, + proxyController *ProxyController, +) *websocket.Conn { + proxyServerUrl := proxyController.GetServerUrl() require.Contains(t, proxyServerUrl, "http://localhost") @@ -51,14 +66,25 @@ func prepareProxyWsTest( require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) - return proxyController, conn + return conn } -func TestProxyWsConnect(t *testing.T) { - srv := cltest.NewClientTestWsServer() - defer srv.StopWs() - defer srv.Close() +func prepareAndConnectProxyWsTest( + t *testing.T, + srv *cltest.ClientTestWsServer, +) (*ProxyController, *websocket.Conn) { + proxyController := prepareProxyWsTest(t, srv) + conn := connectProxyWsTest(t, srv, proxyController) + + return proxyController, conn +} + +func runTestSendReceiveWs( + t *testing.T, + srv *cltest.ClientTestWsServer, + proxyController *ProxyController, +) { // Expectations for the test srv.Connect.SendMessages = append( srv.Connect.SendMessages, @@ -82,8 +108,8 @@ func TestProxyWsConnect(t *testing.T) { {MsgType: websocket.TextMessage, Msg: []byte("hello-world")}, } - proxyController, conn := prepareProxyWsTest(t, srv) - defer proxyController.Stop() + conn := connectProxyWsTest(t, srv, proxyController) + defer conn.Close() wg := sync.WaitGroup{} @@ -150,6 +176,17 @@ func TestProxyWsConnect(t *testing.T) { ) } +func TestProxyWsConnect(t *testing.T) { + srv := cltest.NewClientTestWsServer() + defer srv.StopWs() + defer srv.Close() + + proxyController := prepareProxyWsTest(t, srv) + defer proxyController.Stop() + + runTestSendReceiveWs(t, srv, proxyController) +} + func TestProxyWsWebSocketProtocolHeader(t *testing.T) { srv := cltest.NewClientTestWsServer() defer srv.StopWs() @@ -195,7 +232,7 @@ func TestProxyWsTooMany(t *testing.T) { defer srv.StopWs() defer srv.Close() - proxyController, conn := prepareProxyWsTest(t, srv) + proxyController, conn := prepareAndConnectProxyWsTest(t, srv) defer proxyController.Stop() defer conn.Close() @@ -218,7 +255,7 @@ func TestProxyWsStop(t *testing.T) { defer srv.StopWs() defer srv.Close() - proxyController, conn := prepareProxyWsTest(t, srv) + proxyController, conn := prepareAndConnectProxyWsTest(t, srv) defer proxyController.Stop() defer conn.Close() @@ -309,3 +346,91 @@ func TestProxyWsGoroutines(t *testing.T) { 1*time.Millisecond, ) } + +func TestProxyWsConnectCustomCert(t *testing.T) { + serverCert, err := tls.LoadX509KeyPair( + "../../client/test/server.crt", + "../../client/test/server.key", + ) + require.NoError(t, err) + + tc := tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + + srv := cltest.NewClientTestWsServer(&tc) + defer srv.StopWs() + defer srv.Close() + + conffromfile := conf.MenderConfigFromFile{ + ServerCertificate: "../../client/test/server.crt", + } + testconf := &conf.MenderConfig{MenderConfigFromFile: conffromfile} + httpConfig := testconf.GetHttpConfig() + + api, err := client.New(httpConfig) + require.NoError(t, err) + + wsDialer, err := client.NewWebsocketDialer(httpConfig) + require.NoError(t, err) + + proxyController, err := NewProxyController( + api, + wsDialer, + srv.TestServer.URL, + "SecretJwtToken", + ) + require.NoError(t, err) + defer proxyController.Stop() + + runTestSendReceiveWs(t, srv, proxyController) +} +func TestProxyWsConnectMutualTLS(t *testing.T) { + serverCert, err := tls.LoadX509KeyPair( + "../../client/test/server.crt", + "../../client/test/server.key", + ) + require.NoError(t, err) + + clientClientCertPool := x509.NewCertPool() + pb, err := ioutil.ReadFile("../../client/testdata/client.crt") + require.NoError(t, err) + clientClientCertPool.AppendCertsFromPEM(pb) + + tc := tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: clientClientCertPool, + } + + srv := cltest.NewClientTestWsServer(&tc) + defer srv.StopWs() + defer srv.Close() + + conffromfile := conf.MenderConfigFromFile{ + ServerCertificate: "../../client/test/server.crt", + HttpsClient: client.HttpsClient{ + Certificate: "../../client/testdata/client.crt", + Key: "../../client/testdata/client-cert.key", + }, + } + testconf := &conf.MenderConfig{MenderConfigFromFile: conffromfile} + httpConfig := testconf.GetHttpConfig() + + api, err := client.New(httpConfig) + require.NoError(t, err) + + wsDialer, err := client.NewWebsocketDialer(httpConfig) + require.NoError(t, err) + + proxyController, err := NewProxyController( + api, + wsDialer, + srv.TestServer.URL, + "SecretJwtToken", + ) + require.NoError(t, err) + defer proxyController.Stop() + + runTestSendReceiveWs(t, srv, proxyController) +} diff --git a/client/client.go b/client/client.go index 206729239..3ee6bb786 100644 --- a/client/client.go +++ b/client/client.go @@ -489,8 +489,7 @@ func loadClientTrust(ctx *openssl.Ctx, conf *Config) (*openssl.Ctx, error) { return ctx, nil } -func dialOpenSSL(ctx *openssl.Ctx, conf *Config, network string, addr string) (net.Conn, error) { - +func dialOpenSSL(ctx *openssl.Ctx, conf *Config, _ string, addr string) (net.Conn, error) { flags := openssl.DialFlags(0) if conf.NoVerify { @@ -694,7 +693,7 @@ func newWebsocketDialerTLS(conf Config) (*websocket.Dialer, error) { } dialer := websocket.Dialer{ - NetDial: func(network string, addr string) (net.Conn, error) { + NetDialTLS: func(network string, addr string) (net.Conn, error) { return dialOpenSSL(ctx, &conf, network, addr) }, } diff --git a/go.mod b/go.mod index 4a83e0a01..d347b584c 100644 --- a/go.mod +++ b/go.mod @@ -21,3 +21,5 @@ require ( ) replace github.com/urfave/cli/v2 => github.com/mendersoftware/cli/v2 v2.1.1-minimal + +replace github.com/gorilla/websocket => github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918 diff --git a/go.sum b/go.sum index 05e9840c2..8d3b99913 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/godbus/dbus v4.1.0+incompatible h1:WqqLRTsQic3apZUK9qC5sGNfXthmPXzUZ7nQPrNITa4= github.com/godbus/dbus v4.1.0+incompatible/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/klauspost/compress v1.10.5 h1:7q6vHIqubShURwQz8cQK6yIe/xC3IF0Vm7TGfqjewrc= github.com/klauspost/compress v1.10.5/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= @@ -34,6 +32,8 @@ github.com/mendersoftware/openssl v1.1.0 h1:eRiG3CwzkMIna1xrTE9SiX9lrsme9irlb6i5 github.com/mendersoftware/openssl v1.1.0/go.mod h1:tikEC94q+Y0TU6r19L6mHzwruoTNYPEkrQPvsHEcQyU= github.com/mendersoftware/progressbar v0.0.3 h1:AUdBGPvXO0l9i39rmXKZbEAPet2FzBeiG8b30D5/2Vc= github.com/mendersoftware/progressbar v0.0.3/go.mod h1:NYaLNLhy3UXkRweGjhR3We3Q1ngmUmOWjC3+m8EzwjE= +github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918 h1:bxs2j1011PQiBPAP127cmBdAnw+zzq65tWOUeCFxVXU= +github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/minio/sha256-simd v0.1.1 h1:5QHSlgo3nt5yKOJrC7W8w7X+NFl8cMPZm96iu8kKUJU= github.com/minio/sha256-simd v0.1.1/go.mod h1:B5e1o+1/KgNmWrSQK08Y6Z1Vb5pwIktudl0J58iy0KM= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go index 962c06a39..7a6f09f7b 100644 --- a/vendor/github.com/gorilla/websocket/client.go +++ b/vendor/github.com/gorilla/websocket/client.go @@ -54,9 +54,21 @@ type Dialer struct { NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, net.DialContext is used. + // NetDialContext is nil, NetDial is used. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // NetDialTLS specifies the dial function for creating TLS/TCP connections. If + // NetDialTLS is nil, net.Dial is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + NetDialTLS func(network, addr string) (net.Conn, error) + + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialTLS is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -65,6 +77,8 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. @@ -237,13 +251,34 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Get network dial function. var netDial func(network, add string) (net.Conn, error) - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } - } else if d.NetDial != nil { - netDial = d.NetDial - } else { + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialTLS != nil { + netDial = d.NetDialTLS + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { netDialer := &net.Dialer{} netDial = func(network, addr string) (net.Conn, error) { return netDialer.DialContext(ctx, network, addr) @@ -304,7 +339,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" { + if u.Scheme == "https" && d.NetDialTLSContext == nil && d.NetDialTLS == nil { + // If either NetDialTLS or NetDialTLSContext are set, assume that + // the TLS handshake has already been done + cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort @@ -348,8 +386,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application diff --git a/vendor/github.com/gorilla/websocket/client_clone.go b/vendor/github.com/gorilla/websocket/client_clone.go index 4f0d94372..4179c7a07 100644 --- a/vendor/github.com/gorilla/websocket/client_clone.go +++ b/vendor/github.com/gorilla/websocket/client_clone.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.8 // +build go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/client_clone_legacy.go b/vendor/github.com/gorilla/websocket/client_clone_legacy.go index babb007fb..7e241a88d 100644 --- a/vendor/github.com/gorilla/websocket/client_clone_legacy.go +++ b/vendor/github.com/gorilla/websocket/client_clone_legacy.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !go1.8 // +build !go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/conn_write.go b/vendor/github.com/gorilla/websocket/conn_write.go index a509a21f8..497467adb 100644 --- a/vendor/github.com/gorilla/websocket/conn_write.go +++ b/vendor/github.com/gorilla/websocket/conn_write.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.8 // +build go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/conn_write_legacy.go b/vendor/github.com/gorilla/websocket/conn_write_legacy.go index 37edaff5a..8501a2334 100644 --- a/vendor/github.com/gorilla/websocket/conn_write_legacy.go +++ b/vendor/github.com/gorilla/websocket/conn_write_legacy.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !go1.8 // +build !go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/mask.go b/vendor/github.com/gorilla/websocket/mask.go index 577fce9ef..d0742bf2a 100644 --- a/vendor/github.com/gorilla/websocket/mask.go +++ b/vendor/github.com/gorilla/websocket/mask.go @@ -2,6 +2,7 @@ // this source code is governed by a BSD-style license that can be found in the // LICENSE file. +//go:build !appengine // +build !appengine package websocket diff --git a/vendor/github.com/gorilla/websocket/mask_safe.go b/vendor/github.com/gorilla/websocket/mask_safe.go index 2aac060e5..36250ca7c 100644 --- a/vendor/github.com/gorilla/websocket/mask_safe.go +++ b/vendor/github.com/gorilla/websocket/mask_safe.go @@ -2,6 +2,7 @@ // this source code is governed by a BSD-style license that can be found in the // LICENSE file. +//go:build appengine // +build appengine package websocket diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go index 887d55891..152ebf898 100644 --- a/vendor/github.com/gorilla/websocket/server.go +++ b/vendor/github.com/gorilla/websocket/server.go @@ -115,8 +115,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade -// request. Use the responseHeader to specify cookies (Set-Cookie) and the -// application negotiated subprotocol (Sec-WebSocket-Protocol). +// request. Use the responseHeader to specify cookies (Set-Cookie). To specify +// subprotocols supported by the server, set Upgrader.Subprotocols directly. // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. diff --git a/vendor/github.com/gorilla/websocket/trace.go b/vendor/github.com/gorilla/websocket/trace.go index 834f122a0..246a5d33d 100644 --- a/vendor/github.com/gorilla/websocket/trace.go +++ b/vendor/github.com/gorilla/websocket/trace.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package websocket diff --git a/vendor/github.com/gorilla/websocket/trace_17.go b/vendor/github.com/gorilla/websocket/trace_17.go index 77d05a0b5..f4be940ad 100644 --- a/vendor/github.com/gorilla/websocket/trace_17.go +++ b/vendor/github.com/gorilla/websocket/trace_17.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package websocket diff --git a/vendor/modules.txt b/vendor/modules.txt index e36f25dcb..dd79fb683 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -8,7 +8,7 @@ github.com/davecgh/go-spew/spew # github.com/godbus/dbus v4.1.0+incompatible ## explicit github.com/godbus/dbus -# github.com/gorilla/websocket v1.4.2 +# github.com/gorilla/websocket v1.4.2 => github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918 ## explicit github.com/gorilla/websocket # github.com/klauspost/compress v1.10.5 @@ -74,3 +74,4 @@ golang.org/x/term # gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 gopkg.in/yaml.v3 # github.com/urfave/cli/v2 => github.com/mendersoftware/cli/v2 v2.1.1-minimal +# github.com/gorilla/websocket => github.com/mendersoftware/websocket v1.4.3-0.20211210145825-8a45e5d03918