Skip to content

Commit

Permalink
Dialer: add optional methods NetDialTLS and NetDialTLSContext
Browse files Browse the repository at this point in the history
Fixes issue: #745

With the previous interface, NetDial and NetDialContext were used for
both TLS and non-TLS TCP connections, and afterwards TLSClientConfig was
used to do the TLS handshake.

While this API works for most cases, it prevents from using more advance
authentication methods during the TLS handshake, as this is out of the
control of the user.

This commits introduces another pair of dial methods, NetDialTLS and
NetDialTLSContext which are used when dialing for TLS/TCP. The code then
assumes that the handshake is done there and TLSClientConfig is not
used.

This API change is fully backwards compatible and it better aligns with
net/http.Transport API, which has these four dial flavors. See:
https://pkg.go.dev/net/http#Transport

Signed-off-by: Lluis Campos <lluis.campos@northern.tech>
  • Loading branch information
lluiscampos committed Dec 10, 2021
1 parent d84bdd3 commit e87c21b
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 8 deletions.
54 changes: 46 additions & 8 deletions client.go
Expand Up @@ -54,9 +54,21 @@ type Dialer struct {
NetDial func(network, addr string) (net.Conn, error)

// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
// NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)

// NetDialTLS specifies the dial function for creating TLS/TCP connections. If
// NetDialTLS is nil, net.Dial is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
NetDialTLS func(network, addr string) (net.Conn, error)

// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialTLS is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
Expand All @@ -65,6 +77,8 @@ type Dialer struct {

// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config

// HandshakeTimeout specifies the duration for the handshake to complete.
Expand Down Expand Up @@ -237,13 +251,34 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)

if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialTLS != nil {
netDial = d.NetDialTLS
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}

if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
Expand Down Expand Up @@ -304,7 +339,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()

if u.Scheme == "https" {
if u.Scheme == "https" && d.NetDialTLSContext == nil && d.NetDialTLS == nil {
// If either NetDialTLS or NetDialTLSContext are set, assume that
// the TLS handshake has already been done

cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
Expand Down
267 changes: 267 additions & 0 deletions client_server_test.go
Expand Up @@ -920,3 +920,270 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
defer ws.Close()
sendRecv(t, ws)
}

// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
func TestNetDialConnect(t *testing.T) {

upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})

server := httptest.NewServer(handler)
defer server.Close()

tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()

testUrls := map[*httptest.Server]string{
server: "ws://" + server.Listener.Addr().String() + "/",
tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
}

cas := rootCAs(t, tlsServer)
tlsConfig := &tls.Config{
RootCAs: cas,
ServerName: "example.com",
InsecureSkipVerify: false,
}

tests := map[string]struct {
server *httptest.Server // server to use
netDial func(network, addr string) (net.Conn, error)
netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
netDialTLS func(network, addr string) (net.Conn, error)
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
tlsClientConfig *tls.Config
}{
"HTTP server, all NetDial* defined, shall use NetDialContext": {
server: server,
netDial: func(network, addr string) (net.Conn, error) {
t.Error("NetDial should not be called")
t.FailNow()
return nil, nil
},
netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialTLS: func(network, addr string) (net.Conn, error) {
t.Error("NetDialTLS should not be called")
t.FailNow()
return nil, nil
},
netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
t.Error("NetDialTLSContext should not be called")
t.FailNow()
return nil, nil
},
tlsClientConfig: nil,
},
"HTTP server, all NetDial* undefined": {
server: server,
netDial: nil,
netDialContext: nil,
netDialTLS: nil,
netDialTLSContext: nil,
tlsClientConfig: nil,
},
"HTTP server, NetDialContext undefined, shall fallback to NetDial": {
server: server,
netDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialContext: nil,
netDialTLS: func(network, addr string) (net.Conn, error) {
t.Error("NetDialTLS should not be called")
t.FailNow()
return nil, nil
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
t.Error("NetDialTLSContext should not be called")
t.FailNow()
return nil, nil
},
tlsClientConfig: nil,
},
"HTTPS server, all NetDial* defined, shall use NetDialTLSContext": {
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
t.Error("NetDial should not be called")
t.FailNow()
return nil, nil
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
t.Error("NetDialContext should not be called")
t.FailNow()
return nil, nil
},
netDialTLS: func(network, addr string) (net.Conn, error) {
t.Error("NetDialTLS should not be called")
t.FailNow()
return nil, nil
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
tlsClientConfig: nil,
},
"HTTPS server, NetDialTLSContext undefined, shall fallback to NetTLSDial": {
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
t.Error("NetDial should not be called")
t.FailNow()
return nil, nil
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
t.Error("NetDialContext should not be called")
t.FailNow()
return nil, nil
},
netDialTLS: func(network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
netDialTLSContext: nil,
tlsClientConfig: nil,
},
"HTTPS server, NetDialTLS* undefined, shall fallback to NetDialContext and do handshake": {
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
t.Error("NetDial should not be called")
t.FailNow()
return nil, nil
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialTLS: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
"HTTPS server, NetDialTLS* and NetDialContext undefined, shall fallback to NetDial and do handshake": {
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialContext: nil,
netDialTLS: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
"HTTPS server, all NetDial* undefined": {
server: tlsServer,
netDial: nil,
netDialContext: nil,
netDialTLS: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
"HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake": {
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
t.Error("NetDial should not be called")
t.FailNow()
return nil, nil
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
t.Error("NetDialContext should not be called")
t.FailNow()
return nil, nil
},
netDialTLS: func(network, addr string) (net.Conn, error) {
t.Error("NetDialTLS should not be called")
t.FailNow()
return nil, nil
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
tlsClientConfig: &tls.Config{
RootCAs: nil,
ServerName: "badserver.com",
InsecureSkipVerify: false,
},
},
"HTTPS server, NetDialTLS defined, dummy TlsClientConfig defined, shall not do handshake": {
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
t.Error("NetDial should not be called")
t.FailNow()
return nil, nil
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
t.Error("NetDialContext should not be called")
t.FailNow()
return nil, nil
},
netDialTLS: func(network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
netDialTLSContext: nil,
tlsClientConfig: &tls.Config{
RootCAs: nil,
ServerName: "badserver.com",
InsecureSkipVerify: false,
},
},
}

for name, tc := range tests {
dialer := Dialer{
NetDial: tc.netDial,
NetDialContext: tc.netDialContext,
NetDialTLS: tc.netDialTLS,
NetDialTLSContext: tc.netDialTLSContext,
TLSClientConfig: tc.tlsClientConfig,
}

// Test websocket dial
c, _, err := dialer.Dial(testUrls[tc.server], nil)
if err != nil {
t.Errorf("FAILED %s, err: %s", name, err.Error())
} else {
c.Close()
}
}
}

0 comments on commit e87c21b

Please sign in to comment.