From e87c21ba4de184ff4babea361beb82cc090aa733 Mon Sep 17 00:00:00 2001 From: Lluis Campos Date: Fri, 10 Dec 2021 10:46:09 +0100 Subject: [PATCH] Dialer: add optional methods NetDialTLS and NetDialTLSContext Fixes issue: https://github.com/gorilla/websocket/issues/745 With the previous interface, NetDial and NetDialContext were used for both TLS and non-TLS TCP connections, and afterwards TLSClientConfig was used to do the TLS handshake. While this API works for most cases, it prevents from using more advance authentication methods during the TLS handshake, as this is out of the control of the user. This commits introduces another pair of dial methods, NetDialTLS and NetDialTLSContext which are used when dialing for TLS/TCP. The code then assumes that the handshake is done there and TLSClientConfig is not used. This API change is fully backwards compatible and it better aligns with net/http.Transport API, which has these four dial flavors. See: https://pkg.go.dev/net/http#Transport Signed-off-by: Lluis Campos --- client.go | 54 +++++++-- client_server_test.go | 267 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index c4b62fbc..7a6f09f7 100644 --- a/client.go +++ b/client.go @@ -54,9 +54,21 @@ type Dialer struct { NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, net.DialContext is used. + // NetDialContext is nil, NetDial is used. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // NetDialTLS specifies the dial function for creating TLS/TCP connections. If + // NetDialTLS is nil, net.Dial is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + NetDialTLS func(network, addr string) (net.Conn, error) + + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialTLS is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -65,6 +77,8 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. @@ -237,13 +251,34 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Get network dial function. var netDial func(network, add string) (net.Conn, error) - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } - } else if d.NetDial != nil { - netDial = d.NetDial - } else { + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialTLS != nil { + netDial = d.NetDialTLS + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { netDialer := &net.Dialer{} netDial = func(network, addr string) (net.Conn, error) { return netDialer.DialContext(ctx, network, addr) @@ -304,7 +339,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" { + if u.Scheme == "https" && d.NetDialTLSContext == nil && d.NetDialTLS == nil { + // If either NetDialTLS or NetDialTLSContext are set, assume that + // the TLS handshake has already been done + cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort diff --git a/client_server_test.go b/client_server_test.go index 5fd2c85a..84f9a912 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -920,3 +920,270 @@ func TestEmptyTracingDialWithContext(t *testing.T) { defer ws.Close() sendRecv(t, ws) } + +// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext +func TestNetDialConnect(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + testUrls := map[*httptest.Server]string{ + server: "ws://" + server.Listener.Addr().String() + "/", + tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", + } + + cas := rootCAs(t, tlsServer) + tlsConfig := &tls.Config{ + RootCAs: cas, + ServerName: "example.com", + InsecureSkipVerify: false, + } + + tests := map[string]struct { + server *httptest.Server // server to use + netDial func(network, addr string) (net.Conn, error) + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + netDialTLS func(network, addr string) (net.Conn, error) + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + tlsClientConfig *tls.Config + }{ + "HTTP server, all NetDial* defined, shall use NetDialContext": { + server: server, + netDial: func(network, addr string) (net.Conn, error) { + t.Error("NetDial should not be called") + t.FailNow() + return nil, nil + }, + netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLS: func(network, addr string) (net.Conn, error) { + t.Error("NetDialTLS should not be called") + t.FailNow() + return nil, nil + }, + netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { + t.Error("NetDialTLSContext should not be called") + t.FailNow() + return nil, nil + }, + tlsClientConfig: nil, + }, + "HTTP server, all NetDial* undefined": { + server: server, + netDial: nil, + netDialContext: nil, + netDialTLS: nil, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + "HTTP server, NetDialContext undefined, shall fallback to NetDial": { + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLS: func(network, addr string) (net.Conn, error) { + t.Error("NetDialTLS should not be called") + t.FailNow() + return nil, nil + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Error("NetDialTLSContext should not be called") + t.FailNow() + return nil, nil + }, + tlsClientConfig: nil, + }, + "HTTPS server, all NetDial* defined, shall use NetDialTLSContext": { + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + t.Error("NetDial should not be called") + t.FailNow() + return nil, nil + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Error("NetDialContext should not be called") + t.FailNow() + return nil, nil + }, + netDialTLS: func(network, addr string) (net.Conn, error) { + t.Error("NetDialTLS should not be called") + t.FailNow() + return nil, nil + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: nil, + }, + "HTTPS server, NetDialTLSContext undefined, shall fallback to NetTLSDial": { + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + t.Error("NetDial should not be called") + t.FailNow() + return nil, nil + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Error("NetDialContext should not be called") + t.FailNow() + return nil, nil + }, + netDialTLS: func(network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + "HTTPS server, NetDialTLS* undefined, shall fallback to NetDialContext and do handshake": { + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + t.Error("NetDial should not be called") + t.FailNow() + return nil, nil + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLS: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + "HTTPS server, NetDialTLS* and NetDialContext undefined, shall fallback to NetDial and do handshake": { + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLS: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + "HTTPS server, all NetDial* undefined": { + server: tlsServer, + netDial: nil, + netDialContext: nil, + netDialTLS: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake": { + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + t.Error("NetDial should not be called") + t.FailNow() + return nil, nil + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Error("NetDialContext should not be called") + t.FailNow() + return nil, nil + }, + netDialTLS: func(network, addr string) (net.Conn, error) { + t.Error("NetDialTLS should not be called") + t.FailNow() + return nil, nil + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + "HTTPS server, NetDialTLS defined, dummy TlsClientConfig defined, shall not do handshake": { + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + t.Error("NetDial should not be called") + t.FailNow() + return nil, nil + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Error("NetDialContext should not be called") + t.FailNow() + return nil, nil + }, + netDialTLS: func(network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + netDialTLSContext: nil, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + } + + for name, tc := range tests { + dialer := Dialer{ + NetDial: tc.netDial, + NetDialContext: tc.netDialContext, + NetDialTLS: tc.netDialTLS, + NetDialTLSContext: tc.netDialTLSContext, + TLSClientConfig: tc.tlsClientConfig, + } + + // Test websocket dial + c, _, err := dialer.Dial(testUrls[tc.server], nil) + if err != nil { + t.Errorf("FAILED %s, err: %s", name, err.Error()) + } else { + c.Close() + } + } +}