Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Timeouts methods to Client #1096

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
100 changes: 66 additions & 34 deletions client.go
Expand Up @@ -70,6 +70,10 @@ func DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return defaultClient.DoTimeout(req, resp, timeout)
}

func DoTimeouts(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error {
return defaultClient.doInternal(req, resp, readTimeout, writeTimeout)
}

// DoDeadline performs the given request and waits for response until
// the given deadline.
//
Expand Down Expand Up @@ -117,7 +121,7 @@ func DoDeadline(req *Request, resp *Response, deadline time.Time) error {
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient)
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient, 0, 0)
return err
}

Expand Down Expand Up @@ -326,6 +330,10 @@ func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (stat
return clientGetURLTimeout(dst, url, timeout, c)
}

func (c *Client) GetTimeouts(dst []byte, url string, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) {
return clientGetURLTimeouts(dst, url, readTimeout, writeTimeout, c)
}

// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
Expand All @@ -336,7 +344,8 @@ func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (stat
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func (c *Client) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return clientGetURLDeadline(dst, url, deadline, c)
timeout := deadline.Sub(time.Now())
return clientGetURLTimeout(dst, url, timeout, c)
}

// Post sends POST request to the given url with the given POST arguments.
Expand Down Expand Up @@ -379,10 +388,16 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
//
// Deprecated: please use DoTimeouts if you want to overwrite timeout values.
func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
}

func (c *Client) DoTimeouts(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error {
return c.doInternal(req, resp, readTimeout, writeTimeout)
}

// DoDeadline performs the given request and waits for response until
// the given deadline.
//
Expand All @@ -406,6 +421,8 @@ func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration)
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
//
// Deprecated: please use DoTimeouts if you want to overwrite timeout values.
func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
}
Expand All @@ -430,7 +447,7 @@ func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) er
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c, 0, 0)
return err
}

Expand All @@ -454,6 +471,10 @@ func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) Do(req *Request, resp *Response) error {
return c.doInternal(req, resp, 0, 0)
}

func (c *Client) doInternal(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error {
uri := req.URI()
if uri == nil {
return ErrorInvalidURI
Expand Down Expand Up @@ -521,7 +542,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
go c.mCleaner(m)
}

return hc.Do(req, resp)
return hc.doInternal(req, resp, readTimeout, writeTimeout)
}

// CloseIdleConnections closes any connections which were previously
Expand Down Expand Up @@ -824,6 +845,10 @@ func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) (
return clientGetURLTimeout(dst, url, timeout, c)
}

func (c *HostClient) GetTimeouts(dst []byte, url string, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) {
return clientGetURLTimeouts(dst, url, readTimeout, writeTimeout, c)
}

// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
Expand All @@ -834,7 +859,8 @@ func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) (
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func (c *HostClient) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return clientGetURLDeadline(dst, url, deadline, c)
timeout := deadline.Sub(time.Now())
return clientGetURLTimeout(dst, url, timeout, c)
}

// Post sends POST request to the given url with the given POST arguments.
Expand All @@ -851,34 +877,32 @@ func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode in

type clientDoer interface {
Do(req *Request, resp *Response) error
doInternal(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error
}

func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()

statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c, 0, 0)

ReleaseRequest(req)
return statusCode, body, err
}

func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) {
deadline := time.Now().Add(timeout)
return clientGetURLDeadline(dst, url, deadline, c)
}

type clientURLResponse struct {
statusCode int
body []byte
err error
}

func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) {
timeout := -time.Since(deadline)
if timeout <= 0 {
return 0, dst, ErrTimeout
}
func clientGetURLTimeouts(dst []byte, url string, readTimeout, writeTimeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()
defer ReleaseRequest(req)

return doRequestFollowRedirectsBuffer(req, dst, url, c, readTimeout, writeTimeout)
}

func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) {
var ch chan clientURLResponse
chv := clientURLResponseChPool.Get()
if chv == nil {
Expand All @@ -900,7 +924,7 @@ func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDo
go func() {
req := AcquireRequest()

statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c)
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c, 0, 0)
mu.Lock()
{
if !timedout {
Expand Down Expand Up @@ -958,7 +982,7 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (status
}
}

statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c, 0, 0)

ReleaseRequest(req)
return statusCode, body, err
Expand All @@ -978,14 +1002,14 @@ var (

const defaultMaxRedirectsCount = 16

func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) {
resp := AcquireResponse()
bodyBuf := resp.bodyBuffer()
resp.keepBodyBuffer = true
oldBody := bodyBuf.B
bodyBuf.B = dst

statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c)
statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c, readTimeout, writeTimeout)

body = bodyBuf.B
bodyBuf.B = oldBody
Expand All @@ -995,7 +1019,7 @@ func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clie
return statusCode, body, err
}

func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) {
func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer, readTimeout, writeTimeout time.Duration) (statusCode int, body []byte, err error) {
redirectsCount := 0

for {
Expand All @@ -1004,7 +1028,7 @@ func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedir
return 0, nil, err
}

if err = c.Do(req, resp); err != nil {
if err = c.doInternal(req, resp, readTimeout, writeTimeout); err != nil {
break
}
statusCode = resp.Header.StatusCode()
Expand Down Expand Up @@ -1164,7 +1188,7 @@ func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c, 0, 0)
return err
}

Expand Down Expand Up @@ -1270,6 +1294,10 @@ var errorChPool sync.Pool
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) Do(req *Request, resp *Response) error {
return c.doInternal(req, resp, 0, 0)
}

func (c *HostClient) doInternal(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) error {
var err error
var retry bool
maxAttempts := c.MaxIdemponentCallAttempts
Expand All @@ -1285,7 +1313,7 @@ func (c *HostClient) Do(req *Request, resp *Response) error {

atomic.AddInt32(&c.pendingRequests, 1)
for {
retry, err = c.do(req, resp)
retry, err = c.do(req, resp, readTimeout, writeTimeout)
if err == nil || !retry {
break
}
Expand Down Expand Up @@ -1331,14 +1359,14 @@ func isIdempotent(req *Request) bool {
return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut()
}

func (c *HostClient) do(req *Request, resp *Response) (bool, error) {
func (c *HostClient) do(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) (bool, error) {
nilResp := false
if resp == nil {
nilResp = true
resp = AcquireResponse()
}

ok, err := c.doNonNilReqResp(req, resp)
ok, err := c.doNonNilReqResp(req, resp, readTimeout, writeTimeout)

if nilResp {
ReleaseResponse(resp)
Expand All @@ -1347,7 +1375,7 @@ func (c *HostClient) do(req *Request, resp *Response) (bool, error) {
return ok, err
}

func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) {
func (c *HostClient) doNonNilReqResp(req *Request, resp *Response, readTimeout, writeTimeout time.Duration) (bool, error) {
if req == nil {
panic("BUG: req cannot be nil")
}
Expand All @@ -1365,6 +1393,13 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
return false, ErrHostClientRedirectToDifferentScheme
}

if c.WriteTimeout > writeTimeout {
writeTimeout = c.WriteTimeout
}
if c.ReadTimeout > readTimeout {
readTimeout = c.ReadTimeout
}

atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix))

// Free up resources occupied by response before sending the request,
Expand Down Expand Up @@ -1395,11 +1430,11 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)

resp.parseNetConn(conn)

if c.WriteTimeout > 0 {
if writeTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil {
if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil {
c.closeConn(cc)
return true, err
}
Expand Down Expand Up @@ -1428,11 +1463,8 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
}
c.releaseWriter(bw)

if c.ReadTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil {
if readTimeout > 0 {
if err = conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
c.closeConn(cc)
return true, err
}
Expand Down
70 changes: 70 additions & 0 deletions client_test.go
Expand Up @@ -20,6 +20,76 @@ import (
"github.com/valyala/fasthttp/fasthttputil"
)

func TestClientDoTimeoutsSuccess(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()

s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Error(err)
}
}()
defer s.Shutdown() //nolint:errcheck

c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
defer c.CloseIdleConnections()

var req Request
var resp Response

req.SetRequestURI("http://example.com")
if err := c.DoTimeouts(&req, &resp, time.Second, time.Second); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}

func TestClientDoTimeoutsTimeout(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()

s := &Server{
Handler: func(ctx *RequestCtx) {
time.Sleep(time.Millisecond * 400)
},
Logger: &testLogger{},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Error(err)
}
}()
defer s.Shutdown() //nolint:errcheck

c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxIdemponentCallAttempts: 1,
}
defer c.CloseIdleConnections()

var req Request
var resp Response

req.SetRequestURI("http://example.com")
if err := c.DoTimeouts(&req, &resp, time.Millisecond*200, time.Millisecond*200); err == nil {
t.Fatal("expected timeout error")
}
}

func TestCloseIdleConnections(t *testing.T) {
t.Parallel()

Expand Down