From b40b5a4ca3e3e56378b1bf77f5321a0b92736fd6 Mon Sep 17 00:00:00 2001 From: Mikhail Faraponov <11322032+moredure@users.noreply.github.com> Date: Sun, 10 Apr 2022 02:21:37 +0300 Subject: [PATCH] Update tlsClientHandshake (#1263) * Update tlsClientHandshake * Update client.go * Update client.go * Update client.go * Update client.go * Changes according to the review --- client.go | 52 ++++++++++++++++++++++------------------------------ 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/client.go b/client.go index 5c40f782af..7ffae8b6f4 100644 --- a/client.go +++ b/client.go @@ -1994,41 +1994,33 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config { // ErrTLSHandshakeTimeout indicates there is a timeout from tls handshake. var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out") -var timeoutErrorChPool sync.Pool - -func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { - tc := AcquireTimer(timeout) - defer ReleaseTimer(tc) - - var ch chan error - chv := timeoutErrorChPool.Get() - if chv == nil { - chv = make(chan error) - } - ch = chv.(chan error) - defer timeoutErrorChPool.Put(chv) - - conn := tls.Client(rawConn, tlsConfig) - - go func() { - ch <- conn.Handshake() - }() - - select { - case <-tc.C: - rawConn.Close() - <-ch - return nil, ErrTLSHandshakeTimeout - case err := <-ch: - if err != nil { +func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.Time) (_ net.Conn, retErr error) { + defer func() { + if retErr != nil { rawConn.Close() - return nil, err } - return conn, nil + }() + conn := tls.Client(rawConn, tlsConfig) + err := conn.SetDeadline(deadline) + if err != nil { + return nil, err + } + err = conn.Handshake() + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, ErrTLSHandshakeTimeout } + if err != nil { + return nil, err + } + err = conn.SetDeadline(time.Time{}) + if err != nil { + return nil, err + } + return conn, nil } func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { + deadline := time.Now().Add(timeout) if dial == nil { if dialDualStack { dial = DialDualStack @@ -2049,7 +2041,7 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig * if timeout == 0 { return tls.Client(conn, tlsConfig), nil } - return tlsClientHandshake(conn, tlsConfig, timeout) + return tlsClientHandshake(conn, tlsConfig, deadline) } return conn, nil }