diff --git a/client.go b/client.go index 832513fe25..fa0790b0f2 100644 --- a/client.go +++ b/client.go @@ -70,6 +70,10 @@ func DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return defaultClient.DoTimeout(req, resp, timeout) } +func DoTimeouts(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error { + return defaultClient.doInternal(req, resp, readTimeout, writeTimeout) +} + // DoDeadline performs the given request and waits for response until // the given deadline. // @@ -117,7 +121,7 @@ func DoDeadline(req *Request, resp *Response, deadline time.Time) error { // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error { - _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient) + _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient, 0, 0) return err } @@ -326,6 +330,10 @@ func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (stat return clientGetURLTimeout(dst, url, timeout, c) } +func (c *Client) GetTimeouts(dst []byte, url string, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) { + return clientGetURLTimeouts(dst, url, readTimeout, writeTimeout, c) +} + // GetDeadline returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst @@ -336,7 +344,8 @@ func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (stat // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. func (c *Client) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) { - return clientGetURLDeadline(dst, url, deadline, c) + timeout := deadline.Sub(time.Now()) + return clientGetURLTimeout(dst, url, timeout, c) } // Post sends POST request to the given url with the given POST arguments. @@ -379,10 +388,16 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b // continue in the background and the response will be discarded. // If requests take too long and the connection pool gets filled up please // try setting a ReadTimeout. +// +// Deprecated: please use DoTimeouts if you want to overwrite timeout values. func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return clientDoTimeout(req, resp, timeout, c) } +func (c *Client) DoTimeouts(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error { + return c.doInternal(req, resp, readTimeout, writeTimeout) +} + // DoDeadline performs the given request and waits for response until // the given deadline. // @@ -406,6 +421,8 @@ func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. +// +// Deprecated: please use DoTimeouts if you want to overwrite timeout values. func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error { return clientDoDeadline(req, resp, deadline, c) } @@ -430,7 +447,7 @@ func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) er // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error { - _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c) + _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c, 0, 0) return err } @@ -454,6 +471,10 @@ func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) Do(req *Request, resp *Response) error { + return c.doInternal(req, resp, 0, 0) +} + +func (c *Client) doInternal(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error { uri := req.URI() if uri == nil { return ErrorInvalidURI @@ -521,7 +542,7 @@ func (c *Client) Do(req *Request, resp *Response) error { go c.mCleaner(m) } - return hc.Do(req, resp) + return hc.doInternal(req, resp, readTimeout, writeTimeout) } // CloseIdleConnections closes any connections which were previously @@ -824,6 +845,10 @@ func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) ( return clientGetURLTimeout(dst, url, timeout, c) } +func (c *HostClient) GetTimeouts(dst []byte, url string, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) { + return clientGetURLTimeouts(dst, url, readTimeout, writeTimeout, c) +} + // GetDeadline returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst @@ -834,7 +859,8 @@ func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) ( // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. func (c *HostClient) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) { - return clientGetURLDeadline(dst, url, deadline, c) + timeout := deadline.Sub(time.Now()) + return clientGetURLTimeout(dst, url, timeout, c) } // Post sends POST request to the given url with the given POST arguments. @@ -851,34 +877,32 @@ func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode in type clientDoer interface { Do(req *Request, resp *Response) error + doInternal(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error } func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { req := AcquireRequest() - statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c) + statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c, 0, 0) ReleaseRequest(req) return statusCode, body, err } -func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) { - deadline := time.Now().Add(timeout) - return clientGetURLDeadline(dst, url, deadline, c) -} - type clientURLResponse struct { statusCode int body []byte err error } -func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) { - timeout := -time.Since(deadline) - if timeout <= 0 { - return 0, dst, ErrTimeout - } +func clientGetURLTimeouts(dst []byte, url string, readTimeout, writeTimeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) { + req := AcquireRequest() + defer ReleaseRequest(req) + + return doRequestFollowRedirectsBuffer(req, dst, url, c, readTimeout, writeTimeout) +} +func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) { var ch chan clientURLResponse chv := clientURLResponseChPool.Get() if chv == nil { @@ -900,7 +924,7 @@ func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDo go func() { req := AcquireRequest() - statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c) + statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c, 0, 0) mu.Lock() { if !timedout { @@ -958,7 +982,7 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (status } } - statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c) + statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c, 0, 0) ReleaseRequest(req) return statusCode, body, err @@ -978,14 +1002,14 @@ var ( const defaultMaxRedirectsCount = 16 -func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { +func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) { resp := AcquireResponse() bodyBuf := resp.bodyBuffer() resp.keepBodyBuffer = true oldBody := bodyBuf.B bodyBuf.B = dst - statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c) + statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c, readTimeout, writeTimeout) body = bodyBuf.B bodyBuf.B = oldBody @@ -995,7 +1019,7 @@ func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clie return statusCode, body, err } -func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) { +func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) { redirectsCount := 0 for { @@ -1004,7 +1028,7 @@ func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedir return 0, nil, err } - if err = c.Do(req, resp); err != nil { + if err = c.doInternal(req, resp, readTimeout, writeTimeout); err != nil { break } statusCode = resp.Header.StatusCode() @@ -1164,7 +1188,7 @@ func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error { - _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c) + _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c, 0, 0) return err } @@ -1270,6 +1294,10 @@ var errorChPool sync.Pool // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) Do(req *Request, resp *Response) error { + return c.doInternal(req, resp, 0, 0) +} + +func (c *HostClient) doInternal(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error { var err error var retry bool maxAttempts := c.MaxIdemponentCallAttempts @@ -1285,7 +1313,7 @@ func (c *HostClient) Do(req *Request, resp *Response) error { atomic.AddInt32(&c.pendingRequests, 1) for { - retry, err = c.do(req, resp) + retry, err = c.do(req, resp, readTimeout, writeTimeout) if err == nil || !retry { break } @@ -1331,14 +1359,14 @@ func isIdempotent(req *Request) bool { return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut() } -func (c *HostClient) do(req *Request, resp *Response) (bool, error) { +func (c *HostClient) do(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) (bool, error) { nilResp := false if resp == nil { nilResp = true resp = AcquireResponse() } - ok, err := c.doNonNilReqResp(req, resp) + ok, err := c.doNonNilReqResp(req, resp, readTimeout, writeTimeout) if nilResp { ReleaseResponse(resp) @@ -1347,7 +1375,7 @@ func (c *HostClient) do(req *Request, resp *Response) (bool, error) { return ok, err } -func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) { +func (c *HostClient) doNonNilReqResp(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) (bool, error) { if req == nil { panic("BUG: req cannot be nil") } @@ -1365,6 +1393,13 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) return false, ErrHostClientRedirectToDifferentScheme } + if c.WriteTimeout > writeTimeout { + writeTimeout = c.WriteTimeout + } + if c.ReadTimeout > readTimeout { + readTimeout = c.ReadTimeout + } + atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix)) // Free up resources occupied by response before sending the request, @@ -1395,11 +1430,11 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) resp.parseNetConn(conn) - if c.WriteTimeout > 0 { + if writeTimeout > 0 { // Set Deadline every time, since golang has fixed the performance issue // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details currentTime := time.Now() - if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil { + if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil { c.closeConn(cc) return true, err } @@ -1428,11 +1463,8 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) } c.releaseWriter(bw) - if c.ReadTimeout > 0 { - // Set Deadline every time, since golang has fixed the performance issue - // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details - currentTime := time.Now() - if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil { + if readTimeout > 0 { + if err = conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { c.closeConn(cc) return true, err } diff --git a/client_test.go b/client_test.go index e960745245..f83a23c323 100644 --- a/client_test.go +++ b/client_test.go @@ -20,6 +20,76 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestClientDoTimeoutsSuccess(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + go func() { + if err := s.Serve(ln); err != nil { + t.Error(err) + } + }() + defer s.Shutdown() //nolint:errcheck + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + defer c.CloseIdleConnections() + + var req Request + var resp Response + + req.SetRequestURI("http://example.com") + if err := c.DoTimeouts(&req, &resp, time.Second, time.Second); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } +} + +func TestClientDoTimeoutsTimeout(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + time.Sleep(time.Millisecond * 400) + }, + Logger: &testLogger{}, + } + go func() { + if err := s.Serve(ln); err != nil { + t.Error(err) + } + }() + defer s.Shutdown() //nolint:errcheck + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxIdemponentCallAttempts: 1, + } + defer c.CloseIdleConnections() + + var req Request + var resp Response + + req.SetRequestURI("http://example.com") + if err := c.DoTimeouts(&req, &resp, time.Millisecond*200, time.Millisecond*200); err == nil { + t.Fatal("expected timeout error") + } +} + func TestCloseIdleConnections(t *testing.T) { t.Parallel()