Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tlsClientHandshake #1263

Merged
merged 6 commits into from Apr 9, 2022
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 30 additions & 30 deletions client.go
Expand Up @@ -1994,41 +1994,41 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
// ErrTLSHandshakeTimeout indicates there is a timeout from tls handshake.
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")

var timeoutErrorChPool sync.Pool

func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)

var ch chan error
chv := timeoutErrorChPool.Get()
if chv == nil {
chv = make(chan error)
}
ch = chv.(chan error)
defer timeoutErrorChPool.Put(chv)

conn := tls.Client(rawConn, tlsConfig)

go func() {
ch <- conn.Handshake()
}()

select {
case <-tc.C:
rawConn.Close()
<-ch
return nil, ErrTLSHandshakeTimeout
case err := <-ch:
if err != nil {
func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.Time) (_ net.Conn, retErr error) {
defer func() {
if retErr != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}()
conn := tls.Client(rawConn, tlsConfig)
err := conn.SetReadDeadline(deadline)
if err != nil {
return nil, err
}
err = conn.SetWriteDeadline(deadline)
if err != nil {
return nil, err
}
moredure marked this conversation as resolved.
Show resolved Hide resolved
err = conn.Handshake()
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, ErrTLSHandshakeTimeout
}
if err != nil {
return nil, err
}
err = conn.SetReadDeadline(time.Time{})
if err != nil {
return nil, err
}
err = conn.SetWriteDeadline(time.Time{})
if err != nil {
return nil, err
}
return conn, nil
}

func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
deadline := time.Now().Add(timeout)
if dial == nil {
if dialDualStack {
dial = DialDualStack
Expand All @@ -2049,7 +2049,7 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
if timeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, timeout)
return tlsClientHandshake(conn, tlsConfig, deadline)
}
return conn, nil
}
Expand Down