diff --git a/tcpdialer.go b/tcpdialer.go index e8430cb9c..e5f06bd01 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -3,6 +3,7 @@ package fasthttp import ( "context" "errors" + "fmt" "net" "strconv" "sync" @@ -302,7 +303,7 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne if err == nil { return conn, nil } - if err == ErrDialTimeout { + if errors.Is(err, ErrDialTimeout) { return nil, err } idx++ @@ -316,7 +317,7 @@ func (d *TCPDialer) tryDial( ) (net.Conn, error) { timeout := time.Until(deadline) if timeout <= 0 { - return nil, ErrDialTimeout + return nil, wrapDialWithUpstream(ErrDialTimeout, addr) } if concurrencyCh != nil { @@ -332,7 +333,7 @@ func (d *TCPDialer) tryDial( } ReleaseTimer(tc) if isTimeout { - return nil, ErrDialTimeout + return nil, wrapDialWithUpstream(ErrDialTimeout, addr) } } defer func() { <-concurrencyCh }() @@ -346,15 +347,49 @@ func (d *TCPDialer) tryDial( ctx, cancelCtx := context.WithDeadline(context.Background(), deadline) defer cancelCtx() conn, err := dialer.DialContext(ctx, network, addr) - if err != nil && ctx.Err() == context.DeadlineExceeded { - return nil, ErrDialTimeout + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, wrapDialWithUpstream(ErrDialTimeout, addr) + } + return nil, wrapDialWithUpstream(err, addr) } - return conn, err + return conn, nil } // ErrDialTimeout is returned when TCP dialing is timed out. var ErrDialTimeout = errors.New("dialing to the given TCP address timed out") +// ErrDialWithUpstream wraps dial error with upstream info. +// +// Should use errors.As to get upstream information from error: +// +// hc := fasthttp.HostClient{Addr: "foo.com,bar.com"} +// err := hc.Do(req, res) +// +// var dialErr *fasthttp.ErrDialWithUpstream +// if errors.As(err, &dialErr) { +// upstream = dialErr.Upstream // 34.206.39.153:80 +// } +type ErrDialWithUpstream struct { + Upstream string + wrapErr error +} + +func (e *ErrDialWithUpstream) Error() string { + return fmt.Sprintf("error when dialing %s: %s", e.Upstream, e.wrapErr.Error()) +} + +func (e *ErrDialWithUpstream) Unwrap() error { + return e.wrapErr +} + +func wrapDialWithUpstream(err error, upstream string) error { + return &ErrDialWithUpstream{ + Upstream: upstream, + wrapErr: err, + } +} + // DefaultDialTimeout is timeout used by Dial and DialDualStack // for establishing TCP connections. const DefaultDialTimeout = 3 * time.Second