Skip to content

Commit

Permalink
Improve Client timeout
Browse files Browse the repository at this point in the history
Don't run requests in a separate Goroutine anymore. Instead use proper
conn deadlines to enforce timeouts.

- Also contains some linting fixes.
  • Loading branch information
erikdubbelboer committed Jul 29, 2022
1 parent f3513cc commit 9af9728
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 116 deletions.
4 changes: 2 additions & 2 deletions args_test.go
Expand Up @@ -329,7 +329,7 @@ func TestArgsCopyTo(t *testing.T) {

func testCopyTo(t *testing.T, a *Args) {
keys := make(map[string]struct{})
a.VisitAll(func(k, v []byte) {
a.VisitAll(func(k, _ []byte) {
keys[string(k)] = struct{}{}
})

Expand All @@ -340,7 +340,7 @@ func testCopyTo(t *testing.T, a *Args) {
t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", *a, b) //nolint
}

b.VisitAll(func(k, v []byte) {
b.VisitAll(func(k, _ []byte) {
if _, ok := keys[string(k)]; !ok {
t.Fatalf("unexpected key %q after copying from %q", k, a.String())
}
Expand Down
105 changes: 14 additions & 91 deletions client.go
Expand Up @@ -387,7 +387,8 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
req.timeout = timeout
return c.Do(req, resp)
}

// DoDeadline performs the given request and waits for response until
Expand All @@ -414,7 +415,8 @@ func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration)
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
req.timeout = time.Until(deadline)
return c.Do(req, resp)
}

// DoRedirects performs the given http request and fills the given http response,
Expand Down Expand Up @@ -1150,7 +1152,8 @@ func ReleaseResponse(resp *Response) {
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
req.timeout = timeout
return c.Do(req, resp)
}

// DoDeadline performs the given request and waits for response until
Expand All @@ -1172,7 +1175,8 @@ func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Durati
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
req.timeout = time.Until(deadline)
return c.Do(req, resp)
}

// DoRedirects performs the given http request and fills the given http response,
Expand All @@ -1199,93 +1203,6 @@ func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount
return err
}

func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error {
deadline := time.Now().Add(timeout)
return clientDoDeadline(req, resp, deadline, c)
}

func clientDoDeadline(req *Request, resp *Response, deadline time.Time, c clientDoer) error {
timeout := -time.Since(deadline)
if timeout <= 0 {
return ErrTimeout
}

var ch chan error
chv := errorChPool.Get()
if chv == nil {
chv = make(chan error, 1)
}
ch = chv.(chan error)

// Make req and resp copies, since on timeout they no longer
// may be accessed.
reqCopy := AcquireRequest()
req.copyToSkipBody(reqCopy)
swapRequestBody(req, reqCopy)
respCopy := AcquireResponse()
if resp != nil {
// Not calling resp.copyToSkipBody(respCopy) here to avoid
// unexpected messing with headers
respCopy.SkipBody = resp.SkipBody
}

// Note that the request continues execution on ErrTimeout until
// client-specific ReadTimeout exceeds. This helps limiting load
// on slow hosts by MaxConns* concurrent requests.
//
// Without this 'hack' the load on slow host could exceed MaxConns*
// concurrent requests, since timed out requests on client side
// usually continue execution on the host.

var mu sync.Mutex
var timedout, responded bool

go func() {
reqCopy.timeout = timeout
errDo := c.Do(reqCopy, respCopy)
mu.Lock()
{
if !timedout {
if resp != nil {
respCopy.copyToSkipBody(resp)
swapResponseBody(resp, respCopy)
}
swapRequestBody(reqCopy, req)
ch <- errDo
responded = true
}
}
mu.Unlock()

ReleaseResponse(respCopy)
ReleaseRequest(reqCopy)
}()

tc := AcquireTimer(timeout)
var err error
select {
case err = <-ch:
case <-tc.C:
mu.Lock()
{
if responded {
err = <-ch
} else {
timedout = true
err = ErrTimeout
}
}
mu.Unlock()
}
ReleaseTimer(tc)

errorChPool.Put(chv)

return err
}

var errorChPool sync.Pool

// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
Expand Down Expand Up @@ -1464,6 +1381,12 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
err = bw.Flush()
}
c.releaseWriter(bw)

// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}

isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
c.closeConn(cc)
Expand Down
44 changes: 33 additions & 11 deletions client_test.go
Expand Up @@ -181,7 +181,7 @@ func TestClientInvalidURI(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()
requests := int64(0)
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
atomic.AddInt64(&requests, 1)
},
}
Expand Down Expand Up @@ -636,7 +636,7 @@ func TestClientReadTimeout(t *testing.T) {

timeout := false
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
if timeout {
time.Sleep(time.Second)
} else {
Expand Down Expand Up @@ -1191,7 +1191,7 @@ func TestHostClientPendingRequests(t *testing.T) {
doneCh := make(chan struct{})
readyCh := make(chan struct{}, concurrency)
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
readyCh <- struct{}{}
<-doneCh
},
Expand Down Expand Up @@ -1750,16 +1750,19 @@ func testClientGetTimeoutError(t *testing.T, c *Client, n int) {

type readTimeoutConn struct {
net.Conn
t time.Duration
t time.Duration
wc chan struct{}
rc chan struct{}
}

func (r *readTimeoutConn) Read(p []byte) (int, error) {
time.Sleep(r.t)
return 0, io.EOF
<-r.rc
return 0, os.ErrDeadlineExceeded
}

func (r *readTimeoutConn) Write(p []byte) (int, error) {
return len(p), nil
<-r.wc
return 0, os.ErrDeadlineExceeded
}

func (r *readTimeoutConn) Close() error {
Expand All @@ -1774,12 +1777,30 @@ func (r *readTimeoutConn) RemoteAddr() net.Addr {
return nil
}

func (r *readTimeoutConn) SetReadDeadline(d time.Time) error {
r.rc = make(chan struct{}, 1)
go func() {
time.Sleep(time.Until(d))
r.rc <- struct{}{}
}()
return nil
}

func (r *readTimeoutConn) SetWriteDeadline(d time.Time) error {
r.wc = make(chan struct{}, 1)
go func() {
time.Sleep(time.Until(d))
r.wc <- struct{}{}
}()
return nil
}

func TestClientNonIdempotentRetry(t *testing.T) {
t.Parallel()

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
Expand Down Expand Up @@ -1829,7 +1850,7 @@ func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
Expand Down Expand Up @@ -1866,7 +1887,7 @@ func TestClientIdempotentRequest(t *testing.T) {

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
Expand Down Expand Up @@ -1922,7 +1943,7 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
Expand Down Expand Up @@ -2758,6 +2779,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
time.Sleep(sleep)
ctx.WriteString("foo") //nolint:errcheck
},
Logger: &testLogger{}, // Don't print connection closed errors.
}
serverStopCh := make(chan struct{})
go func() {
Expand Down
2 changes: 1 addition & 1 deletion client_timing_test.go
Expand Up @@ -515,7 +515,7 @@ func BenchmarkNetHTTPClientEndToEndBigResponse10Inmemory(b *testing.B) {

func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
bigResponse := createFixedBody(1024 * 1024)
h := func(w http.ResponseWriter, r *http.Request) {
h := func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set(HeaderContentType, "text/plain")
w.Write(bigResponse) //nolint:errcheck
}
Expand Down
6 changes: 3 additions & 3 deletions header.go
Expand Up @@ -868,15 +868,15 @@ func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool {
// i.e. the number of times f is called in VisitAll.
func (h *ResponseHeader) Len() int {
n := 0
h.VisitAll(func(k, v []byte) { n++ })
h.VisitAll(func(_, _ []byte) { n++ })
return n
}

// Len returns the number of headers set,
// i.e. the number of times f is called in VisitAll.
func (h *RequestHeader) Len() int {
n := 0
h.VisitAll(func(k, v []byte) { n++ })
h.VisitAll(func(_, _ []byte) { n++ })
return n
}

Expand Down Expand Up @@ -1077,7 +1077,7 @@ func (h *ResponseHeader) VisitAll(f func(key, value []byte)) {
f(strServer, server)
}
if len(h.cookies) > 0 {
visitArgs(h.cookies, func(k, v []byte) {
visitArgs(h.cookies, func(_, v []byte) {
f(strSetCookie, v)
})
}
Expand Down
8 changes: 4 additions & 4 deletions header_test.go
Expand Up @@ -1942,7 +1942,7 @@ func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen := false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
Expand All @@ -1963,7 +1963,7 @@ func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen = false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
Expand All @@ -1987,7 +1987,7 @@ func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen := false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
Expand All @@ -2005,7 +2005,7 @@ func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen = false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
Expand Down
9 changes: 5 additions & 4 deletions server_test.go
Expand Up @@ -254,7 +254,7 @@ func TestServerConnState(t *testing.T) {
states := make([]string, 0)
s := &Server{
Handler: func(ctx *RequestCtx) {},
ConnState: func(conn net.Conn, state ConnState) {
ConnState: func(_ net.Conn, state ConnState) {
states = append(states, state.String())
},
}
Expand Down Expand Up @@ -2103,7 +2103,7 @@ func TestServerErrorHandler(t *testing.T) {

s := &Server{
Handler: func(ctx *RequestCtx) {},
ErrorHandler: func(ctx *RequestCtx, err error) {
ErrorHandler: func(ctx *RequestCtx, _ error) {
resultReqStr = ctx.Request.String()
resultRespStr = ctx.Response.String()
},
Expand Down Expand Up @@ -3681,6 +3681,7 @@ func TestStreamRequestBody(t *testing.T) {
checkReader(t, ctx.RequestBodyStream(), part2)
},
StreamRequestBody: true,
Logger: &testLogger{},
}

pipe := fasthttputil.NewPipeConns()
Expand Down Expand Up @@ -3828,7 +3829,7 @@ func TestMaxReadTimeoutPerRequest(t *testing.T) {

headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
t.Error("shouldn't reach handler")
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
Expand Down Expand Up @@ -3971,7 +3972,7 @@ func TestServerChunkedResponse(t *testing.T) {
if err := w.Flush(); err != nil {
t.Errorf("unexpected error: %v", err)
}
time.Sleep(time.Second)
time.Sleep(time.Millisecond * 100)
}
})
for k, v := range trailer {
Expand Down

0 comments on commit 9af9728

Please sign in to comment.