Skip to content

Commit

Permalink
Update tlsClientHandshake (#1263)
Browse files Browse the repository at this point in the history
* Update tlsClientHandshake

* Update client.go

* Update client.go

* Update client.go

* Update client.go

* Changes according to the review
  • Loading branch information
moredure committed Apr 9, 2022
1 parent c7576cc commit b40b5a4
Showing 1 changed file with 22 additions and 30 deletions.
52 changes: 22 additions & 30 deletions client.go
Expand Up @@ -1994,41 +1994,33 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
// ErrTLSHandshakeTimeout indicates there is a timeout from tls handshake.
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")

var timeoutErrorChPool sync.Pool

func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)

var ch chan error
chv := timeoutErrorChPool.Get()
if chv == nil {
chv = make(chan error)
}
ch = chv.(chan error)
defer timeoutErrorChPool.Put(chv)

conn := tls.Client(rawConn, tlsConfig)

go func() {
ch <- conn.Handshake()
}()

select {
case <-tc.C:
rawConn.Close()
<-ch
return nil, ErrTLSHandshakeTimeout
case err := <-ch:
if err != nil {
func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.Time) (_ net.Conn, retErr error) {
defer func() {
if retErr != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}()
conn := tls.Client(rawConn, tlsConfig)
err := conn.SetDeadline(deadline)
if err != nil {
return nil, err
}
err = conn.Handshake()
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, ErrTLSHandshakeTimeout
}
if err != nil {
return nil, err
}
err = conn.SetDeadline(time.Time{})
if err != nil {
return nil, err
}
return conn, nil
}

func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
deadline := time.Now().Add(timeout)
if dial == nil {
if dialDualStack {
dial = DialDualStack
Expand All @@ -2049,7 +2041,7 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
if timeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, timeout)
return tlsClientHandshake(conn, tlsConfig, deadline)
}
return conn, nil
}
Expand Down

0 comments on commit b40b5a4

Please sign in to comment.