Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Client timeout #1346

Merged
merged 1 commit into from Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions args_test.go
Expand Up @@ -329,7 +329,7 @@ func TestArgsCopyTo(t *testing.T) {

func testCopyTo(t *testing.T, a *Args) {
keys := make(map[string]struct{})
a.VisitAll(func(k, v []byte) {
a.VisitAll(func(k, _ []byte) {
keys[string(k)] = struct{}{}
})

Expand All @@ -340,7 +340,7 @@ func testCopyTo(t *testing.T, a *Args) {
t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", *a, b) //nolint
}

b.VisitAll(func(k, v []byte) {
b.VisitAll(func(k, _ []byte) {
if _, ok := keys[string(k)]; !ok {
t.Fatalf("unexpected key %q after copying from %q", k, a.String())
}
Expand Down
105 changes: 14 additions & 91 deletions client.go
Expand Up @@ -387,7 +387,8 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
req.timeout = timeout
return c.Do(req, resp)
}

// DoDeadline performs the given request and waits for response until
Expand All @@ -414,7 +415,8 @@ func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration)
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
req.timeout = time.Until(deadline)
return c.Do(req, resp)
}

// DoRedirects performs the given http request and fills the given http response,
Expand Down Expand Up @@ -1150,7 +1152,8 @@ func ReleaseResponse(resp *Response) {
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
req.timeout = timeout
return c.Do(req, resp)
}

// DoDeadline performs the given request and waits for response until
Expand All @@ -1172,7 +1175,8 @@ func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Durati
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
req.timeout = time.Until(deadline)
return c.Do(req, resp)
}

// DoRedirects performs the given http request and fills the given http response,
Expand All @@ -1199,93 +1203,6 @@ func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount
return err
}

func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error {
deadline := time.Now().Add(timeout)
return clientDoDeadline(req, resp, deadline, c)
}

func clientDoDeadline(req *Request, resp *Response, deadline time.Time, c clientDoer) error {
timeout := -time.Since(deadline)
if timeout <= 0 {
return ErrTimeout
}

var ch chan error
chv := errorChPool.Get()
if chv == nil {
chv = make(chan error, 1)
}
ch = chv.(chan error)

// Make req and resp copies, since on timeout they no longer
// may be accessed.
reqCopy := AcquireRequest()
req.copyToSkipBody(reqCopy)
swapRequestBody(req, reqCopy)
respCopy := AcquireResponse()
if resp != nil {
// Not calling resp.copyToSkipBody(respCopy) here to avoid
// unexpected messing with headers
respCopy.SkipBody = resp.SkipBody
}

// Note that the request continues execution on ErrTimeout until
// client-specific ReadTimeout exceeds. This helps limiting load
// on slow hosts by MaxConns* concurrent requests.
//
// Without this 'hack' the load on slow host could exceed MaxConns*
// concurrent requests, since timed out requests on client side
// usually continue execution on the host.

var mu sync.Mutex
var timedout, responded bool

go func() {
reqCopy.timeout = timeout
errDo := c.Do(reqCopy, respCopy)
mu.Lock()
{
if !timedout {
if resp != nil {
respCopy.copyToSkipBody(resp)
swapResponseBody(resp, respCopy)
}
swapRequestBody(reqCopy, req)
ch <- errDo
responded = true
}
}
mu.Unlock()

ReleaseResponse(respCopy)
ReleaseRequest(reqCopy)
}()

tc := AcquireTimer(timeout)
var err error
select {
case err = <-ch:
case <-tc.C:
mu.Lock()
{
if responded {
err = <-ch
} else {
timedout = true
err = ErrTimeout
}
}
mu.Unlock()
}
ReleaseTimer(tc)

errorChPool.Put(chv)

return err
}

var errorChPool sync.Pool

// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
Expand Down Expand Up @@ -1464,6 +1381,12 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
err = bw.Flush()
}
c.releaseWriter(bw)

// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}

isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
c.closeConn(cc)
Expand Down
44 changes: 33 additions & 11 deletions client_test.go
Expand Up @@ -181,7 +181,7 @@ func TestClientInvalidURI(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()
requests := int64(0)
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
atomic.AddInt64(&requests, 1)
},
}
Expand Down Expand Up @@ -636,7 +636,7 @@ func TestClientReadTimeout(t *testing.T) {

timeout := false
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
if timeout {
time.Sleep(time.Second)
} else {
Expand Down Expand Up @@ -1191,7 +1191,7 @@ func TestHostClientPendingRequests(t *testing.T) {
doneCh := make(chan struct{})
readyCh := make(chan struct{}, concurrency)
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
readyCh <- struct{}{}
<-doneCh
},
Expand Down Expand Up @@ -1750,16 +1750,19 @@ func testClientGetTimeoutError(t *testing.T, c *Client, n int) {

type readTimeoutConn struct {
net.Conn
t time.Duration
t time.Duration
wc chan struct{}
rc chan struct{}
}

func (r *readTimeoutConn) Read(p []byte) (int, error) {
time.Sleep(r.t)
return 0, io.EOF
<-r.rc
return 0, os.ErrDeadlineExceeded
}

func (r *readTimeoutConn) Write(p []byte) (int, error) {
return len(p), nil
<-r.wc
return 0, os.ErrDeadlineExceeded
}

func (r *readTimeoutConn) Close() error {
Expand All @@ -1774,12 +1777,30 @@ func (r *readTimeoutConn) RemoteAddr() net.Addr {
return nil
}

func (r *readTimeoutConn) SetReadDeadline(d time.Time) error {
r.rc = make(chan struct{}, 1)
go func() {
time.Sleep(time.Until(d))
r.rc <- struct{}{}
}()
return nil
}

func (r *readTimeoutConn) SetWriteDeadline(d time.Time) error {
r.wc = make(chan struct{}, 1)
go func() {
time.Sleep(time.Until(d))
r.wc <- struct{}{}
}()
return nil
}

func TestClientNonIdempotentRetry(t *testing.T) {
t.Parallel()

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
Expand Down Expand Up @@ -1829,7 +1850,7 @@ func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
Expand Down Expand Up @@ -1866,7 +1887,7 @@ func TestClientIdempotentRequest(t *testing.T) {

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
Expand Down Expand Up @@ -1922,7 +1943,7 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {

dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
Expand Down Expand Up @@ -2758,6 +2779,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
time.Sleep(sleep)
ctx.WriteString("foo") //nolint:errcheck
},
Logger: &testLogger{}, // Don't print connection closed errors.
}
serverStopCh := make(chan struct{})
go func() {
Expand Down
2 changes: 1 addition & 1 deletion client_timing_test.go
Expand Up @@ -515,7 +515,7 @@ func BenchmarkNetHTTPClientEndToEndBigResponse10Inmemory(b *testing.B) {

func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
bigResponse := createFixedBody(1024 * 1024)
h := func(w http.ResponseWriter, r *http.Request) {
h := func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set(HeaderContentType, "text/plain")
w.Write(bigResponse) //nolint:errcheck
}
Expand Down
6 changes: 3 additions & 3 deletions header.go
Expand Up @@ -868,15 +868,15 @@ func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool {
// i.e. the number of times f is called in VisitAll.
func (h *ResponseHeader) Len() int {
n := 0
h.VisitAll(func(k, v []byte) { n++ })
h.VisitAll(func(_, _ []byte) { n++ })
return n
}

// Len returns the number of headers set,
// i.e. the number of times f is called in VisitAll.
func (h *RequestHeader) Len() int {
n := 0
h.VisitAll(func(k, v []byte) { n++ })
h.VisitAll(func(_, _ []byte) { n++ })
return n
}

Expand Down Expand Up @@ -1077,7 +1077,7 @@ func (h *ResponseHeader) VisitAll(f func(key, value []byte)) {
f(strServer, server)
}
if len(h.cookies) > 0 {
visitArgs(h.cookies, func(k, v []byte) {
visitArgs(h.cookies, func(_, v []byte) {
f(strSetCookie, v)
})
}
Expand Down
8 changes: 4 additions & 4 deletions header_test.go
Expand Up @@ -1942,7 +1942,7 @@ func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen := false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
Expand All @@ -1963,7 +1963,7 @@ func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen = false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
Expand All @@ -1987,7 +1987,7 @@ func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen := false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
Expand All @@ -2005,7 +2005,7 @@ func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen = false
h.VisitAll(func(key, value []byte) {
h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
Expand Down
9 changes: 5 additions & 4 deletions server_test.go
Expand Up @@ -254,7 +254,7 @@ func TestServerConnState(t *testing.T) {
states := make([]string, 0)
s := &Server{
Handler: func(ctx *RequestCtx) {},
ConnState: func(conn net.Conn, state ConnState) {
ConnState: func(_ net.Conn, state ConnState) {
states = append(states, state.String())
},
}
Expand Down Expand Up @@ -2103,7 +2103,7 @@ func TestServerErrorHandler(t *testing.T) {

s := &Server{
Handler: func(ctx *RequestCtx) {},
ErrorHandler: func(ctx *RequestCtx, err error) {
ErrorHandler: func(ctx *RequestCtx, _ error) {
resultReqStr = ctx.Request.String()
resultRespStr = ctx.Response.String()
},
Expand Down Expand Up @@ -3681,6 +3681,7 @@ func TestStreamRequestBody(t *testing.T) {
checkReader(t, ctx.RequestBodyStream(), part2)
},
StreamRequestBody: true,
Logger: &testLogger{},
}

pipe := fasthttputil.NewPipeConns()
Expand Down Expand Up @@ -3828,7 +3829,7 @@ func TestMaxReadTimeoutPerRequest(t *testing.T) {

headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
s := &Server{
Handler: func(ctx *RequestCtx) {
Handler: func(_ *RequestCtx) {
t.Error("shouldn't reach handler")
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
Expand Down Expand Up @@ -3971,7 +3972,7 @@ func TestServerChunkedResponse(t *testing.T) {
if err := w.Flush(); err != nil {
t.Errorf("unexpected error: %v", err)
}
time.Sleep(time.Second)
time.Sleep(time.Millisecond * 100)
}
})
for k, v := range trailer {
Expand Down